예제 #1
0
def plot_projections_each_cluster(cellObj, sessionType='behavior'):
    '''Plot the projection cloud of the spikes of one cell.

    Loads the first session of type ``sessionType`` recorded for this cell and
    passes its spike waveforms (``ephysData['samples']``) to
    ``spikesorting.plot_projections``.

    :param cellObj: Cell object from ephyscore.
    :param sessionType: String naming the type of ephys session to use
        (default ``'behavior'``).
    :raises IndexError: If the cell has no session of type ``sessionType``.
    '''
    sessionInds = cellObj.get_session_inds(sessionType)
    try:
        # Only the first session of the requested type is used.
        sessionInd = sessionInds[0]
    except IndexError:
        raise IndexError('No session of type {!r} found for this cell'.format(sessionType))
    ephysData = cellObj.load_ephys_by_index(sessionInd)
    wavesThisCluster = ephysData['samples']
    spikesorting.plot_projections(wavesThisCluster)
def plot_projections_each_cluster(animal, ephysSession, tetrode, cluster):
    '''Plot the projection cloud of the spikes assigned to one cluster.

    :param animal: String with the animal name.
    :param ephysSession: String with the name of the ephys session; this is the
        full filename, in ``{date}_XX-XX-XX`` format.
    :param tetrode: Integer tetrode number, in range(1, 9).
    :param cluster: Integer cluster number.
    '''
    clusterSpikes = load_spike_data(animal, ephysSession, tetrode, cluster)
    spikesorting.plot_projections(clusterSpikes.samples)
 def plot_report(self, showfig=False):
     '''Plot one summary row per cluster into the current figure.

     For each cluster in self.clustersList (at most self.nRows of them), one
     row of a 3-column layout is drawn: ISI log-histogram (column 1), average
     waveforms (column 2), and projections (top half of column 3) above
     events in time (bottom half of column 3). Clusters beyond self.nRows
     are skipped with a printed warning. The figure title is self.figTitle.

     :param showfig: If True, call plt.show() after plotting.
     '''
     print('Plotting report...')
     self.fig = plt.gcf()
     self.fig.clf()
     self.fig.set_facecolor('w')
     nCols = 3
     # Only the bottom-most drawn row keeps its x-axis labels. Bounding by
     # nRows too keeps those labels when there are more clusters than rows
     # (assumes self.nClusters is the number of entries in clustersList).
     lastRow = min(self.nClusters, self.nRows) - 1
     for indc, clusterID in enumerate(self.clustersList):
         if (indc + 1) > self.nRows:
             print('WARNING! This cluster was ignored (more clusters than rows)')
             continue
         # Select this cluster's spikes.
         spikesThisCluster = self.spikesEachCluster[indc, :]
         tsThisCluster = self.dataTT.timestamps[spikesThisCluster]
         wavesThisCluster = self.dataTT.samples[spikesThisCluster, :, :]
         # -- Plot ISI histogram --
         plt.subplot(self.nRows, nCols, indc * nCols + 1)
         spikesorting.plot_isi_loghist(tsThisCluster)
         if indc < lastRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         plt.ylabel('c%d' % clusterID, rotation=0, va='center', ha='center')
         # -- Plot events in time (bottom half of column 3) --
         plt.subplot(2 * self.nRows, nCols, 2 * (indc * nCols) + 6)
         spikesorting.plot_events_in_time(tsThisCluster)
         if indc < lastRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         # -- Plot projections (top half of column 3) --
         plt.subplot(2 * self.nRows, nCols, 2 * (indc * nCols) + 3)
         spikesorting.plot_projections(wavesThisCluster)
         # -- Plot waveforms --
         plt.subplot(self.nRows, nCols, indc * nCols + 2)
         # NOTE: plot_waveforms_average_all is re-defined elsewhere in this file.
         plot_waveforms_average_all(wavesThisCluster)
     plt.figtext(0.5,
                 0.92,
                 self.figTitle,
                 ha='center',
                 fontweight='bold',
                 fontsize=10)
     if showfig:
         plt.show()
예제 #4
0
 def plot_report(self,showfig=False):
     '''Plot one summary row per cluster into the current figure.

     For each cluster in self.clustersList (at most self.nRows of them), one
     row of a 3-column layout is drawn: ISI log-histogram (column 1),
     waveforms (column 2), and projections (top half of column 3) above
     events in time (bottom half of column 3). Clusters beyond self.nRows
     are skipped with a printed warning. The figure title is self.figTitle.

     :param showfig: If True, call plt.show() after plotting.
     '''
     print('Plotting report...')
     self.fig = plt.gcf()
     self.fig.clf()
     self.fig.set_facecolor('w')
     nCols = 3
     # Only the bottom-most drawn row keeps its x-axis labels. Bounding by
     # nRows too keeps those labels when there are more clusters than rows
     # (assumes self.nClusters is the number of entries in clustersList).
     lastRow = min(self.nClusters, self.nRows) - 1
     for indc,clusterID in enumerate(self.clustersList):
         if (indc+1)>self.nRows:
             print('WARNING! This cluster was ignored (more clusters than rows)')
             continue
         # Select this cluster's spikes.
         spikesThisCluster = self.spikesEachCluster[indc,:]
         tsThisCluster = self.timestamps[spikesThisCluster]
         wavesThisCluster = self.samples[spikesThisCluster,:,:]
         # -- Plot ISI histogram --
         plt.subplot(self.nRows,nCols,indc*nCols+1)
         spikesorting.plot_isi_loghist(tsThisCluster)
         if indc < lastRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         plt.ylabel('c%d'%clusterID,rotation=0,va='center',ha='center')
         # -- Plot events in time (bottom half of column 3) --
         plt.subplot(2*self.nRows,nCols,2*(indc*nCols)+6)
         spikesorting.plot_events_in_time(tsThisCluster)
         if indc < lastRow:
             plt.xlabel('')
             plt.gca().set_xticklabels('')
         # -- Plot projections (top half of column 3) --
         plt.subplot(2*self.nRows,nCols,2*(indc*nCols)+3)
         spikesorting.plot_projections(wavesThisCluster)
         # -- Plot waveforms --
         plt.subplot(self.nRows,nCols,indc*nCols+2)
         spikesorting.plot_waveforms(wavesThisCluster)
     plt.figtext(0.5,0.92, self.figTitle,ha='center',fontweight='bold',fontsize=10)
     if showfig:
         plt.show()
def laser_tc_analysis(site, sitenum):
    '''
    Data analysis function for laser/tuning curve experiments.

    This function takes a RecordingSite object, does multisession clustering on it, and saves all of
    the clusters back to the original session cluster directories. An EphysExperiment object
    (version 2) can then be used to load each session, select clusters, plot the appropriate plots,
    etc. This code was moved out of the EphysExperiment object because that object should be
    general and apply to any kind of recording experiment; this function does the data analysis
    for one specific kind of experiment.

    For every good tetrode and every cluster found on it, a PNG named TT{tetrode}Cluster{cluster}.png
    is saved to the multisession clustering directory. It contains rasters for the noise-burst,
    laser-pulse and (if present) laser-train sessions, a tuning-curve heatmap, and ISI / waveform /
    projection / events-in-time panels for the cluster. A multisession report is also saved for
    each tetrode.

    Args:

        site (RecordingSite object): An instance of the RecordingSite class from the ephys_experiment_v2 module
        sitenum (int): The site number for the site, used for constructing directory names

    Raises:

        ValueError: if the site has no 'noiseBurst', 'laserPulse' or 'tuningCurve' session
            (a missing 'laserTrain' session is tolerated).

    Example:

        from jaratoolbox.test.nick.ephysExperiments import laserTCanalysis
        for indSite, site in enumerate(today.siteList):
            laserTCanalysis.laser_tc_analysis(site, indSite+1)
    '''
    # TODO: incorporate Lan's sorting function here.
    for tetrode in site.goodTetrodes:
        # Construct a multiple-session clustering object with the session list.
        oneTT = cms2.MultipleSessionsToCluster(site.animalName, site.get_session_filenames(), tetrode,
                                               '{}site{}'.format(site.date, sitenum))
        oneTT.load_all_waveforms()

        # Do the clustering only if no cluster file exists yet.
        clusterFile = os.path.join(oneTT.clustersDir, 'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        # Historical notes (LG, 07/10-07/12):
        # - Saving single-session clu files once printed a Tk "application has been destroyed"
        #   message; it was only a warning, and reports and files were still generated.
        # - Skipping this save broke the single-cluster raster plots, so the files must be saved.
        # - A "ValueError: too many boolean indices" in save_single_session_clu_files() was caused
        #   by stale .clu files from a previous clustering run, which made self.clusters and
        #   self.recordingNumber differ in size.
        oneTT.save_single_session_clu_files()

        possibleClusters = np.unique(oneTT.clusters)

        # An EphysExperiment object is needed to load and plot the individual sessions.
        exp2 = ee2.EphysExperiment(site.animalName, site.date, experimenter = site.experimenter)

        # One figure per cluster. (Iterating over possibleClusters[1:] was once used as a hack to
        # omit cluster 1, which usually contains noise and sometimes has no spikes in the raster range.)
        for cluster in possibleClusters:
            plt.figure(figsize = (8.5,11))

            # -- Noise burst raster --
            plt.subplot2grid((5, 6), (0, 0), rowspan = 1, colspan = 3)
            nbIndex = site.get_session_types().index('noiseBurst')
            nbSession = site.get_session_filenames()[nbIndex]
            exp2.plot_session_raster(nbSession, tetrode, cluster = cluster, replace = 1)
            plt.ylabel('Noise Bursts')
            plt.title(nbSession, fontsize = 10)

            # -- Laser pulse raster --
            plt.subplot2grid((5, 6), (1, 0), rowspan = 1, colspan = 3)
            lpIndex = site.get_session_types().index('laserPulse')
            lpSession = site.get_session_filenames()[lpIndex]
            exp2.plot_session_raster(lpSession, tetrode, cluster = cluster, replace = 1)
            plt.ylabel('Laser Pulses')
            plt.title(lpSession, fontsize = 10)

            # -- Laser train raster (optional session) --
            plt.subplot2grid((5, 6), (2, 0), rowspan = 1, colspan = 3)
            try:
                ltIndex = site.get_session_types().index('laserTrain')
                ltSession = site.get_session_filenames()[ltIndex]
                exp2.plot_session_raster(ltSession, tetrode, cluster = cluster, replace = 1)
                plt.ylabel('Laser Trains')
                plt.title(ltSession, fontsize = 10)
            except ValueError:
                # Raised e.g. by .index() when the site has no 'laserTrain' session.
                print('This session doesnot exist.')

            # -- Tuning curve heatmap --
            plt.subplot2grid((5, 6), (0, 3), rowspan = 3, colspan = 3)
            tcIndex = site.get_session_types().index('tuningCurve')
            tcSession = site.get_session_filenames()[tcIndex]
            tcBehavID = site.get_session_behavIDs()[tcIndex]
            exp2.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace = 1, cluster = cluster)
            plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize = 10)

            # TODO: the best-frequency raster and laser pulses at different intensities are not
            # plotted yet. A missing session makes the .index() lookup raise ValueError, so such
            # panels would need the same try/except ValueError as the laser train raster.

            # -- Spike-sorting summary for this cluster (ISI, waveforms, projections, events) --
            # Boolean mask of the spikes assigned to this cluster.
            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster,:,:]

            # -- Plot ISI histogram --
            plt.subplot2grid((5,6), (3,0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            # Label with the ID of the cluster being plotted.
            plt.ylabel('c%d'%cluster,rotation=0,va='center',ha='center')

            # -- Plot waveforms --
            plt.subplot2grid((5,6), (4,0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((5,6), (3,3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((5,6), (4,3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            # Save the figure in the multisession clustering folder so that it is easy to find.
            fig_path = oneTT.clustersDir
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(fig_path, fig_name)
            print(full_fig_path)
            plt.tight_layout()
            plt.savefig(full_fig_path, format = 'png')
            plt.close()

        plt.figure()
        oneTT.save_multisession_report()
        plt.close()
def plot_rew_change_per_cell(oneCell,trialLimit=None,alignment='sound'):
    '''
    Plots raster and PSTH for one cell during reward_change_freq_dis task, split by block.

    Correct trials are split into six conditions: left/right choices in the more_left,
    more_right and same_reward blocks. Below the raster and PSTH, ISI histogram, waveform,
    projection and events-in-time panels are drawn for the cell's spikes.

    Args:
        oneCell: cell record; its animalName, behavSession, tetrode and cluster attributes
            are used for the title and labels.
        trialLimit (sequence of two ints, optional): [first, last) range of trials to include.
            None or empty (default) includes all trials.
        alignment (str): event to align spikes to; one of 'sound', 'center-out' or 'side-in'.

    Raises:
        ValueError: if alignment is not one of the supported values.

    Note: relies on the module-level names soundTriggerChannel, timeRange and binWidth.
    '''
    # Behavior time column to align to; the shift is measured relative to sound onset
    # ('timeTarget'). None means align to sound onset itself.
    alignmentTimeColumns = {'sound': None, 'center-out': 'timeCenterOut', 'side-in': 'timeSideIn'}
    if alignment not in alignmentTimeColumns:
        raise ValueError("alignment must be 'sound', 'center-out' or 'side-in', got {!r}".format(alignment))

    bdata = load_behav_per_cell(oneCell)
    (spikeTimestamps,waveforms,eventOnsetTimes,eventData) = load_ephys_per_cell(oneCell)

    # -- Check to see if ephys has skipped trials, if so remove trials from behav data --
    soundOnsetEvents = (eventData.eventID==1) & (eventData.eventChannel==soundTriggerChannel)
    soundOnsetTimeEphys = eventOnsetTimes[soundOnsetEvents]
    soundOnsetTimeBehav = bdata['timeTarget']
    missingTrials = behavioranalysis.find_missing_trials(soundOnsetTimeEphys,soundOnsetTimeBehav)
    bdata.remove_trials(missingTrials)

    # -- Event times to align to: sound onset, optionally shifted trial by trial by the
    # behavioral delay between sound onset and the chosen event --
    timeColumn = alignmentTimeColumns[alignment]
    if timeColumn is None:
        EventOnsetTimes = soundOnsetTimeEphys
    else:
        EventOnsetTimes = soundOnsetTimeEphys + (bdata[timeColumn] - bdata['timeTarget'])

    # -- Trials in each reward block (columns: same_reward, more_left, more_right) --
    currentBlock = bdata['currentBlock']
    blockLabels = bdata.labels['currentBlock']
    blockTypes = [blockLabels['same_reward'],blockLabels['more_left'],blockLabels['more_right']]
    trialsEachType = behavioranalysis.find_trials_each_type(currentBlock,blockTypes)
    blockSameReward = trialsEachType[:,0]
    blockMoreLeft = trialsEachType[:,1]
    blockMoreRight = trialsEachType[:,2]

    if trialLimit is None or len(trialLimit) == 0:
        validTrials = np.ones(len(currentBlock),dtype=bool)
    else:
        validTrials = np.zeros(len(currentBlock),dtype=bool)
        validTrials[trialLimit[0]:trialLimit[1]] = True

    # -- Correct left/right choices within the selected trial range --
    rightward = bdata['choice']==bdata.labels['choice']['right']
    leftward = bdata['choice']==bdata.labels['choice']['left']
    correct = bdata['outcome']==bdata.labels['outcome']['correct']
    rightcorrect = rightward&correct&validTrials
    leftcorrect = leftward&correct&validTrials

    trialsEachCond = np.c_[leftcorrect&blockMoreLeft, rightcorrect&blockMoreLeft,
                           leftcorrect&blockMoreRight, rightcorrect&blockMoreRight,
                           leftcorrect&blockSameReward, rightcorrect&blockSameReward]
    colorEachCond = ['g','r','m','b','y','darkgray']

    (spikeTimesFromEventOnset,trialIndexForEachSpike,indexLimitsEachTrial) = (
        spikesanalysis.eventlocked_spiketimes(spikeTimestamps,EventOnsetTimes,timeRange))

    # -- Plot raster and PSTH --
    plt.figure()
    ax1 = plt.subplot2grid((8,5), (0, 0), rowspan=4,colspan=5)
    extraplots.raster_plot(spikeTimesFromEventOnset,indexLimitsEachTrial,timeRange,trialsEachCond=trialsEachCond,colorEachCond=colorEachCond,fillWidth=None,labels=None)
    plt.ylabel('Trials')
    plt.xlim(timeRange)

    plt.title('{0}_{1}_TT{2}_c{3}_{4}'.format(oneCell.animalName,oneCell.behavSession,oneCell.tetrode,oneCell.cluster,alignment))

    timeVec = np.arange(timeRange[0],timeRange[-1],binWidth)
    spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(spikeTimesFromEventOnset,indexLimitsEachTrial,timeVec)
    smoothWinSize = 3
    ax2 = plt.subplot2grid((8,5), (4, 0),colspan=5,rowspan=2,sharex=ax1)
    extraplots.plot_psth(spikeCountMat/binWidth,smoothWinSize,timeVec,trialsEachCond=trialsEachCond,colorEachCond=colorEachCond,linestyle=None,linewidth=3,downsamplefactor=1)
    plt.xlabel('Time from {0} onset (s)'.format(alignment))
    plt.ylabel('Firing rate (spk/sec)')

    # -- Plot ISI histogram --
    plt.subplot2grid((8,5), (6,0), rowspan=1, colspan=2)
    spikesorting.plot_isi_loghist(spikeTimestamps)
    plt.ylabel('c%d'%oneCell.cluster,rotation=0,va='center',ha='center')
    plt.xlabel('')

    # -- Plot waveforms --
    plt.subplot2grid((8,5), (7,0), rowspan=1, colspan=3)
    spikesorting.plot_waveforms(waveforms)

    # -- Plot projections --
    plt.subplot2grid((8,5), (6,2), rowspan=1, colspan=3)
    spikesorting.plot_projections(waveforms)

    # -- Plot events in time --
    plt.subplot2grid((8,5), (7,3), rowspan=1, colspan=2)
    spikesorting.plot_events_in_time(spikeTimestamps)

    plt.subplots_adjust(wspace = 0.7)
    plt.gcf().set_size_inches((8.5,11))
예제 #7
0
    def generate_main_report(self, siteName):
        '''
        Generate the reports for all of the sessions in this site. This is where we should interface with
        the multiunit clustering code, since all of the sessions that need to be clustered together have
        been defined at this point.

        For each good tetrode, the site's sessions are clustered together (reusing an existing
        Tetrode<N>.clu.1 file when present) and single-session .clu files are saved. For every cluster a
        PNG named TT{tetrode}Cluster{cluster}.png is written to the multisession clustering directory.
        It contains one raster per 'main' raster session (left half, one row each), the first 'main'
        tuning-curve heatmap if there is one (right half, top three rows), and ISI, waveform, projection
        and events-in-time panels for that cluster (bottom two rows). A multisession report is also saved
        for each tetrode.

        Args:
            siteName (str): appended to the date to name the multisession clustering directory.

        FIXME: This method should be able to load some kind of report template perhaps, so that the
        types of reports we can make are not limited. For instance, what happens when we just have
        rasters for a site and no tuning curve? Implementing this is a lower priority for now.

        Incorporated Lan's code for plotting the cluster reports directly on the main report.
        '''
        # FIXME: import another piece of code to do this?
        # FIXME: OR, break into two functions: one that will do the multisite clustering, and one that
        # knows the type of report that we want. The first one can probably be a method of MSTC, the other
        # should either live in extraplots or should go in someone's directory.

        # These lookups do not depend on the tetrode or cluster, so do them once.
        sessionFilenames = self.get_session_filenames()
        sessionTypes = self.get_session_types()
        sessionBehavIDs = self.get_session_behavIDs()

        mainRasterInds = self.get_session_inds_one_type(plotType='raster', report='main')
        mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
        mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

        mainTCinds = self.get_session_inds_one_type(plotType='tc_heatmap', report='main')
        mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
        mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

        for tetrode in self.goodTetrodes:
            oneTT = cms2.MultipleSessionsToCluster(self.animalName, self.get_session_filenames(), tetrode, '{}_{}'.format(self.date, siteName))
            oneTT.load_all_waveforms()

            # Do the clustering only if no cluster file exists yet.
            clusterFile = os.path.join(oneTT.clustersDir,'Tetrode%d.clu.1'%oneTT.tetrode)
            if not os.path.isfile(clusterFile):
                oneTT.create_multisession_fet_files()
                oneTT.run_clustering()
            oneTT.set_clusters_from_file()

            oneTT.save_single_session_clu_files()
            possibleClusters = np.unique(oneTT.clusters)

            ee = EphysExperiment(self.animalName, self.date, experimenter = self.experimenter)

            # One figure per cluster.
            for cluster in possibleClusters:
                plt.figure()  # The main report for this cluster/tetrode/session

                for indRaster, rasterSession in enumerate(mainRasterSessions):
                    plt.subplot2grid((6, 6), (indRaster, 0), rowspan = 1, colspan = 3)
                    ee.plot_session_raster(rasterSession, tetrode, cluster = cluster, replace = 1, ms=1)
                    plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster], rasterSession.split('_')[1]), fontsize = 10)
                    ax=plt.gca()
                    extraplots.set_ticks_fontsize(ax,6)

                # Only one main tuning curve can be shown for now; skip the panel if there is none.
                if mainTCsessions:
                    plt.subplot2grid((6, 6), (0, 3), rowspan = 3, colspan = 3)
                    tcSession = mainTCsessions[0]
                    tcBehavID = mainTCbehavIDs[0]
                    ee.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace = 1, cluster = cluster)
                    plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize = 10)

                # Boolean mask of the spikes assigned to this cluster.
                spikesThisCluster = (oneTT.clusters == cluster)
                tsThisCluster = oneTT.timestamps[spikesThisCluster]
                wavesThisCluster = oneTT.samples[spikesThisCluster]

                # -- Plot ISI histogram --
                plt.subplot2grid((6,6), (4,0), rowspan=1, colspan=3)
                spikesorting.plot_isi_loghist(tsThisCluster)
                plt.ylabel('c%d'%cluster,rotation=0,va='center',ha='center')
                plt.xlabel('')

                # -- Plot waveforms --
                plt.subplot2grid((6,6), (5,0), rowspan=1, colspan=3)
                spikesorting.plot_waveforms(wavesThisCluster)

                # -- Plot projections --
                plt.subplot2grid((6,6), (4,3), rowspan=1, colspan=3)
                spikesorting.plot_projections(wavesThisCluster)

                # -- Plot events in time --
                plt.subplot2grid((6,6), (5,3), rowspan=1, colspan=3)
                spikesorting.plot_events_in_time(tsThisCluster)

                plt.subplots_adjust(wspace = 0.7)
                fig_path = oneTT.clustersDir
                fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
                full_fig_path = os.path.join(fig_path, fig_name)
                print(full_fig_path)
                plt.savefig(full_fig_path, format = 'png')
                plt.close()

            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
예제 #8
0
    def generate_reports(self):
        '''
        Generate the reports for all of the sessions in this site. This is where we should interface with
        the multiunit clustering code, since all of the sessions that need to be clustered together have
        been defined at this point.

        For each good tetrode, the site's sessions are clustered together (reusing an existing
        Tetrode<N>.clu.1 file when present) and single-session .clu files are saved. For every cluster a
        PNG named TT{tetrode}Cluster{cluster}.png is written to the multisession clustering directory.
        It contains one raster per 'main' raster session, the first 'main' tuning-curve heatmap if there
        is one, and ISI, waveform, projection and events-in-time panels for that cluster. A multisession
        report is also saved for each tetrode.

        FIXME: This method should be able to load some kind of report template perhaps, so that the
        types of reports we can make are not limited. For instance, what happens when we just have
        rasters for a site and no tuning curve? Implementing this is a lower priority for now.

        Incorporated Lan's code for plotting the cluster reports directly on the main report.
        '''
        # FIXME: import another piece of code to do this?

        # These lookups do not depend on the tetrode or cluster, so do them once. The getters'
        # results are indexed one element at a time, so plain Python lists work.
        sessionFilenames = self.get_session_filenames()
        sessionTypes = self.get_session_types()
        sessionBehavIDs = self.get_session_behavIDs()

        mainRasterInds = self.get_session_inds_one_type(plotType='raster', report='main')
        mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
        mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]

        mainTCinds = self.get_session_inds_one_type(plotType='tc_heatmap', report='main')
        mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
        mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

        for tetrode in self.goodTetrodes:
            oneTT = cms2.MultipleSessionsToCluster(self.animalName, self.get_session_filenames(), tetrode, '{}at{}um'.format(self.date, self.depth))
            oneTT.load_all_waveforms()

            # Do the clustering only if no cluster file exists yet.
            clusterFile = os.path.join(oneTT.clustersDir,'Tetrode%d.clu.1'%oneTT.tetrode)
            if not os.path.isfile(clusterFile):
                oneTT.create_multisession_fet_files()
                oneTT.run_clustering()
            oneTT.set_clusters_from_file()

            oneTT.save_single_session_clu_files()
            possibleClusters = np.unique(oneTT.clusters)

            exp2 = ee2.EphysExperiment(self.animalName, self.date, experimenter = self.experimenter)

            # One figure per cluster.
            for cluster in possibleClusters:
                plt.figure()  # The main report for this cluster/tetrode/session

                for indRaster, rasterSession in enumerate(mainRasterSessions):
                    plt.subplot2grid((6, 6), (indRaster, 0), rowspan = 1, colspan = 3)
                    exp2.plot_session_raster(rasterSession, tetrode, cluster = cluster, replace = 1)
                    plt.ylabel(mainRasterTypes[indRaster])
                    plt.title(rasterSession, fontsize = 10)

                # Only one main tuning curve can be shown for now; skip the panel if there is none.
                if mainTCsessions:
                    plt.subplot2grid((6, 6), (0, 3), rowspan = 3, colspan = 3)
                    tcSession = mainTCsessions[0]
                    tcBehavID = mainTCbehavIDs[0]
                    exp2.plot_session_tc_heatmap(tcSession, tetrode, tcBehavID, replace = 1, cluster = cluster)
                    plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID), fontsize = 10)

                # Boolean mask of the spikes assigned to this cluster.
                spikesThisCluster = (oneTT.clusters == cluster)
                tsThisCluster = oneTT.timestamps[spikesThisCluster]
                wavesThisCluster = oneTT.samples[spikesThisCluster,:,:]

                # -- Plot ISI histogram --
                plt.subplot2grid((6,6), (4,0), rowspan=1, colspan=3)
                spikesorting.plot_isi_loghist(tsThisCluster)
                # Label with the ID of the cluster being plotted.
                plt.ylabel('c%d'%cluster,rotation=0,va='center',ha='center')

                # -- Plot waveforms --
                plt.subplot2grid((6,6), (5,0), rowspan=1, colspan=3)
                spikesorting.plot_waveforms(wavesThisCluster)

                # -- Plot projections --
                plt.subplot2grid((6,6), (4,3), rowspan=1, colspan=3)
                spikesorting.plot_projections(wavesThisCluster)

                # -- Plot events in time --
                plt.subplot2grid((6,6), (5,3), rowspan=1, colspan=3)
                spikesorting.plot_events_in_time(tsThisCluster)

                fig_path = oneTT.clustersDir
                fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
                full_fig_path = os.path.join(fig_path, fig_name)
                print(full_fig_path)
                plt.tight_layout()
                plt.savefig(full_fig_path, format = 'png')
                plt.close()

            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
예제 #9
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,),
                          clusters=(8,)):
    '''Save a one-page bandwidth / AM / frequency-tuning report per cluster.

    Session positions are hard-coded:
    - Session 1 is a frequency tuning session, read through the
      'am_tuning_curve' paradigm (its 'currentFreq' / 'currentIntensity').
    - Session 2 is the AM session, also read through 'am_tuning_curve'. Its
      'currentFreq' is used as the modulation rate.
    - Session 3 is the 'bandwidth_am' session.
    - Session 0 is not used.

    :param mouse: Animal name passed to EphysInterface.
    :param date: Recording date passed to EphysInterface.
    :param site: Site object providing ephys/behavior filenames (and
        ``tetrodes`` when ``tetrodes`` is None).
    :param siteName: Site name used for clustering and in the figure title.
    :param tetrodes: Tetrodes to report on. Defaults to (2,), the value that
        was previously hard-coded. Pass None to use ``site.tetrodes``.
    :param clusters: Clusters to report on for each tetrode. Defaults to
        (8,), the value that was previously hard-coded. Pass None to use
        every cluster found in the bandwidth session for that tetrode.

    One PNG per cluster is saved into ``oneTT.clustersDir`` as
    'TT{tetrode}Cluster{cluster}.png'.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()
    # NOTE(review): [-4:-3] keeps a single character of each behavior
    # filename, presumably the session suffix letter -- confirm.
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    # Divide by a float so a non-integer kHz value is not truncated under
    # Python 2 integer division.
    charfreq = str(np.unique(bdata['charFreq'])[0] / 1000.0)
    modrate = str(np.unique(bdata['modRate'])[0])
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']  # plotted as 'Modulation Rate (Hz)'

    # Trial sorting and axis labels depend only on behavior, so compute them
    # once instead of once per tetrode/cluster.
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    possibleRates = np.unique(currentRate)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq) / 1000.0]
    intLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    rateLabels = ["%.1f" % rate for rate in possibleRates]
    bandTrialsEachCond = behavioranalysis.find_trials_each_combination(
        currentBand, possibleBands, currentAmp, possibleAmps)
    rateTrialsEachCond = behavioranalysis.find_trials_each_type(
        currentRate, possibleRates)

    # Event onsets are the same for every tetrode and cluster; load them once.
    bandOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))
    tuningOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    amOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))

    timeRange = [-0.2, 1.5]
    colourList = ['b', 'g', 'y', 'orange', 'r']
    amBinWidth = 0.05
    amBinEdges = np.around(np.arange(-0.2, 0.85, amBinWidth), decimals=2)

    if tetrodes is None:
        tetrodes = site.tetrodes
    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        if clusters is None:
            clustersThisTetrode = np.unique(
                ei.loader.get_session_spikes(sessions[3], tetrode).clusters)
        else:
            clustersThisTetrode = clusters
        for cluster in clustersThisTetrode:
            plt.clf()

            # -- plot bandwidth rasters, one panel per amplitude --
            spikeTimestamps = ei.loader.get_session_spikes(
                sessions[3], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 spikeTimestamps, bandOnsetTimes, timeRange)
            for ind in range(len(possibleAmps)):
                plt.subplot2grid((12, 15), (5 * ind, 0), rowspan=4, colspan=7)
                pRaster, hcond, zline = extraplots.raster_plot(
                    spikeTimesFromEventOnset,
                    indexLimitsEachTrial,
                    timeRange,
                    trialsEachCond=bandTrialsEachCond[:, :, ind],
                    labels=bandLabels)
                plt.setp(pRaster, ms=4)
                plt.title(ampLabels[ind])
                plt.ylabel('bandwidth (octaves)')
                if ind == len(possibleAmps) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps,
                             bandOnsetTimes,
                             currentAmp,
                             currentBand, [0.0, 1.0],
                             title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps,
                             bandOnsetTimes,
                             currentAmp,
                             currentBand, [0.2, 1.0],
                             title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            spikeTimestamps = ei2.loader.get_session_spikes(
                sessions[1], tetrode, cluster=cluster).timestamps
            dataplotter.two_axis_heatmap(
                spikeTimestamps=spikeTimestamps,
                eventOnsetTimes=tuningOnsetTimes,
                firstSortArray=currentInt,
                secondSortArray=currentFreq,
                firstSortLabels=intLabels,
                secondSortLabels=freqLabels,
                xlabel='Frequency (kHz)',
                ylabel='Intensity (dB SPL)',
                plotTitle='Frequency Tuning Curve',
                flipFirstAxis=True,
                flipSecondAxis=False,
                timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps,
                                    tuningOnsetTimes,
                                    sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5],
                                    labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            spikeTimestamps = ei2.loader.get_session_spikes(
                sessions[2], tetrode, cluster=cluster).timestamps
            (spikeTimesFromEventOnset, trialIndexForEachSpike,
             indexLimitsEachTrial) = spikesanalysis.eventlocked_spiketimes(
                 spikeTimestamps, amOnsetTimes, timeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(
                spikeTimesFromEventOnset, indexLimitsEachTrial, amBinEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Dividing counts by the bin width converts them to rates in Hz.
            # (plt.setp(pPSTH) with no properties was removed: it only printed
            # matplotlib's property listing.)
            extraplots.plot_psth(spikeCountMat / amBinWidth,
                                 1,
                                 amBinEdges[:-1],
                                 rateTrialsEachCond,
                                 colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps,
                                    amOnsetTimes,
                                    sortArray=currentRate,
                                    timeRange=[-0.2, 0.8],
                                    labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis --
            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle(
                '{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'
                .format(mouse, date, siteName, tetrode, cluster, charfreq,
                        modrate))
            full_fig_path = os.path.join(
                oneTT.clustersDir, 'TT{0}Cluster{1}.png'.format(tetrode, cluster))
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
예제 #10
0
def nick_lan_main_report(siteObj,
                         show=False,
                         save=True,
                         saveClusterReport=True):
    '''Build one report figure per cluster for every good tetrode of a site.

    For each tetrode in ``siteObj.goodTetrodes``, spikes from all of the
    site's sessions are clustered together. An existing 'Tetrode<N>.clu.1'
    file is reused when present; otherwise the fet files are created and
    clustering is run.

    Each cluster's figure contains:
    - the site's 'main' rasters;
    - the first 'main' tuning-curve heatmap, if any;
    - ISI, waveform, projection and events-in-time panels.

    :param siteObj: Site object giving animal name, date, depth, experimenter,
        sessions and good tetrodes.
    :param show: If True, plt.show() each cluster figure; otherwise close it.
    :param save: If True, save each cluster figure as
        'TT<tetrode>Cluster<cluster>.png' in the tetrode's clusters directory.
    :param saveClusterReport: If True, also save the multisession cluster
        report for each tetrode.
    '''
    # The sessions shown in the report depend only on the site, so resolve
    # them once rather than once per cluster.
    sessionFilenames = siteObj.get_session_filenames()
    sessionTypes = siteObj.get_session_types()
    mainRasterInds = siteObj.get_session_inds_one_type(plotType='raster',
                                                       report='main')
    mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]
    mainTCinds = siteObj.get_session_inds_one_type(plotType='tc_heatmap',
                                                   report='main')
    mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
    sessionBehavIDs = siteObj.get_session_behavIDs()
    mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

    for tetrode in siteObj.goodTetrodes:
        oneTT = cms2.MultipleSessionsToCluster(
            siteObj.animalName, siteObj.get_session_filenames(), tetrode,
            '{}at{}um'.format(siteObj.date, siteObj.depth))
        oneTT.load_all_waveforms()

        # Do the clustering if necessary.
        clusterFile = os.path.join(oneTT.clustersDir,
                                   'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        oneTT.save_single_session_clu_files()
        possibleClusters = np.unique(oneTT.clusters)

        ee = ee3.EphysExperiment(siteObj.animalName,
                                 siteObj.date,
                                 experimenter=siteObj.experimenter)

        # Iterate through the clusters, making a new figure for each cluster.
        for cluster in possibleClusters:
            # The main report for this cluster/tetrode/session
            plt.figure()

            for indRaster, rasterSession in enumerate(mainRasterSessions):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                ee.plot_session_raster(rasterSession,
                                       tetrode,
                                       cluster=cluster,
                                       replace=1,
                                       ms=1)
                plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster],
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # We can only do one main TC for now; skip it if the site has none
            # instead of failing with IndexError.
            if mainTCsessions:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                tcSession = mainTCsessions[0]
                tcBehavID = mainTCbehavIDs[0]
                ee.plot_session_tc_heatmap(tcSession,
                                           tetrode,
                                           tcBehavID,
                                           replace=1,
                                           cluster=cluster)
                plt.title("{0}\nBehavFileID = '{1}'".format(tcSession,
                                                            tcBehavID),
                          fontsize=10)

            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            full_fig_path = os.path.join(
                oneTT.clustersDir, 'TT{0}Cluster{1}.png'.format(tetrode, cluster))
            print(full_fig_path)

            if save:
                plt.savefig(full_fig_path, format='png')
            if show:
                plt.show()
            else:
                plt.close()

        if saveClusterReport:
            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
예제 #11
0
def nick_lan_main_report(siteObj, show=False, save=True, saveClusterReport=True):
    '''Build one report figure per cluster for every good tetrode of a site.

    For each tetrode in ``siteObj.goodTetrodes``, spikes from all of the
    site's sessions are clustered together. An existing 'Tetrode<N>.clu.1'
    file is reused when present; otherwise the fet files are created and
    clustering is run.

    Each cluster's figure contains:
    - the site's 'main' rasters;
    - the first 'main' tuning-curve heatmap, if any;
    - ISI, waveform, projection and events-in-time panels.

    :param siteObj: Site object giving animal name, date, depth, experimenter,
        sessions and good tetrodes.
    :param show: If True, plt.show() each cluster figure; otherwise close it.
    :param save: If True, save each cluster figure as
        'TT<tetrode>Cluster<cluster>.png' in the tetrode's clusters directory.
    :param saveClusterReport: If True, also save the multisession cluster
        report for each tetrode.
    '''
    # The sessions shown in the report depend only on the site, so resolve
    # them once rather than once per cluster.
    sessionFilenames = siteObj.get_session_filenames()
    sessionTypes = siteObj.get_session_types()
    mainRasterInds = siteObj.get_session_inds_one_type(
        plotType='raster',
        report='main')
    mainRasterSessions = [sessionFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]
    mainTCinds = siteObj.get_session_inds_one_type(
        plotType='tc_heatmap',
        report='main')
    mainTCsessions = [sessionFilenames[i] for i in mainTCinds]
    sessionBehavIDs = siteObj.get_session_behavIDs()
    mainTCbehavIDs = [sessionBehavIDs[i] for i in mainTCinds]

    for tetrode in siteObj.goodTetrodes:
        oneTT = cms2.MultipleSessionsToCluster(
            siteObj.animalName,
            siteObj.get_session_filenames(),
            tetrode,
            '{}at{}um'.format(
                siteObj.date,
                siteObj.depth))
        oneTT.load_all_waveforms()

        # Do the clustering if necessary.
        clusterFile = os.path.join(
            oneTT.clustersDir,
            'Tetrode%d.clu.1' %
            oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        oneTT.save_single_session_clu_files()
        possibleClusters = np.unique(oneTT.clusters)

        ee = ee3.EphysExperiment(
            siteObj.animalName,
            siteObj.date,
            experimenter=siteObj.experimenter)

        # Iterate through the clusters, making a new figure for each cluster.
        for cluster in possibleClusters:
            # The main report for this cluster/tetrode/session
            plt.figure()

            for indRaster, rasterSession in enumerate(mainRasterSessions):
                plt.subplot2grid(
                    (6, 6), (indRaster, 0), rowspan=1, colspan=3)
                ee.plot_session_raster(
                    rasterSession,
                    tetrode,
                    cluster=cluster,
                    replace=1,
                    ms=1)
                plt.ylabel(
                    '{}\n{}'.format(
                        mainRasterTypes[indRaster],
                        rasterSession.split('_')[1]),
                    fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # We can only do one main TC for now; skip it if the site has none
            # instead of failing with IndexError.
            if mainTCsessions:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                tcSession = mainTCsessions[0]
                tcBehavID = mainTCbehavIDs[0]
                ee.plot_session_tc_heatmap(
                    tcSession,
                    tetrode,
                    tcBehavID,
                    replace=1,
                    cluster=cluster)
                plt.title(
                    "{0}\nBehavFileID = '{1}'".format(
                        tcSession,
                        tcBehavID),
                    fontsize=10)

            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel(
                'c%d' % cluster,
                rotation=0,
                va='center',
                ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            full_fig_path = os.path.join(
                oneTT.clustersDir,
                'TT{0}Cluster{1}.png'.format(tetrode, cluster))
            print(full_fig_path)

            if save:
                plt.savefig(full_fig_path, format='png')
            if show:
                plt.show()
            else:
                plt.close()

        if saveClusterReport:
            plt.figure()
            oneTT.save_multisession_report()
            plt.close()
예제 #12
0
def nick_lan_daily_report(site,
                          siteName,
                          mainRasterInds,
                          mainTCind,
                          pcasort=True):
    '''Save one report figure per cluster for every tetrode of a site.

    Each figure contains:
    - a raster for each requested session;
    - an optional intensity x frequency tuning heatmap;
    - ISI, waveform, projection and events-in-time panels for the cluster.

    Figures are written as 'TT<tetrode>Cluster<cluster>.png' into the
    tetrode's clusters directory. Tetrodes whose clustering raises
    AttributeError are reported and skipped.

    :param site: Site object providing session filenames/types, tetrodes and
        the experimenter name.
    :param siteName: Name of the site, used for clustering and messages.
    :param mainRasterInds: Indices of the sessions to draw as rasters, one
        row each in the left column.
    :param mainTCind: Index of the tuning-curve session to draw as a heatmap,
        or None/False to omit it. Index 0 is a valid session.
    :param pcasort: If True, cluster with cluster_site_PCA; otherwise use
        cluster_site.
    '''
    loader = dataloader.DataLoader('offline', experimenter=site.experimenter)

    # Which sessions to plot depends only on the site, not on tetrode/cluster.
    ephysFilenames = site.get_mouse_relative_ephys_filenames()
    sessionTypes = site.get_session_types()
    mainRasterEphysFilenames = [ephysFilenames[i] for i in mainRasterInds]
    mainRasterTypes = [sessionTypes[i] for i in mainRasterInds]
    # Compare against None/False explicitly: a truthiness test would silently
    # drop the tuning curve when it is session 0.
    if mainTCind is None or mainTCind is False:
        mainTCsession = None
    else:
        mainTCsession = ephysFilenames[mainTCind]
        mainTCbehavFilename = site.get_mouse_relative_behav_filenames(
        )[mainTCind]

    # Event times and behavior are shared by every tetrode; load them once.
    rasterOnsetTimes = [
        loader.get_event_onset_times(loader.get_session_events(session))
        for session in mainRasterEphysFilenames
    ]
    if mainTCsession is not None:
        bdata = loader.get_session_behavior(mainTCbehavFilename)
        plotTitle = loader.get_session_filename(mainTCsession)
        tcOnsetTimes = loader.get_event_onset_times(
            loader.get_session_events(mainTCsession))
        freqEachTrial = bdata['currentFreq']
        intensityEachTrial = bdata['currentIntensity']
        freqLabels = [
            "%.1f" % freq for freq in np.unique(freqEachTrial) / 1000.0
        ]
        intenLabels = [
            "%.1f" % inten for inten in np.unique(intensityEachTrial)
        ]

    for tetrode in site.tetrodes:

        # Tetrodes with no spikes will cause an error when clustering
        try:
            if pcasort:
                oneTT = cluster_site_PCA(site, siteName, tetrode)
            else:
                oneTT = cluster_site(site, siteName, tetrode)
        except AttributeError:
            print("There was an attribute error for tetrode {} at {}".format(
                tetrode, siteName))
            continue

        # The loaded spike data carries every cluster's spikes, so load it
        # once per tetrode and filter by cluster below.
        rasterSpikeData = [
            loader.get_session_spikes(session, tetrode)
            for session in mainRasterEphysFilenames
        ]
        if mainTCsession is not None:
            tcSpikeData = loader.get_session_spikes(mainTCsession, tetrode)

        # Iterate through the clusters, reusing the current figure each time.
        for cluster in np.unique(oneTT.clusters):
            plt.clf()

            for indRaster, rasterSession in enumerate(
                    mainRasterEphysFilenames):
                plt.subplot2grid((6, 6), (indRaster, 0), rowspan=1, colspan=3)
                rasterSpikes = rasterSpikeData[indRaster]
                spikeTimestamps = rasterSpikes.timestamps[
                    rasterSpikes.clusters == cluster]
                dataplotter.plot_raster(spikeTimestamps,
                                        rasterOnsetTimes[indRaster],
                                        ms=1)
                plt.ylabel('{}\n{}'.format(mainRasterTypes[indRaster],
                                           rasterSession.split('_')[1]),
                           fontsize=10)
                extraplots.set_ticks_fontsize(plt.gca(), 6)

            # We can only do one main TC for now.
            if mainTCsession is not None:
                plt.subplot2grid((6, 6), (0, 3), rowspan=3, colspan=3)
                spikeTimestamps = tcSpikeData.timestamps[
                    tcSpikeData.clusters == cluster]
                dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                             eventOnsetTimes=tcOnsetTimes,
                                             firstSortArray=intensityEachTrial,
                                             secondSortArray=freqEachTrial,
                                             firstSortLabels=intenLabels,
                                             secondSortLabels=freqLabels,
                                             xlabel='Frequency (kHz)',
                                             ylabel='Intensity (dB SPL)',
                                             plotTitle=plotTitle,
                                             flipFirstAxis=True,
                                             flipSecondAxis=False,
                                             timeRange=[0, 0.1])
                plt.title("{0}\n{1}".format(mainTCsession,
                                            mainTCbehavFilename),
                          fontsize=10)
                # No plt.show() here. In an interactive backend it blocks the
                # batch run, and closing the window destroys this figure, so
                # the panels below and savefig would land on a fresh figure.

            inCluster = oneTT.clusters == cluster
            tsThisCluster = oneTT.timestamps[inCluster]
            wavesThisCluster = oneTT.samples[inCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((6, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((6, 6), (5, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((6, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((6, 6), (5, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=0.7)
            full_fig_path = os.path.join(
                oneTT.clustersDir, 'TT{0}Cluster{1}.png'.format(tetrode, cluster))
            print(full_fig_path)
            plt.savefig(full_fig_path, format='png')
예제 #13
0
def plot_bandwidth_report(mouse, date, site, siteName, tetrodes=(2,), clusters=(8,)):
    '''Save a one-page summary figure for each requested cluster of a bandwidth experiment.

    The figure for each tetrode/cluster combines:
      * bandwidth rasters (one panel per sound amplitude) and two bandwidth
        selectivity summaries (full window, and first 200 ms excluded),
      * a frequency tuning heat map and raster,
      * an AM PSTH and raster sorted by modulation rate,
      * cluster-quality panels (ISI histogram, waveforms, projections,
        events in time).
    Each figure is saved as 'TT{tetrode}Cluster{cluster}.png' in the clusters
    directory of the clustering object returned by sitefuncs.cluster_site.

    The session layout is hard-coded by index into the site's session lists:
    index 1 is the frequency tuning session, index 2 the AM rate session and
    index 3 the bandwidth session.

    Args:
        mouse (str): Animal name, passed to EphysInterface.
        date (str): Recording date, passed to EphysInterface.
        site: Recording site object providing get_session_ephys_filenames(),
            get_session_behav_filenames() and (only if tetrodes is None)
            a `tetrodes` attribute.
        siteName (str): Site name used for clustering and in the figure title.
        tetrodes (iterable of int or None): Tetrodes to report. Defaults to
            (2,), the value that used to be hard-coded. None means every
            tetrode in site.tetrodes.
        clusters (iterable of int or None): Clusters to report on every
            tetrode. Defaults to (8,), the value that used to be hard-coded.
            None means every cluster found in the bandwidth session.
    '''
    sessions = site.get_session_ephys_filenames()
    behavFilename = site.get_session_behav_filenames()

    # -- Bandwidth session (index 3) --
    ei = ephysinterface.EphysInterface(mouse, date, '', 'bandwidth_am')
    # [-4:-3] is a one-character slice of the behavior filename; presumably the
    # session suffix letter expected by get_session_behavior -- TODO confirm.
    bdata = ei.loader.get_session_behavior(behavFilename[3][-4:-3])
    # NOTE(review): under Python 2 this is integer division if charFreq is an
    # integer array (kHz value truncated) -- confirm the dtype.
    charfreq = str(np.unique(bdata['charFreq'])[0]/1000)
    modrate = str(np.unique(bdata['modRate'])[0])

    # -- Frequency tuning (index 1) and AM rate (index 2) sessions --
    ei2 = ephysinterface.EphysInterface(mouse, date, '', 'am_tuning_curve')
    bdata2 = ei2.loader.get_session_behavior(behavFilename[1][-4:-3])
    bdata3 = ei2.loader.get_session_behavior(behavFilename[2][-4:-3])
    currentFreq = bdata2['currentFreq']
    currentBand = bdata['currentBand']
    currentAmp = bdata['currentAmp']
    currentInt = bdata2['currentIntensity']
    currentRate = bdata3['currentFreq']  # AM session stores modulation rate in 'currentFreq'

    # Everything below depends only on behavior/events, not on the cluster,
    # so it is computed once instead of once per cluster.
    possibleBands = np.unique(currentBand)
    possibleAmps = np.unique(currentAmp)
    possibleRates = np.unique(currentRate)
    bandLabels = ['{}'.format(band) for band in possibleBands]
    ampLabels = ['Amplitude: {}'.format(amp) for amp in possibleAmps]
    intenLabels = ["%.1f" % inten for inten in np.unique(currentInt)]
    freqLabels = ["%.1f" % freq for freq in np.unique(currentFreq)/1000.0]
    rateLabels = ["%.1f" % rate for rate in possibleRates]

    bandTrialsEachCond = behavioranalysis.find_trials_each_combination(currentBand,
                                                                       possibleBands,
                                                                       currentAmp,
                                                                       possibleAmps)
    rateTrialsEachCond = behavioranalysis.find_trials_each_type(currentRate,
                                                                possibleRates)
    colourList = ['b', 'g', 'y', 'orange', 'r']
    psthBinWidth = 0.05
    binEdges = np.around(np.arange(-0.2, 0.85, psthBinWidth), decimals=2)

    bandEventOnsetTimes = ei.loader.get_event_onset_times(
        ei.loader.get_session_events(sessions[3]))
    tcEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[1]))
    amEventOnsetTimes = ei2.loader.get_event_onset_times(
        ei2.loader.get_session_events(sessions[2]))

    if tetrodes is None:
        tetrodes = site.tetrodes

    for tetrode in tetrodes:
        oneTT = sitefuncs.cluster_site(site, siteName, tetrode)
        if clusters is None:
            clustersThisTetrode = np.unique(
                ei.loader.get_session_spikes(sessions[3], tetrode).clusters)
        else:
            clustersThisTetrode = clusters

        for cluster in clustersThisTetrode:
            plt.clf()

            # -- plot bandwidth rasters, one panel per amplitude --
            # NOTE: the 12-row grid only has room for two amplitude panels.
            spikeTimestamps = ei.loader.get_session_spikes(sessions[3], tetrode,
                                                           cluster=cluster).timestamps
            timeRange = [-0.2, 1.5]
            spikeTimesFromEventOnset, _, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
                spikeTimestamps, bandEventOnsetTimes, timeRange)
            for indAmp in range(len(possibleAmps)):
                plt.subplot2grid((12, 15), (5*indAmp, 0), rowspan=4, colspan=7)
                pRaster, hcond, zline = extraplots.raster_plot(spikeTimesFromEventOnset,
                                                               indexLimitsEachTrial,
                                                               timeRange,
                                                               trialsEachCond=bandTrialsEachCond[:, :, indAmp],
                                                               labels=bandLabels)
                plt.setp(pRaster, ms=4)
                plt.title(ampLabels[indAmp])
                plt.ylabel('bandwidth (octaves)')
                if indAmp == len(possibleAmps) - 1:
                    plt.xlabel("Time from sound onset (sec)")

            # -- plot Yashar plots for bandwidth data --
            plt.subplot2grid((12, 15), (10, 0), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps, bandEventOnsetTimes, currentAmp, currentBand,
                             [0.0, 1.0], title='bandwidth selectivity')
            plt.subplot2grid((12, 15), (10, 3), rowspan=2, colspan=3)
            band_select_plot(spikeTimestamps, bandEventOnsetTimes, currentAmp, currentBand,
                             [0.2, 1.0], title='first 200ms excluded')

            # -- plot frequency tuning heat map --
            plt.subplot2grid((12, 15), (5, 7), rowspan=4, colspan=4)
            spikeTimestamps = ei2.loader.get_session_spikes(sessions[1], tetrode,
                                                            cluster=cluster).timestamps
            dataplotter.two_axis_heatmap(spikeTimestamps=spikeTimestamps,
                                         eventOnsetTimes=tcEventOnsetTimes,
                                         firstSortArray=currentInt,
                                         secondSortArray=currentFreq,
                                         firstSortLabels=intenLabels,
                                         secondSortLabels=freqLabels,
                                         xlabel='Frequency (kHz)',
                                         ylabel='Intensity (dB SPL)',
                                         plotTitle='Frequency Tuning Curve',
                                         flipFirstAxis=True,
                                         flipSecondAxis=False,
                                         timeRange=[0, 0.1])
            plt.ylabel('Intensity (dB SPL)')
            plt.xlabel('Frequency (kHz)')
            plt.title('Frequency Tuning Curve')

            # -- plot frequency tuning raster --
            plt.subplot2grid((12, 15), (0, 7), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps, tcEventOnsetTimes, sortArray=currentFreq,
                                    timeRange=[-0.1, 0.5], labels=freqLabels)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Frequency (kHz)')
            plt.title('Frequency Tuning Raster')

            # -- plot AM PSTH --
            spikeTimestamps = ei2.loader.get_session_spikes(sessions[2], tetrode,
                                                            cluster=cluster).timestamps
            timeRange = [-0.2, 1.5]
            spikeTimesFromEventOnset, _, indexLimitsEachTrial = spikesanalysis.eventlocked_spiketimes(
                spikeTimestamps, amEventOnsetTimes, timeRange)
            spikeCountMat = spikesanalysis.spiketimes_to_spikecounts(spikeTimesFromEventOnset,
                                                                     indexLimitsEachTrial,
                                                                     binEdges)
            plt.subplot2grid((12, 15), (5, 11), rowspan=4, colspan=4)
            # Divide counts by the bin width to get firing rate in Hz.
            extraplots.plot_psth(spikeCountMat/psthBinWidth, 1, binEdges[:-1],
                                 rateTrialsEachCond, colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Firing rate (Hz)')
            plt.title('AM PSTH')

            # -- plot AM raster --
            plt.subplot2grid((12, 15), (0, 11), rowspan=4, colspan=4)
            dataplotter.plot_raster(spikeTimestamps, amEventOnsetTimes, sortArray=currentRate,
                                    timeRange=[-0.2, 0.8], labels=rateLabels,
                                    colorEachCond=colourList)
            plt.xlabel('Time from sound onset (sec)')
            plt.ylabel('Modulation Rate (Hz)')
            plt.title('AM Raster')

            # -- show cluster analysis (multisession clustering data) --
            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster]

            # -- Plot ISI histogram --
            plt.subplot2grid((12, 15), (10, 6), rowspan=2, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')
            plt.xlabel('')

            # -- Plot waveforms --
            plt.subplot2grid((12, 15), (10, 9), rowspan=2, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((12, 15), (10, 12), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((12, 15), (11, 12), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            plt.subplots_adjust(wspace=1.5)
            plt.suptitle('{0}, {1}, {2}, Tetrode {3}, Cluster {4}, {5}kHz, {6}Hz modulation'.format(
                mouse, date, siteName, tetrode, cluster, charfreq, modrate))
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            fig = plt.gcf()
            fig.set_size_inches(24, 12)
            fig.savefig(full_fig_path, format='png', bbox_inches='tight')
def laser_tc_analysis(site, sitenum):
    '''
    Data analysis function for laser/tuning curve experiments

    This function will take a RecordingSite object, do multisession clustering on it, and save all of the clusters
    back to the original session cluster directories. We can then use an EphysExperiment object (version 2)
    to load each session, select clusters, plot the appropriate plots, etc. This code is being removed from
    the EphysExperiment object because that object should be general and apply to any kind of recording
    experiment. This function does the data analysis for one specific kind of experiment.

    For every good tetrode and every cluster found on it, a figure is saved as
    'TT{tetrode}Cluster{cluster}.png' in the multisession clusters directory. The figure contains:
    noise burst, laser pulse and (if present) laser train rasters, a tuning curve heat map, and
    cluster-quality panels (ISI histogram, waveforms, projections, events in time). A multisession
    cluster report is also saved for each tetrode.

    Args:

        site (RecordingSite object): An instance of the RecordingSite class from the ephys_experiment_v2 module
        sitenum (int): The site number for the site, used for constructing directory names

    Raises:
        ValueError: if the site has no 'noiseBurst', 'laserPulse' or 'tuningCurve' session
            (raised by list.index). A missing 'laserTrain' session is reported and skipped.

    Example:

        from jaratoolbox.test.nick.ephysExperiments import laserTCanalysis
        for indSite, site in enumerate(today.siteList):
            laserTCanalysis.laser_tc_analysis(site, indSite+1)
    '''
    # The session lists do not change while this function runs, so fetch them once.
    sessionTypes = site.get_session_types()
    sessionFilenames = site.get_session_filenames()
    sessionBehavIDs = site.get_session_behavIDs()

    #This is where I should incorporate Lan's sorting function
    #Construct a multiple session clustering object with the session list.
    for tetrode in site.goodTetrodes:

        oneTT = cms2.MultipleSessionsToCluster(
            site.animalName, sessionFilenames, tetrode,
            '{}site{}'.format(site.date, sitenum))
        oneTT.load_all_waveforms()

        # Run the clustering only if no cluster file exists yet; either way, load the result.
        clusterFile = os.path.join(oneTT.clustersDir,
                                   'Tetrode%d.clu.1' % oneTT.tetrode)
        if not os.path.isfile(clusterFile):
            oneTT.create_multisession_fet_files()
            oneTT.run_clustering()
        oneTT.set_clusters_from_file()

        # The per-session raster plots below rely on the single-session .clu files, so they must
        # be saved here.
        # LG 0712: a "ValueError: too many boolean indices" raised inside
        # save_single_session_clu_files was caused by stale .clu files from an earlier clustering
        # run, which made self.clusters and self.recordingNumber differ in size.
        # LG 0710: a Tk "ThemeChanged ... application has been destroyed" message printed here is a
        # warning only; reports and files are still generated.
        oneTT.save_single_session_clu_files()

        possibleClusters = np.unique(oneTT.clusters)

        #We also need to initialize an EphysExperiment object to get the sessions
        exp2 = ee2.EphysExperiment(site.animalName,
                                   site.date,
                                   experimenter=site.experimenter)

        # Iterate through the clusters, making a new figure for each cluster.
        # (Using possibleClusters[1:] was once a hack to omit cluster 1, which usually contains
        # noise and sometimes has no spikes in the range the raster plot needs.)
        for cluster in possibleClusters:
            plt.figure(figsize=(8.5, 11))

            #The first noise burst raster plot
            plt.subplot2grid((5, 6), (0, 0), rowspan=1, colspan=3)
            nbSession = sessionFilenames[sessionTypes.index('noiseBurst')]
            exp2.plot_session_raster(nbSession,
                                     tetrode,
                                     cluster=cluster,
                                     replace=1)
            plt.ylabel('Noise Bursts')
            plt.title(nbSession, fontsize=10)

            #The laser pulse raster plot
            plt.subplot2grid((5, 6), (1, 0), rowspan=1, colspan=3)
            lpSession = sessionFilenames[sessionTypes.index('laserPulse')]
            exp2.plot_session_raster(lpSession,
                                     tetrode,
                                     cluster=cluster,
                                     replace=1)
            plt.ylabel('Laser Pulses')
            plt.title(lpSession, fontsize=10)

            #The laser train raster plot (optional session; best effort)
            plt.subplot2grid((5, 6), (2, 0), rowspan=1, colspan=3)
            try:
                ltSession = sessionFilenames[sessionTypes.index('laserTrain')]
                exp2.plot_session_raster(ltSession,
                                         tetrode,
                                         cluster=cluster,
                                         replace=1)
                plt.ylabel('Laser Trains')
                plt.title(ltSession, fontsize=10)
            except ValueError:
                print('This session doesnot exist.')

            #The tuning curve
            plt.subplot2grid((5, 6), (0, 3), rowspan=3, colspan=3)
            tcIndex = sessionTypes.index('tuningCurve')
            tcSession = sessionFilenames[tcIndex]
            tcBehavID = sessionBehavIDs[tcIndex]
            exp2.plot_session_tc_heatmap(tcSession,
                                         tetrode,
                                         tcBehavID,
                                         replace=1,
                                         cluster=cluster)
            plt.title("{0}\nBehavFileID = '{1}'".format(tcSession, tcBehavID),
                      fontsize=10)

            # A best-frequency raster could go here; if that session is not initialized, indexing
            # the session lists raises ValueError (could be wrapped in try/except ValueError).
            #FIXME: Omitting the laser pulses at different intensities for now

            # LG0710: cluster-quality panels (ISI, waveform, events in time, projections) for this
            # cluster, taken from the multisession clustering object.
            spikesThisCluster = (oneTT.clusters == cluster)
            tsThisCluster = oneTT.timestamps[spikesThisCluster]
            wavesThisCluster = oneTT.samples[spikesThisCluster, :, :]

            # -- Plot ISI histogram --
            plt.subplot2grid((5, 6), (3, 0), rowspan=1, colspan=3)
            spikesorting.plot_isi_loghist(tsThisCluster)
            plt.ylabel('c%d' % cluster, rotation=0, va='center', ha='center')

            # -- Plot waveforms --
            plt.subplot2grid((5, 6), (4, 0), rowspan=1, colspan=3)
            spikesorting.plot_waveforms(wavesThisCluster)

            # -- Plot projections --
            plt.subplot2grid((5, 6), (3, 3), rowspan=1, colspan=3)
            spikesorting.plot_projections(wavesThisCluster)

            # -- Plot events in time --
            plt.subplot2grid((5, 6), (4, 3), rowspan=1, colspan=3)
            spikesorting.plot_events_in_time(tsThisCluster)

            #Save the figure in the multisession clustering folder so that it is easy to find
            fig_name = 'TT{0}Cluster{1}.png'.format(tetrode, cluster)
            full_fig_path = os.path.join(oneTT.clustersDir, fig_name)
            print(full_fig_path)
            plt.tight_layout()
            plt.savefig(full_fig_path, format='png')
            plt.close()

        plt.figure()
        oneTT.save_multisession_report()
        plt.close()