def load_ephys_per_cell(oneCell):
    '''
    Load the spike data and the event data of a single cell.

    Args:
        oneCell: CellInfo object (from celldatabase) describing the cell.

    Returns:
        (spikeTimestamps, waveforms, eventOnsetTimes, eventData):
            spikeTimestamps -- spikes.timestamps from ephyscore.CellData
                (already divided by the sampling rate by ephyscore, i.e. seconds).
            waveforms -- spike samples as floats, minus 2**15, multiplied by
                1000/gain[0, 0] (Open Ephys specific conversion to microvolts).
            eventOnsetTimes -- copy of the event timestamps divided by
                SAMPLING_RATE; contains all events, not only sound onsets.
            eventData -- the loadopenephys.Events object. NOTE: its
                `timestamps` attribute is replaced by the rescaled array.
    '''
    # TODO(review): an older note suggested wrapping CellData creation in a try.
    ephysDir = settings.EPHYS_PATH

    # CellData loads the spikes via loadopenephys and exposes them as .spikes
    cellSpikes = ephyscore.CellData(oneCell).spikes
    spikeTimestamps = cellSpikes.timestamps

    # Remove the Open Ephys 2**15 offset, then scale to microvolts.
    centredSamples = cellSpikes.samples.astype(float) - 2**15
    waveforms = (1000.0 / cellSpikes.gain[0, 0]) * centredSamples

    eventsPath = os.path.join(ephysDir, oneCell.animalName,
                              oneCell.ephysSession, 'all_channels.events')
    eventData = loadopenephys.Events(eventsPath)
    # Rescale in place so the returned Events object is also in seconds.
    eventData.timestamps = np.array(eventData.timestamps) / SAMPLING_RATE
    eventOnsetTimes = np.array(eventData.timestamps)

    return (spikeTimestamps, waveforms, eventOnsetTimes, eventData)
def load_one_cell_waveform(oneCell):
    '''
    Return the spike waveforms of one cell.

    Args:
        oneCell: CellInfo object (from jaratoolbox.celldatabase).

    Returns:
        Array of the cell's spike samples converted to float, with 2**15
        subtracted and multiplied by 1000/gain[0, 0] (the Open Ephys
        specific conversion to microvolts).
    '''
    # DataSpikes object produced by loadopenephys, held by CellData.spikes
    spikes = ephyscore.CellData(oneCell).spikes
    centred = spikes.samples.astype(float) - 2**15
    return (1000.0 / spikes.gain[0, 0]) * centred
            leftward = bdata['choice']==bdata.labels['choice']['left']
            valid = (bdata['outcome']==bdata.labels['outcome']['correct'])|(bdata['outcome']==bdata.labels['outcome']['error'])
            correct = bdata['outcome']==bdata.labels['outcome']['correct']
            correctRightward = rightward & correct
            correctLeftward = leftward & correct

            possibleFreq = np.unique(bdata['targetFrequency'])
            numberOfFrequencies = len(possibleFreq)
            numberOfTrials = len(bdata['choice'])
            targetFreqs = bdata['targetFrequency']

            for possFreq in possibleFreq:
                modIDict[behavSession][possFreq] = np.empty([clusNum*numTetrodes])

        # -- Load Spike Data From Certain Cluster --
        spkData = ephyscore.CellData(oneCell)
        spkTimeStamps = spkData.spikes.timestamps


        clusterNumber = (oneCell.tetrode-1)*clusNum+(oneCell.cluster-1)
        for Freq in possibleFreq:
            oneFreq = targetFreqs == Freq

            trialsToUseRight = rightward & oneFreq
            trialsToUseLeft = leftward & oneFreq

            #print 'behavior ',behavSession,' tetrode ',oneCell.tetrode,' cluster ',oneCell.cluster,'freq',Freq

            (spikeTimesFromEventOnset,trialIndexForEachSpike,indexLimitsEachTrial) = \
                spikesanalysis.eventlocked_spiketimes(spkTimeStamps,eventOnsetTimes,timeRange)
def raster_tuning(ax):
    '''
    Plot a frequency-sorted spike raster of the tuning-curve session.

    Reads module-level globals set by the caller: behaviorDir, subject,
    tuningBehavior, tuningEphys, tetrode, cluster, tuning_timeRange and
    SAMPLING_RATE.

    Trials are grouped by their 'currentFreq' value. A light horizontal line
    follows each group, and each group gets one y-tick labelled with its
    frequency in kHz.

    Args:
        ax: matplotlib Axes that receives the separator lines, ticks and
            y-limits. NOTE(review): the spikes and axis labels are drawn with
            pylab `plot`/`ylabel`/`xlabel`, i.e. on the *current* axes; the
            caller creates `ax` with subplot2grid right before calling, so
            both coincide there.
    '''
    # -- Load tuning behavior data --
    fullbehaviorDir = behaviorDir + subject + '/'
    behavName = subject + '_tuning_curve_' + tuningBehavior + '.h5'
    tuningBehavFileName = os.path.join(fullbehaviorDir, behavName)
    tuning_bdata = loadbehavior.BehaviorData(tuningBehavFileName,
                                             readmode='full')
    freqEachTrial = tuning_bdata['currentFreq']
    possibleFreq = np.unique(freqEachTrial)
    numberOfTrials = len(freqEachTrial)

    # -- Sort trials by frequency (for plotting a sorted raster) --
    sortedTrials = []
    numTrialsEachFreq = []  # used to draw a line after each group of trials
    for oneFreq in possibleFreq:
        indsThisFreq = np.flatnonzero(freqEachTrial == oneFreq)
        sortedTrials = np.concatenate((sortedTrials, indsThisFreq))
        numTrialsEachFreq.append(len(indsThisFreq))
    # sortedTrials lists trial indices row by row; its argsort is the inverse
    # mapping, i.e. the raster row of each trial.
    sortingInds = np.argsort(sortedTrials)

    # -- Load event data and convert event timestamps to seconds --
    tuning_ephysDir = os.path.join(settings.EPHYS_PATH, subject, tuningEphys)
    tuning_eventFilename = os.path.join(tuning_ephysDir, 'all_channels.events')
    tuning_ev = loadopenephys.Events(tuning_eventFilename)
    tuning_eventTimes = np.array(tuning_ev.timestamps) / SAMPLING_RATE
    tuning_evID = np.array(tuning_ev.eventID)
    # eventID 1 marks an event turning on; in this session the only event
    # channel is sound start.
    tuning_eventOnsetTimes = tuning_eventTimes[tuning_evID == 1]
    # Drop surplus trailing events so there is at most one onset per trial.
    tuning_eventOnsetTimes = tuning_eventOnsetTimes[:numberOfTrials]

    # HACK: build a throwaway CellInfo so ephyscore.CellData can load the
    # spikes of this tetrode/cluster from the tuning ephys session.
    thisCell = celldatabase.CellInfo(
        animalName=subject,
        ephysSession=tuningEphys,
        tuningSession='DO NOT NEED THIS',
        tetrode=tetrode,
        cluster=cluster,
        quality=1,
        depth=0,
        tuningBehavior='DO NOT NEED THIS',
        behavSession=tuningBehavior)

    tuning_spkData = ephyscore.CellData(thisCell)
    tuning_spkTimeStamps = tuning_spkData.spikes.timestamps

    (tuning_spikeTimesFromEventOnset, tuning_trialIndexForEachSpike,
     tuning_indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         tuning_spkTimeStamps, tuning_eventOnsetTimes, tuning_timeRange)

    # Raster row of the trial each spike belongs to.
    tuning_sortedIndexForEachSpike = sortingInds[tuning_trialIndexForEachSpike]

    plot(tuning_spikeTimesFromEventOnset,
         tuning_sortedIndexForEachSpike,
         '.',
         ms=3)

    # Cumulative trial counts: the upper boundary of each frequency group.
    numTrials = np.cumsum(numTrialsEachFreq)
    for num in numTrials:
        ax.axhline(y=num, xmin=0, xmax=1, color='0.90', zorder=0)

    # Centre each label inside its own group; using the mean group size
    # misplaced labels whenever groups had different numbers of trials.
    tickPositions = numTrials - np.asarray(numTrialsEachFreq) / 2.0
    # Float division keeps the fractional kHz under Python 2 integer division.
    tickLabels = ["%0.2f" % (oneFreq / 1000.0) for oneFreq in possibleFreq]
    ax.set_yticks(tickPositions)
    ax.set_yticklabels(tickLabels)
    ax.set_ylim([-1, numberOfTrials])
    ylabel('Frequency Presented (kHz), {} total trials'.format(numTrials[-1]))
    xlabel('Time (sec)')
    '''
def main():
    '''
    Build and save a report figure for each cell in allcells.cellDB[0:numOfCells].

    Session-level data (behavior, sound-onset event times and center-out
    aligned onset times) is reloaded only when a cell's behavSession differs
    from the one loaded for the previous cell. It is shared with the plotting
    helpers (raster_tuning, raster_sound_block_switching, ...) through the
    module-level globals declared below.

    A cell whose processing raises an exception is skipped. Its behavSession
    is recorded in badSessionList, and the list is printed at the end.
    '''
    global behavSession
    global subject
    global tetrode
    global cluster
    global tuningBehavior  # behavior file name of tuning curve
    global tuningEphys  # ephys session name of tuning curve
    global bdata
    global eventOnsetTimes
    global spikeTimesFromEventOnset
    global indexLimitsEachTrial
    global spikeTimesFromMovementOnset
    global indexLimitsEachMovementTrial
    global titleText

    print("switch_tuning_block_allfreq_report")
    for cellID in range(0, numOfCells):
        oneCell = allcells.cellDB[cellID]
        try:
            if behavSession != oneCell.behavSession:

                subject = oneCell.animalName
                behavSession = oneCell.behavSession
                ephysSession = oneCell.ephysSession
                ephysRoot = os.path.join(ephysRootDir, subject)
                tuningBehavior = oneCell.tuningBehavior
                tuningEphys = oneCell.tuningSession

                print(behavSession)

                # -- Load behavior data --
                behaviorFilename = loadbehavior.path_to_behavior_data(
                    subject=subject,
                    paradigm=paradigm,
                    sessionstr=behavSession)
                bdata = loadbehavior.FlexCategBehaviorData(behaviorFilename)

                # -- Load event data and convert event timestamps to seconds --
                ephysDir = os.path.join(ephysRoot, ephysSession)
                eventFilename = os.path.join(ephysDir, 'all_channels.events')
                events = loadopenephys.Events(eventFilename)
                eventTimes = np.array(events.timestamps) / SAMPLING_RATE

                soundOnsetEvents = (events.eventID == 1) & (
                    events.eventChannel == soundTriggerChannel)
                eventOnsetTimes = eventTimes[soundOnsetEvents]
                soundOnsetTimeBehav = bdata['timeTarget']

                # -- Find and remove trials flagged as missing --
                missingTrials = behavioranalysis.find_missing_trials(
                    eventOnsetTimes, soundOnsetTimeBehav)
                bdata.remove_trials(missingTrials)
                bdata.find_trials_each_block()

                # -- Onset times shifted by (center-out time - sound time) --
                centerOutTimes = bdata['timeCenterOut']
                soundStartTimes = bdata['timeTarget']
                timeDiff = centerOutTimes - soundStartTimes
                if len(eventOnsetTimes) < len(timeDiff):
                    timeDiff = timeDiff[:-1]
                    eventOnsetTimesCenter = eventOnsetTimes + timeDiff
                elif len(eventOnsetTimes) > len(timeDiff):
                    eventOnsetTimesCenter = eventOnsetTimes[:-1] + timeDiff
                else:
                    eventOnsetTimesCenter = eventOnsetTimes + timeDiff

            tetrode = oneCell.tetrode
            cluster = oneCell.cluster

            # -- Load spike data from this cluster --
            spkData = ephyscore.CellData(oneCell)
            spkTimeStamps = spkData.spikes.timestamps

            (spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial) = \
                spikesanalysis.eventlocked_spiketimes(spkTimeStamps, eventOnsetTimes, timeRange)

            (spikeTimesFromMovementOnset, movementTrialIndexForEachSpike, indexLimitsEachMovementTrial) = \
                spikesanalysis.eventlocked_spiketimes(spkTimeStamps, eventOnsetTimesCenter, timeRange)

            # Integer division: subplot2grid needs integer grid positions/spans
            # (identical to '/' for ints under Python 2, required under Python 3).
            thirdCols = numCols // 3
            halfCols = numCols // 2

            plt.clf()
            if len(spkTimeStamps) > 0:
                # -- Cluster quality plots along the bottom row --
                plt.subplot2grid((numRows, numCols),
                                 ((numRows - sizeClusterPlot), 0),
                                 colspan=thirdCols)
                spikesorting.plot_isi_loghist(spkData.spikes.timestamps)
                plt.subplot2grid((numRows, numCols),
                                 ((numRows - sizeClusterPlot), thirdCols * 2),
                                 colspan=thirdCols)
                spikesorting.plot_events_in_time(spkTimeStamps)
                samples = spkData.spikes.samples.astype(float) - 2**15
                samples = (1000.0 / spkData.spikes.gain[0, 0]) * samples
                plt.subplot2grid((numRows, numCols),
                                 ((numRows - sizeClusterPlot), thirdCols),
                                 colspan=thirdCols)
                spikesorting.plot_waveforms(samples)

            # -- Tuning raster (left column) --
            ax4 = plt.subplot2grid((numRows, numCols), (0, 0),
                                   colspan=halfCols,
                                   rowspan=3 * sizeRasters)
            raster_tuning(ax4)
            axvline(x=0, ymin=0, ymax=1, color='r')
            plt.gca().set_xlim(tuning_timeRange)

            # -- Block-switching raster and histogram (right column, top) --
            ax6 = plt.subplot2grid((numRows, numCols), (0, halfCols),
                                   colspan=halfCols,
                                   rowspan=sizeRasters)
            plt.setp(ax6.get_xticklabels(), visible=False)
            plt.setp(ax6.get_yticklabels(), visible=False)
            raster_sound_block_switching()
            plt.title(
                'sound aligned, Top: middle freq in blocks, Bottom: all freqs')

            ax7 = plt.subplot2grid((numRows, numCols),
                                   (sizeRasters, halfCols),
                                   colspan=halfCols,
                                   rowspan=sizeHists,
                                   sharex=ax6)
            hist_sound_block_switching(ax7)
            ax7.yaxis.tick_right()
            ax7.yaxis.set_ticks_position('both')
            plt.setp(ax7.get_xticklabels(), visible=False)
            plt.gca().set_xlim(timeRange)

            # -- All-frequency raster and histogram (right column, bottom) --
            ax10 = plt.subplot2grid((numRows, numCols),
                                    ((sizeRasters + sizeHists), halfCols),
                                    colspan=halfCols,
                                    rowspan=sizeRasters)
            plt.setp(ax10.get_xticklabels(), visible=False)
            plt.setp(ax10.get_yticklabels(), visible=False)
            raster_sound_allFreq_switching()

            ax11 = plt.subplot2grid(
                (numRows, numCols),
                ((2 * sizeRasters + sizeHists), halfCols),
                colspan=halfCols,
                rowspan=sizeHists,
                sharex=ax10)
            hist_sound_allFreq_switching(ax11)
            ax11.yaxis.tick_right()
            ax11.yaxis.set_ticks_position('both')
            ax11.set_xlabel('Time (sec)')
            plt.gca().set_xlim(timeRange)

            modulation_index_switching()
            plt.suptitle(titleText)

            # -- Save the figure --
            tetrodeClusterName = 'T' + str(oneCell.tetrode) + 'c' + str(
                oneCell.cluster)
            plt.gcf().set_size_inches((8.5, 11))
            figformat = 'png'  #'png' #'pdf' #'svg'
            filename = reportname + '_%s_%s_%s.%s' % (
                subject, behavSession, tetrodeClusterName, figformat)
            fulloutputDir = outputDir + subject + '/'
            fullFileName = os.path.join(fulloutputDir, filename)

            directory = os.path.dirname(fulloutputDir)
            if not os.path.exists(directory):  # make sure output folder exists
                os.makedirs(directory)
            plt.gcf().savefig(fullFileName, format=figformat)

        except Exception as err:
            # Best effort: skip this cell, but report why and remember its
            # session. behavSession was assigned before loading started, so a
            # failed load would otherwise make the next cells of this session
            # reuse the previous session's bdata/eventOnsetTimes. Clearing it
            # forces the next cell to reload its session data.
            print('error processing cell {0}: {1}'.format(cellID, err))
            behavSession = None
            if oneCell.behavSession not in badSessionList:
                badSessionList.append(oneCell.behavSession)

    print('error with sessions: ')
    for badSes in badSessionList:
        print(badSes)
def main(oneCell):
    '''
    Plot and save a raster and PSTH of one cell at a single target frequency.

    Trials are restricted to correct trials at the frequency selected by the
    module-level index `Frequency` (into the sorted unique target
    frequencies). They are split into correct-leftward (green) and
    correct-rightward (red) choices.

    Args:
        oneCell: CellInfo-like object with animalName, behavSession,
            ephysSession, tetrode and cluster attributes.

    Also reads the module-level settings ephysRootDir, experimenter, paradigm,
    SAMPLING_RATE, soundTriggerChannel, Frequency, timeRange, binWidth and
    outputDir. The PNG is written to outputDir/<subject>/<freq>/.
    '''
    # The previous version replaced `oneCell` with allcells.cellDB[cellID]
    # and compared an unassigned local `behavSession`, which raised
    # UnboundLocalError on every call. With no shared state to cache into,
    # session data is loaded for the given cell on every call.
    subject = oneCell.animalName
    behavSession = oneCell.behavSession
    ephysSession = oneCell.ephysSession
    ephysRoot = os.path.join(ephysRootDir, subject)

    # -- Load behavior data --
    behaviorFilename = loadbehavior.path_to_behavior_data(subject, experimenter, paradigm, behavSession)
    bdata = loadbehavior.BehaviorData(behaviorFilename)

    # -- Load event data and convert event timestamps to seconds --
    ephysDir = os.path.join(ephysRoot, ephysSession)
    eventFilename = os.path.join(ephysDir, 'all_channels.events')
    events = loadopenephys.Events(eventFilename)
    eventTimes = np.array(events.timestamps) / SAMPLING_RATE
    soundOnsetEvents = (events.eventID == 1) & (events.eventChannel == soundTriggerChannel)
    eventOnsetTimes = eventTimes[soundOnsetEvents]

    # -- Correct trials at the chosen frequency, split by choice --
    rightward = bdata['choice'] == bdata.labels['choice']['right']
    leftward = bdata['choice'] == bdata.labels['choice']['left']
    correct = bdata['outcome'] == bdata.labels['outcome']['correct']

    possibleFreq = np.unique(bdata['targetFrequency'])
    Freq = possibleFreq[Frequency]
    oneFreq = bdata['targetFrequency'] == Freq

    trialsToUseLeft = leftward & correct & oneFreq
    trialsToUseRight = rightward & correct & oneFreq
    trialsEachCond = np.c_[trialsToUseLeft, trialsToUseRight]
    colorEachCond = ['g', 'r']

    # -- Load spike data from this cluster --
    spkData = ephyscore.CellData(oneCell)
    spkTimeStamps = spkData.spikes.timestamps

    (spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial) = \
        spikesanalysis.eventlocked_spiketimes(spkTimeStamps, eventOnsetTimes, timeRange)

    # -- Raster (top two thirds) --
    plt.clf()
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    extraplots.raster_plot(spikeTimesFromEventOnset, indexLimitsEachTrial, timeRange,
                           trialsEachCond=trialsEachCond, colorEachCond=colorEachCond,
                           fillWidth=None, labels=None)
    plt.ylabel('Trials')

    # -- PSTH (bottom third) --
    timeVec = np.arange(timeRange[0], timeRange[-1], binWidth)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(spikeTimesFromEventOnset, indexLimitsEachTrial, timeVec)

    smoothWinSize = 3
    plt.subplot2grid((3, 1), (2, 0), sharex=ax1)
    extraplots.plot_psth(spikeCountMat / binWidth, smoothWinSize, timeVec, trialsEachCond=trialsEachCond,
                         colorEachCond=colorEachCond, linestyle=None, linewidth=3, downsamplefactor=1)
    plt.xlabel('Time from sound onset (s)')
    plt.ylabel('Firing rate (spk/sec)')

    # -- Save the figure --
    nameFreq = str(Freq)
    tetrodeClusterName = 'T' + str(oneCell.tetrode) + 'c' + str(oneCell.cluster)
    plt.gcf().set_size_inches((8.5, 11))
    figformat = 'png'  #'png' #'pdf' #'svg'
    filename = 'rast_%s_%s_%s_%s.%s' % (subject, behavSession, nameFreq, tetrodeClusterName, figformat)
    fulloutputDir = outputDir + subject + '/' + nameFreq + '/'
    fullFileName = os.path.join(fulloutputDir, filename)

    directory = os.path.dirname(fulloutputDir)
    if not os.path.exists(directory):
        os.makedirs(directory)
    print('saving figure to %s' % fullFileName)
    plt.gcf().savefig(fullFileName, format=figformat)
def switch_report(mouseName, behavSession, tetrode, cluster):
    '''
    Build and display a report figure for one cell of a switching session.

    Args:
        mouseName: animal name; also selects the cell-database module
            'allcells_<mouseName>' imported from settings.ALLCELLS_PATH.
        behavSession: behavior session string used to look up the cell.
        tetrode, cluster: identify the cell within the session.

    bdata, eventOnsetTimes and the sound- and center-out-aligned spike times
    are published through module globals for the plotting helpers. The figure
    is shown with plt.show(), not saved.
    '''
    global bdata
    global eventOnsetTimes
    global spikeTimesFromEventOnset
    global indexLimitsEachTrial
    global spikeTimesFromMovementOnset
    global indexLimitsEachMovementTrial

    # -- Find the cell in the per-animal cell database --
    allcellsFileName = 'allcells_' + mouseName
    if settings.ALLCELLS_PATH not in sys.path:  # avoid growing sys.path on every call
        sys.path.append(settings.ALLCELLS_PATH)
    allcells = importlib.import_module(allcellsFileName)

    cellID = allcells.cellDB.findcell(mouseName, behavSession, tetrode,
                                      cluster)
    oneCell = allcells.cellDB[cellID]

    subject = oneCell.animalName
    behavSession = oneCell.behavSession
    ephysSession = oneCell.ephysSession
    ephysRoot = os.path.join(ephysRootDir, subject)

    # -- Load behavior data --
    behaviorFilename = loadbehavior.path_to_behavior_data(
        subject, experimenter, paradigm, behavSession)
    bdata = loadbehavior.FlexCategBehaviorData(behaviorFilename)
    bdata.find_trials_each_block()

    # -- Load event data and convert event timestamps to seconds --
    ephysDir = os.path.join(ephysRoot, ephysSession)
    eventFilename = os.path.join(ephysDir, 'all_channels.events')
    events = loadopenephys.Events(eventFilename)
    eventTimes = np.array(events.timestamps) / SAMPLING_RATE

    soundOnsetEvents = (events.eventID == 1) & (events.eventChannel
                                                == soundTriggerChannel)
    eventOnsetTimes = eventTimes[soundOnsetEvents]

    # -- Onset times shifted by (center-out time - sound time) --
    # Handle a one-element mismatch in either direction (more trials than
    # events, or more events than trials), as the batch report does; the
    # previous code only handled the first case and raised a broadcasting
    # error in the second.
    centerOutTimes = bdata['timeCenterOut']
    soundStartTimes = bdata['timeTarget']
    timeDiff = centerOutTimes - soundStartTimes
    if len(eventOnsetTimes) < len(timeDiff):
        timeDiff = timeDiff[:-1]
        eventOnsetTimesCenter = eventOnsetTimes + timeDiff
    elif len(eventOnsetTimes) > len(timeDiff):
        eventOnsetTimesCenter = eventOnsetTimes[:-1] + timeDiff
    else:
        eventOnsetTimesCenter = eventOnsetTimes + timeDiff

    # -- Load spike data from this cluster --
    spkData = ephyscore.CellData(oneCell)
    spkTimeStamps = spkData.spikes.timestamps

    (spikeTimesFromEventOnset, trialIndexForEachSpike, indexLimitsEachTrial) = \
        spikesanalysis.eventlocked_spiketimes(spkTimeStamps, eventOnsetTimes, timeRange)

    (spikeTimesFromMovementOnset, movementTrialIndexForEachSpike, indexLimitsEachMovementTrial) = \
        spikesanalysis.eventlocked_spiketimes(spkTimeStamps, eventOnsetTimesCenter, timeRange)

    # Integer division: subplot2grid needs integer grid positions/spans
    # (identical to '/' for ints under Python 2, required under Python 3).
    thirdCols = numCols // 3
    halfCols = numCols // 2

    plt.clf()
    if len(spkTimeStamps) > 0:
        # -- Cluster quality plots along the bottom row --
        plt.subplot2grid((numRows, numCols),
                         ((numRows - sizeClusterPlot), 0),
                         colspan=thirdCols)
        spikesorting.plot_isi_loghist(spkData.spikes.timestamps)
        plt.subplot2grid((numRows, numCols),
                         ((numRows - sizeClusterPlot), thirdCols * 2),
                         colspan=thirdCols)
        spikesorting.plot_events_in_time(spkTimeStamps)
        samples = spkData.spikes.samples.astype(float) - 2**15
        samples = (1000.0 / spkData.spikes.gain[0, 0]) * samples
        plt.subplot2grid((numRows, numCols),
                         ((numRows - sizeClusterPlot), thirdCols),
                         colspan=thirdCols)
        spikesorting.plot_waveforms(samples)

    # -- Sound-aligned (left) and movement-aligned (right) block plots --
    plt.subplot2grid((numRows, numCols), (0, 0),
                     colspan=halfCols,
                     rowspan=sizeRasters)
    raster_sound_block_switching()
    plt.subplot2grid((numRows, numCols), (sizeRasters, 0),
                     colspan=halfCols,
                     rowspan=sizeHists)
    hist_sound_block_switching()
    plt.subplot2grid((numRows, numCols), (0, halfCols),
                     colspan=halfCols,
                     rowspan=sizeRasters)
    raster_movement_block_switching()
    plt.subplot2grid((numRows, numCols), (sizeRasters, halfCols),
                     colspan=halfCols,
                     rowspan=sizeHists)
    hist_movement_block_switching()

    # -- All-frequency (left) and per-sound (right) switching plots --
    plt.subplot2grid((numRows, numCols), ((sizeRasters + sizeHists), 0),
                     colspan=halfCols,
                     rowspan=sizeRasters)
    raster_sound_allFreq_switching()
    plt.subplot2grid((numRows, numCols),
                     ((2 * sizeRasters + sizeHists), 0),
                     colspan=halfCols,
                     rowspan=sizeHists)
    hist_sound_allFreq_switching()
    plt.subplot2grid((numRows, numCols),
                     ((sizeRasters + sizeHists), halfCols),
                     colspan=halfCols,
                     rowspan=sizeRasters)
    raster_sound_switching()
    plt.subplot2grid((numRows, numCols),
                     ((2 * sizeRasters + sizeHists), halfCols),
                     colspan=halfCols,
                     rowspan=sizeHists)
    hist_sound_switching()

    tetrodeClusterName = 'T' + str(oneCell.tetrode) + 'c' + str(
        oneCell.cluster)
    plt.suptitle(mouseName + ' ' + behavSession + ' ' + tetrodeClusterName)
    plt.gcf().set_size_inches((8.5, 11))

    plt.show()
from jaratoolbox import celldatabase
import pandas as pd
import os
from jaratoolbox import settings
from jaratoolbox import ephyscore
reload(ephyscore)  # re-import ephyscore so edits to it take effect in an interactive session

dbPath = os.path.join(settings.FIGURES_DATA_PATH, '2018thstr',
                      'celldatabase.h5')
db = pd.read_hdf(dbPath, key='dataframe')

# Select a cell by index label. DataFrame.ix was removed in pandas 1.0; on an
# integer index .ix[...] was label-based, i.e. equivalent to .loc.
# NOTE(review): assumes the database keeps an integer index - use db.iloc for
# positional access otherwise.
cell = db.loc[1033]
# Convert the cell to a dict. Or, you can make a dict yourself that has the right fields
cellDict = cell.to_dict()
cellData = ephyscore.CellData(**cellDict)

spikeData, events = cellData.load_ephys('am')
bdata = cellData.load_bdata('am')
# Beispiel #9 (Example #9)
def main():
    '''
    Save a "center frequency" report figure for every cell in allcells.cellDB.

    For each cell:
      * (re)load behavior data and sound-onset event times whenever the cell belongs
        to a different behavior session than the one currently loaded;
      * align the cell's spikes to sound onset and to center-port exit (sound onset
        shifted by timeCenterOut - timeTarget of each trial);
      * draw ISI / events-in-time / waveform panels plus rasters and histograms for
        the two center frequencies, and save the figure as PNG in outputDir/<subject>/.

    Any exception while processing a cell is reported, the cell's behavior session is
    added to badSessionList, and processing continues with the next cell. The list of
    bad sessions is printed at the end.

    Reads module-level names defined elsewhere in this file (numOfCells, allcells,
    ephysRootDir, experimenter, paradigm, SAMPLING_RATE, soundTriggerChannel,
    timeRange, outputDir, badSessionList and the raster_*/hist_* helpers) and sets the
    globals declared below -- presumably so those helpers can read them (verify).
    '''
    global behavSession
    global subject
    global ephysSession
    global tetrodeID
    global bdata
    global eventOnsetTimes
    global spikeTimesFromEventOnset
    global indexLimitsEachTrial
    global spikeTimesFromMovementOnset
    global indexLimitsEachMovementTrial

    # Behavior session whose data is currently loaded. It is only set after a load
    # completes, so if loading a session fails the next cell retries the load instead
    # of silently reusing (possibly partially overwritten) data of another session.
    loadedSession = None
    for cellID in range(0, numOfCells):
        oneCell = allcells.cellDB[cellID]
        try:
            if loadedSession != oneCell.behavSession:
                loadedSession = None  # invalid until this load finishes
                subject = oneCell.animalName
                behavSession = oneCell.behavSession
                ephysSession = oneCell.ephysSession
                ephysRoot = os.path.join(ephysRootDir, subject)

                print(behavSession)

                # -- Load behavior data --
                behaviorFilename = loadbehavior.path_to_behavior_data(
                    subject, experimenter, paradigm, behavSession)
                bdata = loadbehavior.BehaviorData(behaviorFilename)

                # -- Load event data and convert timestamps to seconds --
                ephysDir = os.path.join(ephysRoot, ephysSession)
                eventFilename = os.path.join(ephysDir, 'all_channels.events')
                events = loadopenephys.Events(eventFilename)
                eventTimes = np.array(events.timestamps) / SAMPLING_RATE

                soundOnsetEvents = (events.eventID == 1) & (
                    events.eventChannel == soundTriggerChannel)
                eventOnsetTimes = eventTimes[soundOnsetEvents]

                # Indices (into the sorted unique target frequencies) of the two
                # middle frequencies. '//' keeps them integers on Python 3 too.
                numberOfFrequencies = len(np.unique(bdata['targetFrequency']))
                centerFrequencies = [numberOfFrequencies // 2 - 1,
                                     numberOfFrequencies // 2]

                # -- Movement onset: each sound onset shifted by that trial's delay
                #    between sound start (timeTarget) and leaving the center port
                #    (timeCenterOut) --
                timeDiff = bdata['timeCenterOut'] - bdata['timeTarget']
                if len(eventOnsetTimes) < len(timeDiff):
                    # NOTE(review): presumably the last behavior trial has no
                    # recorded sound-onset event -- confirm.
                    timeDiff = timeDiff[:-1]
                eventOnsetTimesCenter = eventOnsetTimes + timeDiff

                loadedSession = behavSession

            # -- Load spike data for this cell's cluster --
            spkData = ephyscore.CellData(oneCell)
            spkTimeStamps = spkData.spikes.timestamps

            (spikeTimesFromEventOnset, trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 spkTimeStamps, eventOnsetTimes, timeRange)

            (spikeTimesFromMovementOnset, movementTrialIndexForEachSpike,
             indexLimitsEachMovementTrial) = spikesanalysis.eventlocked_spiketimes(
                 spkTimeStamps, eventOnsetTimesCenter, timeRange)

            plt.clf()
            if len(spkTimeStamps) > 0:
                plt.subplot2grid((7, 6), (6, 0), colspan=2)
                spikesorting.plot_isi_loghist(spkTimeStamps)
                plt.subplot2grid((7, 6), (6, 4), colspan=2)
                spikesorting.plot_events_in_time(spkTimeStamps)
                # Open Ephys samples: remove the 2**15 offset, convert to microvolts
                samples = spkData.spikes.samples.astype(float) - 2**15
                samples = (1000.0 / spkData.spikes.gain[0, 0]) * samples
                plt.subplot2grid((7, 6), (6, 2), colspan=2)
                spikesorting.plot_waveforms(samples)

            # -- Sound-aligned rasters/histograms (left: lower center freq) --
            plt.subplot2grid((7, 6), (0, 0), colspan=3, rowspan=2)
            raster_sound_psycurve(centerFrequencies[0])
            plt.subplot2grid((7, 6), (2, 0), colspan=3)
            hist_sound_psycurve(centerFrequencies[0])
            plt.subplot2grid((7, 6), (0, 3), colspan=3, rowspan=2)
            raster_sound_psycurve(centerFrequencies[1])
            plt.subplot2grid((7, 6), (2, 3), colspan=3)
            hist_sound_psycurve(centerFrequencies[1])

            # -- Movement-aligned rasters/histograms --
            plt.subplot2grid((7, 6), (3, 0), colspan=3, rowspan=2)
            raster_movement_psycurve(centerFrequencies[0])
            plt.subplot2grid((7, 6), (5, 0), colspan=3)
            hist_movement_psycurve(centerFrequencies[0])
            plt.subplot2grid((7, 6), (3, 3), colspan=3, rowspan=2)
            raster_movement_psycurve(centerFrequencies[1])
            plt.subplot2grid((7, 6), (5, 3), colspan=3)
            hist_movement_psycurve(centerFrequencies[1])

            # -- Save the report --
            tetrodeClusterName = 'T' + str(oneCell.tetrode) + 'c' + str(
                oneCell.cluster)
            plt.gcf().set_size_inches((8.5, 11))
            figformat = 'png'
            filename = 'report_centerFreq_%s_%s_%s.%s' % (
                subject, behavSession, tetrodeClusterName, figformat)
            fulloutputDir = outputDir + subject + '/'
            fullFileName = os.path.join(fulloutputDir, filename)

            directory = os.path.dirname(fulloutputDir)
            if not os.path.exists(directory):
                os.makedirs(directory)
            plt.gcf().savefig(fullFileName, format=figformat)

        except Exception as exc:
            # Best effort: report, remember the session, continue with next cell.
            print('error with session %s: %s' % (oneCell.behavSession, exc))
            if oneCell.behavSession not in badSessionList:
                badSessionList.append(oneCell.behavSession)

    print('error with sessions: ')
    for badSes in badSessionList:
        print(badSes)
# Example #10
# 0
def rasterBlock(oneCell):
    '''
    Plot and save a raster and PSTH of one cell, split by behavior block.

    Spikes are aligned to sound onset. Only correct trials whose target frequency is
    possibleFreq[middleFreq] are shown, one condition per block. Block colors
    alternate green/red, starting with green when the first block is
    'low_boundary' and red otherwise. The figure is saved as PNG in
    outputDir/<subject>/ and its path is printed.

    Args:
        oneCell: cell record (e.g. a celldatabase CellInfo) providing animalName,
            behavSession, ephysSession, tetrode and cluster.

    Relies on module-level names defined elsewhere in this file: ephysRootDir,
    experimenter, paradigm, SAMPLING_RATE, soundTriggerChannel, middleFreq,
    timeRange, binWidth, outputDir.
    '''
    subject = oneCell.animalName
    behavSession = oneCell.behavSession
    ephysSession = oneCell.ephysSession
    ephysRoot = os.path.join(ephysRootDir, subject)

    # -- Load behavior data and find the trials of each block --
    behaviorFilename = loadbehavior.path_to_behavior_data(
        subject, experimenter, paradigm, behavSession)
    bdata = loadbehavior.FlexCategBehaviorData(behaviorFilename)
    bdata.find_trials_each_block()

    # -- Load event data and convert timestamps to seconds --
    ephysDir = os.path.join(ephysRoot, ephysSession)
    eventFilename = os.path.join(ephysDir, 'all_channels.events')
    events = loadopenephys.Events(eventFilename)
    eventTimes = np.array(events.timestamps) / SAMPLING_RATE
    soundOnsetEvents = (events.eventID == 1) & (
        events.eventChannel == soundTriggerChannel)
    eventOnsetTimes = eventTimes[soundOnsetEvents]

    # -- Load spike data for this cell's cluster --
    spkData = ephyscore.CellData(oneCell)
    spkTimeStamps = spkData.spikes.timestamps

    # -- Correct trials at the chosen frequency, one column per block --
    correct = bdata['outcome'] == bdata.labels['outcome']['correct']
    possibleFreq = np.unique(bdata['targetFrequency'])
    oneFreq = bdata['targetFrequency'] == possibleFreq[middleFreq]
    correctOneFreq = oneFreq & correct
    correctTrialsEachBlock = (bdata.blocks['trialsEachBlock'] &
                              correctOneFreq[:, np.newaxis])

    # Alternate the color pair block by block; repeat it often enough that every
    # block gets a color (and never provide fewer than six colors).
    if bdata['currentBlock'][0] == bdata.labels['currentBlock'][
            'low_boundary']:
        colorPair = ['g', 'r']
    else:
        colorPair = ['r', 'g']
    nBlocks = correctTrialsEachBlock.shape[1]
    colorEachBlock = max(3, (nBlocks + 1) // 2) * colorPair

    (spikeTimesFromEventOnset, trialIndexForEachSpike,
     indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         spkTimeStamps, eventOnsetTimes, timeRange)

    # -- Raster (top two thirds) --
    plt.clf()
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    extraplots.raster_plot(spikeTimesFromEventOnset,
                           indexLimitsEachTrial,
                           timeRange,
                           trialsEachCond=correctTrialsEachBlock,
                           colorEachCond=colorEachBlock,
                           fillWidth=None,
                           labels=None)
    plt.ylabel('Trials')

    # -- PSTH (bottom third), firing rate in spikes/s --
    timeVec = np.arange(timeRange[0], timeRange[-1], binWidth)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
        spikeTimesFromEventOnset, indexLimitsEachTrial, timeVec)

    smoothWinSize = 3
    plt.subplot2grid((3, 1), (2, 0), sharex=ax1)
    extraplots.plot_psth(spikeCountMat / binWidth,
                         smoothWinSize,
                         timeVec,
                         trialsEachCond=correctTrialsEachBlock,
                         colorEachCond=colorEachBlock,
                         linestyle=None,
                         linewidth=3,
                         downsamplefactor=1)
    plt.xlabel('Time from sound onset (s)')
    plt.ylabel('Firing rate (spk/sec)')

    # -- Save figure --
    nameFreq = str(possibleFreq[middleFreq])
    tetrodeClusterName = 'T' + str(oneCell.tetrode) + 'c' + str(
        oneCell.cluster)
    plt.gcf().set_size_inches((8.5, 11))
    figformat = 'png'
    filename = 'block_%s_%s_%s_%s.%s' % (
        subject, behavSession, tetrodeClusterName, nameFreq, figformat)
    fulloutputDir = outputDir + subject + '/'
    fullFileName = os.path.join(fulloutputDir, filename)

    directory = os.path.dirname(fulloutputDir)
    if not os.path.exists(directory):
        os.makedirs(directory)
    print('saving figure to %s' % fullFileName)
    plt.gcf().savefig(fullFileName, format=figformat)
def raster_tuning(ax):
    '''
    Plot a frequency-sorted spike raster of the tuning-curve session on ax.

    Trials are grouped by presented frequency (ascending). Spikes are aligned to
    event onsets (eventID == 1) and drawn as dots, one row per trial. Light
    horizontal lines separate the frequency groups, and y tick labels (kHz) sit at
    the middle of each group.

    Args:
        ax: matplotlib Axes to draw on (all drawing goes to this Axes).

    Relies on module-level names defined elsewhere in this file: behaviorDir,
    subject, tuningBehavior, tuningEphys, tetrode, cluster, tuning_timeRange,
    SAMPLING_RATE.
    '''
    fullbehaviorDir = behaviorDir + subject + '/'
    behavName = subject + '_tuning_curve_' + tuningBehavior + '.h5'
    tuningBehavFileName = os.path.join(fullbehaviorDir, behavName)

    tuning_bdata = loadbehavior.BehaviorData(tuningBehavFileName,
                                             readmode='full')
    freqEachTrial = tuning_bdata['currentFreq']
    possibleFreq = np.unique(freqEachTrial)
    numberOfTrials = len(freqEachTrial)

    # -- Sort trials by frequency --
    # Trial indices of each frequency (ascending frequency), and group sizes
    # (used for the separator lines and the tick positions).
    indsEachFreq = [np.flatnonzero(freqEachTrial == oneFreq)
                    for oneFreq in possibleFreq]
    numTrialsEachFreq = [len(inds) for inds in indsEachFreq]
    sortedTrials = np.concatenate(indsEachFreq)
    # Inverse permutation: sortingInds[trial] is the raster row of that trial.
    sortingInds = np.argsort(sortedTrials)

    # -- Load event data and convert timestamps to seconds --
    tuning_ephysDir = os.path.join(settings.EPHYS_PATH, subject, tuningEphys)
    tuning_eventFilename = os.path.join(tuning_ephysDir, 'all_channels.events')
    tuning_ev = loadopenephys.Events(tuning_eventFilename)
    tuning_eventTimes = np.array(tuning_ev.timestamps) / SAMPLING_RATE
    tuning_evID = np.array(tuning_ev.eventID)
    # Onsets (eventID == 1) of all events; no channel filtering, since this
    # session is expected to have a single event type (sound start).
    tuning_eventOnsetTimes = tuning_eventTimes[tuning_evID == 1]
    # Discard surplus events beyond the number of behavior trials.
    tuning_eventOnsetTimes = tuning_eventOnsetTimes[:numberOfTrials]

    # HACK: build a throw-away CellInfo for the tuning session so that
    # ephyscore.CellData can load its spikes. Fields set to 'DO NOT NEED THIS'
    # are placeholders.
    thisCell = celldatabase.CellInfo(animalName=subject,
                                     ephysSession=tuningEphys,
                                     tuningSession='DO NOT NEED THIS',
                                     tetrode=tetrode,
                                     cluster=cluster,
                                     quality=1,
                                     depth=0,
                                     tuningBehavior='DO NOT NEED THIS',
                                     behavSession=tuningBehavior)
    tuning_spkData = ephyscore.CellData(thisCell)
    tuning_spkTimeStamps = tuning_spkData.spikes.timestamps

    # Spike times relative to each event onset and the trial of each spike.
    (tuning_spikeTimesFromEventOnset, tuning_trialIndexForEachSpike,
     tuning_indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
         tuning_spkTimeStamps, tuning_eventOnsetTimes, tuning_timeRange)

    # Raster row of each spike: position of its trial in the frequency-sorted order.
    tuning_sortedIndexForEachSpike = sortingInds[tuning_trialIndexForEachSpike]

    ax.plot(tuning_spikeTimesFromEventOnset,
            tuning_sortedIndexForEachSpike,
            '.',
            ms=3)

    # Cumulative trial count at the end of each frequency group; draw a light
    # line after each group.
    numTrials = np.cumsum(numTrialsEachFreq)
    for num in numTrials:
        ax.axhline(y=num, xmin=0, xmax=1, color='0.90', zorder=0)

    # Tick at the middle of each group, using each group's own size so ticks stay
    # centred when groups have different numbers of trials.
    tickPositions = numTrials - np.asarray(numTrialsEachFreq) / 2.0
    # Float division so fractional kHz values are not truncated on Python 2.
    tickLabels = ['%0.2f' % (oneFreq / 1000.0) for oneFreq in possibleFreq]
    ax.set_yticks(tickPositions)
    ax.set_yticklabels(tickLabels)
    ax.set_ylim([-1, numberOfTrials])
    ax.set_ylabel(
        'Frequency Presented (kHz), {} total trials'.format(numTrials[-1]))
    ax.set_xlabel('Time (sec)')
                soundeventData = loadopenephys.Events(
                    soundeventFilename)  # Load events data
                soundeventTimes = np.array(
                    soundeventData.timestamps) / SAMPLING_RATE
                soundOnsetEvents = (soundeventData.eventID == 1) & (
                    soundeventData.eventChannel == soundTriggerChannel)
                soundOnsetTimes = soundeventTimes[soundOnsetEvents]
                print "number of laser trials ", len(
                    laserOnsetTimes), "number of sound trials ", len(
                        soundOnsetTimes)

                maxZLaserDict[behavSession] = np.zeros([clusNum * numTetrodes])
                maxZSoundDict[behavSession] = np.zeros([clusNum * numTetrodes])
                # -- Load Spike Data From Certain Cluster --
            soundSpkData = ephyscore.CellData(
                oneCell
            )  #cannot use this methodfor laser data since it only loads ephys session not laser session
            soundSpkTimeStamps = soundSpkData.spikes.timestamps
            print len(soundSpkTimeStamps), len(soundOnsetTimes)
            laserSpkFullPath = os.path.join(
                laserephysDir, 'Tetrode{0}.spikes'.format(tetrode))
            laserSpkData = loadopenephys.DataSpikes(laserSpkFullPath)
            laserSpkData.timestamps = laserSpkData.timestamps / SAMPLING_RATE
            kkDataDir = os.path.dirname(laserSpkFullPath) + '_kk'
            clusterFilename = 'Tetrode{0}.clu.1'.format(tetrode)
            clusterFullPath = os.path.join(kkDataDir, clusterFilename)
            clusters = np.fromfile(clusterFullPath, dtype='int32', sep=' ')[1:]
            spikesMaskThisCluster = clusters == cluster
            laserSpkTimeStamps = laserSpkData.timestamps[spikesMaskThisCluster]
            print len(laserSpkTimeStamps), len(laserOnsetTimes)