Example #1
0
def cluster_stats(cell, gs):
    """Plot spike-sorting quality panels for one cell into column 0 of ``gs``.

    Rows 0-1 hold the ISI log-histogram, rows 2-4 the waveforms and rows 5-6
    the events-in-time plot. A failure while drawing the ISI histogram or the
    events in time is reported and skipped so the rest of the figure is still
    drawn.

    Args:
        cell: object whose ``load_all_spikedata()`` returns a
            ``(timestamps, samples, recordingNumber)`` tuple.
        gs: GridSpec-like object indexable as ``gs[rows, 0]`` (passed to
            ``plt.subplot``) with at least 7 rows.
    """
    timestamps, samples, _recordingNumber = cell.load_all_spikedata()

    # -- ISI log-histogram --
    plt.subplot(gs[0:2, 0])
    if timestamps is not None:
        try:
            spikesorting.plot_isi_loghist(timestamps)
        except Exception as exc:
            # Best effort: report the cause and keep drawing the other panels.
            # (A bare ``except:`` would also swallow KeyboardInterrupt.)
            print("problem with isi vals: {}".format(exc))

    # -- Waveforms --
    plt.subplot(gs[2:5, 0])
    # samples may be None just like timestamps; len(None) would raise.
    if samples is not None and len(samples) > 0:
        spikesorting.plot_waveforms(samples)

    # -- Events in time --
    plt.subplot(gs[5:7, 0])
    if timestamps is not None:
        try:
            spikesorting.plot_events_in_time(timestamps)
        except Exception as exc:
            print("problem with events in time: {}".format(exc))
def plot_noisebursts_response_raster(animal,
                                     ephysSession,
                                     tetrode,
                                     cluster,
                                     alignment='sound',
                                     timeRange=None):
    '''
    Plot a noise-burst raster (top) and the cluster's waveforms (bottom) so that
    genuine cell responses can be told apart from noise.

    Args:
        animal: animal name, passed to load_event_data / load_spike_data.
        ephysSession: ephys session name; also shown in the figure title.
        tetrode: tetrode number.
        cluster: cluster number.
        alignment: kept for interface compatibility; it is not used. Spikes
            are aligned to every event timestamp of the session.
        timeRange: [start, end] window around each event onset (presumably in
            seconds). Defaults to [-0.1, 0.3]. A None default avoids sharing
            one mutable list between calls.
    '''
    if timeRange is None:
        timeRange = [-0.1, 0.3]

    eventData = load_event_data(animal, ephysSession)
    spikeData = load_spike_data(animal, ephysSession, tetrode, cluster)
    spikeTimestamps = spikeData.timestamps
    eventOnsetTimes = np.array(eventData.timestamps)

    (spikeTimesFromEventOnset, _trialIndexForEachSpike,
     indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         spikeTimestamps, eventOnsetTimes, timeRange)

    # -- Plot raster -- #
    plt.subplot(2, 1, 1)
    extraplots.raster_plot(spikeTimesFromEventOnset,
                           indexLimitsEachTrial,
                           timeRange,
                           trialsEachCond=[],
                           fillWidth=None,
                           labels=None)
    plt.ylabel('Trials')

    # -- Plot waveforms -- #
    plt.subplot(2, 1, 2)
    spikesorting.plot_waveforms(spikeData.samples)
    # ``alignment`` used to be passed here as a 4th format argument that the
    # format string never referenced; it is dropped (same title text).
    plt.title('{0} T{1}C{2} noisebursts'.format(ephysSession, tetrode, cluster),
              fontsize=10)
    plt.show()
Example #3
0
def plot_waveforms_in_event_locked_timerange(spikeSamples, spikeTimes, eventOnsetTimes, timeRange):
    """Plot the waveforms of spikes that fall inside ``timeRange`` around events.

    ``spikesanalysis.eventlocked_spiketimes`` (called with ``spikeindex=True``)
    returns, as its fourth value, the indices of the spikes inside the window.
    The matching rows of ``spikeSamples`` are drawn with
    ``spikesorting.plot_waveforms`` and the axes are titled with the window.
    """
    lockedOutput = spikesanalysis.eventlocked_spiketimes(spikeTimes,
                                                         eventOnsetTimes,
                                                         timeRange,
                                                         spikeindex=True)
    _onsetTimes, _trialIdx, _trialLimits, keptIndices = lockedOutput
    keptSamples = spikeSamples[keptIndices]
    plt.gca()  # fetch (creating if needed) the current axes, as before
    spikesorting.plot_waveforms(keptSamples)
    titleText = 'Waveforms in range {} to {}'.format(*timeRange)
    plt.title(titleText)
def cluster_stats(cell, gs):
    """Plot spike-sorting quality panels for one cell into column 0 of ``gs``.

    Row 0 holds the ISI log-histogram, row 1 the waveforms and row 2 the
    events-in-time plot. A failure while drawing the ISI histogram or the
    events in time is reported and skipped so the rest of the figure is still
    drawn.

    Args:
        cell: object whose ``load_all_spikedata()`` returns a
            ``(timestamps, samples, recordingNumber)`` tuple.
        gs: GridSpec-like object indexable as ``gs[row, 0]`` (passed to
            ``plt.subplot``) with at least 3 rows.
    """
    timestamps, samples, _recordingNumber = cell.load_all_spikedata()

    # -- ISI log-histogram --
    plt.subplot(gs[0, 0])
    if timestamps is not None:
        try:
            spikesorting.plot_isi_loghist(timestamps)
        except Exception as exc:
            # Best effort: report the cause and keep drawing the other panels.
            # (A bare ``except:`` would also swallow KeyboardInterrupt.)
            print("problem with isi vals: {}".format(exc))

    # -- Waveforms --
    plt.subplot(gs[1, 0])
    # samples may be None just like timestamps; len(None) would raise.
    if samples is not None and len(samples) > 0:
        spikesorting.plot_waveforms(samples)

    # -- Events in time --
    plt.subplot(gs[2, 0])
    if timestamps is not None:
        try:
            spikesorting.plot_events_in_time(timestamps)
        except Exception as exc:
            print("problem with events in time: {}".format(exc))
Example #5
0
def plot_waveform_each_cluster(cellObj, sessionType='behavior'):
    '''Plot average and individual waveforms for one isolated cluster.

    :param cellObj: Cell object from ephyscore.
    :param sessionType: string naming the type of ephys session to use; the
        first session index of that type is loaded.
    '''
    firstSessionInd = cellObj.get_session_inds(sessionType)[0]
    loadedEphys = cellObj.load_ephys_by_index(firstSessionInd)
    spikesorting.plot_waveforms(loadedEphys['samples'])
def plot_waveform_each_cluster(animal, ephysSession, tetrode, cluster):
    '''Plot average and individual waveforms for one isolated cluster.

    :param animal: string containing the animal name.
    :param ephysSession: name of the ephys session; the full filename, in
        {date}_XX-XX-XX format.
    :param tetrode: integer in range(1,9) for the tetrode number.
    :param cluster: integer cluster number.
    '''
    clusterSpikes = load_spike_data(animal, ephysSession, tetrode, cluster)
    spikesorting.plot_waveforms(clusterSpikes.samples)
Example #7
0
def plot_waveforms_in_event_locked_timerange(spikeSamples, spikeTimes,
                                             eventOnsetTimes, timeRange):
    """Draw the waveforms of spikes that occur within ``timeRange`` of an event.

    Args:
        spikeSamples: waveform array indexed by the spike indices returned
            from ``spikesanalysis.eventlocked_spiketimes`` (presumably aligned
            with ``spikeTimes``).
        spikeTimes: spike timestamps.
        eventOnsetTimes: event onset timestamps to lock to.
        timeRange: (start, end) window relative to each event onset; also
            used in the plot title.
    """
    (_timesFromOnset, _trialOfSpike, _trialLimits,
     selectedSpikes) = spikesanalysis.eventlocked_spiketimes(spikeTimes,
                                                             eventOnsetTimes,
                                                             timeRange,
                                                             spikeindex=True)
    waveformsInWindow = spikeSamples[selectedSpikes]
    plt.gca()  # fetch (creating if needed) the current axes, as before
    spikesorting.plot_waveforms(waveformsInWindow)
    plt.title('Waveforms in range {} to {}'.format(*timeRange))
Example #8
0
 def plot_report(self, showfig=False):
     '''Draw a spike-sorting report, one row per cluster, into the current figure.

     Each row shows the ISI log-histogram (column 1) and the waveforms
     (column 2). Column 3 holds the projections (upper half) above the events
     in time (lower half). At most ``self.nRows`` clusters are drawn; extra
     clusters are skipped with a warning. ``self.figTitle`` is written at the
     top of the figure.

     Args:
         showfig: if True, call ``plt.show()`` once the report is drawn.
     '''
     # print() with a single argument behaves the same on Python 2 and 3;
     # the print statement is a syntax error on Python 3.
     print('Plotting report...')
     self.fig = plt.gcf()
     self.fig.clf()
     self.fig.set_facecolor('w')
     nCols = 3
     # Only the bottom drawn row keeps its x-axis labels. Clusters past
     # self.nRows are skipped below, so that row is not always nClusters - 1.
     lastDrawnRow = min(self.nClusters, self.nRows) - 1
     for indc, clusterID in enumerate(self.clustersList):
         if (indc + 1) > self.nRows:
             print('WARNING! This cluster was ignored (more clusters than rows)')
             continue
         tsThisCluster = self.timestamps[self.spikesEachCluster[indc, :]]
         wavesThisCluster = self.samples[self.spikesEachCluster[
             indc, :], :, :]
         # -- Plot ISI histogram --
         plt.subplot(self.nRows, nCols, indc * nCols + 1)
         spikesorting.plot_isi_loghist(tsThisCluster)
         if indc < lastDrawnRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         plt.ylabel('c%d' % clusterID, rotation=0, va='center', ha='center')
         # -- Plot events in time (lower half of column 3, on a 2x-row grid) --
         plt.subplot(2 * self.nRows, nCols, 2 * (indc * nCols) + 6)
         spikesorting.plot_events_in_time(tsThisCluster)
         if indc < lastDrawnRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         # -- Plot projections (upper half of column 3) --
         plt.subplot(2 * self.nRows, nCols, 2 * (indc * nCols) + 3)
         spikesorting.plot_projections(wavesThisCluster)
         # -- Plot waveforms --
         plt.subplot(self.nRows, nCols, indc * nCols + 2)
         spikesorting.plot_waveforms(wavesThisCluster)
     plt.figtext(0.5,
                 0.92,
                 self.figTitle,
                 ha='center',
                 fontweight='bold',
                 fontsize=10)
     if showfig:
         plt.show()
 def plot_report(self,showfig=False):
     '''Draw a spike-sorting report, one row per cluster, into the current figure.

     Each row shows the ISI log-histogram (column 1) and the waveforms
     (column 2). Column 3 holds the projections (upper half) above the events
     in time (lower half). At most ``self.nRows`` clusters are drawn; extra
     clusters are skipped with a warning. ``self.figTitle`` is written at the
     top of the figure.

     Args:
         showfig: if True, call ``plt.show()`` once the report is drawn.
     '''
     # print() with a single argument behaves the same on Python 2 and 3;
     # the print statement is a syntax error on Python 3.
     print('Plotting report...')
     self.fig = plt.gcf()
     self.fig.clf()
     self.fig.set_facecolor('w')
     nCols = 3
     # Only the bottom drawn row keeps its x-axis labels. Clusters past
     # self.nRows are skipped below, so that row is not always nClusters - 1.
     lastDrawnRow = min(self.nClusters,self.nRows)-1
     for indc,clusterID in enumerate(self.clustersList):
         if (indc+1)>self.nRows:
             print('WARNING! This cluster was ignored (more clusters than rows)')
             continue
         tsThisCluster = self.timestamps[self.spikesEachCluster[indc,:]]
         wavesThisCluster = self.samples[self.spikesEachCluster[indc,:],:,:]
         # -- Plot ISI histogram --
         plt.subplot(self.nRows,nCols,indc*nCols+1)
         spikesorting.plot_isi_loghist(tsThisCluster)
         if indc<lastDrawnRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         plt.ylabel('c%d'%clusterID,rotation=0,va='center',ha='center')
         # -- Plot events in time (lower half of column 3, on a 2x-row grid) --
         plt.subplot(2*self.nRows,nCols,2*(indc*nCols)+6)
         spikesorting.plot_events_in_time(tsThisCluster)
         if indc<lastDrawnRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         # -- Plot projections (upper half of column 3) --
         plt.subplot(2*self.nRows,nCols,2*(indc*nCols)+3)
         spikesorting.plot_projections(wavesThisCluster)
         # -- Plot waveforms --
         plt.subplot(self.nRows,nCols,indc*nCols+2)
         spikesorting.plot_waveforms(wavesThisCluster)
     plt.figtext(0.5,0.92, self.figTitle,ha='center',fontweight='bold',fontsize=10)
     if showfig:
         plt.show()
Example #10
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,), clusters=(8,)):
    '''
    Build and save a one-page report for each requested tetrode/cluster of a site.

    Each page combines:
      * bandwidth rasters (one per amplitude) and two band_select_plot panels
        from the bandwidth_am session (ephys/behavior index 3);
      * a frequency-tuning heat map and raster from the am_tuning_curve
        session at index 1;
      * an AM PSTH and raster from the am_tuning_curve session at index 2;
      * ISI histogram, waveforms, projections and events in time for the
        cluster, taken from sitefuncs.cluster_site.
    The page is saved as TT{tetrode}Cluster{cluster}.png in oneTT.clustersDir.

    Args:
        mouse: animal name.
        date: recording date, passed to EphysInterface.
        site: site object with get_session_ephys_filenames() and
            get_session_behav_filenames().
        siteName: site name, used for clustering and in the page title.
        tetrodes: tetrodes to report. Defaults to (2,), the selection that
            used to be hard-coded (e.g. pass site.tetrodes for all of them).
        clusters: clusters to report for every tetrode. Defaults to (8,), the
            selection that used to be hard-coded.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()

    # -- behavior data --
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    charfreq = str(np.unique(bdata['charFreq'])[0] / 1000)
    modrate = str(np.unique(bdata['modRate'])[0])
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']

    # -- event onset times: independent of tetrode and cluster, so load once --
    bandEventOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))
    freqEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    amEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))

    # -- trial sorting and labels: depend only on behavior, so compute once --
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    possibleRates = np.unique(currentRate)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    intensityLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq) / 1000.0]
    rateLabels = ["%.1f" % rate for rate in possibleRates]
    trialsEachBandAmp = behavioranalysis.find_trials_each_combination(
        currentBand, possibleBands, currentAmp, possibleAmps)
    trialsEachRate = behavioranalysis.find_trials_each_type(currentRate, possibleRates)

    eventTimeRange = [-0.2, 1.5]
    psthBinWidth = 0.05
    psthBinEdges = np.around(np.arange(-0.2, 0.85, psthBinWidth), decimals=2)
    colourList = ['b', 'g', 'y', 'orange', 'r']

    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        for cluster in clusters:
            plt.clf()

            # -- plot bandwidth rasters, one panel per amplitude --
            bandSpikeTimes = ei.loader.get_session_spikes(
                sessions[3], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 bandSpikeTimes, bandEventOnsetTimes, eventTimeRange)
            for ind in range(len(possibleAmps)):
                plt.subplot2grid((12, 15), (5 * ind, 0), rowspan=4, colspan=7)
                pRaster, _hcond, _zline = extraplots.raster_plot(
                    spikeTimesFromEventOnset,
                    indexLimitsEachTrial,
                    eventTimeRange,
                    trialsEachCond=trialsEachBandAmp[:, :, ind],
                    labels=bandLabels)
                plt.setp(pRaster, ms=4)
                plt.title(ampLabels[ind])
                plt.ylabel('bandwidth (octaves)')
                if ind == len(possibleAmps) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimes,
                             bandEventOnsetTimes,
                             currentAmp,
                             currentBand, [0.0, 1.0],
                             title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimes,
                             bandEventOnsetTimes,
                             currentAmp,
                             currentBand, [0.2, 1.0],
                             title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            freqSpikeTimes = ei2.loader.get_session_spikes(
                sessions[1], tetrode, cluster=cluster).timestamps
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            dataplotter.two_axis_heatmap(spikeTimestamps=freqSpikeTimes,
                                         eventOnsetTimes=freqEventOnsetTimes,
                                         firstSortArray=currentInt,
                                         secondSortArray=currentFreq,
                                         firstSortLabels=intensityLabels,
                                         secondSortLabels=freqLabels,
                                         xlabel='Frequency (kHz)',
                                         ylabel='Intensity (dB SPL)',
                                         plotTitle='Frequency Tuning Curve',
                                         flipFirstAxis=True,
                                         flipSecondAxis=False,
                                         timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(freqSpikeTimes,
                                    freqEventOnsetTimes,
                                    sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5],
                                    labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            amSpikeTimes = ei2.loader.get_session_spikes(
                sessions[2], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 amSpikeTimes, amEventOnsetTimes, eventTimeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, psthBinEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Counts per bin divided by the bin width gives the firing rate.
            # The result used to be passed to a bare plt.setp(...), which sets
            # nothing and only prints every settable property to stdout.
            extraplots.plot_psth(spikeCountMat / psthBinWidth,
                                 1,
                                 psthBinEdges[:-1],
                                 trialsEachRate,
                                 colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(amSpikeTimes,
                                    amEventOnsetTimes,
                                    sortArray=currentRate,
                                    timeRange=[-0.2, 0.8],
                                    labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis (clustered data of this tetrode) --
            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle(
                '{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'
                .format(mouse, date, siteName, tetrode, cluster, charfreq,
                        modrate))
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
Example #11
0
# spikeTimesFromEventOnset and trialIndexForEachSpike (like spikeTimestamps,
# eventOnsetTimes, spikeDataNB, ex0624 and s1BF below) are assumed to be
# defined earlier in this script.
plt.figure()
plt.plot(spikeTimesFromEventOnset, trialIndexForEachSpike, 'k.', ms=1)

plt.show()

# Plot the waveforms of the spikes that occur in a time range after each event.
from jaratoolbox import spikesorting

indexTR = [0, 0.1]
spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial, spikeIndex = spikesanalysis.eventlocked_spiketimes(
    spikeTimestamps, eventOnsetTimes, indexTR, spikeindex=True)
waves = spikeDataNB.samples[spikeIndex]

# Use the pyplot namespace like the rest of this script: a bare figure() only
# exists if pylab names were star-imported, otherwise it raises NameError.
plt.figure()
spikesorting.plot_waveforms(waves)

# For a session that has behavior data, the object can also load that data.
# This is the behavior data for the best frequency presentation.
# At this site, the TC and BF sessions have saved behavior data.
bdata = ex0624.get_session_behav_data(s1BF)

# Ephys filenames, with or without the date attached, can also be passed to
# the object instead of handles. Let's look at the LP session for TT6c3.
spikeDataLP = ex0624.get_session_spike_data_one_tetrode(
    '2015-06-24_15-25-08', 6)
spikeTimesLP = spikeDataLP.timestamps[spikeDataLP.clusters == 3]

# Since the object knows the date, the time-only session name works too:
spikeDataLP = ex0624.get_session_spike_data_one_tetrode('15-25-08', 6)
Example #12
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,), clusters=(8,)):
    '''
    Build and save a one-page report for each requested tetrode/cluster of a site.

    Each page combines:
      * bandwidth rasters (one per amplitude) and two band_select_plot panels
        from the bandwidth_am session (ephys/behavior index 3);
      * a frequency-tuning heat map and raster from the am_tuning_curve
        session at index 1;
      * an AM PSTH and raster from the am_tuning_curve session at index 2;
      * ISI histogram, waveforms, projections and events in time for the
        cluster, taken from sitefuncs.cluster_site.
    The page is saved as TT{tetrode}Cluster{cluster}.png in oneTT.clustersDir.

    Args:
        mouse: animal name.
        date: recording date, passed to EphysInterface.
        site: site object with get_session_ephys_filenames() and
            get_session_behav_filenames().
        siteName: site name, used for clustering and in the page title.
        tetrodes: tetrodes to report. Defaults to (2,), the selection that
            used to be hard-coded (e.g. pass site.tetrodes for all of them).
        clusters: clusters to report for every tetrode. Defaults to (8,), the
            selection that used to be hard-coded.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()

    # -- behavior data --
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    charfreq = str(np.unique(bdata['charFreq'])[0]/1000)
    modrate = str(np.unique(bdata['modRate'])[0])
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']

    # -- event onset times: independent of tetrode and cluster, so load once --
    bandEventOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))
    freqEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    amEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))

    # -- trial sorting and labels: depend only on behavior, so compute once --
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    possibleRates = np.unique(currentRate)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    intensityLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq)/1000.0]
    rateLabels = ["%.1f" % rate for rate in possibleRates]
    trialsEachBandAmp = behavioranalysis.find_trials_each_combination(
        currentBand, possibleBands, currentAmp, possibleAmps)
    trialsEachRate = behavioranalysis.find_trials_each_type(currentRate, possibleRates)

    eventTimeRange = [-0.2, 1.5]
    psthBinWidth = 0.05
    psthBinEdges = np.around(np.arange(-0.2, 0.85, psthBinWidth), decimals=2)
    colourList = ['b', 'g', 'y', 'orange', 'r']

    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        for cluster in clusters:
            plt.clf()

            # -- plot bandwidth rasters, one panel per amplitude --
            bandSpikeTimes = ei.loader.get_session_spikes(
                sessions[3], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 bandSpikeTimes, bandEventOnsetTimes, eventTimeRange)
            for ind in range(len(possibleAmps)):
                plt.subplot2grid((12, 15), (5*ind, 0), rowspan=4, colspan=7)
                pRaster, _hcond, _zline = extraplots.raster_plot(
                    spikeTimesFromEventOnset,
                    indexLimitsEachTrial,
                    eventTimeRange,
                    trialsEachCond=trialsEachBandAmp[:, :, ind],
                    labels=bandLabels)
                plt.setp(pRaster, ms=4)
                plt.title(ampLabels[ind])
                plt.ylabel('bandwidth (octaves)')
                if ind == len(possibleAmps) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimes, bandEventOnsetTimes, currentAmp, currentBand,
                             [0.0, 1.0], title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimes, bandEventOnsetTimes, currentAmp, currentBand,
                             [0.2, 1.0], title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            freqSpikeTimes = ei2.loader.get_session_spikes(
                sessions[1], tetrode, cluster=cluster).timestamps
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            dataplotter.two_axis_heatmap(spikeTimestamps=freqSpikeTimes,
                                         eventOnsetTimes=freqEventOnsetTimes,
                                         firstSortArray=currentInt,
                                         secondSortArray=currentFreq,
                                         firstSortLabels=intensityLabels,
                                         secondSortLabels=freqLabels,
                                         xlabel='Frequency (kHz)',
                                         ylabel='Intensity (dB SPL)',
                                         plotTitle='Frequency Tuning Curve',
                                         flipFirstAxis=True,
                                         flipSecondAxis=False,
                                         timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(freqSpikeTimes, freqEventOnsetTimes, sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5], labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            amSpikeTimes = ei2.loader.get_session_spikes(
                sessions[2], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 amSpikeTimes, amEventOnsetTimes, eventTimeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, psthBinEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Counts per bin divided by the bin width gives the firing rate.
            # The result used to be passed to a bare plt.setp(...), which sets
            # nothing and only prints every settable property to stdout.
            extraplots.plot_psth(spikeCountMat/psthBinWidth, 1, psthBinEdges[:-1],
                                 trialsEachRate, colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(amSpikeTimes, amEventOnsetTimes, sortArray=currentRate,
                                    timeRange=[-0.2, 0.8], labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis (clustered data of this tetrode) --
            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle('{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(
                mouse, date, siteName, tetrode, cluster, charfreq, modrate))
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
Example #13
0
                                         downsamplefactor=downsampleFactorPsth)
            #plt.plot(spikeTimesFromEventOnset,trialIndexForEachSpike,'.', markersize = '3')
            extraplots.boxoff(plt.gca())
            plt.ylabel('Firing rate\n(spk/s)')
            plt.xlabel('time from onset of the sound(s)')

            # --------ISI plot----------------
            ax31 = plt.subplot2grid(scaleGrid, (2 + j / 3, 0), colspan=2)
            plt.subplots_adjust(bottom=bottom, top=top, hspace=hspace)
            spikesorting.plot_isi_loghist(spikeT[i])

            # -- Plot waveforms --
            ax32 = plt.subplot2grid(scaleGrid, (2 + j / 3, 2), colspan=2)
            plt.subplots_adjust(bottom=bottom, top=top, hspace=hspace)
            if waveF[i].any():
                spikesorting.plot_waveforms(waveF[i])

            # -- Plot events in time --
            ax33 = plt.subplot2grid(scaleGrid, (2 + j / 3, 4), colspan=2)
            plt.subplots_adjust(bottom=bottom, top=top, hspace=hspace)
            if spikeT[i].any():
                spikesorting.plot_events_in_time(spikeT[i])
            j = j + 3
    ################################################################################
    # -----------------------------Tuning Curve Raster Plot-------------------------
        else:
            #----------if the length doesn't match, abandon last one from trialsEachType
            while indexLimitsEachTrial.shape[1] < trialsEachType.shape[0]:
                trialsEachType = np.delete(trialsEachType, -1, 0)

            ax11 = plt.subplot2grid(scaleGrid, (0, 6), colspan=3, rowspan=3)
def plot_rew_change_per_cell(oneCell, trialLimit=None, alignment='sound'):
    '''
    Plot raster, PSTH and cluster-quality panels for one cell recorded during the
    reward_change_freq_dis task, with correct trials split by choice side and
    reward block.

    Args:
        oneCell: cell descriptor passed to load_behav_per_cell / load_ephys_per_cell;
            its animalName, behavSession, tetrode and cluster attributes are used
            for the title and the ISI label.
        trialLimit: optional [start, stop] pair of trial indices; only trials in
            start:stop are plotted. None (default) or an empty sequence means all
            trials.
        alignment (str): event spikes are aligned to: 'sound', 'center-out' or
            'side-in'.

    Raises:
        ValueError: if alignment is not one of the supported values.

    Note:
        timeRange, binWidth and soundTriggerChannel are not parameters; they are
        read from module scope.
    '''
    validAlignments = ('sound', 'center-out', 'side-in')
    if alignment not in validAlignments:
        raise ValueError("alignment must be one of {0}, got {1!r}".format(validAlignments, alignment))

    bdata = load_behav_per_cell(oneCell)
    (spikeTimestamps, waveforms, eventOnsetTimes, eventData) = load_ephys_per_cell(oneCell)

    # -- Check whether ephys skipped trials; if so, remove them from the behavior data --
    soundOnsetEvents = (eventData.eventID == 1) & (eventData.eventChannel == soundTriggerChannel)
    soundOnsetTimeEphys = eventOnsetTimes[soundOnsetEvents]
    soundOnsetTimeBehav = bdata['timeTarget']
    missingTrials = behavioranalysis.find_missing_trials(soundOnsetTimeEphys, soundOnsetTimeBehav)
    bdata.remove_trials(missingTrials)

    currentBlock = bdata['currentBlock']
    blockLabels = bdata.labels['currentBlock']
    blockTypes = [blockLabels['same_reward'], blockLabels['more_left'], blockLabels['more_right']]

    nTrials = len(currentBlock)
    if trialLimit is None or len(trialLimit) == 0:
        validTrials = np.ones(nTrials, dtype=bool)
    else:
        validTrials = np.zeros(nTrials, dtype=bool)
        validTrials[trialLimit[0]:trialLimit[1]] = True

    # Columns follow blockTypes: 0=same_reward, 1=more_left, 2=more_right
    trialsEachType = behavioranalysis.find_trials_each_type(currentBlock, blockTypes)

    # -- Event times to align to: ephys sound onsets, optionally shifted by the
    # behavior-measured delay between the sound (timeTarget) and the chosen event --
    if alignment == 'sound':
        EventOnsetTimes = soundOnsetTimeEphys
    elif alignment == 'center-out':
        EventOnsetTimes = soundOnsetTimeEphys + (bdata['timeCenterOut'] - bdata['timeTarget'])
    else:  # 'side-in'
        EventOnsetTimes = soundOnsetTimeEphys + (bdata['timeSideIn'] - bdata['timeTarget'])

    rightward = bdata['choice'] == bdata.labels['choice']['right']
    leftward = bdata['choice'] == bdata.labels['choice']['left']
    correct = bdata['outcome'] == bdata.labels['outcome']['correct']

    # -- Correct trials split by choice side and by reward block --
    rightcorrect = rightward & correct & validTrials
    leftcorrect = leftward & correct & validTrials
    trialsEachCond = np.c_[leftcorrect & trialsEachType[:, 1],
                           rightcorrect & trialsEachType[:, 1],
                           leftcorrect & trialsEachType[:, 2],
                           rightcorrect & trialsEachType[:, 2],
                           leftcorrect & trialsEachType[:, 0],
                           rightcorrect & trialsEachType[:, 0]]
    colorEachCond = ['g', 'r', 'm', 'b', 'y', 'darkgray']

    (spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial) = \
        spikesanalysis.eventlocked_spiketimes(spikeTimestamps, EventOnsetTimes, timeRange)

    # -- Raster --
    plt.figure()
    ax1 = plt.subplot2grid((8, 5), (0, 0), rowspan=4, colspan=5)
    extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange,
                           trialsEachCond=trialsEachCond, colorEachCond=colorEachCond,
                           fillWidth=None, labels=None)
    plt.ylabel('Trials')
    plt.xlim(timeRange)
    plt.title('{0}_{1}_TT{2}_c{3}_{4}'.format(oneCell.animalName, oneCell.behavSession,
                                              oneCell.tetrode, oneCell.cluster, alignment))

    # -- PSTH (shares the raster's time axis) --
    timeVec = np.arange(timeRange[0], timeRange[-1], binWidth)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(spikeTimesFromEventOnset,
                                                             indexLimitsEachTrial, timeVec)
    smoothWinSize = 3
    plt.subplot2grid((8, 5), (4, 0), colspan=5, rowspan=2, sharex=ax1)
    extraplots.plot_psth(spikeCountMat / binWidth, smoothWinSize, timeVec,
                         trialsEachCond=trialsEachCond, colorEachCond=colorEachCond,
                         linestyle=None, linewidth=3, downsamplefactor=1)
    plt.xlabel('Time from {0} onset (s)'.format(alignment))
    plt.ylabel('Firing rate (spk/sec)')

    # -- Plot ISI histogram --
    plt.subplot2grid((8, 5), (6, 0), rowspan=1, colspan=2)
    spikesorting.plot_isi_loghist(spikeTimestamps)
    plt.ylabel('c%d' % oneCell.cluster, rotation=0, va='center', ha='center')
    plt.xlabel('')

    # -- Plot waveforms --
    plt.subplot2grid((8, 5), (7, 0), rowspan=1, colspan=3)
    spikesorting.plot_waveforms(waveforms)

    # -- Plot projections --
    plt.subplot2grid((8, 5), (6, 2), rowspan=1, colspan=3)
    spikesorting.plot_projections(waveforms)

    # -- Plot events in time --
    plt.subplot2grid((8, 5), (7, 3), rowspan=1, colspan=2)
    spikesorting.plot_events_in_time(spikeTimestamps)

    plt.subplots_adjust(wspace=0.7)
    plt.gcf().set_size_inches((8.5, 11))
Example #15
0
def plot_blind_cell_quality(cell):
    '''
    Plot laser responses and cluster-quality panels for one cell into the current figure.

    Layout (5x6 GridSpec):
        - laser pulse raster and PSTH (left half, rows 0-3);
        - laser train raster and PSTH (right half, rows 0-3), only if a
          'laserTrain' session exists;
        - ISI histogram, waveforms and events-in-time along the bottom row.

    Args:
        cell: database row describing the cell; passed to ephyscore.Cell.
    '''
    plt.clf()
    gs = gridspec.GridSpec(5, 6)
    binsize = 10 / 1000.0

    def plot_raster_and_psth(rasterSpec, psthSpec, spikeTimestamps, eventOnsetTimes, timeRange, label):
        '''Plot an event-locked raster and a binsize-binned PSTH for one laser session.'''
        plt.subplot(rasterSpec)
        spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = \
            spikesanalysis.eventlocked_spiketimes(spikeTimestamps, eventOnsetTimes, timeRange)
        extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange)
        plt.xlabel('Time from laser onset (sec)')
        plt.title('{0} Raster'.format(label))

        plt.subplot(psthSpec)
        # Spikes are locked from one bin before the window start, matching binEdges below.
        spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = \
            spikesanalysis.eventlocked_spiketimes(spikeTimestamps, eventOnsetTimes,
                                                  [timeRange[0] - binsize, timeRange[1]])
        binEdges = np.around(np.arange(timeRange[0] - binsize, timeRange[1] + 2 * binsize, binsize),
                             decimals=2)
        spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(spikeTimesFromEventOnset,
                                                                 indexLimitsEachTrial, binEdges)
        extraplots.plot_psth(spikeCountMat / binsize, 1, binEdges[:-1])
        plt.xlim(timeRange)
        plt.xlabel('Time from laser onset (sec)')
        plt.ylabel('Firing Rate (Hz)')
        plt.title('{0} PSTH'.format(label))

    # create cell object for loading data
    cellObj = ephyscore.Cell(cell)

    # -- laser pulse raster and psth --
    laserEphysData, noBehav = cellObj.load('laserPulse')
    laserSpikeTimestamps = laserEphysData['spikeTimes']
    plot_raster_and_psth(gs[0:2, 0:3], gs[2:4, 0:3], laserSpikeTimestamps,
                         laserEphysData['events']['laserOn'], [-0.1, 0.4], 'Laser Pulse')

    # -- laser trains were not recorded for some earlier sessions --
    if len(cellObj.get_session_inds('laserTrain')) > 0:
        laserTrainEphysData, noBehav = cellObj.load('laserTrain')
        laserTrainEventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
            laserTrainEphysData['events']['laserOn'], 0.5)
        plot_raster_and_psth(gs[0:2, 3:], gs[2:4, 3:], laserTrainEphysData['spikeTimes'],
                             laserTrainEventOnsetTimes, [-0.2, 1.0], 'Laser Train')

    # -- cluster analysis, using the spikes recorded in the laser-pulse session --
    tsThisCluster = laserSpikeTimestamps
    wavesThisCluster = laserEphysData['samples']

    # -- Plot ISI histogram --
    plt.subplot(gs[4, 0:2])
    spikesorting.plot_isi_loghist(tsThisCluster)

    # -- Plot waveforms --
    plt.subplot(gs[4, 2:4])
    spikesorting.plot_waveforms(wavesThisCluster)

    # -- Plot events in time --
    plt.subplot(gs[4, 4:6])
    spikesorting.plot_events_in_time(tsThisCluster)
Example #16
0
def plot_bandwidth_report(cell,
                          type='normal',
                          bandTimeRange=[0.2, 1.0],
                          bandBaseRange=[-1.0, -0.2],
                          figPath='/home/jarauser/Pictures/cell reports'):
    '''
    Plot a bandwidth-session report for one cell and save it as a PNG.

    Panels: optional laser pulse/train rasters and PSTHs (only if the cell has a
    'laserPulse' session), frequency tuning raster and gaussian-fitted tuning
    curve, AM raster and PSTH, bandwidth rasters and tuning curves, and
    cluster-quality plots (ISI, waveforms, events in time) from the bandwidth
    session.

    Args:
        cell: database row for the cell. Reads 'bestBandSession', 'tetrode',
            'cluster', 'subject', 'date', 'depth', 'tuningTimeRange',
            'gaussFit' and 'tuningFitR2'.
        type (str): how bandwidth trials are split: 'normal' sorts by
            'currentAmp'; 'laser' sorts by 'laserTrial' (Arch-inactivation
            experiments).
        bandTimeRange (list): response window for calculate_tuning_curve_inputs.
        bandBaseRange (list): baseline window for calculate_tuning_curve_inputs.
        figPath (str): directory the PNG is written to.

    Raises:
        ValueError: if type is not 'normal' or 'laser'.
    '''
    if type not in ('normal', 'laser'):
        raise ValueError("type must be 'normal' or 'laser', got {0!r}".format(type))

    plt.clf()

    # A missing session index may be None or NaN (e.g. an empty database cell);
    # int() fails on both, so check before converting.
    bestBandSession = cell['bestBandSession']
    if bestBandSession is None or (isinstance(bestBandSession, float) and np.isnan(bestBandSession)):
        print('No bandwidth session given')
        return
    bandIndex = int(bestBandSession)

    # create cell object for loading data
    cellObj = ephyscore.Cell(cell)

    # Reports for cells with laser sessions get 4 extra rows at the top for laser plots.
    laser = len(cellObj.get_session_inds('laserPulse')) > 0
    gs = gridspec.GridSpec(13 if laser else 9, 6)
    offset = 4 * laser
    gs.update(left=0.15, right=0.85, top=0.96, wspace=0.7, hspace=1.0)

    tetrode = int(cell['tetrode'])
    cluster = int(cell['cluster'])

    # -- load bandwidth ephys and behaviour data --
    bandEphysData, bandBData = cellObj.load_by_index(bandIndex)
    bandEventOnsetTimes = ephysanalysis.get_sound_onset_times(bandEphysData, 'bandwidth')
    bandSpikeTimestamps = bandEphysData['spikeTimes']
    bandEachTrial = bandBData['currentBand']
    numBands = np.unique(bandEachTrial)

    # Secondary trial split (and colour scheme) depends on the experiment type.
    if type == 'laser':
        secondSort = bandBData['laserTrial']
        secondSortLabels = ['no laser', 'laser']
        colours = ['k', '#c4a000']
        errorColours = ['0.5', '#fce94f']
    else:
        secondSort = bandBData['currentAmp']
        secondSortLabels = ['{} dB'.format(amp) for amp in np.unique(secondSort)]
        colours = ['#4e9a06', '#5c3566']
        errorColours = ['#8ae234', '#ad7fa8']
    gaussFitCol = 'gaussFit'
    tuningR2Col = 'tuningFitR2'

    charfreq = str(np.unique(bandBData['charFreq'])[0] / 1000)
    modrate = str(np.unique(bandBData['modRate'])[0])

    # -- plot rasters of the bandwidth trials --
    # np.tile needs an integer repeat count; '/' would give a float on Python 3.
    nColourRepeats = len(numBands) // 2 + 1
    rasterColours = [np.tile([colours[0], errorColours[0]], nColourRepeats),
                     np.tile([colours[1], errorColours[1]], nColourRepeats)]
    plot_separated_rasters(gs, [0, 3],
                           5 + offset,
                           bandEachTrial,
                           secondSort,
                           bandSpikeTimestamps,
                           bandEventOnsetTimes,
                           colours=rasterColours,
                           titles=secondSortLabels,
                           plotHeight=2)

    # -- plot bandwidth tuning curves --
    plt.subplot(gs[5 + offset:, 3:])
    tuningDict = ephysanalysis.calculate_tuning_curve_inputs(bandSpikeTimestamps,
                                                             bandEventOnsetTimes,
                                                             bandEachTrial,
                                                             secondSort,
                                                             bandTimeRange,
                                                             baseRange=bandBaseRange,
                                                             info='plotting')
    plot_tuning_curve(tuningDict['responseArray'],
                      tuningDict['errorArray'],
                      numBands,
                      tuningDict['baselineSpikeRate'],
                      linecolours=colours,
                      errorcolours=errorColours)

    # -- load tuning ephys and behaviour data --
    tuningEphysData, tuningBData = cellObj.load('tuningCurve')
    tuningEventOnsetTimes = ephysanalysis.get_sound_onset_times(tuningEphysData, 'tuningCurve')
    tuningSpikeTimestamps = tuningEphysData['spikeTimes']

    # -- plot frequency tuning at intensity used in bandwidth trial with gaussian fit --
    # high amp bandwidth trials used to select appropriate frequency
    maxAmp = max(np.unique(bandBData['currentAmp']))
    if maxAmp < 1:
        maxAmp = 66.0  # HARDCODED dB VALUE FOR SESSIONS DONE BEFORE NOISE CALIBRATION

    # find tone intensity that corresponds to tone sessions in bandwidth trial
    toneInt = maxAmp - 15.0  # HARDCODED DIFFERENCE IN TONE AND NOISE AMP BASED ON OSCILLOSCOPE READINGS FROM RIG 2

    freqEachTrial = tuningBData['currentFreq']

    plt.subplot(gs[2 + offset:4 + offset, 0:3])
    plot_tuning_fitted_gaussian(tuningSpikeTimestamps,
                                tuningEventOnsetTimes,
                                tuningBData,
                                toneInt,
                                cell[gaussFitCol],
                                cell[tuningR2Col],
                                timeRange=cell['tuningTimeRange'])

    # -- plot frequency tuning raster --
    plt.subplot(gs[0 + offset:2 + offset, 0:3])
    freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0]
    plot_sorted_raster(tuningSpikeTimestamps,
                       tuningEventOnsetTimes,
                       freqEachTrial,
                       timeRange=[-0.2, 0.6],
                       labels=freqLabels)
    plt.title('Frequency Tuning Raster')

    # -- plot AM PSTH --
    amEphysData, amBData = cellObj.load('AM')
    amEventOnsetTimes = ephysanalysis.get_sound_onset_times(amEphysData, 'AM')
    amSpikeTimestamps = amEphysData['spikeTimes']
    rateEachTrial = amBData['currentFreq']
    colourList = ['b', 'g', 'y', 'orange', 'r']

    plt.subplot(gs[2 + offset:4 + offset, 3:])
    plot_sorted_psth(amSpikeTimestamps,
                     amEventOnsetTimes,
                     rateEachTrial,
                     timeRange=[-0.2, 0.8],
                     binsize=25,
                     colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Firing rate (Hz)')
    plt.title('AM PSTH')

    # -- plot AM raster --
    plt.subplot(gs[0 + offset:2 + offset, 3:])
    rateLabels = ["%.0f" % rate for rate in np.unique(rateEachTrial)]
    plot_sorted_raster(amSpikeTimestamps,
                       amEventOnsetTimes,
                       rateEachTrial,
                       timeRange=[-0.2, 0.8],
                       labels=rateLabels,
                       colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Modulation Rate (Hz)')
    plt.title('AM Raster')

    # -- plot laser pulse and laser train data (if available) --
    if laser:
        def plot_laser_raster_and_psth(rasterSpec, psthSpec, spikeTimestamps, eventOnsetTimes,
                                       timeRange, label):
            '''Plot an event-locked raster and a 10 ms-binned PSTH for one laser session.'''
            binsize = 10 / 1000.0
            plt.subplot(rasterSpec)
            spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = \
                spikesanalysis.eventlocked_spiketimes(spikeTimestamps, eventOnsetTimes, timeRange)
            extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange)
            plt.xlabel('Time from laser onset (sec)')
            plt.title('{0} Raster'.format(label))

            plt.subplot(psthSpec)
            # Spikes are locked from one bin before the window start, matching binEdges below.
            spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = \
                spikesanalysis.eventlocked_spiketimes(spikeTimestamps, eventOnsetTimes,
                                                      [timeRange[0] - binsize, timeRange[1]])
            binEdges = np.around(np.arange(timeRange[0] - binsize,
                                           timeRange[1] + 2 * binsize, binsize),
                                 decimals=2)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)
            extraplots.plot_psth(spikeCountMat / binsize, 1, binEdges[:-1])
            plt.xlim(timeRange)
            plt.xlabel('Time from laser onset (sec)')
            plt.ylabel('Firing Rate (Hz)')
            plt.title('{0} PSTH'.format(label))

        laserEphysData, noBehav = cellObj.load('laserPulse')
        plot_laser_raster_and_psth(gs[0:2, 0:3], gs[2:4, 0:3], laserEphysData['spikeTimes'],
                                   laserEphysData['events']['laserOn'], [-0.1, 0.4],
                                   'Laser Pulse')

        # -- laser trains were not recorded for some earlier sessions --
        if len(cellObj.get_session_inds('laserTrain')) > 0:
            laserTrainEphysData, noBehav = cellObj.load('laserTrain')
            laserTrainEventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
                laserTrainEphysData['events']['laserOn'], 0.5)
            plot_laser_raster_and_psth(gs[0:2, 3:], gs[2:4, 3:],
                                       laserTrainEphysData['spikeTimes'],
                                       laserTrainEventOnsetTimes, [-0.2, 1.0], 'Laser Train')

    # -- cluster analysis, using the bandwidth session's spikes --
    # -- Plot ISI histogram --
    plt.subplot(gs[4 + offset, 0:2])
    spikesorting.plot_isi_loghist(bandSpikeTimestamps)

    # -- Plot waveforms --
    plt.subplot(gs[4 + offset, 2:4])
    spikesorting.plot_waveforms(bandEphysData['samples'])

    # -- Plot events in time --
    plt.subplot(gs[4 + offset, 4:6])
    spikesorting.plot_events_in_time(bandSpikeTimestamps)

    title = '{0}, {1}, {2}um, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(
        cell['subject'], cell['date'], cell['depth'], tetrode, cluster, charfreq, modrate)
    plt.suptitle(title)

    fig_name = '{0}_{1}_{2}um_TT{3}Cluster{4}.png'.format(
        cell['subject'], cell['date'], cell['depth'], tetrode, cluster)
    full_fig_path = os.path.join(figPath, fig_name)
    fig = plt.gcf()
    fig.set_size_inches(20, 25)
    fig.savefig(full_fig_path, format='png', bbox_inches='tight')
Example #17
0
def nick_lan_main_report(siteObj, show=False, save=True, saveClusterReport=True):
    '''
    Cluster each good tetrode at a recording site and make one main report figure per cluster.

    For every tetrode in siteObj.goodTetrodes the multisession waveforms are
    loaded and clustered. Clustering is skipped when a 'Tetrode<N>.clu.1' file
    already exists in the clusters directory. Each cluster's figure shows the
    main-report rasters, the first main tuning-curve heatmap, and ISI /
    waveform / projection / events-in-time panels. It is saved as
    'TT<tetrode>Cluster<cluster>.png' in the clusters directory.

    Args:
        siteObj: recording-site object providing the session lists and metadata.
        show (bool): display each figure (plt.show()); otherwise it is closed.
        save (bool): save each cluster figure as a PNG.
        saveClusterReport (bool): also save the multisession cluster report per tetrode.
    '''
    # Session lookups depend only on siteObj, so compute them once.
    sessionFilenames = siteObj.get_session_filenames()
    sessionTypes = siteObj.get_session_types()
    sessionBehavIDs = siteObj.get_session_behavIDs()

    mainRasterInds = siteObj.get_session_inds_one_type(plotType='raster', report='main')
    mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

    mainTCinds = siteObj.get_session_inds_one_type(plotType='tc_heatmap', report='main')
    mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
    mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

    for tetrode in siteObj.goodTetrodes:
        oneTT = cms2.MultipleSessionsToCluster(
            siteObj.animalName,
            sessionFilenames,
            tetrode,
            '{}at{}um'.format(siteObj.date, siteObj.depth))
        oneTT.load_all_waveforms()

        # Run the clustering only if no cluster file exists yet.
        clusterFile = os.path.join(oneTT.clustersDir, 'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        oneTT.save_single_session_clu_files()
        possibleClusters = np.unique(oneTT.clusters)

        ee = ee3.EphysExperiment(siteObj.animalName, siteObj.date,
                                 experimenter=siteObj.experimenter)

        # One new figure per cluster.
        for cluster in possibleClusters:
            plt.figure()

            for indRaster, rasterSession in enumerate(mainRasterSessions):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                ee.plot_session_raster(rasterSession, tetrode, cluster=cluster,
                                       replace=1, ms=1)
                plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster],
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # We can only do one main TC for now.
            plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
            tcSession = mainTCsessions[0]
            tcBehavID = mainTCbehavIDs[0]
            ee.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace=1,
                                       cluster=cluster)
            plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize=10)

            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)

            if save:
                plt.savefig(full_fig_path, format='png')
            if show:
                plt.show()
            else:
                plt.close()

        if saveClusterReport:
            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
Example #18
0
            timeRange=timeRange,
            trialsEachCond=trialsEachCond,
            colorEachCond=colorEachCond,
            fillWidth=None,
            labels=None)
        plt.setp(pRaster, ms=msRaster)
        plt.xlabel('Time from sound onset (s)',
                   fontsize=fontSizeLabels,
                   labelpad=labelDis)
        plt.ylabel('Trials', fontsize=fontSizeLabels, labelpad=labelDis)
        plt.title('freq:{}; modInd:{};\n modSig:{}; maxZ:{}'.format(
            middleFreqs[1], cell['modIndexMid2'], cell['modSigMid2'],
            cell['maxZSoundMid2']),
                  fontsize=fontSizeLabels - 2)

        # -- Plot waveform -- #
        ax7 = plt.subplot(gs00[3, :])
        wavesThisCluster = spikeData.samples
        spikesorting.plot_waveforms(wavesThisCluster)

        # -- Plot ISI -- #
        ax8 = plt.subplot(gs01[3, :])
        spikesorting.plot_isi_loghist(spikeTimestamps)

        print 'Saving figure'

        plt.savefig(fullFigname)

    except:
        continue
Example #19
0
def _plot_pinp_raster(ax, spikeTimes, eventOnsetTimes, timeRange, xticks):
    '''
    Draw an event-locked spike raster on `ax` and return the aligned spikes.

    Args:
        ax: axes that is current (raster_plot draws on the current axes).
        spikeTimes: spike timestamps of the cell.
        eventOnsetTimes: onset time of each trial's event.
        timeRange: [start, stop] window around each onset (s).
        xticks: x tick locations for the raster.

    Returns:
        (spikeTimesFromEventOnset, indexLimitsEachTrial) as returned by
        spikesanalysis.eventlocked_spiketimes, for reuse by the PSTH.
    '''
    (spikeTimesFromEventOnset, trialIndexForEachSpike,
     indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         spikeTimes, eventOnsetTimes, timeRange)
    pRaster, hCond, zLine = extraplots.raster_plot(
        spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange)
    plt.setp(pRaster, ms=1)
    ax.set_xlim(timeRange)
    ax.set_xticks(xticks)
    return spikeTimesFromEventOnset, indexLimitsEachTrial


def _plot_pinp_smoothed_psth(ax, spikeTimesFromEventOnset,
                             indexLimitsEachTrial, timeRange):
    '''
    Plot a trial-averaged, smoothed PSTH with 1 ms bins on `ax`.

    Spike counts per bin are averaged over axis 0 of the count matrix (assumed
    to index trials) and smoothed with a unit-sum 7-point Hann window. Each
    value is drawn at the right edge of its bin.

    Args:
        ax: axes to draw on; its x-limits are set to timeRange.
        spikeTimesFromEventOnset: spike times aligned to event onset (s).
        indexLimitsEachTrial: per-trial spike index limits from
            spikesanalysis.eventlocked_spiketimes.
        timeRange: [start, stop] window (s).
    '''
    # Same coefficients as scipy.signal.hanning(7); normalised to sum to 1.
    win = np.array([0, 0.25, 0.75, 1, 0.75, 0.25, 0])
    win = win / np.sum(win)
    binEdges = np.arange(timeRange[0], timeRange[-1], 0.001)
    # FIXME: is the right bin edge the best way to define the time axis?
    timeVec = binEdges[1:]
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
        spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)
    avResp = np.mean(spikeCountMat, axis=0)
    smoothPSTH = np.convolve(avResp, win, mode='same')
    ax.plot(timeVec, smoothPSTH, 'k-', mec='none', lw=2)
    ax.set_xlim(timeRange)


def plot_pinp_report(dbRow, saveDir=None, useModifiedClusters=True):
    '''
    Plot a one-page summary report for one cell of the cell database.

    Panels are added only for the session types named in
    dbRow['sessionType']:
      * 'noiseburst': raster + smoothed PSTH
      * 'laserpulse': raster + smoothed PSTH
      * 'lasertrain': raster + smoothed PSTH of train onsets
      * 'tc': frequency-sorted raster + intensity/frequency response heatmap,
        with threshold/CF and bandwidth markers when present in dbRow
      * 'am': AM raster with rate plot + spike-phase histograms per AM rate
    The bottom row always shows ISI histogram, waveforms and events-in-time
    for all spikes of the cluster.

    Args:
        dbRow: cell-database row (presumably a pandas Series). Keys read here:
            'sessionType', 'threshold', 'cf', 'upperFreq', 'lowerFreq',
            'subject', 'date', 'depth', 'tetrode', 'cluster'.
        saveDir: if not None, the figure is saved there as
            '<subject>_<date>_<depth>um_TT<tetrode>c<cluster>.png'.
        useModifiedClusters: forwarded to ephyscore.Cell.
    '''
    cell = ephyscore.Cell(dbRow, useModifiedClusters=useModifiedClusters)

    plt.clf()
    gs = gridspec.GridSpec(11, 6)
    gs.update(left=0.15, right=0.95, bottom=0.15, wspace=1, hspace=1)

    if 'noiseburst' in dbRow['sessionType']:
        # -- Noiseburst raster --
        ax0 = plt.subplot(gs[0:2, 0:3])
        ephysData, bdata = cell.load('noiseburst')
        eventOnsetTimes = ephysData['events']['stimOn']
        timeRange = [-0.3, 0.5]
        spikeTimesFromEventOnset, indexLimitsEachTrial = _plot_pinp_raster(
            ax0, ephysData['spikeTimes'], eventOnsetTimes, timeRange, [])

        # -- Noiseburst PSTH --
        # NOTE(review): this is the same grid slot as the laser pulse PSTH
        # below. When both sessions exist the two PSTHs end up overlaid and
        # the later x-label wins. Confirm the intended location.
        ax1 = plt.subplot(gs[4:6, 0:3])
        _plot_pinp_smoothed_psth(ax1, spikeTimesFromEventOnset,
                                 indexLimitsEachTrial, timeRange)
        ax1.set_xlabel('Time from noise onset (s)')

    if 'laserpulse' in dbRow['sessionType']:
        # -- Laser pulse raster --
        ax0 = plt.subplot(gs[2:4, 0:3])
        ephysData, bdata = cell.load('laserpulse')
        eventOnsetTimes = ephysData['events']['stimOn']
        timeRange = [-0.3, 0.5]
        spikeTimesFromEventOnset, indexLimitsEachTrial = _plot_pinp_raster(
            ax0, ephysData['spikeTimes'], eventOnsetTimes, timeRange, [])

        # -- Laser pulse PSTH --
        ax1 = plt.subplot(gs[4:6, 0:3])
        _plot_pinp_smoothed_psth(ax1, spikeTimesFromEventOnset,
                                 indexLimitsEachTrial, timeRange)
        ax1.set_xlabel('Time from laser pulse onset (s)')

    if 'lasertrain' in dbRow['sessionType']:
        # -- Laser train raster --
        ax2 = plt.subplot(gs[2:4, 3:6])
        ephysData, bdata = cell.load('lasertrain')
        eventOnsetTimes = ephysData['events']['stimOn']
        # Presumably keeps only the first pulse of each train by dropping
        # events closer than 0.5 s to the previous one -- confirm.
        eventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
            eventOnsetTimes, 0.5)
        timeRange = [-0.5, 1]
        pulseTimes = [0, 0.2, 0.4, 0.6, 0.8]
        spikeTimesFromEventOnset, indexLimitsEachTrial = _plot_pinp_raster(
            ax2, ephysData['spikeTimes'], eventOnsetTimes, timeRange,
            pulseTimes)

        # -- Laser train PSTH --
        ax3 = plt.subplot(gs[4:6, 3:6])
        _plot_pinp_smoothed_psth(ax3, spikeTimesFromEventOnset,
                                 indexLimitsEachTrial, timeRange)
        ax3.set_xticks(pulseTimes)
        ax3.set_xlabel('Time from first pulse onset (s)')

    if 'tc' in dbRow['sessionType']:
        # -- Frequency-sorted tuning raster --
        ax4 = plt.subplot(gs[6:8, 0:3])
        ephysData, bdata = cell.load('tc')
        eventOnsetTimes = ephysData['events']['stimOn']
        spikeTimes = ephysData['spikeTimes']
        timeRange = [-0.5, 1]
        (spikeTimesFromEventOnset, trialIndexForEachSpike,
         indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
             spikeTimes, eventOnsetTimes, timeRange)
        freqEachTrial = bdata['currentFreq']  # Hz (labels divide by 1000 -> kHz)
        possibleFreq = np.unique(freqEachTrial)
        freqLabels = ['{0:.1f}'.format(freq / 1000.0) for freq in possibleFreq]
        trialsEachCondition = behavioranalysis.find_trials_each_type(
            freqEachTrial, possibleFreq)
        pRaster, hCond, zLine = extraplots.raster_plot(
            spikeTimesFromEventOnset,
            indexLimitsEachTrial,
            timeRange,
            trialsEachCond=trialsEachCondition,
            labels=freqLabels)
        plt.setp(pRaster, ms=1)
        ax4.set_ylabel('Frequency (kHz)')

        # -- Tuning heatmap: mean response count per (intensity, frequency) --
        ax5 = plt.subplot(gs[8:10, 0:3])
        baseRange = [-0.1, 0]  # only used to extend the alignment window
        responseRange = [0, 0.1]
        alignmentRange = [baseRange[0], responseRange[1]]
        intensityEachTrial = bdata['currentIntensity']
        possibleIntensity = np.unique(intensityEachTrial)

        allIntenResp = np.empty((len(possibleIntensity), len(possibleFreq)))
        for indinten, inten in enumerate(possibleIntensity):
            for indfreq, freq in enumerate(possibleFreq):
                selectinds = np.flatnonzero((freqEachTrial == freq)
                                            & (intensityEachTrial == inten))
                selectedOnsetTimes = eventOnsetTimes[selectinds]
                (spikeTimesThisCond, trialIndexThisCond,
                 indexLimitsThisCond) = spikesanalysis.eventlocked_spiketimes(
                     spikeTimes, selectedOnsetTimes, alignmentRange)
                nspkResp = spikesanalysis.spiketimes_to_spikecounts(
                    spikeTimesThisCond, indexLimitsThisCond, responseRange)
                allIntenResp[indinten, indfreq] = np.mean(nspkResp)

        # imshow puts pixel centres at 0..n-1, so ticks must span [0, n-1].
        nFreqLabels = 3
        freqTickLocations = np.linspace(0, len(possibleFreq) - 1, nFreqLabels)
        # Tick labels in kHz. Evenly spaced tick indices map to log-spaced
        # frequencies only if possibleFreq is log-spaced (assumed).
        freqTickValues = np.logspace(np.log10(possibleFreq.min()),
                                     np.log10(possibleFreq.max()),
                                     nFreqLabels) / 1000.0

        nIntenLabels = 3
        intensities = np.linspace(possibleIntensity.min(),
                                  possibleIntensity.max(), nIntenLabels)
        intenTickLocations = np.linspace(0,
                                         len(possibleIntensity) - 1,
                                         nIntenLabels)

        # Flip so that the highest intensity is drawn at the top.
        ax5.imshow(np.flipud(allIntenResp),
                   interpolation='nearest',
                   cmap='Blues')
        ax5.set_yticks(intenTickLocations)
        ax5.set_yticklabels(intensities[::-1])
        ax5.set_xticks(freqTickLocations)
        ax5.set_xticklabels(
            ['{0:.1f}'.format(freq) for freq in freqTickValues])
        ax5.set_xlabel('Frequency (kHz)')
        ax5.set_ylabel('Intensity (db SPL)')

        if not pd.isnull(dbRow['threshold']):
            # Row index of the threshold intensity in the flipped image.
            indThresh = (len(possibleIntensity) - 1) - np.where(
                dbRow['threshold'] == possibleIntensity)[0]
            indCF = np.where(dbRow['cf'] == possibleFreq)[0]
            ax5.plot(indCF, indThresh, 'r*')
            # Axes title: a figure suptitle would be replaced by the report
            # title set at the end of this function.
            ax5.set_title('Threshold: {}'.format(dbRow['threshold']))

            # Bandwidth markers sit 10 dB above threshold, so they need
            # indThresh (and at least two intensities to know the dB step).
            if (not pd.isnull(dbRow['upperFreq'])
                    and len(possibleIntensity) > 1):
                intenStep = np.diff(possibleIntensity)[0]
                # 10.0 avoids integer floor division for int intensities.
                threshPlus10 = indThresh - (10.0 / intenStep)
                logLow = np.log2(possibleFreq[0])
                logSpan = np.log2(possibleFreq[-1]) - logLow
                indUpper = ((np.log2(dbRow['upperFreq']) - logLow) / logSpan
                            * (len(possibleFreq) - 1))
                indLower = ((np.log2(dbRow['lowerFreq']) - logLow) / logSpan
                            * (len(possibleFreq) - 1))
                ax5.plot(indUpper, threshPlus10, 'b*')
                ax5.plot(indLower, threshPlus10, 'b*')

    if 'am' in dbRow['sessionType']:
        # -- AM raster with rate plot --
        ax6spec = gs[6:8, 3:6]
        ephysData, bdata = cell.load('am')
        eventOnsetTimes = ephysData['events']['stimOn']
        sortArray = bdata['currentFreq']  # AM rate, used as Hz below
        colors = get_colors(len(np.unique(sortArray)))
        timeRange = [-0.5, 1]
        (spikeTimesFromEventOnset, trialIndexForEachSpike,
         indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
             ephysData['spikeTimes'], eventOnsetTimes, timeRange)
        plot_example_with_rate(ax6spec,
                               spikeTimesFromEventOnset,
                               indexLimitsEachTrial,
                               sortArray,
                               colorEachCond=colors,
                               maxSyncRate=cell.dbRow['highestSyncCorrected'])

        # -- AM cycle histogram: spike phase within the modulation period --
        ax7 = plt.subplot(gs[8:10, 3:6])
        for indFreq, (freq, spikeTimesThisFreq, _) in enumerate(
                spiketimes_each_frequency(spikeTimesFromEventOnset,
                                          trialIndexForEachSpike,
                                          sortArray)):
            radsPerSec = freq * 2 * np.pi
            spikeRads = (spikeTimesThisFreq * radsPerSec) % (2 * np.pi)
            ax7.hist(spikeRads,
                     bins=20,
                     color=colors[indFreq],
                     histtype='step')

    (timestamps, samples, recordingNumber) = cell.load_all_spikedata()

    # -- ISI log-histogram (best effort: never abort the whole report) --
    ax8 = plt.subplot(gs[10, 0:2])
    if timestamps is not None:
        try:
            spikesorting.plot_isi_loghist(timestamps)
        except Exception as err:
            print('problem with isi vals: {}'.format(err))

    # -- Waveforms --
    ax9 = plt.subplot(gs[10, 2:4])
    if samples is not None and len(samples) > 0:
        spikesorting.plot_waveforms(samples)

    # -- Events in time (best effort) --
    ax10 = plt.subplot(gs[10, 4:6])
    if timestamps is not None:
        try:
            spikesorting.plot_events_in_time(timestamps)
        except Exception as err:
            print('problem with events in time: {}'.format(err))

    fig = plt.gcf()
    fig.set_size_inches(8.5 * 2, 11 * 2)

    figName = '{}_{}_{}um_TT{}c{}.png'.format(dbRow['subject'], dbRow['date'],
                                              int(dbRow['depth']),
                                              int(dbRow['tetrode']),
                                              int(dbRow['cluster']))

    plt.suptitle(os.path.splitext(figName)[0])

    if saveDir is not None:
        figPath = os.path.join(saveDir, figName)
        plt.savefig(figPath)
Example #20
0
def _plot_laser_raster_and_psth(rasterSpec, psthSpec, spikeTimestamps,
                                eventOnsetTimes, timeRange, name):
    '''
    Plot an event-locked raster and a 10 ms-bin PSTH for laser events.

    Args:
        rasterSpec: gridspec slot for the raster.
        psthSpec: gridspec slot for the PSTH.
        spikeTimestamps: spike timestamps of the cell.
        eventOnsetTimes: laser onset times.
        timeRange: [start, stop] window around each onset (sec).
        name: panel name used in the titles, e.g. 'Laser Pulse'.
    '''
    plt.subplot(rasterSpec)
    (spikeTimesFromEventOnset, trialIndexForEachSpike,
     indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         spikeTimestamps, eventOnsetTimes, timeRange)
    extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial,
                           timeRange)
    plt.xlabel('Time from laser onset (sec)')
    plt.title('{} Raster'.format(name))

    plt.subplot(psthSpec)
    binsize = 10 / 1000.0
    # The alignment window starts one bin early to match binEdges, which
    # start at timeRange[0] - binsize.
    (spikeTimesFromEventOnset, trialIndexForEachSpike,
     indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         spikeTimestamps, eventOnsetTimes,
         [timeRange[0] - binsize, timeRange[1]])
    binEdges = np.around(np.arange(timeRange[0] - binsize,
                                   timeRange[1] + 2 * binsize, binsize),
                         decimals=2)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
        spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)
    extraplots.plot_psth(spikeCountMat / binsize, 1, binEdges[:-1])
    plt.xlim(timeRange)
    plt.xlabel('Time from laser onset (sec)')
    plt.ylabel('Firing Rate (Hz)')
    plt.title('{} PSTH'.format(name))


def plot_bandwidth_report(cell, bandIndex, type='normal',
                          saveDir='/home/jarauser/Pictures/cell reports'):
    '''
    Plot and save a summary report for a cell recorded in a bandwidth session.

    Panels:
      * bandwidth rasters split by a second sort variable, plus tuning curves
      * frequency tuning raster and Gaussian-fitted tuning at the tone
        intensity matched to the bandwidth noise
      * AM raster and PSTH
      * laser pulse (and, if recorded, laser train) raster/PSTH
      * ISI histogram, waveforms and events-in-time of the bandwidth session

    Args:
        cell: cell-database row; keys read here include 'tetrode',
            'cluster', 'subject', 'date', 'depth', 'gaussFit',
            'tuningFitR2' and 'tuningTimeRange'.
        bandIndex: session index of the bandwidth session. If None, a
            message is printed and nothing is plotted.
        type: 'normal' splits bandwidth trials by 'currentAmp'; 'laser'
            splits them by 'laserTrial' (for inactivation experiments).
        saveDir: directory the PNG is written to. The default is the path
            that used to be hard-coded.

    Raises:
        ValueError: if `type` is neither 'normal' nor 'laser'.
    '''
    plt.clf()
    if bandIndex is None:
        print('No bandwidth session given')
        return
    if type not in ('laser', 'normal'):
        raise ValueError(
            "type must be 'normal' or 'laser', got {!r}".format(type))

    cellObj = ephyscore.Cell(cell)

    # Four extra grid rows at the top hold the laser panels, if any.
    laser = len(cellObj.get_session_inds('laserPulse')) > 0
    gs = gridspec.GridSpec(13 if laser else 9, 6)
    offset = 4 * laser
    gs.update(left=0.15, right=0.85, top=0.96, wspace=0.7, hspace=1.0)

    tetrode = int(cell['tetrode'])
    cluster = int(cell['cluster'])

    # -- load bandwidth ephys and behaviour data --
    bandEphysData, bandBData = cellObj.load_by_index(bandIndex)
    bandEventOnsetTimes = ephysanalysis.get_sound_onset_times(
        bandEphysData, 'bandwidth')
    bandSpikeTimestamps = bandEphysData['spikeTimes']
    bandEachTrial = bandBData['currentBand']
    numBands = np.unique(bandEachTrial)

    # The second sort variable depends on the experiment type, so the same
    # report serves Arch-inactivation (laser) experiments.
    if type == 'laser':
        secondSort = bandBData['laserTrial']
        secondSortLabels = ['no laser', 'laser']
        colours = ['k', '#c4a000']
        errorColours = ['0.5', '#fce94f']
    else:
        secondSort = bandBData['currentAmp']
        secondSortLabels = ['{} dB'.format(amp)
                            for amp in np.unique(secondSort)]
        colours = ['#4e9a06', '#5c3566']
        errorColours = ['#8ae234', '#ad7fa8']
    gaussFitCol = 'gaussFit'
    tuningR2Col = 'tuningFitR2'

    # 1000.0 avoids integer truncation of the kHz value under Python 2.
    charfreq = str(np.unique(bandBData['charFreq'])[0] / 1000.0)
    modrate = str(np.unique(bandBData['modRate'])[0])

    # -- plot rasters of the bandwidth trials --
    # Alternate main/error colours across bands; // keeps np.tile's repeat
    # count an integer under Python 3.
    nRepeats = len(numBands) // 2 + 1
    rasterColours = [np.tile([colours[0], errorColours[0]], nRepeats),
                     np.tile([colours[1], errorColours[1]], nRepeats)]
    plot_separated_rasters(gs, [0, 3], 5 + offset, bandEachTrial, secondSort,
                           bandSpikeTimestamps, bandEventOnsetTimes,
                           colours=rasterColours, titles=secondSortLabels,
                           plotHeight=2)

    # -- plot bandwidth tuning curves --
    plt.subplot(gs[5 + offset:, 3:])
    timeRange = [0.2, 1.0]
    tuningDict = ephysanalysis.calculate_tuning_curve_inputs(
        bandSpikeTimestamps, bandEventOnsetTimes, bandEachTrial, secondSort,
        timeRange, info='plotting')
    plot_tuning_curve(tuningDict['responseArray'], tuningDict['errorArray'],
                      numBands, tuningDict['baselineSpikeRate'],
                      linecolours=colours, errorcolours=errorColours)

    # -- load tuning ephys and behaviour data --
    tuningEphysData, tuningBData = cellObj.load('tuningCurve')
    tuningEventOnsetTimes = ephysanalysis.get_sound_onset_times(
        tuningEphysData, 'tuningCurve')
    tuningSpikeTimestamps = tuningEphysData['spikeTimes']

    # -- frequency tuning at the intensity used in the bandwidth trials --
    maxAmp = max(np.unique(bandBData['currentAmp']))
    if maxAmp < 1:
        maxAmp = 66.0  # HARDCODED dB VALUE FOR SESSIONS DONE BEFORE NOISE CALIBRATION
    # HARDCODED DIFFERENCE IN TONE AND NOISE AMP BASED ON OSCILLOSCOPE READINGS FROM RIG 2
    toneInt = maxAmp - 15.0

    freqEachTrial = tuningBData['currentFreq']

    plt.subplot(gs[2 + offset:4 + offset, 0:3])
    plot_tuning_fitted_gaussian(tuningSpikeTimestamps, tuningEventOnsetTimes,
                                tuningBData, toneInt, cell[gaussFitCol],
                                cell[tuningR2Col],
                                timeRange=cell['tuningTimeRange'])

    # -- plot frequency tuning raster --
    plt.subplot(gs[0 + offset:2 + offset, 0:3])
    freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0]
    plot_sorted_raster(tuningSpikeTimestamps, tuningEventOnsetTimes,
                       freqEachTrial, timeRange=[-0.2, 0.6],
                       labels=freqLabels)
    plt.title('Frequency Tuning Raster')

    # -- plot AM PSTH --
    amEphysData, amBData = cellObj.load('AM')
    amEventOnsetTimes = ephysanalysis.get_sound_onset_times(amEphysData, 'AM')
    amSpikeTimestamps = amEphysData['spikeTimes']
    rateEachTrial = amBData['currentFreq']
    colourList = ['b', 'g', 'y', 'orange', 'r']

    plt.subplot(gs[2 + offset:4 + offset, 3:])
    plot_sorted_psth(amSpikeTimestamps, amEventOnsetTimes, rateEachTrial,
                     timeRange=[-0.2, 0.8], binsize=25,
                     colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Firing rate (Hz)')
    plt.title('AM PSTH')

    # -- plot AM raster --
    plt.subplot(gs[0 + offset:2 + offset, 3:])
    rateLabels = ["%.0f" % rate for rate in np.unique(rateEachTrial)]
    plot_sorted_raster(amSpikeTimestamps, amEventOnsetTimes, rateEachTrial,
                       timeRange=[-0.2, 0.8], labels=rateLabels,
                       colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Modulation Rate (Hz)')
    plt.title('AM Raster')

    # -- plot laser pulse and laser train data (if available) --
    if laser:
        laserEphysData, noBehav = cellObj.load('laserPulse')
        _plot_laser_raster_and_psth(gs[0:2, 0:3], gs[2:4, 0:3],
                                    laserEphysData['spikeTimes'],
                                    laserEphysData['events']['laserOn'],
                                    [-0.1, 0.4], 'Laser Pulse')

        # -- didn't record laser trains for some earlier sessions --
        if len(cellObj.get_session_inds('laserTrain')) > 0:
            laserTrainEphysData, noBehav = cellObj.load('laserTrain')
            laserTrainEventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
                laserTrainEphysData['events']['laserOn'], 0.5)
            _plot_laser_raster_and_psth(gs[0:2, 3:], gs[2:4, 3:],
                                        laserTrainEphysData['spikeTimes'],
                                        laserTrainEventOnsetTimes,
                                        [-0.2, 1.0], 'Laser Train')

    # -- cluster quality panels (bandwidth session spikes) --
    plt.subplot(gs[4 + offset, 0:2])
    spikesorting.plot_isi_loghist(bandSpikeTimestamps)

    plt.subplot(gs[4 + offset, 2:4])
    spikesorting.plot_waveforms(bandEphysData['samples'])

    plt.subplot(gs[4 + offset, 4:6])
    spikesorting.plot_events_in_time(bandSpikeTimestamps)

    title = '{0}, {1}, {2}um, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(
        cell['subject'], cell['date'], cell['depth'], tetrode, cluster,
        charfreq, modrate)
    plt.suptitle(title)

    fig_name = '{0}_{1}_{2}um_TT{3}Cluster{4}.png'.format(
        cell['subject'], cell['date'], cell['depth'], tetrode, cluster)
    full_fig_path = os.path.join(saveDir, fig_name)
    fig = plt.gcf()
    fig.set_size_inches(20, 25)
    fig.savefig(full_fig_path, format='png', bbox_inches='tight')
# Dot raster: one marker per spike, x = spike time relative to its event
# onset, y = index of the trial the spike belongs to.
# NOTE(review): spikeTimesFromEventOnset / trialIndexForEachSpike are defined
# earlier in this script (not visible here).
plt.figure()
plt.plot(spikeTimesFromEventOnset, trialIndexForEachSpike, 'k.', ms=1)

plt.show()



#Will plot the waveforms of spikes in a certain time range after an event
from jaratoolbox import spikesorting

# Window after event onset whose spikes are kept (presumably seconds, like the
# other time ranges passed to eventlocked_spiketimes).
indexTR = [0, 0.1]
# With spikeindex=True a 4th output is returned; it is used below to index the
# waveform samples, so it is presumably the index of each selected spike.
spikeTimesFromEventOnset,trialIndexForEachSpike,indexLimitsEachTrial, spikeIndex = spikesanalysis.eventlocked_spiketimes(spikeTimestamps,eventOnsetTimes,indexTR, spikeindex=True)
waves = spikeDataNB.samples[spikeIndex]

# NOTE(review): bare figure() relies on a pylab-style star import earlier in
# the script (not visible here); plt.figure() is the explicit equivalent.
figure()
spikesorting.plot_waveforms(waves)


 

#For a session that has behavior data, we can use the object to get the behavior data.
#This is the behavior data for the best frequency presentation.
#At this site, the TC and BF sessions have saved behavior data.
bdata = ex0624.get_session_behav_data(s1BF)


#We can also pass ephys filenames, with or without the date attached, to the object instead of handles.
#Let's look at the LP session for TT6c3.
spikeDataLP = ex0624.get_session_spike_data_one_tetrode('2015-06-24_15-25-08', 6)
# Keep only the timestamps of spikes assigned to cluster 3 on tetrode 6.
spikeTimesLP = spikeDataLP.timestamps[spikeDataLP.clusters==3]
Example #22
0
                             eventOnsetTimes,
                             firstSortArray=intensityEachTrial,
                             secondSortArray=freqEachTrial,
                             firstSortLabels=intenLabels,
                             secondSortLabels=freqLabels,
                             timeRange=[0, 0.1])

# NOTE(review): xlabel / ylabel are defined earlier in this script (not visible).
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.show()

### I think that the only other plot you need for the figure that you gave me is the spike waveform plot
##The waveforms for cell 9, the one with the good amp modulation

plt.clf()
spikesorting.plot_waveforms(spikeData.samples)
plt.show()

#The plot_waveforms method currently computes the average over only the 40 selected spikes.
#Spikes are selected and then aligned, and the aligned ones are used to calculate the average.
#We should think about calculating the average over all the spikes if it isn't too much to align them all.

# Load the 'AM' session data for database entry 9 and plot its raster,
# with trials sorted by the behavior-data field currentFreq.
spikeData, eventData, behavData = loader.get_cluster_data(figdb[9], 'AM')

spikeTimestamps = spikeData.timestamps
eventOnsetTimes = loader.get_event_onset_times(eventData)
currentFreq = behavData['currentFreq']
dataplotter.plot_raster(spikeTimestamps,
                        eventOnsetTimes,
                        sortArray=currentFreq)
Example #23
0
                                     smoothWinSizePsth,
                                     timeVec,
                                     trialsEachCond=[],
                                     linestyle=None,
                                     linewidth=lwPsth,
                                     downsamplefactor=downsampleFactorPsth,
                                     colorEachCond='r')
        axLaserpulsePSTH.set_xlim(-0.3, 0.5)
        extraplots.boxoff(plt.gca())
        plt.ylabel('Firing rate\n(spk/s)')
        plt.xlabel('time from onset of the sound(s)')

        # Plot waveforms
        plt.sca(axLaserpulseWaveform)
        if laserWaveform.any():
            allLaserWaves, meanLaserWaves, scaleBar = spikesorting.plot_waveforms(
                laserWaveform)
            plt.setp(meanLaserWaves, color='r')
        plt.title("Laserpulse Waveform")

# TODO: Add tuning curve and AM in same fashion as laserpulse and noiseburst above
# -----------AM---------------------
    if "am" in sessions:
        # Loading data for session
        amEphysData, amBehavData = oneCell.load('am')

        # General variables for am calculations/plotting
        amSpikeTimes = amEphysData['spikeTimes']
        amOnsetTime = amEphysData['events']['soundDetectorOn']
        amCurrentFreq = amBehavData['currentFreq']
        amUniqFreq = np.unique(amCurrentFreq)
        amTimeRange = [-0.2, 0.7]
Example #24
0
def _plot_laser_raster_and_psth(spikeTimestamps, eventOnsetTimes, timeRange,
                                rasterSpec, psthSpec, label):
    '''
    Plot an event-locked raster and a PSTH (10 ms bins) for one laser session.

    Args:
        spikeTimestamps: spike times of the cell.
        eventOnsetTimes: laser onset times the spikes are locked to.
        timeRange (list): [start, end] window around each onset, in the same units as the timestamps.
        rasterSpec: gridspec slot used for the raster.
        psthSpec: gridspec slot used for the PSTH.
        label (str): title prefix, e.g. 'Laser Pulse' gives 'Laser Pulse Raster' / 'Laser Pulse PSTH'.
    '''
    binsize = 10 / 1000.0

    # -- raster --
    plt.subplot(rasterSpec)
    spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
        spikeTimestamps, eventOnsetTimes, timeRange)
    extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange)
    plt.xlabel('Time from laser onset (sec)')
    plt.title('{} Raster'.format(label))

    # -- PSTH --
    plt.subplot(psthSpec)
    # Spikes are locked over a window that starts one bin before timeRange[0], matching the first bin edge.
    spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
        spikeTimestamps, eventOnsetTimes, [timeRange[0] - binsize, timeRange[1]])
    # Rounded to 2 decimals, presumably to remove floating-point drift from np.arange.
    binEdges = np.around(np.arange(timeRange[0] - binsize,
                                   timeRange[1] + 2 * binsize, binsize),
                         decimals=2)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
        spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)
    # Dividing counts by the bin width converts them to a firing rate.
    extraplots.plot_psth(spikeCountMat / binsize, 1, binEdges[:-1])
    plt.xlim(timeRange)
    plt.xlabel('Time from laser onset (sec)')
    plt.ylabel('Firing Rate (Hz)')
    plt.title('{} PSTH'.format(label))


def plot_blind_cell_quality(cell):
    '''
    Plot a one-page quality report for a cell on the current figure.

    Top rows: laser pulse raster/PSTH (left) and, when the session exists, laser train
    raster/PSTH (right). Bottom row: ISI histogram, waveforms and events-in-time of the cluster.

    Args:
        cell: cell database row used to build an ephyscore.Cell.
    '''
    plt.clf()
    gs = gridspec.GridSpec(5, 6)

    # Create cell object for loading data
    cellObj = ephyscore.Cell(cell)

    # -- laser pulse raster and PSTH --
    laserEphysData, noBehav = cellObj.load('laserPulse')
    _plot_laser_raster_and_psth(laserEphysData['spikeTimes'],
                                laserEphysData['events']['laserOn'],
                                [-0.1, 0.4],
                                gs[0:2, 0:3], gs[2:4, 0:3],
                                'Laser Pulse')

    # -- didn't record laser trains for some earlier sessions --
    if len(cellObj.get_session_inds('laserTrain')) > 0:
        laserTrainEphysData, noBehav = cellObj.load('laserTrain')
        # Onsets are filtered with minimum_event_onset_diff (0.5) before locking spikes to them.
        laserTrainEventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
            laserTrainEphysData['events']['laserOn'], 0.5)
        _plot_laser_raster_and_psth(laserTrainEphysData['spikeTimes'],
                                    laserTrainEventOnsetTimes,
                                    [-0.2, 1.0],
                                    gs[0:2, 3:], gs[2:4, 3:],
                                    'Laser Train')

    # -- show cluster analysis --
    # These were previously used without ever being loaded (their loading line was commented
    # out), which made this function always fail with a NameError.
    tsThisCluster, wavesThisCluster, recordingNumber = cellObj.load_all_spikedata()

    # -- Plot ISI histogram --
    plt.subplot(gs[4, 0:2])
    if tsThisCluster is not None:
        spikesorting.plot_isi_loghist(tsThisCluster)

    # -- Plot waveforms --
    plt.subplot(gs[4, 2:4])
    if wavesThisCluster is not None and len(wavesThisCluster) > 0:
        spikesorting.plot_waveforms(wavesThisCluster)

    # -- Plot events in time --
    plt.subplot(gs[4, 4:6])
    if tsThisCluster is not None:
        spikesorting.plot_events_in_time(tsThisCluster)
Example #25
0
    def generate_reports(self):
        '''
        Generate the reports for all of the sessions in this site. This is where we should interface with
        the multiunit clustering code, since all of the sessions that need to be clustered together have
        been defined at this point.

        For each good tetrode, the sessions are clustered together (or an existing cluster file is reused).
        One PNG per cluster ('TT<tetrode>Cluster<cluster>.png') is saved in the multisession clustering
        directory, followed by the multisession cluster report.

        FIXME: This method should be able to load some kind of report template perhaps, so that the
        types of reports we can make are not limited. For instance, what happens when we just have
        rasters for a site and no tuning curve? Implementing this is a lower priority for now.

        Incorporated Lan's code for plotting the cluster reports directly on the main report.
        '''
        #FIXME: import another piece of code to do this?

        for tetrode in self.goodTetrodes:
            oneTT = cms2.MultipleSessionsToCluster(self.animalName, self.get_session_filenames(), tetrode, '{}at{}um'.format(self.date, self.depth))
            oneTT.load_all_waveforms()

            # Do the clustering if necessary; otherwise reuse the existing cluster file.
            clusterFile = os.path.join(oneTT.clustersDir, 'Tetrode%d.clu.1' % oneTT.tetrode)
            if not os.path.isfile(clusterFile):
                oneTT.create_multisession_fet_files()
                oneTT.run_clustering()
            oneTT.set_clusters_from_file()

            oneTT.save_single_session_clu_files()
            possibleClusters = np.unique(oneTT.clusters)

            exp2 = ee2.EphysExperiment(self.animalName, self.date, experimenter=self.experimenter)

            # The session selection does not depend on the cluster, so compute it once per tetrode.
            # Integer indices are used element-wise, which works whether the getters return lists or arrays.
            sessionFilenames = self.get_session_filenames()
            sessionTypes = self.get_session_types()
            sessionBehavIDs = self.get_session_behavIDs()

            mainRasterInds = self.get_session_inds_one_type(plotType='raster', report='main')
            mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
            mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

            mainTCinds = self.get_session_inds_one_type(plotType='tc_heatmap', report='main')
            mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
            mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

            # Iterate through the clusters, making a new figure for each cluster.
            for cluster in possibleClusters:

                plt.figure()  # The main report for this cluster/tetrode/session

                for indRaster, rasterSession in enumerate(mainRasterSessions):
                    plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                    exp2.plot_session_raster(rasterSession, tetrode, cluster=cluster, replace=1)
                    plt.ylabel(mainRasterTypes[indRaster])
                    plt.title(rasterSession, fontsize=10)

                # We can only do one main TC for now; skip it when the site has none.
                if len(mainTCsessions) > 0:
                    plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                    tcSession = mainTCsessions[0]
                    tcBehavID = mainTCbehavIDs[0]
                    exp2.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace=1, cluster=cluster)
                    plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize=10)

                # Boolean mask selecting the spikes (and waveforms) of this cluster only.
                spikesThisCluster = (oneTT.clusters == cluster)
                tsThisCluster = oneTT.timestamps[spikesThisCluster]
                wavesThisCluster = oneTT.samples[spikesThisCluster, :, :]

                # -- Plot ISI histogram --
                plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
                spikesorting.plot_isi_loghist(tsThisCluster)
                # Label with the current cluster (previously the last cluster of an inner loop was used).
                plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')

                # -- Plot waveforms --
                plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
                spikesorting.plot_waveforms(wavesThisCluster)

                # -- Plot projections --
                plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
                spikesorting.plot_projections(wavesThisCluster)

                # -- Plot events in time --
                plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
                spikesorting.plot_events_in_time(tsThisCluster)

                fig_path = oneTT.clustersDir
                fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
                full_fig_path = os.path.join(fig_path, fig_name)
                print(full_fig_path)
                plt.tight_layout()
                plt.savefig(full_fig_path, format='png')
                plt.close()

            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
Example #26
0
                    spkMat[sessInd] / binWidth,
                    smoothWinSizePsth,
                    timeVec,
                    trialsEachCond=[],
                    linestyle=None,
                    linewidth=lwPsth,
                    downsamplefactor=downsampleFactorPsth,
                    colorEachCond='r')
                extraplots.boxoff(plt.gca())
                plt.ylabel('Firing rate\n(spk/s)')
                plt.xlabel('time from onset of the sound(s)')

                # Plot waveforms
                plt.sca(axLaserpulseWaveform)
                if waveF[sessInd].any():
                    all_waves, mean_waves, scale_bar = spikesorting.plot_waveforms(
                        waveF[sessInd])
                    plt.setp(mean_waves, color='r')
                plt.title(con.title())

        # Tuning Curve
        else:
            # ----------if the length doesn't match, abandon last one from trialsEachType
            while indexLimitsEachTrial.shape[1] < trialsEachType.shape[0]:
                trialsEachType = np.delete(trialsEachType, -1, 0)

            plt.sca(axTuningCurveRaster)
            pRaster, hcond, zline = extraplots.raster_plot(
                rast1[sessInd],
                rast2[sessInd],
                timeRange,
                trialsEachCond=trialsEachType)
Example #27
0
def plot_NBQX_report(dbRow, saveDir=None):
    '''
    Plot a pre/post report for one cell on the current figure.

    The left column shows raster + PSTH for the 'noiseburst_pre', 'laserpulse_pre' and
    'lasertrain_pre' sessions, and the right column shows the matching '_post' sessions.
    The bottom row shows ISI histogram, waveforms and events-in-time of the cluster.

    Args:
        dbRow: cell database row. Reads 'subject', 'date', 'depth', 'tetrode', 'cluster' and 'autoTagged'.
        saveDir (str or None): directory to save the PNG into. Nothing is saved when None.
    '''
    # Init cell object
    cell = ephyscore.Cell(dbRow)

    plt.clf()
    gs = gridspec.GridSpec(4, 2, hspace=0.5, wspace=0.5)

    def two_row_spec(subplotSpec):
        '''Split one slot of the main grid into stacked raster (top) and PSTH (bottom) axes.'''
        return gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=subplotSpec, hspace=0)

    gsNoisePre = two_row_spec(gs[0, 0])
    gsPulsePre = two_row_spec(gs[1, 0])
    gsTrainPre = two_row_spec(gs[2, 0])

    gsNoisePost = two_row_spec(gs[0, 1])
    gsPulsePost = two_row_spec(gs[1, 1])
    gsTrainPost = two_row_spec(gs[2, 1])

    gsCluster = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=gs[3, :])

    def plot_raster_and_PSTH(sessiontype, gs, color='k'):
        '''Plot a dot raster in gs[0] and a smoothed PSTH (spk/s) in gs[1] for one session type.'''
        axRaster = plt.subplot(gs[0])
        ephysData, bdata = cell.load(sessiontype)
        eventOnsetTimes = ephysData['events']['stimOn']
        eventOnsetTimes = spikesanalysis.minimum_event_onset_diff(
            eventOnsetTimes, minEventOnsetDiff=0.5)
        timeRange = [-0.3, 1.0]
        (spikeTimesFromEventOnset, trialIndexForEachSpike,
         indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
             ephysData['spikeTimes'], eventOnsetTimes, timeRange)
        # Rasterized so the many dots do not bloat vector outputs.
        axRaster.plot(spikeTimesFromEventOnset,
                      trialIndexForEachSpike,
                      'k.',
                      ms=1,
                      rasterized=True)
        axRaster.set_xlim(timeRange)
        axRaster.set_xticks([])
        extraplots.boxoff(axRaster)
        # A single y tick showing the number of trials.
        axRaster.set_yticks([len(eventOnsetTimes)])

        axPSTH = plt.subplot(gs[1])
        smoothPSTH = True
        psthLineWidth = 2
        smoothWinSize = 1
        binsize = 10  # in milliseconds
        binEdges = np.around(np.arange(timeRange[0] - (binsize / 1000.0),
                                       timeRange[1] + 2 * (binsize / 1000.0),
                                       (binsize / 1000.0)),
                             decimals=2)
        winShape = np.concatenate((np.zeros(smoothWinSize),
                                   np.ones(smoothWinSize)))  # Square (causal)
        winShape = winShape / np.sum(winShape)
        psthTimeBase = np.linspace(timeRange[0],
                                   timeRange[1],
                                   num=len(binEdges) - 1)

        spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
            spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)

        thisPSTH = np.mean(spikeCountMat, axis=0)
        if smoothPSTH:
            thisPSTH = np.convolve(thisPSTH, winShape, mode='same')
        # Mean count per bin divided by the bin width in seconds gives spikes/s.
        ratePSTH = thisPSTH / float(binsize / 1000.0)
        axPSTH.plot(psthTimeBase, ratePSTH, '-', color=color, lw=psthLineWidth)

        axPSTH.set_xlim(timeRange)
        extraplots.boxoff(axPSTH)
        axPSTH.set_ylim([0, max(ratePSTH)])
        axPSTH.set_yticks([0, np.floor(np.max(ratePSTH))])
        axPSTH.set_ylabel('spk/s')

    # colorNoise / colorLaser / figSize are module-level settings defined outside this function.
    plot_raster_and_PSTH('noiseburst_pre', gsNoisePre, color=colorNoise)
    plot_raster_and_PSTH('laserpulse_pre', gsPulsePre, color=colorLaser)
    plot_raster_and_PSTH('lasertrain_pre', gsTrainPre, color=colorLaser)
    plot_raster_and_PSTH('noiseburst_post', gsNoisePost, color=colorNoise)
    plot_raster_and_PSTH('laserpulse_post', gsPulsePost, color=colorLaser)
    plot_raster_and_PSTH('lasertrain_post', gsTrainPost, color=colorLaser)

    (timestamps, samples, recordingNumber) = cell.load_all_spikedata()

    # ISI loghist (best effort: a failure is reported, not raised)
    plt.subplot(gsCluster[0])
    if timestamps is not None:
        try:
            spikesorting.plot_isi_loghist(timestamps)
        except Exception:
            print("problem with isi vals")

    # Waveforms
    plt.subplot(gsCluster[1])
    if samples is not None and len(samples) > 0:
        spikesorting.plot_waveforms(samples)

    # Events in time (best effort, like the ISI plot)
    plt.subplot(gsCluster[2])
    if timestamps is not None:
        try:
            spikesorting.plot_events_in_time(timestamps)
        except Exception:
            print("problem with isi vals")

    fig = plt.gcf()
    fig.set_size_inches(figSize)

    figName = '{}_{}_{}um_TT{}c{}.png'.format(dbRow['subject'], dbRow['date'],
                                              int(dbRow['depth']),
                                              int(dbRow['tetrode']),
                                              int(dbRow['cluster']))

    if dbRow['autoTagged'] == 1:
        autoTaggedStatus = "PASS"
    elif dbRow['autoTagged'] == 0:
        autoTaggedStatus = "FAIL"
    else:
        # Any other value (e.g. missing) previously left autoTaggedStatus unbound.
        autoTaggedStatus = "UNKNOWN"

    plt.suptitle("{}\nautoTagged:{}".format(figName[:-4], autoTaggedStatus))

    if saveDir is not None:
        figPath = os.path.join(saveDir, figName)
        print("Saving figure to: {}".format(figPath))
        plt.savefig(figPath)
Example #28
0
dbPath = os.path.join(settings.FIGURES_DATA_PATH, studyparams.STUDY_NAME)
dbFilename = os.path.join(dbPath, 'celldb_{}.h5'.format(studyparams.STUDY_NAME))

figFormat = 'png'
outputDir = os.path.join(settings.FIGURES_DATA_PATH, studyparams.STUDY_NAME, 'reports')

# -- Load the database of cells --
celldb = celldatabase.load_hdf(dbFilename)

# Plot the waveforms of the 'oddball' session for the selected cell(s) and save them.
for indRow, dbRow in celldb[266:267].iterrows():
    oneCell = ephyscore.Cell(dbRow)

    # Other session types that can be loaded here instead: 'noiseburst', 'tc', 'standard'.
    ephysData, bdata = oneCell.load('oddball')

    spikesorting.plot_waveforms(ephysData['samples'])

    # -- Save the figure --
    # Size and layout are applied before savefig; previously they were set after saving,
    # so the written file did not use them.
    plt.gcf().set_size_inches([6, 4])
    plt.tight_layout()
    figFilename = '{}_{}_{}um_T{}_c{}_oddwave.{}'.format(dbRow['subject'], dbRow['date'], dbRow['depth'],
                                                         dbRow['tetrode'], dbRow['cluster'], figFormat)
    figFullpath = os.path.join(outputDir, figFilename)
    plt.savefig(figFullpath, format=figFormat)

    plt.show()
def laser_tc_analysis(site, sitenum):
    '''
    Data analysis function for laser/tuning curve experiments.

    This function takes a RecordingSite object, does multisession clustering on it, and saves all of the
    clusters back to the original session cluster directories. An EphysExperiment object (version 2) is then
    used to load each session, select clusters and plot the appropriate plots. This code was removed from
    the EphysExperiment object because that object should be general and apply to any kind of recording
    experiment; this function does the data analysis for one specific kind of experiment.

    For every good tetrode, one figure per cluster is saved as 'TT<tetrode>Cluster<cluster>.png' in the
    multisession clustering directory, followed by the multisession cluster report.

    Args:

        site (RecordingSite object): An instance of the RecordingSite class from the ephys_experiment_v2 module
        sitenum (int): The site number for the site, used for constructing directory names

    Raises:

        ValueError: if the site has no 'noiseBurst', 'laserPulse' or 'tuningCurve' session (from list.index).
            A missing 'laserTrain' session is reported and skipped.

    Example:

        from jaratoolbox.test.nick.ephysExperiments import laserTCanalysis
        for indSite, site in enumerate(today.siteList):
            laserTCanalysis.laser_tc_analysis(site, indSite+1)
    '''
    for tetrode in site.goodTetrodes:

        # Construct a multiple session clustering object with the session list.
        oneTT = cms2.MultipleSessionsToCluster(site.animalName, site.get_session_filenames(), tetrode, '{}site{}'.format(site.date, sitenum))
        oneTT.load_all_waveforms()

        # Do the clustering if necessary; otherwise reuse the existing cluster file.
        clusterFile = os.path.join(oneTT.clustersDir, 'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        # The single-session .clu files are needed; without them the single-cluster raster plots failed.
        # Debug notes (LG, 07/10-07/12): a Tk 'can't invoke "event" command ... ttk::ThemeChanged' message
        # printed while saving these files is only a warning (reports and files were still generated).
        # A 'ValueError: too many boolean indices' in save_single_session_clu_files was caused by old
        # .clu files from a previous clustering, which made self.clusters and self.recordingNumber differ in size.
        oneTT.save_single_session_clu_files()

        possibleClusters = np.unique(oneTT.clusters)

        # We also need to initialize an EphysExperiment object to get the sessions.
        exp2 = ee2.EphysExperiment(site.animalName, site.date, experimenter=site.experimenter)

        # The session lists do not depend on the cluster, so fetch them once per tetrode.
        sessionTypes = site.get_session_types()
        sessionFilenames = site.get_session_filenames()
        sessionBehavIDs = site.get_session_behavIDs()

        # Iterate through the clusters, making a new figure for each cluster.
        # (Iterating over possibleClusters[1:] was once used as a hack to omit cluster 1, which usually
        # contains noise and sometimes has no spikes in the range needed for the raster plot.)
        for cluster in possibleClusters:
            plt.figure(figsize=(8.5, 11))

            # The noise burst raster plot
            plt.subplot2grid((5, 6), (0, 0), rowspan=1, colspan=3)
            nbSession = sessionFilenames[sessionTypes.index('noiseBurst')]
            exp2.plot_session_raster(nbSession, tetrode, cluster=cluster, replace=1)
            plt.ylabel('Noise Bursts')
            plt.title(nbSession, fontsize=10)

            # The laser pulse raster plot
            plt.subplot2grid((5, 6), (1, 0), rowspan=1, colspan=3)
            lpSession = sessionFilenames[sessionTypes.index('laserPulse')]
            exp2.plot_session_raster(lpSession, tetrode, cluster=cluster, replace=1)
            plt.ylabel('Laser Pulses')
            plt.title(lpSession, fontsize=10)

            # The laser train raster plot (not every site has this session)
            plt.subplot2grid((5, 6), (2, 0), rowspan=1, colspan=3)
            try:
                ltSession = sessionFilenames[sessionTypes.index('laserTrain')]
                exp2.plot_session_raster(ltSession, tetrode, cluster=cluster, replace=1)
                plt.ylabel('Laser Trains')
                plt.title(ltSession, fontsize=10)
            except ValueError:
                print('This session doesnot exist.')

            # The tuning curve
            plt.subplot2grid((5, 6), (0, 3), rowspan=3, colspan=3)
            tcIndex = sessionTypes.index('tuningCurve')
            tcSession = sessionFilenames[tcIndex]
            tcBehavID = sessionBehavIDs[tcIndex]
            exp2.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace=1, cluster=cluster)
            plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize=10)

            # TODO: the best frequency raster ('bestFreq' session) is not plotted. A missing session makes
            # .index() raise ValueError, so it would need the same try/except as the laser train.
            # FIXME: Omitting the laser pulses at different intensities for now.

            # LG0710: Cluster reports (ISI, waveform, events in time, projections) are added to the
            # session summary figure. The MultiSessionClusterReport initializer plots its report
            # automatically, so the needed pieces are reproduced here instead.
            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster, :, :]

            # -- Plot ISI histogram --
            plt.subplot2grid((5, 6), (3, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            # Label with the current cluster (previously the last cluster of an inner loop was used).
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')

            # -- Plot waveforms --
            plt.subplot2grid((5, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((5, 6), (3, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((5, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            # Save the figure in the multisession clustering folder so that it is easy to find
            fig_path = oneTT.clustersDir
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(fig_path, fig_name)
            print(full_fig_path)
            plt.tight_layout()
            plt.savefig(full_fig_path, format='png')
            plt.close()

        plt.figure()
        oneTT.save_multisession_report()
        plt.close()
Example #30
0
    def generate_main_report(self, siteName):
        '''
        Generate the reports for all of the sessions in this site. This is where we should interface with
        the multiunit clustering code, since all of the sessions that need to be clustered together have
        been defined at this point.

        For each good tetrode, the sessions are clustered together (or an existing cluster file is reused).
        One PNG per cluster ('TT<tetrode>Cluster<cluster>.png') is saved in the multisession clustering
        directory, followed by the multisession cluster report.

        Args:
            siteName (str): appended to the date to name the multisession clustering directory.

        FIXME: This method should be able to load some kind of report template perhaps, so that the
        types of reports we can make are not limited. For instance, what happens when we just have
        rasters for a site and no tuning curve? Implementing this is a lower priority for now.

        Incorporated Lan's code for plotting the cluster reports directly on the main report.
        '''
        #FIXME: import another piece of code to do this?
        #FIXME: OR, break into two functions: one that will do the multisite clustering, and one that
        #knows the type of report that we want. The first one can probably be a method of MSTC, the other
        #should either live in extraplots or should go in someone's directory

        for tetrode in self.goodTetrodes:
            oneTT = cms2.MultipleSessionsToCluster(self.animalName, self.get_session_filenames(), tetrode, '{}_{}'.format(self.date, siteName))
            oneTT.load_all_waveforms()

            # Do the clustering if necessary; otherwise reuse the existing cluster file.
            clusterFile = os.path.join(oneTT.clustersDir, 'Tetrode%d.clu.1' % oneTT.tetrode)
            if not os.path.isfile(clusterFile):
                oneTT.create_multisession_fet_files()
                oneTT.run_clustering()
            oneTT.set_clusters_from_file()

            oneTT.save_single_session_clu_files()
            possibleClusters = np.unique(oneTT.clusters)

            ee = EphysExperiment(self.animalName, self.date, experimenter=self.experimenter)

            # The session selection does not depend on the cluster, so compute it once per tetrode.
            sessionFilenames = self.get_session_filenames()
            sessionTypes = self.get_session_types()
            sessionBehavIDs = self.get_session_behavIDs()

            mainRasterInds = self.get_session_inds_one_type(plotType='raster', report='main')
            mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
            mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

            mainTCinds = self.get_session_inds_one_type(plotType='tc_heatmap', report='main')
            mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
            mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

            # Iterate through the clusters, making a new figure for each cluster.
            for cluster in possibleClusters:

                plt.figure()  # The main report for this cluster/tetrode/session

                for indRaster, rasterSession in enumerate(mainRasterSessions):
                    plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                    ee.plot_session_raster(rasterSession, tetrode, cluster=cluster, replace=1, ms=1)
                    # Label: session type on the first line, second '_'-separated part of the filename below.
                    plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster], rasterSession.split('_')[1]), fontsize=10)
                    extraplots.set_ticks_fontsize(plt.gca(), 6)

                # We can only do one main TC for now.
                if len(mainTCsessions) > 0:
                    plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                    tcSession = mainTCsessions[0]
                    tcBehavID = mainTCbehavIDs[0]
                    ee.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace=1, cluster=cluster)
                    plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize=10)

                # Boolean mask selecting this cluster's spikes, computed once for timestamps and waveforms.
                spikesThisCluster = (oneTT.clusters == cluster)
                tsThisCluster = oneTT.timestamps[spikesThisCluster]
                wavesThisCluster = oneTT.samples[spikesThisCluster]

                # -- Plot ISI histogram --
                plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
                spikesorting.plot_isi_loghist(tsThisCluster)
                plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
                plt.xlabel('')

                # -- Plot waveforms --
                plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
                spikesorting.plot_waveforms(wavesThisCluster)

                # -- Plot projections --
                plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
                spikesorting.plot_projections(wavesThisCluster)

                # -- Plot events in time --
                plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
                spikesorting.plot_events_in_time(tsThisCluster)

                plt.subplots_adjust(wspace=0.7)
                fig_path = oneTT.clustersDir
                fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
                full_fig_path = os.path.join(fig_path, fig_name)
                print(full_fig_path)
                plt.savefig(full_fig_path, format='png')
                plt.close()

            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
Example #31
0
     samples,
     recordingNumber) = dataloader.load_all_spikedata(cell)

    #ISI loghist
    ax8 = plt.subplot(gs[4, 0:2])
    if timestamps is not None:
        try:
            spikesorting.plot_isi_loghist(timestamps)
        except:
            # raise AttributeError
            print "problem with isi vals"

    #Waveforms
    ax9 = plt.subplot(gs[4, 2:4])
    if timestamps is not None:
        spikesorting.plot_waveforms(samples)

    #Events in time
    ax10 = plt.subplot(gs[4, 4:6])
    if timestamps is not None:
        try:
            spikesorting.plot_events_in_time(timestamps)
        except:
            print "problem with isi vals"

    fig = plt.gcf()
    fig.set_size_inches(8.5, 11)

    figName = '{}_{}_{}_TT{}c{}.png'.format(cell['subject'],
                                            cell['date'],
                                            int(cell['depth']),
Example #32
0
def nick_lan_main_report(siteObj,
                         show=False,
                         save=True,
                         saveClusterReport=True):
    '''
    Cluster each good tetrode of a site and draw one main report per cluster.

    For every tetrode in siteObj.goodTetrodes, waveforms from all of the site's
    sessions are clustered together. An existing 'Tetrode<N>.clu.1' file in
    oneTT.clustersDir is reused; otherwise clustering is run first. The
    single-session clu files are then saved, and a 6x6 subplot2grid figure is
    drawn for each cluster containing:
        - rasters for the sessions selected with plotType='raster', report='main',
        - a tuning-curve heatmap for the first session selected with
          plotType='tc_heatmap', report='main' (omitted when there is none),
        - ISI log-histogram, waveforms, projections and events-in-time panels.

    Args:
        siteObj: recording-site object. Uses animalName, date, depth,
            experimenter, goodTetrodes and the get_session_* accessors.
        show (bool): call plt.show() for each cluster figure. When False the
            figure is closed after (optionally) saving it.
        save (bool): save each cluster figure as
            '<oneTT.clustersDir>/TT<tetrode>Cluster<cluster>.png'.
        saveClusterReport (bool): also call oneTT.save_multisession_report()
            on a fresh figure for each tetrode.
    '''
    # The session selection depends only on the site, not on the tetrode or
    # the cluster, so resolve it once.
    sessionFilenames = siteObj.get_session_filenames()
    sessionTypes = siteObj.get_session_types()

    mainRasterInds = siteObj.get_session_inds_one_type(plotType='raster',
                                                       report='main')
    mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

    mainTCinds = siteObj.get_session_inds_one_type(plotType='tc_heatmap',
                                                   report='main')
    mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
    sessionBehavIDs = siteObj.get_session_behavIDs()
    mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

    for tetrode in siteObj.goodTetrodes:
        oneTT = cms2.MultipleSessionsToCluster(
            siteObj.animalName, siteObj.get_session_filenames(), tetrode,
            '{}at{}um'.format(siteObj.date, siteObj.depth))
        oneTT.load_all_waveforms()

        # Do the clustering only if no cluster file exists yet.
        clusterFile = os.path.join(oneTT.clustersDir,
                                   'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        oneTT.save_single_session_clu_files()
        possibleClusters = np.unique(oneTT.clusters)

        ee = ee3.EphysExperiment(siteObj.animalName,
                                 siteObj.date,
                                 experimenter=siteObj.experimenter)

        # One report figure per cluster.
        for cluster in possibleClusters:
            plt.figure()

            # -- Rasters: one row per main raster session (left half) --
            for indRaster, rasterSession in enumerate(mainRasterSessions):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                ee.plot_session_raster(rasterSession,
                                       tetrode,
                                       cluster=cluster,
                                       replace=1,
                                       ms=1)
                plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster],
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # -- Tuning curve heatmap (upper right) --
            # Only one main TC can be shown: use the first one, and skip the
            # panel entirely when the site has no main TC session.
            if mainTCsessions:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                tcSession = mainTCsessions[0]
                tcBehavID = mainTCbehavIDs[0]
                ee.plot_session_tc_heatmap(tcSession,
                                           tetrode,
                                           tcBehavID,
                                           replace=1,
                                           cluster=cluster)
                plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID),
                          fontsize=10)

            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)

            if save:
                plt.savefig(full_fig_path, format='png')
            if show:
                plt.show()
            else:
                plt.close()

        if saveClusterReport:
            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
Example #33
0
                             eventOnsetTimes,
                             firstSortArray=intensityEachTrial,
                             secondSortArray=freqEachTrial,
                             firstSortLabels=intenLabels,
                             secondSortLabels=freqLabels,
                             timeRange=[0, 0.1])

plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.show()

# The only other panel needed for the requested figure is the spike-waveform plot.
# Waveforms for cell 9 (the cell with the good amplitude modulation).

plt.clf()
spikesorting.plot_waveforms(spikeData.samples)
plt.show()

# NOTE (original author): plot_waveforms currently computes the average over only
# the 40 selected spikes. Spikes are selected and then aligned, and the aligned
# ones are used to calculate the average. Consider averaging over all spikes if
# aligning all of them is not too expensive.


# Load the 'AM' session of figdb[9] and plot its raster sorted by the behavior
# field 'currentFreq' (presumably the AM modulation rate for AM sessions - confirm).
spikeData, eventData, behavData = loader.get_cluster_data(figdb[9], 'AM')


spikeTimestamps = spikeData.timestamps
eventOnsetTimes = loader.get_event_onset_times(eventData)
currentFreq = behavData['currentFreq']
dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray = currentFreq)
def _load_bandwidth_report_session(loader, cellInfo, sessionIndex):
    '''Return (spikeTimestamps, eventOnsetTimes) of the cell's cluster for ephys session `sessionIndex`.'''
    ephysDir = cellInfo['ephysDirs'][sessionIndex]
    eventData = loader.get_session_events(ephysDir)
    spikeData = loader.get_session_spikes(ephysDir, cellInfo['tetrode'],
                                          cluster=cellInfo['cluster'])
    return spikeData.timestamps, loader.get_event_onset_times(eventData)


def _plot_bandwidth_report_laser_panels(loader, cellInfo, sessionIndex, timeRange,
                                        rasterSpec, psthSpec, name):
    '''Plot the raster (in rasterSpec) and PSTH (in psthSpec) of one laser session, titled with `name`.'''
    spikeTimestamps, eventOnsetTimes = _load_bandwidth_report_session(
        loader, cellInfo, sessionIndex)

    plt.subplot(rasterSpec)
    dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, timeRange=timeRange)
    plt.xlabel('Time from sound onset (sec)')
    plt.title('{} Raster'.format(name))

    plt.subplot(psthSpec)
    dataplotter.plot_psth(spikeTimestamps, eventOnsetTimes, timeRange=timeRange, binsize=10)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Firing Rate (Hz)')
    plt.title('{} PSTH'.format(name))


def plot_bandwidth_report(cell, bandIndex,
                          figPath='/home/jarauser/Pictures/cell reports'):
    '''
    Draw and save a one-page report for a cell recorded with bandwidth stimuli.

    The page uses a 9x6 GridSpec, or a 13x6 GridSpec when the cell has laser
    sessions. In the laser case the top 4 rows hold laser pulse (and, if
    recorded, laser train) rasters and PSTHs, and every other panel is shifted
    down by 4 rows. The remaining panels are:
        - frequency tuning raster and intensity x frequency heatmap (last
          session in cellInfo['tuningIndex']),
        - AM raster and PSTH (last session in cellInfo['amIndex']),
        - ISI histogram, waveforms and events-in-time for the cluster,
        - one bandwidth raster per amplitude for session `bandIndex`, plus the
          band_select summary plot for that session.
    The figure is saved as
    '<figPath>/<subject>_<date>_<depth>um_TT<tetrode>Cluster<cluster>.png'.

    Args:
        cell: database entry for the cell. It is passed to get_cell_info, and
            cell['subject'] selects the DataLoader.
        bandIndex (int): index into cellInfo['ephysDirs'] / cellInfo['behavDirs']
            of the bandwidth session to plot.
        figPath (str): output directory for the PNG.
    '''
    cellInfo = get_cell_info(cell)
    loader = dataloader.DataLoader(cell['subject'])

    laser = len(cellInfo['laserIndex']) > 0
    if laser:
        gs = gridspec.GridSpec(13, 6)
    else:
        gs = gridspec.GridSpec(9, 6)
    # Rows reserved at the top for the laser panels (0 when there is no laser data).
    offset = 4 * laser
    gs.update(left=0.15, right=0.85, top=0.96, wspace=0.7, hspace=1.0)

    # -- plot bandwidth rasters --
    plt.clf()
    spikeTimestamps, eventOnsetTimes = _load_bandwidth_report_session(
        loader, cellInfo, bandIndex)
    timeRange = [-0.2, 1.5]
    bandBData = loader.get_session_behavior(cellInfo['behavDirs'][bandIndex])
    bandEachTrial = bandBData['currentBand']
    ampEachTrial = bandBData['currentAmp']
    charfreq = str(np.unique(bandBData['charFreq'])[0]/1000)
    modrate = str(np.unique(bandBData['modRate'])[0])
    possibleBands = np.unique(bandEachTrial)
    possibleAmps = np.unique(ampEachTrial)

    secondSortLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    (spikeTimesFromEventOnset, indexLimitsEachTrial, trialsEachCond,
     firstSortLabels) = bandwidth_raster_inputs(eventOnsetTimes, spikeTimestamps,
                                                bandEachTrial, ampEachTrial)
    # Two alternating shades per amplitude palette. Floor division keeps the
    # np.tile repetition count an integer under both Python 2 and Python 3.
    # NOTE: only two palettes (and two raster slots in the grid) exist, so this
    # layout assumes at most two amplitudes.
    nReps = len(possibleBands) // 2 + 1
    colours = [np.tile(['#4e9a06', '#8ae234'], nReps),
               np.tile(['#5c3566', '#ad7fa8'], nReps)]
    for ind, ampLabel in enumerate(secondSortLabels):
        plt.subplot(gs[5+2*ind+offset:7+2*ind+offset, 0:3])
        trialsThisSecondVal = trialsEachCond[:, :, ind]
        pRaster, hcond, zline = extraplots.raster_plot(spikeTimesFromEventOnset,
                                                        indexLimitsEachTrial,
                                                        timeRange,
                                                        trialsEachCond=trialsThisSecondVal,
                                                        labels=firstSortLabels,
                                                        colorEachCond=colours[ind])
        plt.setp(pRaster, ms=4)
        plt.title(ampLabel)
        plt.ylabel('bandwidth (octaves)')
        if ind == len(possibleAmps) - 1:
            plt.xlabel("Time from sound onset (sec)")

    # -- plot Yashar plots for bandwidth data --
    plt.subplot(gs[5+offset:, 3:])
    spikeArray, errorArray, baseSpikeRate = band_select(spikeTimestamps, eventOnsetTimes,
                                                        ampEachTrial, bandEachTrial,
                                                        timeRange=[0.0, 1.0])
    band_select_plot(spikeArray, errorArray, baseSpikeRate, possibleBands, legend=True)

    # -- plot frequency tuning heat map --
    tuningIndex = cellInfo['tuningIndex'][-1]
    tuningBData = loader.get_session_behavior(cellInfo['behavDirs'][tuningIndex])
    freqEachTrial = tuningBData['currentFreq']
    intEachTrial = tuningBData['currentIntensity']
    spikeTimestamps, eventOnsetTimes = _load_bandwidth_report_session(
        loader, cellInfo, tuningIndex)

    plt.subplot(gs[2+offset:4+offset, 0:3])
    dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                 eventOnsetTimes=eventOnsetTimes,
                                 firstSortArray=intEachTrial,
                                 secondSortArray=freqEachTrial,
                                 firstSortLabels=["%.0f" % inten for inten in np.unique(intEachTrial)],
                                 secondSortLabels=["%.1f" % freq for freq in np.unique(freqEachTrial)/1000.0],
                                 xlabel='Frequency (kHz)',
                                 ylabel='Intensity (dB SPL)',
                                 plotTitle='Frequency Tuning Curve',
                                 flipFirstAxis=False,
                                 flipSecondAxis=False,
                                 timeRange=[0, 0.1])
    plt.ylabel('Intensity (dB SPL)')
    plt.xlabel('Frequency (kHz)')
    plt.title('Frequency Tuning Curve')

    # -- plot frequency tuning raster --
    plt.subplot(gs[0+offset:2+offset, 0:3])
    freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial)/1000.0]
    dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray=freqEachTrial,
                            timeRange=[-0.1, 0.5], labels=freqLabels)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Frequency (kHz)')
    plt.title('Frequency Tuning Raster')

    # -- plot AM PSTH --
    amIndex = cellInfo['amIndex'][-1]
    amBData = loader.get_session_behavior(cellInfo['behavDirs'][amIndex])
    rateEachTrial = amBData['currentFreq']
    spikeTimestamps, eventOnsetTimes = _load_bandwidth_report_session(
        loader, cellInfo, amIndex)

    colourList = ['b', 'g', 'y', 'orange', 'r']
    plt.subplot(gs[2+offset:4+offset, 3:])
    dataplotter.plot_psth(spikeTimestamps, eventOnsetTimes, rateEachTrial,
                          timeRange=[-0.2, 0.8], binsize=25, colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Firing rate (Hz)')
    plt.title('AM PSTH')

    # -- plot AM raster --
    plt.subplot(gs[0+offset:2+offset, 3:])
    rateLabels = ["%.0f" % rate for rate in np.unique(rateEachTrial)]
    dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray=rateEachTrial,
                            timeRange=[-0.2, 0.8], labels=rateLabels, colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Modulation Rate (Hz)')
    plt.title('AM Raster')

    # -- plot laser pulse and laser train data (if available) --
    if laser:
        _plot_bandwidth_report_laser_panels(loader, cellInfo, cellInfo['laserIndex'][-1],
                                            [-0.1, 0.4], gs[0:2, 0:3], gs[2:4, 0:3],
                                            'Laser Pulse')
        # -- didn't record laser trains for some earlier sessions --
        if len(cellInfo['laserTrainIndex']) > 0:
            _plot_bandwidth_report_laser_panels(loader, cellInfo, cellInfo['laserTrainIndex'][-1],
                                                [-0.2, 1.0], gs[0:2, 3:], gs[2:4, 3:],
                                                'Laser Train')

    # -- show cluster analysis --
    tsThisCluster, wavesThisCluster = load_cluster_waveforms(cellInfo)

    # -- Plot ISI histogram --
    plt.subplot(gs[4+offset, 0:2])
    spikesorting.plot_isi_loghist(tsThisCluster)
    plt.ylabel('c%d' % cellInfo['cluster'], rotation=0, va='center', ha='center')
    plt.xlabel('')

    # -- Plot waveforms --
    plt.subplot(gs[4+offset, 2:4])
    spikesorting.plot_waveforms(wavesThisCluster)

    # -- Plot events in time --
    plt.subplot(gs[4+offset, 4:6])
    spikesorting.plot_events_in_time(tsThisCluster)

    plt.suptitle('{0}, {1}, {2}um, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(
        cellInfo['subject'],
        cellInfo['date'],
        cellInfo['depth'],
        cellInfo['tetrode'],
        cellInfo['cluster'],
        charfreq,
        modrate))

    fig_name = '{0}_{1}_{2}um_TT{3}Cluster{4}.png'.format(cellInfo['subject'], cellInfo['date'],
                                                          cellInfo['depth'], cellInfo['tetrode'],
                                                          cellInfo['cluster'])
    full_fig_path = os.path.join(figPath, fig_name)
    fig = plt.gcf()
    fig.set_size_inches(20, 25)
    fig.savefig(full_fig_path, format='png', bbox_inches='tight')
Example #35
0
def nick_lan_daily_report(site,
                          siteName,
                          mainRasterInds,
                          mainTCind,
                          pcasort=True):
    '''
    Cluster each tetrode of a site and save one report figure per cluster.

    For every tetrode in site.tetrodes the spikes are clustered with
    cluster_site_PCA (pcasort=True) or cluster_site (pcasort=False). Tetrodes
    whose clustering raises AttributeError (e.g. tetrodes with no spikes) are
    reported and skipped. For each resulting cluster the current figure is
    cleared and filled on a 6x6 subplot2grid layout with:
        - one raster per session in mainRasterInds (left half, one row each),
        - an intensity x frequency heatmap of session mainTCind (upper right),
          unless mainTCind is None,
        - ISI log-histogram, waveforms, projections and events-in-time panels.
    The figure is saved to '<oneTT.clustersDir>/TT<tetrode>Cluster<cluster>.png'.

    Args:
        site: recording-site object. Uses site.experimenter, site.tetrodes and
            the get_mouse_relative_ephys_filenames /
            get_mouse_relative_behav_filenames / get_session_types accessors.
        siteName (str): passed to the clustering function and used in messages.
        mainRasterInds (list of int): indices of the sessions drawn as rasters.
        mainTCind (int or None): index of the tuning-curve session drawn as a
            heatmap, or None to omit it. Index 0 is a valid session.
        pcasort (bool): use cluster_site_PCA instead of cluster_site.
    '''
    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    # The session selection depends only on the site, so resolve it once.
    ephysFilenames = site.get_mouse_relative_ephys_filenames()
    sessionTypes = site.get_session_types()
    mainRasterEphysFilenames = [ephysFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]
    # Compare against None: 0 is a valid session index and must not disable the TC.
    if mainTCind is not None:
        mainTCsession = ephysFilenames[mainTCind]
        mainTCbehavFilename = site.get_mouse_relative_behav_filenames()[mainTCind]
    else:
        mainTCsession = None

    for tetrode in site.tetrodes:

        # Tetrodes with no spikes will cause an error when clustering.
        try:
            if pcasort:
                oneTT = cluster_site_PCA(site, siteName, tetrode)
            else:
                oneTT = cluster_site(site, siteName, tetrode)
        except AttributeError:
            print("There was an attribute error for tetrode {} at {}".format(
                tetrode, siteName))
            continue

        possibleClusters = np.unique(oneTT.clusters)

        # Session data do not depend on the cluster: load each session at most
        # once per tetrode and reuse it for every cluster.
        rasterCache = {}
        tcData = None

        # Iterate through the clusters, drawing one figure per cluster.
        for cluster in possibleClusters:

            plt.clf()

            for indRaster, rasterSession in enumerate(mainRasterEphysFilenames):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)

                if rasterSession not in rasterCache:
                    rasterEvents = loader.get_session_events(rasterSession)
                    rasterCache[rasterSession] = (
                        loader.get_session_spikes(rasterSession, tetrode),
                        loader.get_event_onset_times(rasterEvents))
                rasterSpikes, eventOnsetTimes = rasterCache[rasterSession]
                spikeTimestamps = rasterSpikes.timestamps[rasterSpikes.clusters == cluster]

                dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, ms=1)

                plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster],
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # We can only do one main TC for now.
            if mainTCsession is not None:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)

                if tcData is None:
                    tcEvents = loader.get_session_events(mainTCsession)
                    tcData = (loader.get_session_behavior(mainTCbehavFilename),
                              loader.get_session_filename(mainTCsession),
                              loader.get_session_spikes(mainTCsession, tetrode),
                              loader.get_event_onset_times(tcEvents))
                bdata, plotTitle, spikeData, eventOnsetTimes = tcData

                spikeTimestamps = spikeData.timestamps[spikeData.clusters == cluster]
                freqEachTrial = bdata['currentFreq']
                intensityEachTrial = bdata['currentIntensity']

                freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0]
                intenLabels = ["%.1f" % inten for inten in np.unique(intensityEachTrial)]

                dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                             eventOnsetTimes=eventOnsetTimes,
                                             firstSortArray=intensityEachTrial,
                                             secondSortArray=freqEachTrial,
                                             firstSortLabels=intenLabels,
                                             secondSortLabels=freqLabels,
                                             xlabel='Frequency (kHz)',
                                             ylabel='Intensity (dB SPL)',
                                             plotTitle=plotTitle,
                                             flipFirstAxis=True,
                                             flipSecondAxis=False,
                                             timeRange=[0, 0.1])

                plt.title("{0}\n{1}".format(mainTCsession, mainTCbehavFilename),
                          fontsize=10)

            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=0.7)
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)
            plt.savefig(full_fig_path, format='png')
def plot_report():
    '''
    Draw a 2x2 summary figure for one cell, save it, and return its cluster data.

    Panels: noiseburst raster (top-left), tuning-curve raster sorted by
    frequency (top-right), ISI log-histogram (bottom-left) and cluster
    waveforms (bottom-right). The figure is saved to a fixed test path and
    then cleared.

    NOTE(review): this function takes no arguments but reads the names `cell`,
    `cellInfo`, `dbRow` and `cellInd`, none of which are defined inside it.
    Presumably they are module-level globals set by the caller. TODO confirm,
    and consider passing them in as parameters.

    Returns:
        (tsThisCluster, wavesThisCluster): timestamps and waveforms of the spikes
        assigned to cellInfo['cluster'] on cellInfo['tetrode'].
    '''
    plt.figure()

    gs = gridspec.GridSpec(2, 2)

    # -- Plot noiseburst raster (top-left) --
    plt.subplot(gs[0, 0])

    ephysData, bdata = cell.load('noisebursts')

    eventOnsetTimes = ephysData['events']['stimOn']
    spikeTimeStamps = ephysData['spikeTimes']
    timeRange = [-0.1, 1.0]

    spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
        spikeTimeStamps, eventOnsetTimes, timeRange)

    extraplots.raster_plot(spikeTimesFromEventOnset,
                           indexLimitsEachTrial,
                           timeRange,
                           trialsEachCond=[])

    plt.title('Noise')
    plt.xlabel('time (s)')
    plt.ylabel('Trial')

    # -- Plot tuning curve raster (top-right) --
    plt.subplot(gs[0, 1])

    ephysData, bdata = cell.load('tuningCurve')

    freqEachTrial = bdata['currentFreq']

    eventOnsetTimes = ephysData['events']['stimOn']
    spikeTimeStamps = ephysData['spikeTimes']
    timeRange = [-0.1, 1.0]

    possiblefreqs = np.unique(freqEachTrial)
    # Divide by a float so kHz labels keep their decimal under Python 2.
    freqLabels = [round(x / 1000.0, 1) for x in possiblefreqs]

    trialsEachCond = behavioranalysis.find_trials_each_type(
        freqEachTrial, possiblefreqs)
    spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
        spikeTimeStamps, eventOnsetTimes, timeRange)

    extraplots.raster_plot(spikeTimesFromEventOnset,
                           indexLimitsEachTrial,
                           timeRange,
                           trialsEachCond=trialsEachCond,
                           labels=freqLabels)

    plt.title('Tuning Curve')
    plt.xlabel('time (s)')
    plt.ylabel('Trial')

    # -- Load this cell's cluster from the multisession clustering --
    idString = 'exp{}site{}'.format(dbRow['experimentInd'],
                                    cellInfo['siteInd'])
    oneTT = cms2.MultipleSessionsToCluster(cellInfo['subject'],
                                           cellInfo['ephysDirs'],
                                           cellInfo['tetrode'], idString)
    oneTT.load_all_waveforms()
    oneTT.set_clusters_from_file()
    spikesThisCluster = (oneTT.clusters == cellInfo['cluster'])
    tsThisCluster = oneTT.timestamps[spikesThisCluster]
    wavesThisCluster = oneTT.samples[spikesThisCluster]

    # -- Plot ISI histogram (bottom-left) --
    plt.subplot(gs[1, 0])
    spikesorting.plot_isi_loghist(tsThisCluster)
    plt.ylabel('c%d' % cellInfo['cluster'],
               rotation=0,
               va='center',
               ha='center')
    plt.xlabel('')

    # -- Plot waveforms (bottom-right; gs[0, 1] holds the tuning raster) --
    plt.subplot(gs[1, 1])
    spikesorting.plot_waveforms(wavesThisCluster)

    print("Saving Cell " + str(cellInd))
    figname = '/home/jarauser/data/reports_alex/dapa008/test/test.png'
    plt.savefig(figname)

    plt.clf()
    return tsThisCluster, wavesThisCluster
def laser_tc_analysis(site, sitenum):
    '''
    Cluster each good tetrode of a laser/tuning-curve recording site and save one report figure per cluster.

    For every tetrode in ``site.goodTetrodes`` this function:

    1. Builds a ``cms2.MultipleSessionsToCluster`` object over all sessions of the
       site. If a ``Tetrode<N>.clu.1`` file already exists in its clusters directory,
       that file is reused; otherwise the multisession feature files are created and
       clustering is run.
    2. Writes the multisession cluster assignments back to the per-session cluster
       files (``save_single_session_clu_files``). An EphysExperiment (version 2)
       object can then load each session by cluster.
    3. For each cluster, draws one figure and saves it as
       ``TT<tetrode>Cluster<cluster>.png`` in the multisession clusters directory.
       The figure contains:
       - the noise-burst raster;
       - the laser-pulse raster;
       - the laser-train raster, when that session exists;
       - the tuning-curve heatmap;
       - the cluster's ISI histogram, waveforms, projections and events in time.
    4. Saves the multisession clustering report for the tetrode.

    This analysis lives here rather than in EphysExperiment because it is specific
    to one kind of experiment, while EphysExperiment should stay general.

    Args:

        site (RecordingSite object): An instance of the RecordingSite class from the ephys_experiment_v2 module.
            Its session types must include 'noiseBurst', 'laserPulse' and
            'tuningCurve'; 'laserTrain' is optional.
        sitenum (int): The site number for the site, used for constructing directory names
            ('<date>site<sitenum>').

    Raises:

        ValueError: if 'noiseBurst', 'laserPulse' or 'tuningCurve' is missing from
            the site's session types (raised by ``list.index``).

    Example:

        from jaratoolbox.test.nick.ephysExperiments import laserTCanalysis
        for indSite, site in enumerate(today.siteList):
            laserTCanalysis.laser_tc_analysis(site, indSite+1)
    '''
    # TODO: this is where Lan's sorting function should be incorporated.
    for tetrode in site.goodTetrodes:

        # Construct a multiple-session clustering object with the session list.
        oneTT = cms2.MultipleSessionsToCluster(
            site.animalName, site.get_session_filenames(), tetrode,
            '{}site{}'.format(site.date, sitenum))
        oneTT.load_all_waveforms()

        # Do the clustering only if no cluster file exists yet; either way the
        # cluster assignments are then loaded from that file.
        clusterFile = os.path.join(oneTT.clustersDir,
                                   'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        # Write the multisession assignments back to each session's cluster
        # directory so single sessions can be plotted by cluster below.
        # Dev notes (LG 07/10-07/12):
        #  - A Tk message "can't invoke "event" command: application has been
        #    destroyed" seen while saving turned out to be a warning, not an
        #    error; reports and files were still generated.
        #  - Skipping this call broke the single-cluster raster plots, so it
        #    is required.
        #  - "ValueError: too many boolean indices" in
        #    save_single_session_clu_files() was caused by stale .clu files from
        #    an earlier clustering, which made self.clusters and
        #    self.recordingNumber different sizes.
        oneTT.save_single_session_clu_files()

        possibleClusters = np.unique(oneTT.clusters)

        # An EphysExperiment object is needed to plot the individual sessions.
        exp2 = ee2.EphysExperiment(site.animalName,
                                   site.date,
                                   experimenter=site.experimenter)

        sessionTypes = site.get_session_types()
        sessionFilenames = site.get_session_filenames()

        # One figure per cluster. (Iterating over possibleClusters[1:] was once
        # used as a hack to skip cluster 1, which usually contains noise and
        # sometimes has no spikes in the raster time range.)
        for cluster in possibleClusters:
            plt.figure(figsize=(8.5, 11))

            # -- Noise-burst raster --
            plt.subplot2grid((5, 6), (0, 0), rowspan=1, colspan=3)
            nbSession = sessionFilenames[sessionTypes.index('noiseBurst')]
            exp2.plot_session_raster(nbSession,
                                     tetrode,
                                     cluster=cluster,
                                     replace=1)
            plt.ylabel('Noise Bursts')
            plt.title(nbSession, fontsize=10)

            # -- Laser-pulse raster --
            plt.subplot2grid((5, 6), (1, 0), rowspan=1, colspan=3)
            lpSession = sessionFilenames[sessionTypes.index('laserPulse')]
            exp2.plot_session_raster(lpSession,
                                     tetrode,
                                     cluster=cluster,
                                     replace=1)
            plt.ylabel('Laser Pulses')
            plt.title(lpSession, fontsize=10)

            # -- Laser-train raster (optional session) --
            # Only the lookup is guarded, so errors raised while plotting are
            # not misreported as a missing session.
            plt.subplot2grid((5, 6), (2, 0), rowspan=1, colspan=3)
            try:
                ltIndex = sessionTypes.index('laserTrain')
            except ValueError:
                print('This session doesnot exist.')
            else:
                ltSession = sessionFilenames[ltIndex]
                exp2.plot_session_raster(ltSession,
                                         tetrode,
                                         cluster=cluster,
                                         replace=1)
                plt.ylabel('Laser Trains')
                plt.title(ltSession, fontsize=10)

            # -- Tuning curve heatmap --
            plt.subplot2grid((5, 6), (0, 3), rowspan=3, colspan=3)
            tcIndex = sessionTypes.index('tuningCurve')
            tcSession = sessionFilenames[tcIndex]
            tcBehavID = site.get_session_behavIDs()[tcIndex]
            exp2.plot_session_tc_heatmap(tcSession,
                                         tetrode,
                                         tcBehavID,
                                         replace=1,
                                         cluster=cluster)
            plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID),
                      fontsize=10)

            # TODO: the best-frequency raster ('bestFreq' session) and the laser
            # pulses at different intensities are omitted for now. A session
            # that was not initialized makes .index() raise ValueError, so these
            # would need the same try/except treatment as the laser train.

            # -- Per-cluster spike reports --
            # Drawn by hand because MultiSessionClusterReport's initializer
            # always calls plot_report, so reusing it would repeat work.
            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster, :, :]

            # ISI histogram, labelled with the cluster being plotted.
            plt.subplot2grid((5, 6), (3, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')

            # Waveforms
            plt.subplot2grid((5, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # Projections
            plt.subplot2grid((5, 6), (3, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # Events in time
            plt.subplot2grid((5, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            # Save next to the multisession clustering output so it is easy to find.
            full_fig_path = os.path.join(
                oneTT.clustersDir,
                'TT{0}Cluster{1}.png'.format(tetrode, cluster))
            print(full_fig_path)
            plt.tight_layout()
            plt.savefig(full_fig_path, format='png')
            plt.close()

        plt.figure()
        oneTT.save_multisession_report()
        plt.close()