Esempio n. 1
0
    def plot_complex(self, **kwargs):
        """Plot the complex band decomposition (``self.Z``) over a time window.

        Produces up to three kinds of figure:

        * ``include_ifreq``: the low-pass-filtered instantaneous frequency of
          every electrode, one grayscale image per band, under the stimulus
          spectrogram.
        * always: the phase of every electrode, one image per band (rendered by
          ``make_phase_image``), optionally demodulated, overlaid with the
          summed-amplitude envelope and optionally preceded by a spike raster.
        * ``include_spec``: an overview figure of stimulus spectrograms plus one
          figure per electrode in which each band's instantaneous frequency is
          drawn as a colour-coded trace whose opacity follows its amplitude.

        Keyword Args:
            start_time: start of the window (default ``self.start_time``).
            end_time: end of the window (default ``self.end_time``).
            output_dir: directory to save PNGs into. When None, figures are not
                saved, and ``plt.show()`` is called at the end (only reached
                when ``include_spec`` is True).
            include_spec (bool): draw the per-electrode band-frequency figures.
            include_ifreq (bool): draw the instantaneous-frequency figure.
            spikes (bool): add a spike raster (first electrode array only).
            nbands_to_plot (int or None): number of bands to plot (all if None;
                clipped to the number of bands available).
            sort_code (str): sort code passed to ``get_spike_slice``.
            demodulate (bool): demodulate the phase before plotting.
            phase_only (bool): set every non-zero amplitude to 1.
            log_amplitude (bool): log10-compress the normalised amplitudes.
        """
        # fill in defaults for any keyword argument the caller did not supply
        kw_params = {
            'start_time': self.start_time,
            'end_time': self.end_time,
            'output_dir': None,
            'include_spec': False,
            'include_ifreq': False,
            'spikes': False,
            'nbands_to_plot': None,
            'sort_code': '0',
            'demodulate': True,
            'phase_only': False,
            'log_amplitude': False
        }
        for key, val in kw_params.items():
            kwargs.setdefault(key, val)

        nbands, nelectrodes, nt = self.X.shape
        start_time = kwargs['start_time']
        end_time = kwargs['end_time']
        duration = end_time - start_time
        output_dir = kwargs['output_dir']

        # convert the requested window into sample indices into self.Z
        sr = self.get_sample_rate()
        t1 = int((start_time - self.start_time) * sr)
        d = int((end_time - start_time) * sr)
        t2 = t1 + d
        t = (np.arange(d) / sr) + start_time

        spike_trains = None
        bin_size = 1e-3
        if kwargs['spikes']:
            if self.spike_rasters is None:
                # lazily load (and cache on self) the spike rasters for the whole segment
                self.spike_rasters = self.experiment.get_spike_slice(
                    self.segment,
                    0.0,
                    self.segment.annotations['duration'],
                    rcg_names=self.rcg_names,
                    as_matrix=False,
                    sort_code=kwargs['sort_code'],
                    bin_size=bin_size)
            if len(self.spike_rasters) > 1:
                print(
                    "WARNING: plot_complex doesn't work well when more than one electrode array is specified."
                )
            # only the first electrode array's spikes are plotted
            spike_trains_full, _ = self.spike_rasters[self.rcg_names[0]]
            # select out the spikes for the interval to plot
            spike_trains = [st[(st >= start_time) & (st <= end_time)]
                            for st in spike_trains_full]

        # per-band colours used by the band-frequency figures (include_spec);
        # bands beyond the palette length wrap around
        colors = np.array([
            [244.0, 244.0, 244.0],  #white
            [241.0, 37.0, 9.0],  #red
            [238.0, 113.0, 25.0],  #orange
            [255.0, 200.0, 8.0],  #yellow
            [19.0, 166.0, 50.0],  #green
            [1.0, 134.0, 141.0],  #blue
            [244.0, 244.0, 244.0],  #white
        ]) / 255.0

        # get stimulus spectrogram
        stim_spec_t, stim_spec_freq, stim_spec = self.experiment.get_spectrogram_slice(
            self.segment, start_time, end_time)

        # compute the amplitude and phase of the complex signal in the window
        amplitude = np.abs(self.Z[:, :, t1:t2])
        phase = np.angle(self.Z[:, :, t1:t2])

        # rescale each (band, electrode) amplitude trace so its peak is 1;
        # traces that are identically zero stay zero instead of becoming NaN
        amp_max = amplitude.max(axis=-1, keepdims=True)
        amp_max[amp_max == 0] = 1.0
        amplitude /= amp_max

        if kwargs['phase_only']:
            # make sure the (non-zero) amplitude is equal to 1
            amplitude[amplitude > 0] = 1.0

        if kwargs['log_amplitude']:
            nz = amplitude > 0
            amplitude[nz] = np.log10(amplitude[nz])
            amplitude[nz] -= amplitude[nz].min()
            amplitude /= amplitude.max()

        nbands_to_plot = nbands
        if kwargs['nbands_to_plot'] is not None:
            # never ask for more bands than exist
            nbands_to_plot = min(kwargs['nbands_to_plot'], nbands)

        seg_uname = segment_to_unique_name(self.segment)

        # The smoothed instantaneous frequency is used by both the ifreq figure and
        # the band-frequency figures, so compute it whenever either is requested.
        ifreq = None
        if kwargs['include_ifreq'] or kwargs['include_spec']:
            ifreq = np.zeros([nbands, nelectrodes, d])
            for k in range(nbands):
                for j in range(nelectrodes):
                    ifreq[k, j, :] = compute_instantaneous_frequency(
                        self.Z[k, j, t1:t2], sr)
                    # smooth with lowpass_filter (cutoff_freq=50.0)
                    ifreq[k, j, :] = lowpass_filter(ifreq[k, j, :],
                                                    sr,
                                                    cutoff_freq=50.0)

        if kwargs['include_ifreq']:
            ##################
            ## make plots of the joint instantaneous frequency per band
            ##################
            rcParams.update({'font.size': 10})
            ifreq_fig = plt.figure(figsize=(24.0, 13.5))
            plt.subplots_adjust(top=0.98,
                                bottom=0.01,
                                left=0.03,
                                right=0.99,
                                hspace=0.10)
            nsubplots = nbands_to_plot + 2

            # plot the stimulus spectrogram
            ax = plt.subplot(nsubplots, 1, 1)
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.afmhot_r,
                             colorbar=False,
                             fmax=8000.0)
            plt.ylabel('')
            plt.yticks([])

            # plot each band's instantaneous frequency as a grayscale image; the
            # grey range spans the band's 5th-95th percentile (higher = darker)
            for k in range(nbands_to_plot):
                ifreq_min = np.percentile(ifreq[k, :, :], 5)
                ifreq_max = np.percentile(ifreq[k, :, :], 95)
                ifreq_dist = ifreq_max - ifreq_min

                cnorm = np.clip((ifreq[k, :, :] - ifreq_min) / ifreq_dist,
                                0.0, 1.0)
                img = np.zeros([nelectrodes, d, 4], dtype='float32')
                img[:, :, :3] = (1.0 - cnorm)[:, :, np.newaxis]
                img[:, :, 3] = 1.0  # fully opaque

                ax = plt.subplot(nsubplots, 1, k + 2)
                ax.set_facecolor('black')
                plt.imshow(img,
                           interpolation='nearest',
                           aspect='auto',
                           origin='upper',
                           extent=[t.min(), t.max(), 1, nelectrodes])
                plt.axis('tight')
                plt.ylabel('Electrode')
                plt.title('band %d' % (k + 1))
            plt.suptitle('Instantaneous Frequency')

            if output_dir is not None:
                fname = 'band_ifreq_%s_%s_start%0.6f_end%0.6f.png' % (
                    seg_uname, ','.join(self.rcg_names), start_time, end_time)
                plt.savefig(os.path.join(output_dir, fname))
                # release the saved figure so repeated calls do not accumulate figures
                plt.close(ifreq_fig)

        ##################
        ## make plots of the joint phase per band
        ##################

        def compute_envelope(the_matrix, log=False):
            """Sum |the_matrix| over axis 0 and rescale the result to [0, 1].

            With ``log=True`` the envelope is log10-compressed, floored at minus
            the 95th percentile of the absolute log values, and rescaled to
            [0, 1] again.
            """
            tm_env = np.abs(the_matrix).sum(axis=0)
            tm_env -= tm_env.min()
            tm_env /= tm_env.max()

            if log:
                nz = tm_env > 0.0
                tm_env[nz] = np.log10(tm_env[nz])
                tm_env_thresh = -np.percentile(np.abs(tm_env[nz]), 95)
                tm_env[~nz] = tm_env_thresh
                tm_env[tm_env <= tm_env_thresh] = tm_env_thresh
                tm_env -= tm_env_thresh
                tm_env /= tm_env.max()

            return tm_env

        # compute the amplitude envelope for the spectrogram
        stim_spec_env = compute_envelope(stim_spec)

        rcParams.update({'font.size': 10})
        fig = plt.figure(figsize=(24.0, 13.5), facecolor='gray')
        plt.subplots_adjust(top=0.98,
                            bottom=0.01,
                            left=0.03,
                            right=0.99,
                            hspace=0.10)
        nsubplots = nbands_to_plot + 1 + int(kwargs['spikes'])

        # plot the stimulus spectrogram with its envelope overlaid
        ax = plt.subplot(nsubplots, 1, 1)
        ax.set_facecolor('black')
        plot_spectrogram(stim_spec_t,
                         stim_spec_freq,
                         stim_spec,
                         ax=ax,
                         colormap=cm.afmhot,
                         colorbar=False,
                         fmax=8000.0)
        plt.plot(stim_spec_t,
                 stim_spec_env * stim_spec_freq.max(),
                 'w-',
                 linewidth=3.0,
                 alpha=0.75)
        plt.axis('tight')
        plt.ylabel('')
        plt.yticks([])

        # plot the spike raster with its spike-count envelope
        if kwargs['spikes']:
            spike_count_env = spike_envelope(spike_trains,
                                             start_time,
                                             duration,
                                             bin_size=bin_size,
                                             win_size=30.0)
            ax = plt.subplot(nsubplots, 1, 2)
            plot_raster(spike_trains,
                        ax=ax,
                        duration=duration,
                        bin_size=bin_size,
                        time_offset=start_time,
                        ylabel='Cell',
                        bgcolor='k',
                        spike_color='#ff0000')
            tenv = np.arange(len(spike_count_env)) * bin_size + start_time
            plt.plot(tenv,
                     spike_count_env * len(spike_trains),
                     'w-',
                     linewidth=1.5,
                     alpha=0.5)
            plt.axis('tight')
            plt.xticks([])
            plt.yticks([])
            plt.ylabel('Spikes')

        for k in range(nbands_to_plot):

            the_phase = phase[k, :, :]
            if kwargs['demodulate']:
                # rebuild the (normalised) complex signal and demodulate its phase
                Ztemp = amplitude[k, :, :] * (
                    np.cos(the_phase) + complex(0, 1) * np.sin(the_phase))
                the_phase, _ = demodulate(Ztemp, depth=1)
                del Ztemp

            img = make_phase_image(amplitude[k, :, :],
                                   the_phase,
                                   normalize=True,
                                   threshold=True,
                                   saturate=True)
            amp_env = compute_envelope(amplitude[k, :, :], log=False)

            ax = plt.subplot(nsubplots, 1, k + 2 + int(kwargs['spikes']))
            ax.set_facecolor('black')
            plt.imshow(img,
                       interpolation='nearest',
                       aspect='auto',
                       origin='upper',
                       extent=[t.min(), t.max(), 1, nelectrodes])
            if not kwargs['phase_only']:
                plt.plot(t,
                         amp_env * nelectrodes,
                         'w-',
                         linewidth=2.0,
                         alpha=0.75)
            plt.axis('tight')
            plt.ylabel('Electrode')
            plt.title('band %d' % (k + 1))

        plt.suptitle('Phase')

        if output_dir is not None:
            fname = 'band_phase_%s_%s_start%0.6f_end%0.6f.png' % (
                seg_uname, ','.join(self.rcg_names), start_time, end_time)
            plt.savefig(os.path.join(output_dir, fname),
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
            # release the saved figure (del alone does not close a pyplot figure)
            plt.close(fig)

        if not kwargs['include_spec']:
            return

        ##################
        ## make plots of the band spectrograms
        ##################
        # overview figure: a row of stimulus spectrograms
        subplot_nrows = 8 + 1
        subplot_ncols = len(self.rcg_names) * 2
        plt.figure()
        plt.subplots_adjust(top=0.95,
                            bottom=0.05,
                            left=0.03,
                            right=0.99,
                            hspace=0.10)

        for k in range(subplot_ncols):
            ax = plt.subplot(subplot_nrows, subplot_ncols, k + 1)
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.gist_yarg,
                             colorbar=False)
            plt.ylabel('')
            plt.yticks([])

        # frequency resolution of the per-electrode band-frequency image
        nf = 150
        time_idx = np.arange(ifreq.shape[-1])

        for j in range(nelectrodes):

            plt.figure(figsize=(24.0, 13.5))
            plt.subplots_adjust(top=0.95,
                                bottom=0.05,
                                left=0.03,
                                right=0.99,
                                hspace=0.10)

            electrode = self.index2electrode[j]

            # top 20% of the figure: stimulus spectrogram
            gs = GridSpec(100, 1)
            ax = plt.subplot(gs[:20])
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.gist_yarg,
                             colorbar=False)
            plt.ylabel('')
            plt.yticks([])

            # remaining 80%: band-coloured instantaneous frequency
            ax = plt.subplot(gs[20:])
            ax.set_facecolor('black')

            # get the maximum frequency and set the resolution
            max_freq = ifreq[:, j, :].max()
            df = max_freq / nf

            # RGBA image: rows are frequency bins, columns are time samples
            img = np.zeros([nf, ifreq.shape[-1], 4], dtype='float32')

            # fill in the image for each band; later bands overwrite earlier ones
            for k in range(nbands):
                max_amp = np.percentile(amplitude[k, j, :], 85)

                # quantise each sample's frequency into one of the nf rows
                freq_bin = (ifreq[k, j, :] / df).astype('int') - 1
                freq_bin = np.clip(freq_bin, 0, nf - 1)

                # opacity follows amplitude relative to its 85th percentile
                alpha = amplitude[k, j, :] / max_amp
                alpha[alpha > 1.0] = 1.0  #saturate
                alpha[alpha < 0.05] = 0.0  #nonlinear threshold

                img[freq_bin, time_idx, :3] = colors[k % len(colors)]
                img[freq_bin, time_idx, 3] = alpha

            # plot the image
            plt.imshow(img,
                       interpolation='nearest',
                       aspect='auto',
                       origin='lower',
                       extent=[t.min(), t.max(), 0.0, max_freq])
            plt.ylabel('E%d' % electrode)
            plt.axis('tight')
            plt.ylim(0.0, 140.0)

            if output_dir is not None:
                fname = 'band_spec_e%d_%s_%s_start%0.6f_end%0.6f.png' % (
                    electrode, seg_uname, ','.join(
                        self.rcg_names), start_time, end_time)
                plt.savefig(os.path.join(output_dir, fname))
                plt.close('all')

        if output_dir is None:
            plt.show()
Esempio n. 2
0
def plot_single_trial_data(d, syllable_index, trial_index):
    """Plot acoustic features, spectrogram, raw LFP and spike raster for one syllable on one trial.

    The plotted window is the syllable's start/end time padded by 30 ms on each
    side. The figure width is proportional to the window length (12 inches per
    2.5 s), so time is drawn at the same scale across stimuli.

    Args:
        d: dict-like holding the data. Keys read here: ``syllable_props``,
            ``spec``, ``spec_t``, ``spec_freq``, ``lfp_sample_rate``, ``lfp``
            (indexed ``[trial, electrode, time]``), ``electrode_order`` and
            ``spikes`` (list-of-lists with shape (num_trials, num_neurons)).
        syllable_index: index into ``d['syllable_props']``.
        trial_index: trial to plot, indexing ``d['lfp']`` and ``d['spikes']``.

    ``d`` is not modified.
    """
    sprops = d['syllable_props'][syllable_index]

    # pad the syllable by 30 ms on each side
    syllable_start = sprops['start_time'] - 0.030
    syllable_end = sprops['end_time'] + 0.030
    window_dur = syllable_end - syllable_start

    # set the figure width proportional to the length of the syllable for uniformity across stimuli
    max_fig_width = 12.0
    max_stim_duration = 2.5
    fig_width = (window_dur / max_stim_duration) * max_fig_width

    fig = plt.figure(figsize=(fig_width, 10), facecolor='w')
    fig.subplots_adjust(top=0.95,
                        bottom=0.02,
                        right=0.97,
                        left=0.03,
                        hspace=0.20,
                        wspace=0.20)

    gs = plt.GridSpec(100, 1)

    # plot the biosound features: one bar per acoustic property, coloured by type
    ax = plt.subplot(gs[:10])
    aprops = USED_ACOUSTIC_PROPS
    vals = [sprops[a] for a in aprops]
    plt.axhline(0, c='k')
    for k, (aprop, v) in enumerate(zip(aprops, vals)):
        rgb = np.array(ACOUSTIC_PROP_COLORS_BY_TYPE[aprop]).astype('int')
        plt.bar(k, v, color=to_hex(*rgb), alpha=0.7)
    ax.xaxis.tick_top()
    plt.xticks([])

    # plot the spectrogram with its lowest 15% of values zeroed; work on a copy
    # so the caller's d['spec'] is left untouched
    ax = plt.subplot(gs[15:40])
    spec = np.array(d['spec'])
    spec[spec < np.percentile(spec, 15)] = 0

    # the spectogram is already log transformed, make sure log=False and dBNoise=None
    plot_spectrogram(d['spec_t'],
                     d['spec_freq'] * 1e-3,
                     spec,
                     ax=ax,
                     colormap='SpectroColorMap',
                     ticks=False,
                     log=False,
                     dBNoise=None,
                     colorbar=False)

    plt.xlim(syllable_start, syllable_end)

    # plot the raw LFP, electrodes stacked vertically (first electrode on top)
    ax = plt.subplot(gs[45:70])

    sr = d['lfp_sample_rate']
    raw_lfp = d['lfp'][trial_index, :, :]
    nelectrodes, _ = raw_lfp.shape
    lfp_t = np.arange(raw_lfp.shape[1]) / sr
    lfp_i = (lfp_t >= syllable_start) & (lfp_t <= syllable_end)

    voffset = 5
    for n in range(nelectrodes):
        plt.plot(lfp_t[lfp_i],
                 raw_lfp[nelectrodes - n - 1, :][lfp_i] + voffset * n,
                 'k-',
                 linewidth=2.0,
                 alpha=0.75)
    plt.axis('tight')
    ytick_locs = np.arange(nelectrodes) * voffset
    plt.yticks(ytick_locs, list(reversed(d['electrode_order'])))
    plt.xticks([])

    # plot the spike train raster
    ax = plt.subplot(gs[75:])

    # d['spikes'] has shape (num_trials, num_neurons), so the neuron count for
    # this trial is the length of its row
    trial_spikes = d['spikes'][trial_index]
    print('# of neurons: ', len(trial_spikes))

    # keep the spikes inside the window, re-referenced to the window start
    raw_spikes = [
        spike_train[(spike_train >= syllable_start)
                    & (spike_train <= syllable_end)] - syllable_start
        for spike_train in trial_spikes
    ]

    plot_raster(raw_spikes,
                ax=ax,
                duration=window_dur,
                bin_size=0.001,
                time_offset=0.0,
                ylabel='',
                groups=None,
                bgcolor=None,
                spike_color='k')
    plt.xticks([])
Esempio n. 3
0
    def testFFT(self):
        """Plot the power spectra of Poisson spike trains driven by a 35 Hz rate.

        A rate profile (PSTH) is built from the sinusoids in ``modulation_freqs``
        and rescaled to the range [0, peak_rate]. Ten Poisson trials are drawn
        from it and binned at 1 ms. The second panel plots, in dB:

        * each trial's power spectrum;
        * the mean of those spectra (solid black);
        * the spectrum of the trial-averaged PSTH (dashed black).

        The first panel shows the spike raster.
        """
        sample_rate = 1000.
        modulation_freqs = [35.]
        duration = 0.500
        num_samples = int(duration * sample_rate)
        times = np.arange(num_samples) / sample_rate

        # rate profile: sum of sinusoids, rescaled to lie in [0, peak_rate]
        rate = np.zeros([num_samples])
        for freq in modulation_freqs:
            rate += np.sin(2 * np.pi * freq * times)

        peak_rate = 0.1
        rate /= rate.max()
        rate += 1.
        rate /= 2.0
        rate *= peak_rate

        # draw Poisson spike trains from the rate profile, then bin them
        trials = simulate_poisson(rate, duration, num_trials=10)

        bin_size = 0.001
        binned = spike_trains_to_matrix(trials, bin_size, 0.0, duration)
        trial_avg = binned.mean(axis=0)

        window_len = 0.090
        increment = 0.010

        def _to_db(power):
            # in place: positive entries -> 20*log10(x) + 100, negatives clamped to 0
            positive = power > 0
            power[positive] = 20 * np.log10(power[positive]) + 100
            power[power < 0] = 0
            return power

        # power spectrum of every individual trial
        trial_psds = []
        for row in binned:
            freqs, power, _, _ = power_spectrum_jn(row, 1.0 / bin_size,
                                                   window_len, increment)
            trial_psds.append(_to_db(power))
        trial_psds = np.array(trial_psds)
        mean_of_psds = trial_psds.mean(axis=0)

        # power spectrum of the trial-averaged PSTH
        freqs, psd_of_mean, _, _ = power_spectrum_jn(trial_avg, 1.0 / bin_size,
                                                     window_len, increment)
        psd_of_mean = _to_db(psd_of_mean)

        plt.figure()

        raster_ax = plt.subplot(2, 1, 1)
        plot_raster(trials, ax=raster_ax, duration=duration, bin_size=0.001,
                    time_offset=0.0, ylabel='Trial #', bgcolor=None,
                    spike_color='k')

        plt.subplot(2, 1, 2)
        plt.plot(freqs, mean_of_psds, 'k-', linewidth=3.0)
        for single_psd in trial_psds:
            plt.plot(freqs, single_psd, '-', linewidth=2.0, alpha=0.75)

        plt.plot(freqs, psd_of_mean, 'k--', linewidth=3.0, alpha=0.60)
        plt.axis('tight')
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power (dB)')
        plt.xlim(0, 100.)

        plt.show()
Esempio n. 4
0
    def transform(self,
                  stim_event,
                  lags=np.arange(-10, 11, 1),
                  min_syllable_dur=0.050,
                  post_syllable_dur=0.030,
                  rep_type='raw',
                  debug=False,
                  window_fraction=0.60,
                  noise_db=25.):

        assert isinstance(stim_event, StimEventTransform)
        assert rep_type in stim_event.lfp_reps_by_stim
        self.rep_type = rep_type
        self.bird = stim_event.bird

        self.segment_uname = stim_event.seg_uname
        self.rcg_names = stim_event.rcg_names

        self.lags = (lags / stim_event.lfp_sample_rate) * 1e3
        stim_ids = list(stim_event.lfp_reps_by_stim[rep_type].keys())

        all_psds = list()
        all_cross_cfs = list()

        all_spike_rates = list()
        all_spike_sync = list()

        # zscore the LFPs
        stim_event.zscore(rep_type)

        data = {
            'stim_id': list(),
            'stim_type': list(),
            'stim_duration': list(),
            'order': list(),
            'decomp': list(),
            'electrode1': list(),
            'electrode2': list(),
            'region1': list(),
            'region2': list(),
            'cell_index': list(),
            'cell_index2': list(),
            'index': list()
        }

        # get map of electrode indices to region
        index2region = list()
        for e in stim_event.index2electrode:
            i = stim_event.electrode_data['electrode'] == e
            index2region.append(stim_event.electrode_data['region'][i][0])
        print('index2region=', index2region)

        # map cell indices to electrodes
        ncells = len(stim_event.cell_df)
        if ncells > 0:
            print('ncells=%d' % ncells)
            cell_index2electrode = [0] * ncells
            assert len(stim_event.cell_df.sort_code.unique()) == 1
            for ci, e in zip(stim_event.cell_df['index'],
                             stim_event.cell_df['electrode']):
                cell_index2electrode[ci] = e
        else:
            cell_index2electrode = [-1]

        print('len(cell_index2electrode)=%d' % len(cell_index2electrode))

        # make a list of all valid syllables for each stimulus
        num_valid_syllables_per_type = dict()
        stim_syllables = dict()
        lags_max = np.abs(lags).max() / stim_event.lfp_sample_rate
        good_stim_ids = list()
        for stim_id in stim_ids:
            seg_times = list()
            i = stim_event.segment_df['stim_id'] == stim_id
            if i.sum() == 0:
                print('Missing stim information for stim %d!' % stim_id)
                continue

            stype = stim_event.segment_df['stim_type'][i].values[0]
            for k, (stime, etime, order) in enumerate(
                    zip(stim_event.segment_df['start_time'][i],
                        stim_event.segment_df['end_time'][i],
                        stim_event.segment_df['order'][i])):
                dur = (etime - stime) + post_syllable_dur
                if dur < min_syllable_dur:
                    continue

                # make sure the duration is long enough to support the lags
                assert dur > lags_max, "Lags is too long, duration=%0.3f, lags_max=%0.3f" % (
                    dur, lags_max)

                # add the syllable to the list, add in the extra post-stimulus time
                seg_times.append((stime, etime + post_syllable_dur, order))
                if stype not in num_valid_syllables_per_type:
                    num_valid_syllables_per_type[stype] = 0
                num_valid_syllables_per_type[stype] += 1

            if len(seg_times) > 0:
                stim_syllables[stim_id] = np.array(seg_times)
                good_stim_ids.append(stim_id)

        print('# of syllables per category:')
        for stype, nstype in list(num_valid_syllables_per_type.items()):
            print('%s: %d' % (stype, nstype))

        # specify the window size for the auto-spectra and cross-coherence. the minimum segment size is 80ms
        psd_window_size = 0.060
        psd_increment = 2 / stim_event.lfp_sample_rate

        for stim_id in good_stim_ids:
            # get stim type
            i = stim_event.trial_df['stim_id'] == stim_id
            stim_type = stim_event.trial_df['stim_type'][i].values[0]

            print('Computing CFs for stim %d (%s)' % (stim_id, stim_type))

            # get the raw LFP
            X = stim_event.lfp_reps_by_stim[rep_type][stim_id]
            ntrials, nelectrodes, nt = X.shape

            # get the spike trains, a ragged array of shape (num_trials, num_cells, num_spikes)
            # The important thing to know about the spike times is that they are with respect
            # to the stimulus onset. Negative spike times occur prior to stimulus onset.
            spike_mat = stim_event.spikes_by_stim[stim_id]
            assert ntrials == len(
                spike_mat), "Weird number of trials in spike_mat: %d" % (
                    len(spike_mat))

            # get segment data for this stim, start and end times of segments
            seg_times = stim_syllables[stim_id]

            # go through each syllable of the stimulus
            for stime, etime, order in seg_times:
                # compute the start and end indices of the LFP for this syllable, keeping in mind that the LFP
                # segment for this stimulus includes the pre and post stim times.
                lfp_syllable_start = (stim_event.pre_stim_time + stime)
                lfp_syllable_end = (stim_event.pre_stim_time + etime)
                si = int(lfp_syllable_start * stim_event.lfp_sample_rate)
                ei = int(lfp_syllable_end * stim_event.lfp_sample_rate)

                stim_dur = etime - stime

                # because spike times are with respect to the onset of the first syllable
                # of this stimulus, stime and etime define the appropriate window of analysis
                # for this syllable for spike times.
                spike_syllable_start = stime
                spike_syllable_end = etime
                spike_syllable_dur = etime - stime

                if debug:
                    # get the spectrogram, lfp and spikes for the stimulus
                    stim_spec = stim_event.spec_by_stim[stim_id]
                    syllable_spec = stim_spec[:, si:ei]
                    the_lfp = X[:, :, si:ei]
                    the_spikes = list()
                    lfp_t = np.arange(
                        the_lfp.shape[-1]) / stim_event.lfp_sample_rate
                    spec_t = np.arange(
                        syllable_spec.shape[1]) / stim_event.lfp_sample_rate
                    spec_freq = stim_event.spec_freq

                    for k in range(ntrials):
                        the_cell_spikes = list()
                        for n in range(ncells):
                            st = spike_mat[k][n]
                            i = (st >= stime) & (st <= etime)
                            print(
                                'trial=%d, cell=%d, nspikes=%d, (%0.3f, %0.3f)'
                                % (k, n, i.sum(), spike_syllable_start,
                                   spike_syllable_end))
                            print('st=', st)
                            the_cell_spikes.append(st[i] - stime)
                        the_spikes.append(the_cell_spikes)

                    # make some plots to check the raw data, left hand plot is LFP, right hand is spikes
                    nrows = ntrials + 1
                    plt.figure()
                    gs = plt.GridSpec(nrows, 2)

                    ax = plt.subplot(gs[0, 0])
                    plot_spectrogram(spec_t,
                                     spec_freq,
                                     syllable_spec,
                                     ax=ax,
                                     colormap='gray',
                                     colorbar=False)

                    ax = plt.subplot(gs[0, 1])
                    plot_spectrogram(spec_t,
                                     spec_freq,
                                     syllable_spec,
                                     ax=ax,
                                     colormap='gray',
                                     colorbar=False)

                    for k in range(ntrials):
                        ax = plt.subplot(gs[k + 1, 0])
                        lllfp = the_lfp[k, :, :]
                        absmax = np.abs(lllfp).max()
                        plt.imshow(
                            lllfp,
                            interpolation='nearest',
                            aspect='auto',
                            cmap=plt.cm.seismic,
                            vmin=-absmax,
                            vmax=absmax,
                            origin='lower',
                            extent=[lfp_t.min(),
                                    lfp_t.max(), 1, nelectrodes])

                        ax = plt.subplot(gs[k + 1, 1])
                        plot_raster(the_spikes[k],
                                    ax=ax,
                                    duration=spike_syllable_end -
                                    spike_syllable_start,
                                    time_offset=0,
                                    ylabel='')

                    plt.title('stim %d, syllable %d' % (stim_id, order))
                    plt.show()

                # compute the LFP props
                lfp_props = self.compute_lfp_spectra_and_cfs(
                    X[:, :, si:ei], stim_event.lfp_sample_rate, lags,
                    psd_window_size, psd_increment, window_fraction, noise_db)

                # save the power spectra to the data frame and data matrix
                decomp_types = ['trial_avg', 'mean_sub', 'full', 'onewin']
                for n, e in enumerate(stim_event.index2electrode):
                    for decomp in decomp_types:

                        data['stim_id'].append(stim_id)
                        data['stim_type'].append(stim_type)
                        data['order'].append(order)
                        data['decomp'].append(decomp)
                        data['electrode1'].append(e)
                        data['electrode2'].append(e)
                        data['region1'].append(index2region[n])
                        data['region2'].append(index2region[n])
                        data['cell_index'].append(-1)
                        data['cell_index2'].append(-1)
                        data['index'].append(len(all_psds))
                        data['stim_duration'].append(stim_dur)

                        pkey = '%s_psds' % decomp
                        all_psds.append(lfp_props[pkey][n])

                # save the cross terms to the data frame and data matrix
                for n1, e1 in enumerate(stim_event.index2electrode):
                    for n2 in range(n1):
                        e2 = stim_event.index2electrode[n2]

                        # get index of cfs for this electrode pair
                        pi = lfp_props['cross_electrodes'].index((n1, n2))
                        for decomp in decomp_types:

                            if decomp == 'onewin':
                                continue

                            data['stim_id'].append(stim_id)
                            data['stim_type'].append(stim_type)
                            data['order'].append(order)
                            data['decomp'].append(decomp)
                            data['electrode1'].append(e1)
                            data['electrode2'].append(e2)
                            data['region1'].append(index2region[n1])
                            data['region2'].append(index2region[n2])
                            data['cell_index'].append(-1)
                            data['cell_index2'].append(-1)
                            data['index'].append(len(all_cross_cfs))
                            data['stim_duration'].append(stim_dur)

                            pkey = '%s_cfs' % decomp
                            all_cross_cfs.append(lfp_props[pkey][pi])

                if len(cell_index2electrode
                       ) == 1 and cell_index2electrode[0] == -1:
                    continue

                # compute the spike rate vector for each neuron
                for ci, e in enumerate(cell_index2electrode):
                    rates = list()
                    for k in range(ntrials):
                        st = spike_mat[k][ci]

                        i = (st >= spike_syllable_start) & (st <=
                                                            spike_syllable_end)
                        r = i.sum() / (spike_syllable_end -
                                       spike_syllable_start)
                        rates.append(r)

                    rates = np.array(rates)
                    rv = [rates.mean(), rates.std(ddof=1)]

                    data['stim_id'].append(stim_id)
                    data['stim_type'].append(stim_type)
                    data['order'].append(order)
                    data['decomp'].append('spike_rate')
                    data['electrode1'].append(e)
                    data['electrode2'].append(e)
                    data['region1'].append(index2region[n])
                    data['region2'].append(index2region[n])
                    data['cell_index'].append(ci)
                    data['cell_index2'].append(ci)
                    data['index'].append(len(all_spike_rates))
                    data['stim_duration'].append(stim_dur)

                    all_spike_rates.append(rv)

                # compute the pairwise synchrony between spike trains
                spike_sync = np.zeros([ntrials, ncells, ncells])
                for k in range(ntrials):
                    for ci1, e1 in enumerate(cell_index2electrode):
                        st1 = spike_mat[k][ci1]
                        if len(st1) == 0:
                            continue

                        for ci2 in range(ci1):
                            st2 = spike_mat[k][ci2]
                            if len(st2) == 0:
                                continue
                            sync12 = simple_synchrony(st1,
                                                      st2,
                                                      spike_syllable_dur,
                                                      bin_size=3e-3)
                            spike_sync[k, ci1, ci2] = sync12
                            spike_sync[k, ci2, ci1] = sync12

                # average pairwise synchrony across trials
                spike_sync_avg = spike_sync.mean(axis=0)

                # save the spike sync data into the data frame
                for ci1, e1 in enumerate(cell_index2electrode):
                    n1 = stim_event.index2electrode.index(e1)
                    for ci2 in range(ci1):
                        e2 = cell_index2electrode[ci2]
                        n2 = stim_event.index2electrode.index(e2)

                        data['stim_id'].append(stim_id)
                        data['stim_type'].append(stim_type)
                        data['order'].append(order)
                        data['decomp'].append('spike_sync')
                        data['electrode1'].append(e1)
                        data['electrode2'].append(e2)
                        data['region1'].append(index2region[n1])
                        data['region2'].append(index2region[n2])
                        data['cell_index'].append(ci1)
                        data['cell_index2'].append(ci2)
                        data['index'].append(len(all_spike_sync))
                        data['stim_duration'].append(stim_dur)

                        all_spike_sync.append(spike_sync_avg[ci1, ci2])

        self.cell_index2electrode = cell_index2electrode

        self.psds = np.array(all_psds)
        self.cross_cfs = np.array(all_cross_cfs)

        self.spike_rate = np.array(all_spike_rates)
        self.spike_synchrony = np.array(all_spike_sync)

        self.data = data
        self.df = pd.DataFrame(self.data)
Esempio n. 5
0
    def plot(self, start_time, end_time):
        """Print summary statistics and plot the detected events of every segment.

        For each ``(seg_uname, events)`` in ``self.events`` this prints the number
        of events plus min/max/mean/median of the event durations, inter-event
        intervals (IEI) and amplitudes, then opens two figures: 1D histograms of
        those quantities and a 2D duration-vs-amplitude histogram. When
        ``self.experiment`` is set, a third figure shows, for the window
        ``[start_time, end_time]``, the stimulus spectrogram with its amplitude
        envelope, the stored stimulus envelope, and the spike raster of the first
        entry of ``self.rcg_names`` overlaid with the spike envelope and markers
        for event starts (green ``^``) and ends (blue ``v``).

        ``events`` is indexed as a 2D array whose column 0 is the event start
        time, column 1 the event end time and the last column the amplitude.

        :param start_time: start of the window for the raster/spectrogram figure.
        :param end_time: end of the window for the raster/spectrogram figure.
        """

        def _stats_line(label, values):
            # One tab-indented line of min/max/mean/median for ``values``.
            # ndarray.min()/max() raise ValueError on empty input (a segment with
            # no events, or the IEI of a segment with a single event).
            if len(values) == 0:
                return '\t%s: no values' % label
            return '\t%s: min=%0.6f, max=%0.6f, mean=%0.6f, median=%0.6f' % (
                label, values.min(), values.max(), values.mean(),
                np.median(values))

        for seg_uname, events in list(self.events.items()):

            event_start_times = events[:, 0]
            event_end_times = events[:, 1]
            amplitudes = events[:, -1]
            durations = event_end_times - event_start_times
            nevents = len(durations)

            # gap between the end of each event and the start of the next one
            inter_event_intervals = event_start_times[1:] - event_end_times[:-1]

            # print some event info
            print('Segment: %s' % seg_uname)
            print('\tIdentified %d events' % nevents)
            print(_stats_line('Duration', durations))
            print(_stats_line('IEI', inter_event_intervals))
            print(_stats_line('Amplitude', amplitudes))

            if nevents > 0:
                # plot some event statistics
                plt.figure()
                plt.subplot(3, 1, 1)
                plt.hist(durations, bins=100, color='r')
                plt.axis('tight')
                plt.title('Event Duration (s)')

                plt.subplot(3, 1, 2)
                if len(inter_event_intervals) > 0:
                    plt.hist(inter_event_intervals, bins=100, color='k')
                plt.axis('tight')
                plt.title('Inter-event Interval (s)')

                plt.subplot(3, 1, 3)
                plt.hist(amplitudes, bins=100, color='g')
                plt.title('Peak Amplitude')
                plt.suptitle(seg_uname)
                plt.axis('tight')

                # plot event duration vs amplitude histogram (log-scaled counts;
                # only meaningful when there is at least one event)
                plt.figure()
                plt.hist2d(durations,
                           amplitudes,
                           bins=[40, 30],
                           cmap=cm.Greys,
                           norm=LogNorm())
                plt.xlim(0, 2.)
                plt.xlabel('Duration')
                plt.ylabel('Amplitude')
                plt.colorbar(label="# of Joint Events")
                plt.suptitle(seg_uname)

            if self.experiment is None:
                continue

            # get the segment
            seg = segment_from_unique_name(self.experiment.blocks, seg_uname)

            # get a slice of the spike raster for the first electrode array
            spike_rasters = self.experiment.get_spike_slice(
                seg,
                start_time,
                end_time,
                rcg_names=self.rcg_names,
                as_matrix=False,
                sort_code=self.sort_code,
                bin_size=self.bin_size)
            spike_trains = spike_rasters[self.rcg_names[0]][0]

            # the envelopes hold one value per bin_size, so convert the window
            # to bin indices
            si = int(start_time / self.bin_size)
            ei = int(end_time / self.bin_size)
            spike_env = self.envelopes[seg_uname][si:ei]
            stim_env = self.spec_envelopes[seg_uname][si:ei]

            # get stimulus spectrogram and its amplitude envelope
            stim_spec_t, stim_spec_freq, stim_spec = \
                self.experiment.get_spectrogram_slice(seg, start_time, end_time)
            stim_spec_env = stim_envelope(stim_spec)

            plt.figure()
            plt.suptitle(seg_uname)

            # spectrogram with the amplitude envelope drawn on top, scaled to
            # the frequency axis
            ax = plt.subplot(3, 1, 1)
            # Axes.set_axis_bgcolor was removed from matplotlib (2.2);
            # set_facecolor is its replacement
            ax.set_facecolor('black')
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.afmhot,
                             colorbar=False,
                             fmax=8000.0)
            plt.plot(stim_spec_t,
                     stim_spec_env * stim_spec_freq.max(),
                     'w-',
                     linewidth=3.0,
                     alpha=0.75)
            plt.axis('tight')
            plt.ylabel('')
            plt.yticks([])

            # plot the stimulus amplitude envelope; its time axis is built from
            # its own length in case its slice is shorter than spike_env's
            stim_tenv = np.arange(len(stim_env)) * self.bin_size + start_time
            ax = plt.subplot(3, 1, 2)
            ax.set_facecolor('black')
            plt.plot(stim_tenv, stim_env, 'w-', linewidth=2.0)
            plt.axis('tight')

            # make a plot of the spike raster, the envelope, and the events
            spike_tenv = np.arange(len(spike_env)) * self.bin_size + start_time
            ax = plt.subplot(3, 1, 3)
            plot_raster(spike_trains,
                        ax=ax,
                        duration=end_time - start_time,
                        bin_size=self.bin_size,
                        time_offset=start_time,
                        ylabel='',
                        bgcolor='k',
                        spike_color='#ff0000')
            # scale the envelope to the number of raster rows
            plt.plot(spike_tenv, spike_env * len(spike_trains), 'w-', alpha=0.75)

            # plot start events that are within plotting range
            sei = (event_start_times >= start_time) & (event_start_times <=
                                                       end_time)
            plt.plot(event_start_times[sei], [1] * sei.sum(),
                     'g^',
                     markersize=10)

            # plot end events that are within plotting range
            eei = (event_end_times >= start_time) & (event_end_times <=
                                                     end_time)
            plt.plot(event_end_times[eei], [1] * eei.sum(),
                     'bv',
                     markersize=10)

            plt.axis('tight')
            plt.yticks([])
            plt.ylabel('Spikes')
    def plot_segment_slice(self,
                           seg,
                           start_time,
                           end_time,
                           output_dir=None,
                           sort_code='0'):
        """ Plot all the data for a given slice of time in a segment. This function will plot the spectrogram,
            multi-electrode spike trains, and LFPs for each array all on one figure.

        The figure has ``1 + 2 * len(seg.block.recordingchannelgroups)`` rows in a
        single column: the stimulus spectrogram on top, then for each recording
        channel group a spike raster followed by an image of its z-scored LFPs.

        :param seg: segment to plot; ``seg.block.recordingchannelgroups`` decides
            which arrays get a raster/LFP pair of rows.
        :param start_time: start of the time slice, forwarded to the experiment's
            spectrogram/spike/LFP slice getters.
        :param end_time: end of the time slice, forwarded the same way.
        :param output_dir: NOTE(review): not referenced in the lines shown here;
            presumably a directory to save the figure to — confirm.
        :param sort_code: spike sort code forwarded to
            ``self.experiment.get_spike_slice``.
        :return: TODO confirm — no return statement in the lines shown here.
        """

        num_arrays = len(seg.block.recordingchannelgroups)

        # one spectrogram row plus a raster row and an LFP row per array
        nrows = 1 + 2 * num_arrays
        ncols = 1

        plt.figure()

        #get the slice of spectrogram to plot
        stim_spec_t, stim_spec_freq, stim_spec = self.experiment.get_spectrogram_slice(
            seg, start_time, end_time)

        #get the spikes to plot; indexed below by recording channel group name,
        #each entry being a (spike_trains, spike_train_groups) pair
        spikes = self.experiment.get_spike_slice(seg,
                                                 start_time,
                                                 end_time,
                                                 sort_code=sort_code)

        #get the multielectrode LFPs for each channel; indexed below by recording
        #channel group name, each entry being (electrode_indices, lfp_matrix, sample_rate)
        lfps = self.experiment.get_lfp_slice(seg, start_time, end_time)

        #plot the spectrogram
        ax = plt.subplot(nrows, ncols, 1)
        plot_spectrogram(stim_spec_t,
                         stim_spec_freq,
                         stim_spec,
                         ax=ax,
                         colormap=cm.afmhot_r,
                         colorbar=False,
                         fmax=8000.0)
        plt.ylabel('')
        plt.yticks([])

        #plot the LFPs and spikes for each recording array
        for k, rcg in enumerate(seg.block.recordingchannelgroups):
            # raster for array k goes in row 2 + 2k, its LFP image in row 3 + 2k
            spike_trains, spike_train_groups = spikes[rcg.name]
            ax = plt.subplot(nrows, ncols, 2 + k * 2)
            plot_raster(spike_trains,
                        ax=ax,
                        duration=end_time - start_time,
                        bin_size=0.001,
                        time_offset=start_time,
                        ylabel='Electrode',
                        groups=spike_train_groups)

            electrode_indices, lfp_matrix, sample_rate = lfps[rcg.name]
            nelectrodes = len(electrode_indices)

            #zscore the LFP of each electrode (rows of lfp_matrix appear to be
            #electrodes, as the imshow extent/yticks below treat them). Only rows
            #with nonzero std are normalized, avoiding division by zero.
            #NOTE(review): the updates go through the .T view, so the array
            #returned by get_lfp_slice is modified in place; they would also
            #fail if lfp_matrix has an integer dtype — confirm this is intended.
            LFPmean = lfp_matrix.T.mean(axis=0)
            LFPstd = lfp_matrix.T.std(axis=0, ddof=1)
            nz = LFPstd > 0.0
            lfp_matrix.T[:, nz] -= LFPmean[nz]
            lfp_matrix.T[:, nz] /= LFPstd[nz]

            ax = plt.subplot(nrows, ncols, 2 + k * 2 + 1)
            plt.imshow(lfp_matrix,
                       interpolation='nearest',
                       aspect='auto',
                       cmap=cm.seismic,
                       extent=[start_time, end_time, 1, nelectrodes])
            # labels are reversed because imshow's default origin draws the
            # first row at the top.
            # NOTE(review): ticks are placed at 0..nelectrodes-1 while the extent
            # spans 1..nelectrodes, so labels may be offset by a row — confirm.
            lbls = ['%d' % e for e in electrode_indices]
            lbls.reverse()
            plt.yticks(range(nelectrodes), lbls)
            plt.axis('tight')
            plt.ylabel('Electrode')