예제 #1
0
    def __init__(self,
                 animalName,
                 date,
                 experimenter,
                 defaultParadigm=None,
                 defaultTetrodes=None,
                 serverUser='******',
                 serverName='jarahub',
                 serverBehavPathBase='/data/behavior'):
        '''
        Store session identifiers, create an online DataLoader and build the
        remote location of this animal's behavior data.

        Args:
            animalName (str): animal identifier, also used in the server path.
            date (str): recording date, forwarded to the DataLoader.
            experimenter (str): experimenter name, forwarded to the DataLoader
                and used in the server path.
            defaultParadigm: forwarded to the DataLoader as its default paradigm.
            defaultTetrodes (list or None): tetrodes used by default.
                None (the default) means [3, 4, 5, 6].
            serverUser (str): user name in the remote location.
            serverName (str): host name in the remote location.
            serverBehavPathBase (str): root directory of behavior data on the server.
        '''
        # A list literal as default would be one object shared (and mutable)
        # across every instance; build a fresh list per instance instead.
        if defaultTetrodes is None:
            defaultTetrodes = [3, 4, 5, 6]

        self.animalName = animalName
        self.date = date
        self.experimenter = experimenter
        self.defaultParadigm = defaultParadigm
        self.defaultTetrodes = defaultTetrodes

        self.loader = dataloader.DataLoader('online', animalName, date,
                                            experimenter, defaultParadigm)

        self.serverUser = serverUser
        self.serverName = serverName
        self.serverBehavPathBase = serverBehavPathBase
        # <serverBehavPathBase>/<experimenter>/<animalName>
        self.serverBehavPath = os.path.join(self.serverBehavPathBase,
                                            self.experimenter, self.animalName)
        # "user@server:path" form of the behavior-data location
        self.remoteBehavLocation = '{0}@{1}:{2}'.format(
            self.serverUser, self.serverName, self.serverBehavPath)
def plot_cluster_tuning(clusterObj, indTC, experimenter='nick', *args, **kwargs):
    '''
    Plot an intensity x frequency response heatmap for one cluster.

    Args:
        clusterObj: cluster/cell record accepted by DataLoader.get_cluster_data.
        indTC: session selector forwarded to get_cluster_data
            (e.g. a session type such as 'tc_heatmap').
        experimenter (str): experimenter name for the offline DataLoader.
        *args, **kwargs: forwarded to dataplotter.two_axis_heatmap. A
            'timeRange' keyword overrides the default window [0, 0.1].

    Returns:
        (ax, cax, cbar) as returned by dataplotter.two_axis_heatmap.
    '''
    loader = dataloader.DataLoader('offline', experimenter=experimenter)
    spikeData, eventData, behavData = loader.get_cluster_data(clusterObj, indTC)

    spikeTimestamps = spikeData.timestamps
    eventOnsetTimes = loader.get_event_onset_times(eventData)
    freqEachTrial = behavData['currentFreq']
    intensityEachTrial = behavData['currentIntensity']

    possibleFreq = np.unique(freqEachTrial)
    possibleIntensity = np.unique(intensityEachTrial)

    xlabel = 'Frequency (kHz)'
    ylabel = 'Intensity (dB SPL)'

    # Frequencies are stored in Hz; label them in kHz.
    freqLabels = ["%.1f" % freq for freq in possibleFreq/1000.0]
    intenLabels = ["%.1f" % inten for inten in possibleIntensity]

    # Passing timeRange both literally and via **kwargs raised
    # "got multiple values for keyword argument"; let callers override it.
    kwargs.setdefault('timeRange', [0, 0.1])

    ax, cax, cbar = dataplotter.two_axis_heatmap(spikeTimestamps,
                                                 eventOnsetTimes,
                                                 firstSortArray=intensityEachTrial,
                                                 secondSortArray=freqEachTrial,
                                                 firstSortLabels=intenLabels,
                                                 secondSortLabels=freqLabels,
                                                 *args,
                                                 **kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    return ax, cax, cbar
예제 #3
0
def calculate_site_response(site, siteName, sessionInd, maxZonly=False):
    '''
    Compute binned response z-scores for every cluster recorded at a site.

    Args:
        site: recording-site object (provides experimenter, tetrodes and
            get_mouse_relative_ephys_filenames()).
        siteName (str): site name, passed to cluster_site and used in keys.
        sessionInd (int): index of the session to analyse.
        maxZonly (bool): if True return only the maxZ dictionary.

    Returns:
        Dictionaries keyed by '<siteName>T<tetrode>c<cluster>':
        maxZ only if maxZonly, otherwise (zStat, pValue, maxZ).
    '''
    from jaratoolbox import spikesanalysis

    # Z-score settings from billy
    baseRange = [-0.050, -0.025]        # Baseline range (in seconds)
    binTime = baseRange[1] - baseRange[0]   # Time-bin size: same width as the baseline
    responseTimeRange = [-0.5, 1]       # Range to calculate z values for (should be divisible by binTime)
    # NOTE(review): np.arange excludes the stop value, so the last edge is
    # below responseTimeRange[1] — confirm this is intended.
    binEdges = np.arange(responseTimeRange[0], responseTimeRange[1], binTime)
    timeRange = [-0.5, 1]               # Window for event-locked spike times

    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    sessionEphys = site.get_mouse_relative_ephys_filenames()[sessionInd]
    # Event onsets do not depend on tetrode or cluster: load them once.
    eventOnsetTimes = loader.get_event_onset_times(
        loader.get_session_events(sessionEphys))

    siteClusterMaxZ = {}
    siteClusterPval = {}
    siteClusterZstat = {}

    for tetrode in site.tetrodes:
        oneTT = cluster_site(site, siteName, tetrode, report=False)
        # Spikes depend only on the tetrode: load once, then split by cluster.
        tetrodeSpikes = loader.get_session_spikes(sessionEphys, tetrode)

        for cluster in np.unique(oneTT.clusters):
            spikeTimes = tetrodeSpikes.timestamps[tetrodeSpikes.clusters == cluster]

            (spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial) = \
                spikesanalysis.eventlocked_spiketimes(spikeTimes, eventOnsetTimes, timeRange)

            # z score for each bin: zStat is the array of z scores,
            # maxZ the maximum z value across the bins.
            zStat, pValue, maxZ = spikesanalysis.response_score(
                spikeTimesFromEventOnset, indexLimitsEachTrial, baseRange, binEdges)

            tetClustName = '{0}T{1}c{2}'.format(siteName, tetrode, cluster)
            siteClusterMaxZ[tetClustName] = maxZ
            siteClusterPval[tetClustName] = pValue
            siteClusterZstat[tetClustName] = zStat

    if maxZonly:
        return siteClusterMaxZ
    return siteClusterZstat, siteClusterPval, siteClusterMaxZ
예제 #4
0
def am_mod_report(site, siteName, amSessionInd):
    '''
    Save one raster per cluster for the amplitude-modulation session of a site.

    For every tetrode that clusters successfully, each cluster's spikes from
    session amSessionInd are plotted as a raster sorted by the behavior
    'currentFreq' values and saved to oneTT.clustersDir as
    'TT<tetrode>Cluster<cluster>_Amp_Mod.png'.

    Args:
        site: recording-site object (provides experimenter, tetrodes and the
            ephys/behavior filename lists).
        siteName (str): site name, passed to cluster_site and used in messages.
        amSessionInd (int): index of the amplitude-modulation session.
    '''
    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    # The session, its events and its behavior are the same for every
    # tetrode and cluster, so look them up once.
    amFilename = site.get_mouse_relative_ephys_filenames()[amSessionInd]
    amBehav = site.get_mouse_relative_behav_filenames()[amSessionInd]
    eventOnsetTimes = loader.get_event_onset_times(
        loader.get_session_events(amFilename))
    currentFreq = loader.get_session_behavior(amBehav)['currentFreq']

    for tetrode in site.tetrodes:
        # Tetrodes with no spikes raise AttributeError when clustering; skip them.
        try:
            oneTT = cluster_site(site, siteName, tetrode)
        except AttributeError:
            print("There was an attribute error for tetrode {} at {}".format(tetrode, siteName))
            continue

        for cluster in np.unique(oneTT.clusters):
            plt.clf()

            spikeData = loader.get_session_spikes(amFilename, tetrode, cluster=cluster)
            dataplotter.plot_raster(spikeData.timestamps, eventOnsetTimes,
                                    sortArray=currentFreq)

            fig_name = 'TT{0}Cluster{1}_Amp_Mod.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)
            plt.savefig(full_fig_path, format='png')
reload(simple_spike_selector)

from jaratoolbox.test.nick.database import dataloader
reload(dataloader)

from jaratoolbox.test.nick.database import cellDB
reload(cellDB)

from jaratoolbox.test.nick.database import dataplotter
from jaratoolbox import spikesorting
from jaratoolbox.test.nick import clustercutting

from pylab import *


# Offline loader for experimenter 'nick'.
loader = dataloader.DataLoader('offline', experimenter='nick')

# Load the cell database from its JSON file.
dbFn = '/home/nick/data/database/nick_thalamus_cells.json'
db = cellDB.CellDB()
db.load_from_json(dbFn)


# Take the first cell in the database and load its spikes, events and
# behavior for 'tc_heatmap'.
c = db[0]
spikeData, eventData, behavData = loader.get_cluster_data(c, 'tc_heatmap')

spikeTimes = spikeData.timestamps
eventOnsetTimes = loader.get_event_onset_times(eventData)
currentFreq = behavData['currentFreq']
currentIntensity = behavData['currentIntensity']
# Keep only the event onsets of trials presented at intensity 70.
# NOTE(review): this boolean mask assumes exactly one event onset per
# behavior trial; if the counts differ the indexing fails or misaligns.
trialsToPlot = currentIntensity==70
eventOnsetTimes=eventOnsetTimes[trialsToPlot]
예제 #6
0
def compare_session_spike_waveforms(spikeSamples, spikeTimes, eventOnsetTimes,
                                    rasterTimeRange, timeRangeList):
    '''
    Plot an event-locked raster above one waveform panel per time range.

    Args:
        spikeSamples: spike waveforms, passed to
            dataplotter.plot_waveforms_in_event_locked_timerange.
        spikeTimes: spike timestamps.
        eventOnsetTimes: event onset timestamps.
        rasterTimeRange: time window for the raster.
        timeRangeList: iterable of time windows (relative to event onset);
            one waveform panel is drawn for each.

    Returns:
        The matplotlib Figure that was created.
    '''
    # Materialize so any iterable works and its length can be used for layout.
    timeRangeList = list(timeRangeList)
    # One bottom-row column per time range. Keep at least the original two
    # columns so the layout is unchanged for up to two ranges; a fixed 3x2
    # grid could not place a third panel.
    nCols = max(2, len(timeRangeList))

    fig = plt.figure()
    plt.subplot2grid((3, nCols), (0, 0), rowspan=2, colspan=nCols)
    dataplotter.plot_raster(spikeTimes,
                            eventOnsetTimes,
                            timeRange=rasterTimeRange)

    for ind, timeRange in enumerate(timeRangeList):
        plt.subplot2grid((3, nCols), (2, ind), rowspan=1, colspan=1)
        dataplotter.plot_waveforms_in_event_locked_timerange(
            spikeSamples, spikeTimes, eventOnsetTimes, timeRange)

    plt.subplots_adjust(hspace=0.7)
    return fig


if __name__ == "__main__":
    # Demo: load cluster 8 of tetrode 4 from online session '22-10-33'
    # and compare its waveforms in two event-locked windows.
    # NOTE(review): if DataLoader's positional order is (mode, animalName,
    # date, experimenter, defaultParadigm), 'laser_tuning_curve' -- which
    # looks like a paradigm name -- lands in the experimenter slot; confirm
    # against DataLoader's signature.
    loader = dataloader.DataLoader('online', 'pinp005', '2015-07-30',
                                   'laser_tuning_curve')
    spikeData = loader.get_session_spikes('22-10-33', 4, cluster=8)
    events = loader.get_session_events('22-10-33')
    eventOnsetTimes = loader.get_event_onset_times(events)

    spikeTimes = spikeData.timestamps
    waveforms = spikeData.samples
    # Raster window [-0.5, 1]; waveform windows before ([-0.2, 0]) and
    # after ([0, 0.1]) event onset.
    compare_session_spike_waveforms(waveforms, spikeTimes, eventOnsetTimes,
                                    [-0.5, 1], [[-0.2, 0], [0, 0.1]])

#TODO: It would be awesome if we could show the spikes on the raster in a color that corresponds to the waveforms
예제 #7
0
##Getting the data from a cluster to make a plot
# -----------------------
# NOTE(review): cell1 is not defined in this snippet; it must already exist
# (presumably a cell/cluster record with get_data_filenames and .cluster).

#Sessions with no behavior data return only the ephys filename
cell1NoisePhys = cell1.get_data_filenames('noiseBurst')

#Sessions with behavior data return the tuple (ephysFilename, behavFilename)
cell1TuningPhys, cell1TuningBehavior = cell1.get_data_filenames('tcHeatmap')

# -----------------------
##Initialize an offline data loader
# -----------------------

#For now we still need to specify the experimenter for offline data analysis,
#because the behavior data is organized by experimenter
loader = dataloader.DataLoader('offline', experimenter='nick')

# -----------------------
##Get ephys and behavior data by passing the filenames from the cluster
# -----------------------

# Tetrode 6 is hard-coded here; presumably it is cell1's tetrode -- confirm.
cell1NoiseSpikesTT6 = loader.get_session_spikes(cell1NoisePhys,
                                                6,
                                                cluster=cell1.cluster)
cell1NoiseEvents = loader.get_session_events(cell1NoisePhys)
cell1TuningBdata = loader.get_session_behavior(cell1TuningBehavior)
eventOnsetTimes = loader.get_event_onset_times(cell1NoiseEvents)
spikeTimes = cell1NoiseSpikesTT6.timestamps

# -----------------------
##Make a raster plot
예제 #8
0
def _plot_report_rasters(rasterData, cluster):
    '''Draw one raster per main session down the left half of the 6x6 report grid.'''
    for indRaster, (session, sessionType, spikeData, eventOnsetTimes) in enumerate(rasterData):
        plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
        spikeTimestamps = spikeData.timestamps[spikeData.clusters == cluster]
        dataplotter.plot_raster(spikeTimestamps, eventOnsetTimes, ms=1)
        plt.ylabel('{}\n{}'.format(sessionType, session.split('_')[1]), fontsize=10)
        extraplots.set_ticks_fontsize(plt.gca(), 6)  # Should this go in dataplotter?


def _load_report_tc(loader, tcSession, tcBehavFilename, tetrode):
    '''Load behavior, plot title, event onsets and tetrode spikes for the tuning-curve session.'''
    bdata = loader.get_session_behavior(tcBehavFilename)
    plotTitle = loader.get_session_filename(tcSession)
    eventOnsetTimes = loader.get_event_onset_times(loader.get_session_events(tcSession))
    spikeData = loader.get_session_spikes(tcSession, tetrode)
    return bdata, plotTitle, eventOnsetTimes, spikeData


def _plot_report_tc(tcData, cluster, tcSession, tcBehavFilename):
    '''Draw the intensity x frequency heatmap for one cluster in the top-right of the grid.'''
    bdata, plotTitle, eventOnsetTimes, spikeData = tcData
    plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)

    spikeTimestamps = spikeData.timestamps[spikeData.clusters == cluster]
    freqEachTrial = bdata['currentFreq']
    intensityEachTrial = bdata['currentIntensity']

    # Frequencies are stored in Hz; label them in kHz.
    freqLabels = ["%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0]
    intenLabels = ["%.1f" % inten for inten in np.unique(intensityEachTrial)]

    dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                 eventOnsetTimes=eventOnsetTimes,
                                 firstSortArray=intensityEachTrial,
                                 secondSortArray=freqEachTrial,
                                 firstSortLabels=intenLabels,
                                 secondSortLabels=freqLabels,
                                 xlabel='Frequency (kHz)',
                                 ylabel='Intensity (dB SPL)',
                                 plotTitle=plotTitle,
                                 flipFirstAxis=True,
                                 flipSecondAxis=False,
                                 timeRange=[0, 0.1])

    plt.title("{0}\n{1}".format(tcSession, tcBehavFilename), fontsize=10)


def _plot_cluster_summary(oneTT, cluster):
    '''Draw ISI histogram, waveforms, projections and events-in-time for one cluster.'''
    inCluster = oneTT.clusters == cluster
    tsThisCluster = oneTT.timestamps[inCluster]
    wavesThisCluster = oneTT.samples[inCluster]

    # -- Plot ISI histogram --
    plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
    spikesorting.plot_isi_loghist(tsThisCluster)
    plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
    plt.xlabel('')

    # -- Plot waveforms --
    plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
    spikesorting.plot_waveforms(wavesThisCluster)

    # -- Plot projections --
    plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
    spikesorting.plot_projections(wavesThisCluster)

    # -- Plot events in time --
    plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
    spikesorting.plot_events_in_time(tsThisCluster)

    plt.subplots_adjust(wspace=0.7)


def nick_lan_daily_report(site, siteName, mainRasterInds, mainTCind, pcasort=True):
    '''
    Save one summary figure per cluster for every tetrode at a site.

    Each figure uses a 6x6 subplot grid containing: one raster per session in
    mainRasterInds (left half, one row each), an optional intensity x
    frequency heatmap for session mainTCind (top right), and the cluster's
    ISI histogram, waveforms, projections and events-in-time (bottom two
    rows). Figures are saved to oneTT.clustersDir as
    'TT<tetrode>Cluster<cluster>.png'.

    Args:
        site: recording-site object (provides experimenter, tetrodes and the
            session filename/type lists).
        siteName (str): site name, passed to the clustering function and
            used in messages.
        mainRasterInds (list of int): session indices to plot as rasters.
        mainTCind (int or None): session index of the tuning curve. None (or
            False) skips the heatmap; index 0 is a valid session.
        pcasort (bool): cluster with cluster_site_PCA instead of cluster_site.
    '''
    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    # Session lookups do not depend on tetrode or cluster: do them once.
    ephysFilenames = site.get_mouse_relative_ephys_filenames()
    sessionTypes = site.get_session_types()
    mainRasterSessions = [(ephysFilenames[i], sessionTypes[i]) for i in mainRasterInds]

    # A truthiness test would skip a tuning curve at index 0, so compare
    # against the "no TC" sentinels explicitly. We can only do one main TC for now.
    hasMainTC = mainTCind is not None and mainTCind is not False
    if hasMainTC:
        mainTCsession = ephysFilenames[mainTCind]
        mainTCbehavFilename = site.get_mouse_relative_behav_filenames()[mainTCind]

    for tetrode in site.tetrodes:

        # Tetrodes with no spikes will cause an error when clustering
        try:
            if pcasort:
                oneTT = cluster_site_PCA(site, siteName, tetrode)
            else:
                oneTT = cluster_site(site, siteName, tetrode)
        except AttributeError:
            print("There was an attribute error for tetrode {} at {}".format(tetrode, siteName))
            continue

        possibleClusters = np.unique(oneTT.clusters)

        # Load every session once per tetrode rather than once per cluster.
        rasterData = [(session, sessionType,
                       loader.get_session_spikes(session, tetrode),
                       loader.get_event_onset_times(loader.get_session_events(session)))
                      for session, sessionType in mainRasterSessions]
        if hasMainTC:
            tcData = _load_report_tc(loader, mainTCsession, mainTCbehavFilename, tetrode)

        # Make a new figure (by clearing the current one) for each cluster.
        for cluster in possibleClusters:
            plt.clf()

            _plot_report_rasters(rasterData, cluster)
            if hasMainTC:
                _plot_report_tc(tcData, cluster, mainTCsession, mainTCbehavFilename)
            _plot_cluster_summary(oneTT, cluster)

            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)
            plt.savefig(full_fig_path, format='png')