Beispiel #1
0
def plot_cluster_tuning(clusterObj,
                        indTC,
                        experimenter='nick',
                        *args,
                        **kwargs):
    '''
    Plot a frequency/intensity tuning heat map for one cluster.

    Loads the spike, event and behavior data of session ``indTC`` for
    ``clusterObj`` and draws intensity (first axis) against frequency
    (second axis) with dataplotter.two_axis_heatmap.

    Args:
        clusterObj: Cluster description passed to DataLoader.get_cluster_data.
        indTC (int): Index of the tuning-curve session for this cluster.
        experimenter (str): Experimenter name used to build the offline
            DataLoader.
        *args: Kept only for backwards compatibility of the signature. Extra
            positional arguments cannot be forwarded to two_axis_heatmap:
            its third and fourth positional parameters are the sort arrays,
            which are already passed here by keyword.
        **kwargs: Extra keyword arguments forwarded to two_axis_heatmap.
            ``timeRange`` may be supplied to override the default [0, 0.1].

    Returns:
        (ax, cax, cbar) as returned by dataplotter.two_axis_heatmap.

    Raises:
        TypeError: if any extra positional arguments are given.
    '''
    if args:
        # Forwarding *args would bind them to firstSortArray/secondSortArray,
        # which are also given by keyword, and fail with "got multiple values".
        raise TypeError('plot_cluster_tuning does not accept extra positional '
                        'arguments; pass two_axis_heatmap options by keyword')

    loader = dataloader.DataLoader('offline', experimenter=experimenter)
    spikeData, eventData, behavData = loader.get_cluster_data(
        clusterObj, indTC)

    spikeTimestamps = spikeData.timestamps
    eventOnsetTimes = loader.get_event_onset_times(eventData)
    freqEachTrial = behavData['currentFreq']
    intensityEachTrial = behavData['currentIntensity']

    possibleFreq = np.unique(freqEachTrial)
    possibleIntensity = np.unique(intensityEachTrial)

    xlabel = 'Frequency (kHz)'
    ylabel = 'Intensity (dB SPL)'

    # Frequencies are labelled in kHz, intensities in dB, one decimal each.
    freqLabels = ["%.1f" % freq for freq in possibleFreq / 1000.0]
    intenLabels = ["%.1f" % inten for inten in possibleIntensity]

    # Default analysis window; callers may override it through kwargs.
    kwargs.setdefault('timeRange', [0, 0.1])
    ax, cax, cbar = dataplotter.two_axis_heatmap(
        spikeTimestamps,
        eventOnsetTimes,
        firstSortArray=intensityEachTrial,
        secondSortArray=freqEachTrial,
        firstSortLabels=intenLabels,
        secondSortLabels=freqLabels,
        **kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    return ax, cax, cbar
def _load_cell_session(loader, cellInfo, sessionIndex):
    '''
    Load spike times and event onset times for one recording session of a cell.

    Args:
        loader: dataloader.DataLoader used for all file access.
        cellInfo (dict): Output of get_cell_info; reads 'ephysDirs',
            'tetrode' and 'cluster'.
        sessionIndex (int): Index into cellInfo['ephysDirs'].

    Returns:
        (spikeTimestamps, eventOnsetTimes) for that session, with spikes
        loaded for the cell's tetrode and cluster.
    '''
    ephysDir = cellInfo['ephysDirs'][sessionIndex]
    eventData = loader.get_session_events(ephysDir)
    spikeData = loader.get_session_spikes(ephysDir, cellInfo['tetrode'], cluster=cellInfo['cluster'])
    eventOnsetTimes = loader.get_event_onset_times(eventData)
    return spikeData.timestamps, eventOnsetTimes


def plot_bandwidth_report(cell, bandIndex, fig_path='/home/jarauser/Pictures/cell reports'):
    '''
    Draw and save a one-page summary report for a single cell.

    Layout, top to bottom on a 6-column grid:
      * laser pulse/train rasters and PSTHs (only if the cell has a laser session),
      * frequency-tuning raster (left) and AM raster (right),
      * frequency-tuning heat map (left) and AM PSTH (right),
      * cluster quality: ISI histogram, waveforms, events in time,
      * one bandwidth raster per amplitude (left) and bandwidth selectivity (right).

    Args:
        cell: Cell record; passed to get_cell_info and cell['subject'] is used
            to build the DataLoader.
        bandIndex (int): Index of the bandwidth session in cellInfo['ephysDirs']
            and cellInfo['behavDirs'].
        fig_path (str): Directory the PNG is written to. Defaults to the
            previously hard-coded location.

    Side effects:
        Clears the current figure, resizes it to 20x25 inches and saves it as
        '{subject}_{date}_{depth}um_TT{tetrode}Cluster{cluster}.png' in fig_path.
    '''
    cellInfo = get_cell_info(cell)
    loader = dataloader.DataLoader(cell['subject'])

    # Four extra grid rows at the top hold the laser plots when a laser session
    # exists; every other panel is shifted down by `offset` rows.
    laser = len(cellInfo['laserIndex']) > 0
    gs = gridspec.GridSpec(13 if laser else 9, 6)
    offset = 4 * laser
    gs.update(left=0.15, right=0.85, top=0.96, wspace=0.7, hspace=1.0)

    # -- plot bandwidth rasters --
    plt.clf()
    spikeTimestamps, eventOnsetTimes = _load_cell_session(loader, cellInfo, bandIndex)
    timeRange = [-0.2, 1.5]
    bandBData = loader.get_session_behavior(cellInfo['behavDirs'][bandIndex])
    bandEachTrial = bandBData['currentBand']
    ampEachTrial = bandBData['currentAmp']
    charfreq = str(np.unique(bandBData['charFreq'])[0]/1000)
    modrate = str(np.unique(bandBData['modRate'])[0])
    possibleBands = np.unique(bandEachTrial)
    possibleAmps = np.unique(ampEachTrial)

    secondSortLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    spikeTimesFromEventOnset, indexLimitsEachTrial, trialsEachCond, firstSortLabels = bandwidth_raster_inputs(
        eventOnsetTimes, spikeTimestamps, bandEachTrial, ampEachTrial)
    # One two-shade palette per amplitude, tiled so every bandwidth gets a
    # colour. Floor division keeps the repeat count an integer, which np.tile
    # requires (plain `/` yields a float under true division).
    numRepeats = len(possibleBands) // 2 + 1
    colours = [np.tile(['#4e9a06', '#8ae234'], numRepeats),
               np.tile(['#5c3566', '#ad7fa8'], numRepeats)]
    for ind in range(len(possibleAmps)):
        plt.subplot(gs[5+2*ind+offset:7+2*ind+offset, 0:3])
        trialsThisAmp = trialsEachCond[:, :, ind]
        pRaster, hcond, zline = extraplots.raster_plot(spikeTimesFromEventOnset,
                                                        indexLimitsEachTrial,
                                                        timeRange,
                                                        trialsEachCond=trialsThisAmp,
                                                        labels=firstSortLabels,
                                                        colorEachCond=colours[ind])
        plt.setp(pRaster, ms=4)
        plt.title(secondSortLabels[ind])
        plt.ylabel('bandwidth (octaves)')
        if ind == len(possibleAmps) - 1:
            plt.xlabel("Time from sound onset (sec)")

    # -- plot Yashar plots for bandwidth data --
    plt.subplot(gs[5+offset:, 3:])
    spikeArray, errorArray, baseSpikeRate = band_select(spikeTimestamps, eventOnsetTimes, ampEachTrial,
                                                        bandEachTrial, timeRange=[0.0, 1.0])
    band_select_plot(spikeArray, errorArray, baseSpikeRate, possibleBands, legend=True)

    # -- plot frequency tuning heat map --
    # The last tuning session recorded for this cell is used.
    tuningIndex = cellInfo['tuningIndex'][-1]
    tuningBData = loader.get_session_behavior(cellInfo['behavDirs'][tuningIndex])
    freqEachTrial = tuningBData['currentFreq']
    intEachTrial = tuningBData['currentIntensity']
    spikeTimestamps, eventOnsetTimes = _load_cell_session(loader, cellInfo, tuningIndex)
    freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial)/1000.0]

    plt.subplot(gs[2+offset:4+offset, 0:3])
    dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                 eventOnsetTimes=eventOnsetTimes,
                                 firstSortArray=intEachTrial,
                                 secondSortArray=freqEachTrial,
                                 firstSortLabels=["%.0f" % inten for inten in np.unique(intEachTrial)],
                                 secondSortLabels=freqLabels,
                                 xlabel='Frequency (kHz)',
                                 ylabel='Intensity (dB SPL)',
                                 plotTitle='Frequency Tuning Curve',
                                 flipFirstAxis=False,
                                 flipSecondAxis=False,
                                 timeRange=[0, 0.1])
    plt.ylabel('Intensity (dB SPL)')
    plt.xlabel('Frequency (kHz)')
    plt.title('Frequency Tuning Curve')

    # -- plot frequency tuning raster --
    plt.subplot(gs[0+offset:2+offset, 0:3])
    dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray=freqEachTrial,
                            timeRange=[-0.1, 0.5], labels=freqLabels)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Frequency (kHz)')
    plt.title('Frequency Tuning Raster')

    # -- plot AM PSTH --
    # The AM session stores the modulation rate of each trial in 'currentFreq'.
    amIndex = cellInfo['amIndex'][-1]
    amBData = loader.get_session_behavior(cellInfo['behavDirs'][amIndex])
    rateEachTrial = amBData['currentFreq']
    spikeTimestamps, eventOnsetTimes = _load_cell_session(loader, cellInfo, amIndex)
    colourList = ['b', 'g', 'y', 'orange', 'r']

    plt.subplot(gs[2+offset:4+offset, 3:])
    dataplotter.plot_psth(spikeTimestamps, eventOnsetTimes, rateEachTrial, timeRange=[-0.2, 0.8],
                          binsize=25, colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Firing rate (Hz)')
    plt.title('AM PSTH')

    # -- plot AM raster --
    plt.subplot(gs[0+offset:2+offset, 3:])
    rateLabels = ["%.0f" % rate for rate in np.unique(rateEachTrial)]
    dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray=rateEachTrial, timeRange=[-0.2, 0.8],
                            labels=rateLabels, colorEachCond=colourList)
    plt.xlabel('Time from sound onset (sec)')
    plt.ylabel('Modulation Rate (Hz)')
    plt.title('AM Raster')

    # -- plot laser pulse and laser train data (if available) --
    if laser:
        # -- plot laser pulse raster --
        plt.subplot(gs[0:2, 0:3])
        spikeTimestamps, eventOnsetTimes = _load_cell_session(loader, cellInfo, cellInfo['laserIndex'][-1])
        timeRange = [-0.1, 0.4]
        dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, timeRange=timeRange)
        plt.xlabel('Time from laser onset (sec)')
        plt.title('Laser Pulse Raster')

        # -- plot laser pulse psth --
        plt.subplot(gs[2:4, 0:3])
        dataplotter.plot_psth(spikeTimestamps, eventOnsetTimes, timeRange=timeRange, binsize=10)
        plt.xlabel('Time from laser onset (sec)')
        plt.ylabel('Firing Rate (Hz)')
        plt.title('Laser Pulse PSTH')

        # -- didn't record laser trains for some earlier sessions --
        if len(cellInfo['laserTrainIndex']) > 0:
            # -- plot laser train raster --
            plt.subplot(gs[0:2, 3:])
            spikeTimestamps, eventOnsetTimes = _load_cell_session(loader, cellInfo,
                                                                  cellInfo['laserTrainIndex'][-1])
            timeRange = [-0.2, 1.0]
            dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, timeRange=timeRange)
            plt.xlabel('Time from laser onset (sec)')
            plt.title('Laser Train Raster')

            # -- plot laser train psth --
            plt.subplot(gs[2:4, 3:])
            dataplotter.plot_psth(spikeTimestamps, eventOnsetTimes, timeRange=timeRange, binsize=10)
            plt.xlabel('Time from laser onset (sec)')
            plt.ylabel('Firing Rate (Hz)')
            plt.title('Laser Train PSTH')

    # -- show cluster analysis --
    tsThisCluster, wavesThisCluster = load_cluster_waveforms(cellInfo)

    # -- Plot ISI histogram --
    plt.subplot(gs[4+offset, 0:2])
    spikesorting.plot_isi_loghist(tsThisCluster)
    plt.ylabel('c%d' % cellInfo['cluster'], rotation=0, va='center', ha='center')
    plt.xlabel('')

    # -- Plot waveforms --
    plt.subplot(gs[4+offset, 2:4])
    spikesorting.plot_waveforms(wavesThisCluster)

    # -- Plot events in time --
    plt.subplot(gs[4+offset, 4:6])
    spikesorting.plot_events_in_time(tsThisCluster)

    plt.suptitle('{0}, {1}, {2}um, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(cellInfo['subject'],
                                                                                            cellInfo['date'],
                                                                                            cellInfo['depth'],
                                                                                            cellInfo['tetrode'],
                                                                                            cellInfo['cluster'],
                                                                                            charfreq,
                                                                                            modrate))

    fig_name = '{0}_{1}_{2}um_TT{3}Cluster{4}.png'.format(cellInfo['subject'], cellInfo['date'], cellInfo['depth'],
                                                          cellInfo['tetrode'], cellInfo['cluster'])
    full_fig_path = os.path.join(fig_path, fig_name)
    fig = plt.gcf()
    fig.set_size_inches(20, 25)
    fig.savefig(full_fig_path, format='png', bbox_inches='tight')
Beispiel #3
0
# Dataplotter base plotting functions: interactive smoke test of each plotter.
# NOTE(review): spikeTimestamps, eventOnsetTimes and bdata are not defined in
# this snippet; they are assumed to already exist in the interactive session.
import matplotlib.pyplot as plt
from jaratest.nick.database import dataplotter
reload(dataplotter)  # Python 2 builtin: pick up edits made to dataplotter

# Test plot_raster, sorted by frequency
dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, sortArray=bdata['currentFreq'])

# Test two_axis_sorted_raster: frequency first, intensity second
dataplotter.two_axis_sorted_raster(spikeTimestamps, eventOnsetTimes,
                                   bdata['currentFreq'], bdata['currentIntensity'])

# Test two_axis_heatmap: the sort arrays go in the opposite order here
# (intensity first, frequency second)
dataplotter.two_axis_heatmap(spikeTimestamps, eventOnsetTimes,
                             bdata['currentIntensity'], bdata['currentFreq'])

# Test one_axis_tc_or_rlf: a tuning curve (sorted by frequency), then a
# rate-level function (sorted by intensity)
for sortKey in ('currentFreq', 'currentIntensity'):
    dataplotter.one_axis_tc_or_rlf(spikeTimestamps, eventOnsetTimes, sortArray=bdata[sortKey])

# TODO: plot waveforms in the event-locked time range (nothing here yet)

plt.close('all')
Beispiel #4
0
# Quick visual check of the tuning heat map and frequency raster.
# NOTE(review): bdata, spikeData and eventOnsetTimes are assumed to exist in
# the interactive session; they are not defined in this snippet.
intensityEachTrial = bdata['currentIntensity']
freqEachTrial = bdata['currentFreq']
possibleFreq = np.unique(freqEachTrial)
possibleIntensity = np.unique(intensityEachTrial)

clusterSpikeTimes = spikeData.timestamps

plt.close('all')

# Heat map: intensity on the first axis, frequency (kHz, 3 significant
# digits) on the second.
freqTickLabels = ['{:.3}'.format(freq / 1000.) for freq in possibleFreq]
plt.figure()
dataplotter.two_axis_heatmap(clusterSpikeTimes,
                             eventOnsetTimes,
                             firstSortArray=intensityEachTrial,
                             secondSortArray=freqEachTrial,
                             flipFirstAxis=False,
                             firstSortLabels=np.unique(intensityEachTrial),
                             secondSortLabels=freqTickLabels,
                             timeRange=[0, 0.2])
plt.show()

# Raster sorted by frequency
plt.figure()
dataplotter.plot_raster(clusterSpikeTimes, eventOnsetTimes, sortArray=freqEachTrial)
plt.show()

plt.figure()
Beispiel #5
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,), clusters=(8,)):
    '''
    Draw and save a bandwidth / AM / tuning report figure for clusters at a site.

    The hard-coded session indices assume this layout (NOTE(review): inferred
    from how each index is used below; confirm against the site definition):
    1 = frequency tuning (currentFreq/currentIntensity), 2 = AM rates
    (modulation rate stored in currentFreq), 3 = bandwidth session
    (currentBand/currentAmp).

    Args:
        mouse (str): Passed to ephysinterface.EphysInterface.
        date (str): Passed to ephysinterface.EphysInterface.
        site: Site object providing get_session_ephys_filenames() and
            get_session_behav_filenames(); also passed to sitefuncs.cluster_site.
        siteName (str): Site label used for clustering and in the figure title.
        tetrodes (iterable of int or None): Tetrodes to report. Defaults to
            (2,), the value that used to be hard-coded. None means every
            tetrode in site.tetrodes.
        clusters (iterable of int or None): Clusters to report on each tetrode.
            Defaults to (8,), the value that used to be hard-coded. None means
            every cluster present in the bandwidth session for that tetrode.

    Side effects:
        For each (tetrode, cluster) the current figure is cleared, redrawn and
        saved as 'TT{tetrode}Cluster{cluster}.png' in that tetrode's
        clustersDir.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    # [-4:-3] extracts a single character of the behavior filename, which is
    # what get_session_behavior is given here.
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    charfreq = str(np.unique(bdata['charFreq'])[0] / 1000)
    modrate = str(np.unique(bdata['modRate'])[0])
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']

    # Everything below depends only on the sessions, not on the tetrode or
    # cluster, so it is computed once instead of once per cluster.
    bandTimeRange = [-0.2, 1.5]
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    bandTrialsEachCond = behavioranalysis.find_trials_each_combination(
        currentBand, possibleBands, currentAmp, possibleAmps)
    bandOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))

    tuningOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    intenLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq) / 1000.0]

    amTimeRange = [-0.2, 1.5]
    amOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))
    colourList = ['b', 'g', 'y', 'orange', 'r']
    rateTrialsEachCond = behavioranalysis.find_trials_each_type(
        currentRate, np.unique(currentRate))
    # 50 ms bins from -0.2 to 0.8 s; rounding avoids float drift in the edges.
    binEdges = np.around(np.arange(-0.2, 0.85, 0.05), decimals=2)
    rateLabels = ["%.1f" % rate for rate in np.unique(currentRate)]

    if tetrodes is None:
        tetrodes = site.tetrodes
    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        if clusters is None:
            clustersThisTT = np.unique(
                ei.loader.get_session_spikes(sessions[3], tetrode).clusters)
        else:
            clustersThisTT = clusters
        for cluster in clustersThisTT:
            plt.clf()

            # -- plot bandwidth rasters --
            spikeTimestamps = ei.loader.get_session_spikes(
                sessions[3], tetrode, cluster=cluster).timestamps
            spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
                spikeTimestamps, bandOnsetTimes, bandTimeRange)
            for ind in range(len(possibleAmps)):
                plt.subplot2grid((12, 15), (5 * ind, 0), rowspan=4, colspan=7)
                pRaster, hcond, zline = extraplots.raster_plot(
                    spikeTimesFromEventOnset,
                    indexLimitsEachTrial,
                    bandTimeRange,
                    trialsEachCond=bandTrialsEachCond[:, :, ind],
                    labels=bandLabels)
                plt.setp(pRaster, ms=4)

                plt.title(ampLabels[ind])
                plt.ylabel('bandwidth (octaves)')
                if ind == len(possibleAmps) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps,
                             bandOnsetTimes,
                             currentAmp,
                             currentBand, [0.0, 1.0],
                             title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps,
                             bandOnsetTimes,
                             currentAmp,
                             currentBand, [0.2, 1.0],
                             title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            spikeTimestamps = ei2.loader.get_session_spikes(
                sessions[1], tetrode, cluster=cluster).timestamps
            dataplotter.two_axis_heatmap(
                spikeTimestamps=spikeTimestamps,
                eventOnsetTimes=tuningOnsetTimes,
                firstSortArray=currentInt,
                secondSortArray=currentFreq,
                firstSortLabels=intenLabels,
                secondSortLabels=freqLabels,
                xlabel='Frequency (kHz)',
                ylabel='Intensity (dB SPL)',
                plotTitle='Frequency Tuning Curve',
                flipFirstAxis=True,
                flipSecondAxis=False,
                timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps,
                                    tuningOnsetTimes,
                                    sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5],
                                    labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            spikeTimestamps = ei2.loader.get_session_spikes(
                sessions[2], tetrode, cluster=cluster).timestamps
            spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
                spikeTimestamps, amOnsetTimes, amTimeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, binEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Counts per 50 ms bin / 0.05 s = firing rate in Hz.
            extraplots.plot_psth(spikeCountMat / 0.05,
                                 1,
                                 binEdges[:-1],
                                 rateTrialsEachCond,
                                 colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps,
                                    amOnsetTimes,
                                    sortArray=currentRate,
                                    timeRange=[-0.2, 0.8],
                                    labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis --
            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle(
                '{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'
                .format(mouse, date, siteName, tetrode, cluster, charfreq,
                        modrate))
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
# Scratch check: tuning heat map for a single cluster.
# NOTE(review): spikeData, eventData, ei, bdata and plotTitle are not defined
# in this snippet; they must already exist in the interactive session
# (plotTitle in particular is not assigned anywhere visible).
cluster = 5
inCluster = spikeData.clusters == cluster
spikeTimestamps = spikeData.timestamps[inCluster]
eventOnsetTimes = ei.loader.get_event_onset_times(eventData)

freqEachTrial = bdata['currentFreq']
intensityEachTrial = bdata['currentIntensity']

possibleFreq, possibleIntensity = np.unique(freqEachTrial), np.unique(intensityEachTrial)

xlabel, ylabel = 'Frequency (kHz)', 'Intensity (dB SPL)'

# Tick labels: frequencies in kHz and intensities in dB, one decimal each.
freqLabels = ["%.1f" % (freq / 1000.0) for freq in possibleFreq]
intenLabels = ["%.1f" % inten for inten in possibleIntensity]

heatmapOptions = dict(firstSortArray=intensityEachTrial,
                      secondSortArray=freqEachTrial,
                      firstSortLabels=intenLabels,
                      secondSortLabels=freqLabels,
                      xlabel=xlabel,
                      ylabel=ylabel,
                      plotTitle=plotTitle,
                      flipFirstAxis=True,
                      flipSecondAxis=False,
                      timeRange=[0, 0.1])
dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                             eventOnsetTimes=eventOnsetTimes,
                             **heatmapOptions)

## The bug was that I passed the first/second sort labels in place of the possible values, which broke finding the trials for each combination.
def plot_cell_phys(cell, rastertypes, tctype):
    '''
    Plot rasters for several session types and a tuning heat map for one cell.

    The current figure is cleared and laid out on a 6x6 grid: one raster row
    per entry of ``rastertypes`` in the left half, and the tuning heat map in
    the top-right 3x3 cells. The figure is then shown with plt.show().

    Args:
        cell (dict-like): Cell record. Keys read here: 'subject',
            'sessiontype' (list of session type names), 'ephys' and
            'behavior' (indexed with positions from 'sessiontype'),
            'tetrode' and 'cluster'.
        rastertypes (list of str): Session types to plot as rasters. A type
            the cell lacks leaves its row empty.
        tctype (str): Session type of the tuning-curve session. If the cell
            lacks it, 'no tc' is printed and no heat map is drawn.
    '''
    loader = dataloader.DataLoader(cell['subject'])

    plt.clf()

    # Plot raster sessions
    for indRaster, rastertype in enumerate(rastertypes):
        # The axes is created before the lookup so a missing session type
        # still leaves its (empty) row in place.
        plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)

        try:
            sessiontypeIndex = cell['sessiontype'].index(rastertype)
        except ValueError:  # The cell does not have this session type
            continue

        sessionEphys = cell['ephys'][sessiontypeIndex]

        rasterSpikes = loader.get_session_spikes(sessionEphys,
                                                 int(cell['tetrode']),
                                                 cluster=int(cell['cluster']))
        spikeTimestamps = rasterSpikes.timestamps

        rasterEvents = loader.get_session_events(sessionEphys)
        eventOnsetTimes = loader.get_event_onset_times(rasterEvents)

        dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, ms=1)

        # Row label: session type plus the second '_'-separated field of the
        # ephys session name.
        plt.ylabel('{}\n{}'.format(rastertype,
                                   sessionEphys.split('_')[1]),
                   fontsize=10)
        ax = plt.gca()
        extraplots.set_ticks_fontsize(ax, 6)

    # Plot tuning curve. Only the session-type lookup is guarded: a ValueError
    # raised while loading or plotting is a real error and must not be
    # reported as a missing tuning curve.
    try:
        sessiontypeIndex = cell['sessiontype'].index(tctype)
    except ValueError:  # The cell does not have this session type
        print('no tc')
    else:
        plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)

        bdata = loader.get_session_behavior(cell['behavior'][sessiontypeIndex])
        eventData = loader.get_session_events(cell['ephys'][sessiontypeIndex])
        spikeData = loader.get_session_spikes(cell['ephys'][sessiontypeIndex],
                                              int(cell['tetrode']),
                                              cluster=int(cell['cluster']))

        spikeTimestamps = spikeData.timestamps
        eventOnsetTimes = loader.get_event_onset_times(eventData)

        freqEachTrial = bdata['currentFreq']
        intensityEachTrial = bdata['currentIntensity']

        possibleFreq = np.unique(freqEachTrial)
        possibleIntensity = np.unique(intensityEachTrial)

        xlabel = 'Frequency (kHz)'
        ylabel = 'Intensity (dB SPL)'

        freqLabels = ["%.1f" % freq for freq in possibleFreq / 1000.0]
        intenLabels = ["%.1f" % inten for inten in possibleIntensity]

        dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                     eventOnsetTimes=eventOnsetTimes,
                                     firstSortArray=intensityEachTrial,
                                     secondSortArray=freqEachTrial,
                                     firstSortLabels=intenLabels,
                                     secondSortLabels=freqLabels,
                                     xlabel=xlabel,
                                     ylabel=ylabel,
                                     plotTitle='',
                                     flipFirstAxis=True,
                                     flipSecondAxis=False,
                                     timeRange=[0, 0.1],
                                     cmap='magma')

    plt.show()
def plot_cell_tc(cell, tctype):
    '''
    Plot a large-font frequency tuning heat map for one cell and show it.

    Clears the current axes, draws intensity (rows) against frequency
    (columns) with dataplotter.two_axis_heatmap using a 0-0.1 s window,
    labels the colour bar as firing rate and shows the figure.

    Args:
        cell (dict-like): Cell record. Keys read here: 'subject',
            'sessiontype', 'behavior', 'ephys', 'tetrode', 'cluster'.
        tctype (str): Session type of the tuning-curve session.

    Raises:
        ValueError: if the cell has no session of type ``tctype``
            (raised by list.index).
    '''
    loader = dataloader.DataLoader(cell['subject'])
    sessiontypeIndex = cell['sessiontype'].index(tctype)

    plt.cla()

    bdata = loader.get_session_behavior(cell['behavior'][sessiontypeIndex])
    eventData = loader.get_session_events(cell['ephys'][sessiontypeIndex])
    spikeData = loader.get_session_spikes(cell['ephys'][sessiontypeIndex],
                                          int(cell['tetrode']),
                                          cluster=int(cell['cluster']))

    spikeTimestamps = spikeData.timestamps
    eventOnsetTimes = loader.get_event_onset_times(eventData)

    freqEachTrial = bdata['currentFreq']
    intensityEachTrial = bdata['currentIntensity']

    possibleFreq = np.unique(freqEachTrial)
    possibleIntensity = np.unique(intensityEachTrial)

    xlabel = 'Frequency (kHz)'
    ylabel = 'Intensity (dB SPL)'

    freqLabels = ["%.1f" % freq for freq in possibleFreq / 1000.0]
    intenLabels = ["%d" % inten for inten in possibleIntensity]

    ax, cax, cbar, spikeArray = dataplotter.two_axis_heatmap(
        spikeTimestamps=spikeTimestamps,
        eventOnsetTimes=eventOnsetTimes,
        firstSortArray=intensityEachTrial,
        secondSortArray=freqEachTrial,
        firstSortLabels=intenLabels,
        secondSortLabels=freqLabels,
        xlabel=xlabel,
        ylabel=ylabel,
        plotTitle='',
        flipFirstAxis=False,
        flipSecondAxis=False,
        timeRange=[0, 0.1],
        cmap='YlOrRd')

    fontsize = 20

    # Three frequency ticks at the first, middle and last heat-map columns,
    # labelled in kHz. The middle label is the geometric mean of the end
    # frequencies, so this assumes log-spaced frequencies. For 16 frequencies
    # spanning 2-40 kHz it reproduces the previous hard-coded ticks exactly.
    ax.set_xticks(np.linspace(0, len(possibleFreq) - 1, 3))
    ax.set_xticklabels(
        np.round(np.logspace(np.log10(possibleFreq[0] / 1000.0),
                             np.log10(possibleFreq[-1] / 1000.0), 3),
                 decimals=1))
    ax.set_xlabel('Frequency (kHz)', fontsize=fontsize)
    ax.set_ylabel('Intensity (dB SPL)', fontsize=fontsize)
    cbar.ax.yaxis.labelpad = -10

    # NOTE(review): the x10 in the top tick label presumably converts a spike
    # count in the 0.1 s window to Hz; confirm what two_axis_heatmap returns.
    maxFr = np.max(spikeArray.ravel())
    cbar.set_clim(0, maxFr)
    cbar.set_ticks([0, maxFr])
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    cbar.set_ticklabels([0, int(maxFr * 10)])
    cbar.set_label('Firing rate (Hz)', fontsize=fontsize)

    # Apply the tick font sizes before showing, so they are visible even when
    # plt.show() blocks.
    extraplots.set_ticks_fontsize(ax, fontsize)
    extraplots.set_ticks_fontsize(cbar.ax, fontsize)
    plt.show()
Beispiel #9
0
def nick_lan_daily_report(site,
                          siteName,
                          mainRasterInds,
                          mainTCind,
                          pcasort=True):
    '''
    Save a one-page summary figure for every cluster of every tetrode at a site.

    For each tetrode in ``site.tetrodes`` the spikes are clustered (with
    ``cluster_site_PCA`` when ``pcasort`` is true, otherwise ``cluster_site``)
    and, for each resulting cluster, a figure on a 6x6 grid is drawn with:

    * one raster per session listed in ``mainRasterInds`` (left, top rows),
    * an intensity x frequency heatmap of session ``mainTCind`` (top right),
    * ISI histogram, waveforms, projections and events-in-time (bottom rows).

    Each figure is saved as ``TT<tetrode>Cluster<cluster>.png`` in the
    clustering object's ``clustersDir``. Tetrodes whose clustering raises
    ``AttributeError`` (e.g. no spikes) are reported and skipped.

    Args:
        site: Site object providing ``experimenter``, ``tetrodes``,
            ``get_mouse_relative_ephys_filenames()``,
            ``get_mouse_relative_behav_filenames()`` and
            ``get_session_types()``.
        siteName: Name passed to the clustering function and used in messages.
        mainRasterInds: Indices (into the site's session lists) of the
            sessions drawn as rasters.
        mainTCind: Index of the tuning-curve session drawn as a heatmap, or
            ``None`` to omit the heatmap. Index 0 is a valid session.
        pcasort: If true cluster with ``cluster_site_PCA``, otherwise with
            ``cluster_site``.
    '''
    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    # Session metadata does not depend on the tetrode or cluster, so look it
    # up once instead of on every pass through the cluster loop.
    ephysFilenames = site.get_mouse_relative_ephys_filenames()
    sessionTypes = site.get_session_types()
    mainRasterEphysFilenames = [ephysFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

    # Only None disables the tuning-curve panel; a plain truthiness test
    # would also skip the (valid) session at index 0.
    if mainTCind is not None:
        mainTCsession = ephysFilenames[mainTCind]
        mainTCbehavFilename = site.get_mouse_relative_behav_filenames(
        )[mainTCind]
    else:
        mainTCsession = None

    # Event onset times and behavior depend only on the session, so they are
    # loaded once and reused for every tetrode and cluster.
    rasterEventOnsetTimes = [
        loader.get_event_onset_times(loader.get_session_events(session))
        for session in mainRasterEphysFilenames
    ]
    if mainTCsession:
        bdata = loader.get_session_behavior(mainTCbehavFilename)
        plotTitle = loader.get_session_filename(mainTCsession)
        tcEventOnsetTimes = loader.get_event_onset_times(
            loader.get_session_events(mainTCsession))
        freqEachTrial = bdata['currentFreq']
        intensityEachTrial = bdata['currentIntensity']
        freqLabels = [
            "%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0
        ]
        intenLabels = [
            "%.1f" % inten for inten in np.unique(intensityEachTrial)
        ]

    for tetrode in site.tetrodes:

        # Tetrodes with no spikes raise an AttributeError when clustering.
        try:
            if pcasort:
                oneTT = cluster_site_PCA(site, siteName, tetrode)
            else:
                oneTT = cluster_site(site, siteName, tetrode)
        except AttributeError:
            print("There was an attribute error for tetrode {} at {}".format(
                tetrode, siteName))
            continue

        # Spike files for this tetrode are shared by all of its clusters:
        # load each session once and select the cluster's spikes below.
        rasterSpikesEachSession = [
            loader.get_session_spikes(session, tetrode)
            for session in mainRasterEphysFilenames
        ]
        if mainTCsession:
            tcSpikes = loader.get_session_spikes(mainTCsession, tetrode)

        # One figure per cluster; the current figure is cleared and reused.
        for cluster in np.unique(oneTT.clusters):
            plt.clf()

            # -- Rasters for the main sessions (left half, one row each) --
            rasterData = zip(mainRasterEphysFilenames, mainRasterTypes,
                             rasterSpikesEachSession, rasterEventOnsetTimes)
            for indRaster, (rasterSession, rasterType, rasterSpikes,
                            eventOnsetTimes) in enumerate(rasterData):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                spikeTimestamps = rasterSpikes.timestamps[
                    rasterSpikes.clusters == cluster]
                dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, ms=1)
                plt.ylabel('{}\n{}'.format(rasterType,
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # -- Tuning-curve heatmap (only one main TC is supported) --
            if mainTCsession:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                spikeTimestamps = tcSpikes.timestamps[tcSpikes.clusters ==
                                                      cluster]
                dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                             eventOnsetTimes=tcEventOnsetTimes,
                                             firstSortArray=intensityEachTrial,
                                             secondSortArray=freqEachTrial,
                                             firstSortLabels=intenLabels,
                                             secondSortLabels=freqLabels,
                                             xlabel='Frequency (kHz)',
                                             ylabel='Intensity (dB SPL)',
                                             plotTitle=plotTitle,
                                             flipFirstAxis=True,
                                             flipSecondAxis=False,
                                             timeRange=[0, 0.1])
                plt.title("{0}\n{1}".format(mainTCsession,
                                            mainTCbehavFilename),
                          fontsize=10)
                plt.show()

            tsThisCluster = oneTT.timestamps[oneTT.clusters == cluster]
            wavesThisCluster = oneTT.samples[oneTT.clusters == cluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=0.7)
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)
            plt.savefig(full_fig_path, format='png')
Beispiel #10
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,),
                          clusters=(8,)):
    '''
    Save a summary figure per cluster for a bandwidth / AM / tuning site.

    Sessions are taken by position from ``site.get_session_ephys_filenames()``
    and ``site.get_session_behav_filenames()``:

    * index 1 -- frequency/intensity tuning curve (heatmap and raster),
    * index 2 -- AM session; its ``currentFreq`` is plotted as modulation
      rate (PSTH and raster),
    * index 3 -- bandwidth session (rasters per amplitude and bandwidth
      selectivity plots).

    For each requested tetrode the spikes are clustered with
    ``sitefuncs.cluster_site``. For each requested cluster a figure on a 12x15
    grid is saved as ``TT<tetrode>Cluster<cluster>.png`` (24x12 inches) in the
    clustering object's ``clustersDir``.

    Args:
        mouse: Passed to ``ephysinterface.EphysInterface``; shown in the title.
        date: Passed to ``ephysinterface.EphysInterface``; shown in the title.
        site: Site object providing the session filename lists and, when
            ``tetrodes`` is None, ``tetrodes``.
        siteName: Passed to ``sitefuncs.cluster_site``; shown in the title.
        tetrodes: Tetrodes to report. Defaults to ``(2,)``, the previously
            hard-coded value. ``None`` means every tetrode in
            ``site.tetrodes``.
        clusters: Clusters to report on each tetrode. Defaults to ``(8,)``,
            the previously hard-coded value. ``None`` means every cluster
            present in the bandwidth session for that tetrode.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()

    # NOTE(review): [-4:-3] keeps a single character of each behavior
    # filename -- presumably the session suffix letter; confirm.
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    charfreq = str(np.unique(bdata['charFreq'])[0]/1000)
    modrate = str(np.unique(bdata['modRate'])[0])
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']

    # Trial sorting and axis labels depend only on behavior: compute once.
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    possibleRates = np.unique(currentRate)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    intenLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq)/1000.0]
    rateLabels = ["%.1f" % rate for rate in possibleRates]
    trialsEachBandAmp = behavioranalysis.find_trials_each_combination(
        currentBand, possibleBands, currentAmp, possibleAmps)
    trialsEachRate = behavioranalysis.find_trials_each_type(currentRate,
                                                            possibleRates)

    # Event onset times depend only on the session: load each one once.
    bandEventOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))
    tcEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    amEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))

    colourList = ['b', 'g', 'y', 'orange', 'r']
    bandTimeRange = [-0.2, 1.5]
    amTimeRange = [-0.2, 1.5]
    psthBinWidth = 0.05
    psthBinEdges = np.around(np.arange(-0.2, 0.85, psthBinWidth), decimals=2)

    if tetrodes is None:
        tetrodes = site.tetrodes

    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        # Do not rebind the ``clusters`` argument: it must stay the same for
        # every tetrode.
        if clusters is None:
            clustersThisTetrode = np.unique(
                ei.loader.get_session_spikes(sessions[3], tetrode).clusters)
        else:
            clustersThisTetrode = clusters

        for cluster in clustersThisTetrode:
            plt.clf()

            # -- plot bandwidth rasters, one panel per amplitude --
            bandSpikeTimestamps = ei.loader.get_session_spikes(
                sessions[3], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 bandSpikeTimestamps, bandEventOnsetTimes, bandTimeRange)
            for ind, ampLabel in enumerate(ampLabels):
                plt.subplot2grid((12, 15), (5*ind, 0), rowspan=4, colspan=7)
                pRaster, hcond, zline = extraplots.raster_plot(
                    spikeTimesFromEventOnset,
                    indexLimitsEachTrial,
                    bandTimeRange,
                    trialsEachCond=trialsEachBandAmp[:, :, ind],
                    labels=bandLabels)
                plt.setp(pRaster, ms=4)
                plt.title(ampLabel)
                plt.ylabel('bandwidth (octaves)')
                if ind == len(ampLabels) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimestamps, bandEventOnsetTimes,
                             currentAmp, currentBand, [0.0, 1.0],
                             title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(bandSpikeTimestamps, bandEventOnsetTimes,
                             currentAmp, currentBand, [0.2, 1.0],
                             title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            tcSpikeTimestamps = ei2.loader.get_session_spikes(
                sessions[1], tetrode, cluster=cluster).timestamps
            dataplotter.two_axis_heatmap(spikeTimestamps=tcSpikeTimestamps,
                                         eventOnsetTimes=tcEventOnsetTimes,
                                         firstSortArray=currentInt,
                                         secondSortArray=currentFreq,
                                         firstSortLabels=intenLabels,
                                         secondSortLabels=freqLabels,
                                         xlabel='Frequency (kHz)',
                                         ylabel='Intensity (dB SPL)',
                                         plotTitle='Frequency Tuning Curve',
                                         flipFirstAxis=True,
                                         flipSecondAxis=False,
                                         timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(tcSpikeTimestamps, tcEventOnsetTimes,
                                    sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5], labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            amSpikeTimestamps = ei2.loader.get_session_spikes(
                sessions[2], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, _,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 amSpikeTimestamps, amEventOnsetTimes, amTimeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, psthBinEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Dividing counts by the bin width converts them to spikes/s.
            # (The former bare ``plt.setp(pPSTH)`` set nothing and only
            # printed every settable line property to stdout.)
            extraplots.plot_psth(spikeCountMat/psthBinWidth, 1,
                                 psthBinEdges[:-1], trialsEachRate,
                                 colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(amSpikeTimestamps, amEventOnsetTimes,
                                    sortArray=currentRate,
                                    timeRange=[-0.2, 0.8], labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis --
            tsThisCluster = oneTT.timestamps[oneTT.clusters == cluster]
            wavesThisCluster = oneTT.samples[oneTT.clusters == cluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle('{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(mouse, date, siteName, tetrode, cluster, charfreq, modrate))
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
# -- Exploration script: load the first cell of a JSON cell database, plot its
# intensity x frequency heatmap and compute spike features for cluster cutting.
loader = dataloader.DataLoader('offline', experimenter='nick')

# Cell database (JSON) to read cells from.
dbFn = '/home/nick/data/database/nick_thalamus_cells.json'
db = cellDB.CellDB()
db.load_from_json(dbFn)


# First cell in the database. NOTE(review): 'tc_heatmap' presumably selects
# the session by its type name -- confirm against get_cluster_data.
c = db[0]
spikeData, eventData, behavData = loader.get_cluster_data(c, 'tc_heatmap')

spikeTimes = spikeData.timestamps
eventOnsetTimes = loader.get_event_onset_times(eventData)

# Heatmap of all returned spikes, sorted first by intensity and then by
# frequency (no axis labels are passed).
# NOTE(review): figure/title/show are used unqualified; presumably provided by
# a pylab-style star import elsewhere in the file -- confirm.
figure()
dataplotter.two_axis_heatmap(spikeTimes, eventOnsetTimes, behavData['currentIntensity'], behavData['currentFreq'])
title('All Spikes')
show()

# Disabled: waveform plot of all spikes.
# figure()
# spikesorting.plot_waveforms(spikeData.samples)
# title('all spikes')

# Disabled: raster of all spikes.
# figure()
# dataplotter.plot_raster(spikeTimes, eventOnsetTimes)
# show()

# Waveform features ('peak', 'valley', 'energy') of every spike.
fet = spikesorting.calculate_features(spikeData.samples, ['peak', 'valley', 'energy'])
# fet = spikesorting.calculate_features(spikeData.samples, ['peak'])

# Cluster-cutting object built from the raw waveforms (presumably an
# interactive tool -- verify in clustercutting).
cw = clustercutting.AdvancedClusterCutter(spikeData.samples)