コード例 #1
0
def compare_timefreqs(s, sample_rate, win_sizes=(0.050, 0.100, 0.250, 0.500, 1.25),
                      estimators=None, enames=None):
    """
    Compare the time-frequency representation of a signal across window sizes and estimators.

    For every window size and every estimator, ``timefreq`` is computed with a
    one-sample increment and the amplitude of the result is drawn in a grid of
    subplots (rows = window sizes, columns = estimators) in a new figure. The
    figure is not shown; call ``plt.show()`` afterwards if needed.

    :param s: The signal to analyze.
    :param sample_rate: Sample rate of ``s``.
    :param win_sizes: Window sizes in seconds. A tuple by default so the shared
        default cannot be mutated between calls; any sequence works.
    :param estimators: Optional sequence of spectrum estimators to compare.
        Defaults to a single ``WaveletSpectrumEstimator`` (10 cycles per window,
        50 frequencies from 1 to ``sample_rate/2``, nstd=6).
    :param enames: Optional column titles, one per estimator. Defaults to
        ``['wavelet']`` for the default estimator, otherwise the estimators'
        class names.
    :raises ValueError: if ``enames`` and ``estimators`` differ in length.
    """
    if estimators is None:
        estimators = [WaveletSpectrumEstimator(num_cycles_per_window=10, min_freq=1,
                                               max_freq=sample_rate / 2, num_freqs=50,
                                               nstd=6)]
        if enames is None:
            enames = ['wavelet']
    elif enames is None:
        enames = [type(est).__name__ for est in estimators]

    if len(enames) != len(estimators):
        raise ValueError('enames must contain exactly one name per estimator')

    # the representation is evaluated at every input sample; this does not
    # depend on the window size, so compute it once
    increment = 1.0 / sample_rate
    nrows, ncols = len(win_sizes), len(estimators)

    # run each estimator for each window size and plot the amplitude of the
    # time-frequency representation
    plt.figure()
    spnum = 1
    for k, win_size in enumerate(win_sizes):
        for est, ename in zip(estimators, enames):
            t, freq, tf = timefreq(s, sample_rate, win_size, increment, est)
            print('freq=', freq)
            ax = plt.subplot(nrows, ncols, spnum)
            plot_spectrogram(t, freq, np.abs(tf), ax=ax, colorbar=True, ticks=True)
            if k == 0:
                # only the top row carries the estimator titles
                plt.title(ename)
            spnum += 1
コード例 #2
0
    def test_delta(self):
        """Visually check a Gaussian STFT of a slow, frequency-modulated sine.

        A 30 s signal sampled at 1 kHz is built as ``sin(2*pi*freqs*t)`` with
        ``freqs`` ramping linearly from 0.5 to 1.5. Its Gaussian STFT (2 s
        window, 100 ms increment) is computed. The figure shows the waveform on
        top and, below, the power of the STFT for 0 < freq < 10 with the
        signal's instantaneous frequency overlaid. The figure is shown with
        ``plt.show()``.
        """
        dur = 30.
        sample_rate = 1e3
        nt = int(dur * sample_rate)
        t = np.arange(nt) / sample_rate
        freqs = np.linspace(0.5, 1.5, nt)
        s = np.sin(2 * np.pi * freqs * t)

        # The phase is 2*pi*freqs*t, so the instantaneous frequency is
        # d(freqs*t)/dt = freqs + t*dfreqs/dt (0.5 -> 2.5 here), not freqs
        # itself; overlay that so it lines up with the spectrogram ridge.
        inst_freq = np.gradient(freqs * t, t)

        win_len = 2.
        spec_t, spec_freq, spec, spec_rms = gaussian_stft(
            s, sample_rate, win_len, 100e-3)

        # keep only the positive frequencies below 10
        fi = (spec_freq < 10) & (spec_freq > 0)

        plt.figure()
        gs = plt.GridSpec(100, 1)
        plt.subplot(gs[:30, 0])
        plt.plot(t, s, 'k-', linewidth=4.0, alpha=0.7)

        ax = plt.subplot(gs[35:, 0])
        plot_spectrogram(spec_t,
                         spec_freq[fi],
                         np.abs(spec[fi, :])**2,
                         ax=ax,
                         colorbar=False,
                         colormap=plt.cm.afmhot_r)
        plt.plot(t, inst_freq, 'k-', alpha=0.7, linewidth=4.0)
        plt.axis('tight')
        plt.show()
コード例 #3
0
 def _render():
     """Redraw ``spec_ax``: spectrogram, scaled amplitude envelope, click markers.

     Reads ``spec_ax``, ``spec_t``, ``spec_freq``, ``spec``, ``wave_t``,
     ``amp_env`` and ``click_points`` from the enclosing scope.
     """
     plt.sca(spec_ax)
     plt.cla()
     plot_spectrogram(spec_t, spec_freq, spec,
                      ax=spec_ax, colorbar=False,
                      fmin=300., fmax=8000.,
                      colormap='SpectroColorMap')
     # envelope is multiplied by 8000 (the same value as fmax) before drawing
     envelope_trace = amp_env * 8000
     plt.plot(wave_t, envelope_trace, 'k-', linewidth=3.0, alpha=0.7)
     plt.axis('tight')
     # click points are labelled in pairs: indices 0,1 -> "0", 2,3 -> "1", ...
     for idx, click_time in enumerate(click_points):
         pair_label = str(idx // 2)
         plt.axvline(click_time, c='k', linewidth=2.0, alpha=0.8)
         plt.text(click_time, 7000., pair_label, fontsize=14)
     plt.draw()
コード例 #4
0
    def transform(self,
                  experiment,
                  stim_types_to_segment=('Ag', 'Di', 'Be', 'DC', 'Te', 'Ne',
                                         'LT', 'Th', 'song'),
                  plot=False,
                  excluded_types=tuple()):
        """Compute acoustic properties for every stimulus used in ``experiment``.

        Stimulus ids are collected from the epoch tables of all segments. For
        each stimulus whose type is not in ``excluded_types``:

        * if its type is in ``stim_types_to_segment``, the spectrogram amplitude
          envelope is broken into syllables. One row is recorded per syllable
          of at least 1024 samples; ``order`` numbers the kept syllables from 0.
        * otherwise one row is recorded for the whole waveform (``order`` 0).

        Each row holds ``stim_id``, ``stim_type``, ``start_time`` and
        ``end_time`` (sample index / sample rate), ``order``, and one column per
        name in ``self.acoustic_props`` read from a ``BioSound`` object.
        Properties that are ``None`` are stored as -1 in both cases.

        Results are stored in ``self.stim_data`` (dict of lists) and
        ``self.stim_df`` (``pandas.DataFrame``). ``self.bird`` is set to
        ``experiment.bird_name``.

        :param experiment: an ``Experiment`` instance.
        :param stim_types_to_segment: stimulus types that are split into syllables.
        :param plot: if True, draw and ``plt.show()`` a segmentation plot for
            each segmented stimulus.
        :param excluded_types: stimulus types that are skipped entirely.
        """

        assert isinstance(
            experiment, Experiment
        ), 'experiment argument must be an instance of class Experiment!'

        self.bird = experiment.bird_name

        # gather the ids of every stimulus referenced by any segment's epoch table
        all_stim_ids = list()
        for seg in experiment.get_all_segments():
            etable = experiment.get_epoch_table(seg)
            all_stim_ids.extend(etable['id'].unique())
        stim_ids = np.unique(all_stim_ids)

        stim_data = {
            'stim_id': list(),
            'stim_type': list(),
            'start_time': list(),
            'end_time': list(),
            'order': list()
        }
        for aprop in self.acoustic_props:
            stim_data[aprop] = list()

        # type-specific envelope thresholds for segmentation; any type not
        # listed here uses 'default'
        seg_params = {
            'default': {
                'min_thresh': 0.05,
                'max_thresh': 0.25
            },
            'Ag': {
                'min_thresh': 0.05,
                'max_thresh': 0.10
            },
            'song': {
                'min_thresh': 0.05,
                'max_thresh': 0.10
            },
            'Di': {
                'min_thresh': 0.15,
                'max_thresh': 0.20
            }
        }

        for stim_id in stim_ids:

            print('Transforming stimulus {}'.format(stim_id))

            # look up the stimulus type; for calls the call id is used as the type
            si = experiment.stim_table['id'] == str(stim_id)
            assert si.sum(
            ) == 1, "Zero or more than one stimulus defined for id=%d, (si.sum()=%d)" % (
                stim_id, si.sum())
            stim_type = experiment.stim_table['type'][si].values[0]
            if stim_type == 'call':
                stim_type = experiment.stim_table['callid'][si].values[0]

            if stim_type in excluded_types:
                continue

            # get the stimulus waveform and sample rate
            sound = experiment.sound_manager.reconstruct(stim_id)
            waveform = np.array(sound.squeeze())
            sample_rate = float(sound.samplerate)

            if stim_type not in stim_types_to_segment:
                # analyze the whole stimulus as a single piece
                stim_dur = len(waveform) / sample_rate
                self._append_acoustic_row(stim_data, stim_id, stim_type,
                                          waveform, sample_rate, 0, stim_dur, 0)
                continue

            params = seg_params.get(stim_type, seg_params['default'])
            spans = self._segment_syllables(waveform,
                                            sample_rate,
                                            params['min_thresh'],
                                            params['max_thresh'],
                                            plot=plot,
                                            stim_id=stim_id,
                                            stim_type=stim_type)

            the_order = 0
            for sii, eii in spans:
                s = waveform[sii:eii]
                # syllables shorter than 1024 samples are not analyzed
                if len(s) < 1024:
                    continue
                self._append_acoustic_row(stim_data, stim_id, stim_type, s,
                                          sample_rate, sii / sample_rate,
                                          eii / sample_rate, the_order)
                the_order += 1

        # build the outputs once, after all stimuli are processed, so they are
        # always set (even when every stimulus was excluded)
        self.stim_data = stim_data
        self.stim_df = pd.DataFrame(self.stim_data)

    def _segment_syllables(self,
                           waveform,
                           sample_rate,
                           min_thresh,
                           max_thresh,
                           plot=False,
                           stim_id=None,
                           stim_type=None):
        """Split ``waveform`` into syllables; return ``(start, end)`` sample index pairs.

        The amplitude envelope (``spec_rms`` of a Gaussian STFT with a 0.007
        window, sampled at 1000 per second) is normalized to [0, 1] and passed
        to ``break_envelope_into_events`` with the given thresholds and a merge
        threshold of 4 envelope samples. When ``plot`` is True, a diagnostic
        figure titled with ``stim_id``/``stim_type`` is drawn and shown.
        """
        spec_sample_rate = 1000.
        spec_t, spec_freq, spec_stft, spec_rms = gaussian_stft(
            waveform, sample_rate, 0.007, 1.0 / spec_sample_rate)

        # normalize the amplitude envelope to the range [0, 1]
        amp_env = spec_rms - spec_rms.min()
        amp_env /= amp_env.max()

        # segment the amplitude envelope (merge threshold in envelope samples)
        minimum_isi = int(4e-3 * spec_sample_rate)
        syllable_times = break_envelope_into_events(amp_env,
                                                    threshold=min_thresh,
                                                    merge_thresh=minimum_isi,
                                                    max_amp_thresh=max_thresh)

        if plot:
            # log-scaled spectrogram (+50 offset, clipped at 0), used only for display
            spec = np.abs(spec_stft)
            nz = spec > 0
            spec[nz] = 20 * np.log10(spec[nz]) + 50
            spec[spec < 0] = 0

            plt.figure()
            ax = plt.subplot(111)
            plot_spectrogram(spec_t,
                             spec_freq,
                             np.abs(spec),
                             ax=ax,
                             fmin=300.0,
                             fmax=8000.0,
                             colormap=plt.cm.afmhot,
                             colorbar=False)

            # map the envelope and thresholds onto the frequency axis so they
            # overlay the spectrogram
            sfd = spec_freq.max() - spec_freq.min()
            plot_env = amp_env * sfd + spec_freq.min()
            tline = sfd * min_thresh + plot_env.min()
            tline2 = sfd * max_thresh + plot_env.min()
            plt.axhline(tline, c='w', alpha=0.50)
            plt.axhline(tline2, c='w', alpha=0.50)

            plt.plot(spec_t, plot_env, 'w-', linewidth=2.0, alpha=0.75)
            for si, ei, _max_amp in syllable_times:
                plt.plot(spec_t[si], 0, 'go', markersize=8)
                plt.plot(spec_t[ei], 0, 'ro', markersize=8)
            plt.title('stim %d, %s, minimum_isi=%d' %
                      (stim_id, stim_type, minimum_isi))
            plt.axis('tight')
            plt.show()

        # convert envelope indices to waveform sample indices
        return [(int((si / spec_sample_rate) * sample_rate),
                 int((ei / spec_sample_rate) * sample_rate))
                for si, ei, _max_amp in syllable_times]

    def _append_acoustic_row(self, stim_data, stim_id, stim_type, sound_wave,
                             sample_rate, start_time, end_time, order):
        """Run BioSound on ``sound_wave`` and append one row to ``stim_data``.

        Acoustic properties that BioSound leaves as ``None`` are stored as -1.
        """
        bs = BioSound(soundWave=sound_wave, fs=sample_rate)
        bs.spectrum(f_high=8000.)
        bs.ampenv()
        bs.fundest()

        stim_data['stim_id'].append(stim_id)
        stim_data['stim_type'].append(stim_type)
        stim_data['start_time'].append(start_time)
        stim_data['end_time'].append(end_time)
        stim_data['order'].append(order)

        for aprop in self.acoustic_props:
            aval = getattr(bs, aprop)
            stim_data[aprop].append(-1 if aval is None else aval)
コード例 #5
0
    def plot_single_electrode(self, enumber, start_time, end_time):
        """Plot the stimulus spectrogram, raw LFP and bands 0-5 of one electrode.

        Opens a new figure with one column. The rows are the stimulus
        spectrogram (300-8000), the raw LFP slice, then one trace per band;
        only the last band row keeps its x tick labels. Updates the global
        ``rcParams`` font size to 18.

        :param enumber: electrode number; located via ``self.index2electrode``.
        :param start_time: start of the window, in seconds.
        :param end_time: end of the window, in seconds.
        """
        spec_t, spec_freq, spec = self.experiment.get_spectrogram_slice(
            self.segment, start_time, end_time)
        lfp = self.experiment.get_single_lfp_slice(self.segment, enumber,
                                                   start_time, end_time)

        sr = self.get_sample_rate()
        # NOTE(review): these indices use the absolute start_time, whereas
        # plot()/plot_complex() subtract self.start_time before indexing the
        # band matrix -- confirm get_bands() is indexed from time zero.
        first_idx = int(start_time * sr)
        last_idx = int(end_time * sr)
        t = start_time + np.arange(last_idx - first_idx) / sr

        eindex = list(self.index2electrode).index(enumber)
        bands_to_plot = list(range(6))
        band_matrix = self.get_bands()
        bands = np.array(
            [band_matrix[b, eindex, first_idx:last_idx] for b in bands_to_plot])

        nrows = len(bands_to_plot) + 2
        rcParams.update({'font.size': 18})
        plt.figure()
        plt.subplots_adjust(top=0.95, bottom=0.07, left=0.03, right=0.99,
                            hspace=0.10)

        # row 1: stimulus spectrogram
        spec_ax = plt.subplot(nrows, 1, 1)
        plot_spectrogram(spec_t, spec_freq, spec, ax=spec_ax,
                         colormap=cm.afmhot_r, colorbar=False,
                         fmin=300.0, fmax=8000.0)
        plt.ylabel('Spectrogram')
        plt.yticks([])
        plt.xticks([])

        # row 2: raw LFP
        plt.subplot(nrows, 1, 2)
        plt.plot(t, lfp, 'k-', linewidth=3.0)
        plt.xticks([])
        plt.yticks([])
        plt.ylabel('LFP')
        plt.axis('tight')

        # remaining rows: one band each; only the bottom row shows the time axis
        last_band = len(bands_to_plot) - 1
        for b in bands_to_plot:
            plt.subplot(nrows, 1, b + 3)
            plt.plot(t, bands[b, :], 'r-', linewidth=3.0)
            plt.yticks([])
            plt.ylabel('band %d' % b)
            if b == last_band:
                plt.xlabel('Time (s)')
            else:
                plt.xticks([])
            plt.axis('tight')
コード例 #6
0
    def plot(self, **kwargs):
        """Plot band-averaged power spectra and per-electrode band traces.

        Keyword arguments (all optional):
            start_time, end_time: window to plot in seconds (defaults:
                ``self.start_time`` / ``self.end_time``).
            bands_to_plot: band indices (default: ``range(self.num_bands)``).
            output_dir: if given, every figure is saved there as PNG and
                closed; otherwise ``plt.show()`` is called at the end.

        First figure: for each band, the power spectrum (|FFT|^2, positive
        frequencies) of every electrode's band signal is averaged across
        electrodes. It is smoothed with a Gaussian filter (sigma = 10 bins),
        log10-transformed where positive, and clipped below 0.
        Then one figure per band: the stimulus spectrogram across the top row
        and each electrode's band trace at the grid position given by its
        ``row``/``col`` channel annotations.
        """
        seg_uname = segment_to_unique_name(self.segment)

        # fill in defaults for any keyword the caller did not supply
        kw_params = {
            'start_time': self.start_time,
            'end_time': self.end_time,
            'bands_to_plot': list(range(self.num_bands)),
            'output_dir': None
        }
        for key, val in kw_params.items():
            kwargs.setdefault(key, val)

        bands_to_plot = kwargs['bands_to_plot']
        start_time = kwargs['start_time']
        end_time = kwargs['end_time']
        output_dir = kwargs['output_dir']

        sr = self.get_sample_rate()
        stim_spec_t, stim_spec_freq, stim_spec = self.experiment.get_spectrogram_slice(
            self.segment, start_time, end_time)
        # sample indices into self.X, whose time axis starts at self.start_time
        t1 = int((start_time - self.start_time) * sr)
        t2 = int((end_time - self.start_time) * sr)

        # power spectrum of each band for each electrode (positive freqs only)
        band_ps_list = list()
        band_ps_freq = None
        for n in bands_to_plot:
            ps = list()
            for k, _enumber in enumerate(self.index2electrode):
                band = self.X[n, k, t1:t2].squeeze()
                band_fft = fft(band)
                freq = fftfreq(len(band), d=1.0 / sr)
                findex = freq > 0.0
                ps.append(np.real(band_fft[findex] * np.conj(band_fft[findex])))
                band_ps_freq = freq[findex]
            band_ps_list.append(np.array(ps))

        band_ps = np.array(band_ps_list)

        # plot the average power spectrums
        thresh = 0.0
        plt.figure(figsize=(24.0, 13.5))
        max_pow = -np.inf
        clrs = ['b', 'g', 'r', 'c', 'm', 'y']
        for k, electrode_ps in enumerate(band_ps):
            # mean across electrodes, then smooth along frequency
            band_ps_mean = electrode_ps.mean(axis=0)
            band_ps_mean_filt = gaussian_filter1d(band_ps_mean, 10)
            # log-transform only the positive entries; the right-hand side must
            # be masked too or the assignment has mismatched shapes
            nzindex = band_ps_mean_filt > 0.0
            band_ps_mean_filt[nzindex] = np.log10(band_ps_mean_filt[nzindex])

            # threshold the power spectrum at zero
            band_ps_mean_filt[band_ps_mean_filt < thresh] = 0.0

            max_pow = max(max_pow, band_ps_mean_filt.max())
            # cycle the palette so more than six bands do not raise IndexError
            plt.plot(band_ps_freq, band_ps_mean_filt, '-',
                     c=clrs[k % len(clrs)], linewidth=3.0)
        # max_pow stays -inf when no band is plotted; ylim rejects infinite limits
        if np.isfinite(max_pow):
            plt.ylim(0.0, max_pow)
        plt.legend(['%d' % (n + 1) for n in bands_to_plot])
        plt.title("Band Average Power Spectrum Across Electrodes")
        plt.xlabel('Freq (Hz)')
        plt.ylabel('Power (dB)')
        plt.axis('tight')

        if output_dir is not None:
            fname = 'band_ps_%s_%s_start%0.6f_end%0.6f.png' % (
                seg_uname, ','.join(self.rcg_names), start_time, end_time)
            plt.savefig(os.path.join(output_dir, fname))
            plt.close('all')

        # plot the bands
        t = np.arange(t2 - t1) / sr

        # row 0 holds the stimulus spectrogram; electrodes go in rows 1..8
        # (assumes the row annotation is < 8 -- TODO confirm)
        subplot_nrows = 8 + 1
        subplot_ncols = len(self.rcg_names) * 2

        for n in bands_to_plot:
            plt.figure(figsize=(24.0, 13.5))
            plt.subplots_adjust(top=0.95,
                                bottom=0.05,
                                left=0.03,
                                right=0.99,
                                hspace=0.10)
            plt.suptitle('Band %d' % (n + 1))

            for k in range(subplot_ncols):
                ax = plt.subplot(subplot_nrows, subplot_ncols, k + 1)
                plot_spectrogram(stim_spec_t,
                                 stim_spec_freq,
                                 stim_spec,
                                 ax=ax,
                                 colormap=cm.gist_yarg,
                                 colorbar=False)
                plt.ylabel('')
                plt.yticks([])

            for k, electrode in enumerate(self.index2electrode):

                _rcg, rc = self.experiment.get_channel(self.segment.block,
                                                       electrode)
                row = rc.annotations['row']
                col = rc.annotations['col']

                if len(self.rcg_names) > 1:
                    sp = (row + 1) * subplot_ncols + col + 1
                else:
                    sp = (row + 1) * subplot_ncols + (col % 2) + 1

                plt.subplot(subplot_nrows, subplot_ncols, sp)

                band = self.X[n, k, t1:t2].squeeze()
                plt.plot(t, band, 'k-', linewidth=2.0)
                plt.ylabel('E%d' % electrode)
                plt.axis('tight')

            if output_dir is not None:
                fname = 'band_raw_%s_%s_%d_start%0.6f_end%0.6f.png' % (
                    seg_uname, ','.join(
                        self.rcg_names), n + 1, start_time, end_time)
                plt.savefig(os.path.join(output_dir, fname))
                plt.close('all')

        if output_dir is None:
            plt.show()
コード例 #7
0
    def plot_complex(self, **kwargs):

        #set keywords to defaults
        kw_params = {
            'start_time': self.start_time,
            'end_time': self.end_time,
            'output_dir': None,
            'include_spec': False,
            'include_ifreq': False,
            'spikes': False,
            'nbands_to_plot': None,
            'sort_code': '0',
            'demodulate': True,
            'phase_only': False,
            'log_amplitude': False
        }
        for key, val in kw_params.items():
            if key not in kwargs:
                kwargs[key] = val

        nbands, nelectrodes, nt = self.X.shape
        start_time = kwargs['start_time']
        end_time = kwargs['end_time']
        duration = end_time - start_time
        output_dir = kwargs['output_dir']

        sr = self.get_sample_rate()
        t1 = int((start_time - self.start_time) * sr)
        d = int((end_time - start_time) * sr)
        t2 = t1 + d
        t = (np.arange(d) / sr) + kwargs['start_time']

        spike_trains = None
        bin_size = 1e-3
        if kwargs['spikes']:
            if self.spike_rasters is None:
                #load up the spike raster for this time slice
                self.spike_rasters = self.experiment.get_spike_slice(
                    self.segment,
                    0.0,
                    self.segment.annotations['duration'],
                    rcg_names=self.rcg_names,
                    as_matrix=False,
                    sort_code=kwargs['sort_code'],
                    bin_size=bin_size)
            if len(self.spike_rasters) > 1:
                print(
                    "WARNING: plot_complex doesn't work well when more than one electrode array is specified."
                )
            spike_trains_full, spike_train_group = self.spike_rasters[
                self.rcg_names[0]]
            #select out the spikes for the interval to plot
            spike_trains = list()
            for st in spike_trains_full:
                sindex = (st >= start_time) & (st <= end_time)
                spike_trains.append(st[sindex])

        colors = np.array([
            [244.0, 244.0, 244.0],  #white
            [241.0, 37.0, 9.0],  #red
            [238.0, 113.0, 25.0],  #orange
            [255.0, 200.0, 8.0],  #yellow
            [19.0, 166.0, 50.0],  #green
            [1.0, 134.0, 141.0],  #blue
            [244.0, 244.0, 244.0],  #white
        ])
        colors /= 255.0

        #get stimulus spectrogram
        stim_spec_t, stim_spec_freq, stim_spec = self.experiment.get_spectrogram_slice(
            self.segment, kwargs['start_time'], kwargs['end_time'])

        #compute the amplitude, phase, and instantaneous frequency of the complex signal
        amplitude = np.abs(self.Z[:, :, t1:t2])
        phase = np.angle(self.Z[:, :, t1:t2])

        # rescale the amplitude of each electrode so it ranges from 0 to 1
        for k in range(nbands):
            for n in range(nelectrodes):
                amplitude[k, n, :] /= amplitude[k, n].max()

        if kwargs['phase_only']:
            # make sure the amplitude is equal to 1
            nz = amplitude > 0
            amplitude[nz] /= amplitude[nz]

        if kwargs['log_amplitude']:
            nz = amplitude > 0
            amplitude[nz] = np.log10(amplitude[nz])
            amplitude[nz] -= amplitude[nz].min()
            amplitude /= amplitude.max()

        nbands_to_plot = nbands
        if kwargs['nbands_to_plot'] is not None:
            nbands_to_plot = kwargs['nbands_to_plot']

        seg_uname = segment_to_unique_name(self.segment)

        if kwargs['include_ifreq']:
            ##################
            ## make plots of the joint instantaneous frequency per band
            ##################
            rcParams.update({'font.size': 10})
            plt.figure(figsize=(24.0, 13.5))
            plt.subplots_adjust(top=0.98,
                                bottom=0.01,
                                left=0.03,
                                right=0.99,
                                hspace=0.10)
            nsubplots = nbands_to_plot + 2

            #plot the stimulus spectrogram
            ax = plt.subplot(nsubplots, 1, 1)
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.afmhot_r,
                             colorbar=False,
                             fmax=8000.0)
            plt.ylabel('')
            plt.yticks([])

            ifreq = np.zeros([nbands, nelectrodes, d])
            sr = self.get_sample_rate()
            for k in range(nbands):
                for j in range(nelectrodes):
                    ifreq[k, j, :] = compute_instantaneous_frequency(
                        self.Z[k, j, t1:t2], sr)
                    ifreq[k, j, :] = lowpass_filter(ifreq[k, j, :],
                                                    sr,
                                                    cutoff_freq=50.0)

            #plot the instantaneous frequency along with it's amplitude
            for k in range(nbands_to_plot):
                img = np.zeros([nelectrodes, d, 4], dtype='float32')

                ifreq_min = np.percentile(ifreq[k, :, :], 5)
                ifreq_max = np.percentile(ifreq[k, :, :], 95)
                ifreq_dist = ifreq_max - ifreq_min
                #print 'ifreq_max=%0.3f, ifreq_min=%0.3f' % (ifreq_max, ifreq_min)

                for j in range(nelectrodes):
                    max_amp = np.percentile(amplitude[k, j, :], 85)

                    #set the alpha and color for the bins
                    alpha = amplitude[k, j, :] / max_amp
                    alpha[alpha > 1.0] = 1.0  #saturate
                    alpha[alpha < 0.05] = 0.0  #nonlinear threshold

                    cnorm = (ifreq[k, j, :] - ifreq_min) / ifreq_dist
                    cnorm[cnorm > 1.0] = 1.0
                    cnorm[cnorm < 0.0] = 0.0
                    img[j, :, 0] = 1.0 - cnorm
                    img[j, :, 1] = 1.0 - cnorm
                    img[j, :, 2] = 1.0 - cnorm
                    #img[j, :, 3] = alpha
                    img[j, :, 3] = 1.0

                ax = plt.subplot(nsubplots, 1, k + 2)
                ax.set_axis_bgcolor('black')
                im = plt.imshow(img,
                                interpolation='nearest',
                                aspect='auto',
                                origin='upper',
                                extent=[t.min(),
                                        t.max(), 1, nelectrodes])
                plt.axis('tight')
                plt.ylabel('Electrode')
                plt.title('band %d' % (k + 1))
            plt.suptitle('Instantaneous Frequency')

            if output_dir is not None:
                fname = 'band_ifreq_%s_%s_start%0.6f_end%0.6f.png' % (
                    seg_uname, ','.join(self.rcg_names), start_time, end_time)
                plt.savefig(os.path.join(output_dir, fname))

        ##################
        ## make plots of the joint phase per band
        ##################

        def compute_envelope(the_matrix, log=False):
            tm_env = np.abs(the_matrix).sum(axis=0)
            tm_env -= tm_env.min()
            tm_env /= tm_env.max()

            if log:
                nz = tm_env > 0.0
                tm_env[nz] = np.log10(tm_env[nz])
                tm_env_thresh = -np.percentile(np.abs(tm_env[nz]), 95)
                tm_env[~nz] = tm_env_thresh
                tm_env[tm_env <= tm_env_thresh] = tm_env_thresh
                tm_env -= tm_env_thresh
                tm_env /= tm_env.max()

            return tm_env

        #compute the amplitude envelope for the spectrogram
        stim_spec_env = compute_envelope(stim_spec)

        rcParams.update({'font.size': 10})
        fig = plt.figure(figsize=(24.0, 13.5), facecolor='gray')
        plt.subplots_adjust(top=0.98,
                            bottom=0.01,
                            left=0.03,
                            right=0.99,
                            hspace=0.10)
        nsubplots = nbands_to_plot + 1 + int(kwargs['spikes'])

        #plot the stimulus spectrogram
        ax = plt.subplot(nsubplots, 1, 1)
        ax.set_axis_bgcolor('black')
        plot_spectrogram(stim_spec_t,
                         stim_spec_freq,
                         stim_spec,
                         ax=ax,
                         colormap=cm.afmhot,
                         colorbar=False,
                         fmax=8000.0)
        plt.plot(stim_spec_t,
                 stim_spec_env * stim_spec_freq.max(),
                 'w-',
                 linewidth=3.0,
                 alpha=0.75)
        plt.axis('tight')
        plt.ylabel('')
        plt.yticks([])

        #plot the spike raster
        if kwargs['spikes']:
            spike_count_env = spike_envelope(spike_trains,
                                             start_time,
                                             duration,
                                             bin_size=bin_size,
                                             win_size=30.0)
            ax = plt.subplot(nsubplots, 1, 2)
            plot_raster(spike_trains,
                        ax=ax,
                        duration=duration,
                        bin_size=bin_size,
                        time_offset=start_time,
                        ylabel='Cell',
                        bgcolor='k',
                        spike_color='#ff0000')
            tenv = np.arange(len(spike_count_env)) * bin_size + start_time
            plt.plot(tenv,
                     spike_count_env * len(spike_trains),
                     'w-',
                     linewidth=1.5,
                     alpha=0.5)
            plt.axis('tight')
            plt.xticks([])
            plt.yticks([])
            plt.ylabel('Spikes')

        #phase_min = phase.min()
        #phase_max = phase.max()
        #print 'phase_max=%0.3f, phase_min=%0.3f' % (phase_max, phase_min)

        for k in range(nbands_to_plot):

            the_phase = phase[k, :, :]
            if kwargs['demodulate']:
                Ztemp = amplitude[k, :, :] * (
                    np.cos(the_phase) + complex(0, 1) * np.sin(the_phase))
                the_phase, complex_pcs = demodulate(Ztemp, depth=1)
                del Ztemp

            img = make_phase_image(amplitude[k, :, :],
                                   the_phase,
                                   normalize=True,
                                   threshold=True,
                                   saturate=True)
            amp_env = compute_envelope(amplitude[k, :, :], log=False)

            ax = plt.subplot(nsubplots, 1, k + 2 + int(kwargs['spikes']))
            ax.set_axis_bgcolor('black')
            im = plt.imshow(img,
                            interpolation='nearest',
                            aspect='auto',
                            origin='upper',
                            extent=[t.min(), t.max(), 1, nelectrodes])
            if not kwargs['phase_only']:
                plt.plot(t,
                         amp_env * nelectrodes,
                         'w-',
                         linewidth=2.0,
                         alpha=0.75)
            plt.axis('tight')
            plt.ylabel('Electrode')
            plt.title('band %d' % (k + 1))

        plt.suptitle('Phase')

        if output_dir is not None:
            fname = 'band_phase_%s_%s_start%0.6f_end%0.6f.png' % (
                seg_uname, ','.join(self.rcg_names), start_time, end_time)
            plt.savefig(os.path.join(output_dir, fname),
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
            del fig

        if not kwargs['include_spec']:
            return

        ##################
        ## make plots of the band spectrograms
        ##################
        subplot_nrows = 8 + 1
        subplot_ncols = len(self.rcg_names) * 2
        plt.figure()
        plt.subplots_adjust(top=0.95,
                            bottom=0.05,
                            left=0.03,
                            right=0.99,
                            hspace=0.10)

        for k in range(subplot_ncols):
            ax = plt.subplot(subplot_nrows, subplot_ncols, k + 1)
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.gist_yarg,
                             colorbar=False)
            plt.ylabel('')
            plt.yticks([])

        for j in range(nelectrodes):

            plt.figure(figsize=(24.0, 13.5))
            plt.subplots_adjust(top=0.95,
                                bottom=0.05,
                                left=0.03,
                                right=0.99,
                                hspace=0.10)

            electrode = self.index2electrode[j]
            rcg, rc = self.experiment.get_channel(self.segment.block,
                                                  electrode)
            row = rc.annotations['row']
            col = rc.annotations['col']
            if len(self.rcg_names) > 1:
                sp = (row + 1) * subplot_ncols + col + 1
            else:
                sp = (row + 1) * subplot_ncols + (col % 2) + 1

            #ax = plt.subplot(subplot_nrows, subplot_ncols, sp)
            gs = GridSpec(100, 1)
            ax = plt.subplot(gs[:20])
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.gist_yarg,
                             colorbar=False)
            plt.ylabel('')
            plt.yticks([])

            ax = plt.subplot(gs[20:])
            ax.set_axis_bgcolor('black')

            #get the maximum frequency and set the resolution
            max_freq = ifreq[:, j, :].max()
            nf = 150
            df = max_freq / nf

            #create an image to hold the frequencies
            img = np.zeros([nf, ifreq.shape[-1], 4], dtype='float32')

            #fill in the image for each band
            for k in range(nbands):
                max_amp = np.percentile(amplitude[k, j, :], 85)

                freq_bin = (ifreq[k, j, :] / df).astype('int') - 1
                freq_bin[freq_bin < 0] = 0

                #set the color and alpha for the bins
                alpha = amplitude[k, j, :] / max_amp
                alpha[alpha > 1.0] = 1.0  #saturate
                alpha[alpha < 0.05] = 0.0  #nonlinear threshold

                for m, fbin in enumerate(freq_bin):
                    #print 'm=%d, fbin=%d, colors[k, :].shape=%s' % (m, fbin, str(colors[k, :].shape))
                    img[fbin, m, :3] = colors[k, :]
                    img[fbin, m, 3] = alpha[m]

            #plot the image
            im = plt.imshow(img,
                            interpolation='nearest',
                            aspect='auto',
                            origin='lower',
                            extent=[t.min(), t.max(), 0.0, max_freq])
            plt.ylabel('E%d' % electrode)
            plt.axis('tight')
            plt.ylim(0.0, 140.0)

            if output_dir is not None:
                fname = 'band_spec_e%d_%s_%s_start%0.6f_end%0.6f.png' % (
                    electrode, seg_uname, ','.join(
                        self.rcg_names), start_time, end_time)
                plt.savefig(os.path.join(output_dir, fname))
                plt.close('all')

        if output_dir is None:
            plt.show()
コード例 #8
0
        #               np.save("%sSeg2.npy" %fname[:-4], seg2)
        #           else:
        #               print('short seg')
        #           shutil.move('/Users/Alicia/Desktop/BirdCallProject/Freq_20_10000/filteredCalls/%s' %fname, \
        #               '/Users/Alicia/Desktop/BirdCallProject/Freq_20_10000/reSegmentedCalls/%s' %fname)
        #       if len(minimum) == 2:
        #           seg1 = sound[:minimum[0]]
        #           seg2 = sound[minimum[0]:minimum[1]]
        #           seg3 = sound[minimum[1]:]
        #           print(len(seg1), len(seg2), len(seg3))
        #           if len(seg1) > 1323 and len(seg2) > 1323 and len(seg3) > 1323:
        #               np.save("%sSeg1.npy" %fname[:-4], seg1)
        #               np.save("%sSeg2.npy" %fname[:-4], seg2)
        #               np.save("%sSeg3.npy" %fname[:-4], seg3)
        #           else:
        #               print('short seg')
        #           shutil.move('/Users/Alicia/Desktop/BirdCallProject/Freq_20_10000/filteredCalls/%s' %fname, \
        #               '/Users/Alicia/Desktop/BirdCallProject/Freq_20_10000/reSegmentedCalls/%s' %fname)

        #        plt.plot(sound_env/sound_env.max(), color="red", linewidth=2)
        plt.figure()
        (tDebug, freqDebug, specDebug, rms) = spectrogram(sound,
                                                          fs,
                                                          1000.0,
                                                          50,
                                                          min_freq=0,
                                                          max_freq=10000,
                                                          nstd=6,
                                                          cmplx=True)
        plot_spectrogram(tDebug, freqDebug, specDebug)
        plt.show()
コード例 #9
0
ファイル: figure4.py プロジェクト: theunissenlab/zeebeez3
def plot_single_trial_data(d, syllable_index, trial_index):
    """Plot one trial's response to one syllable: acoustic features, spectrogram, raw LFP and spike raster.

    The analysis window is the syllable's [start_time, end_time] padded by 30 ms on each side.
    The figure width is scaled by the window length so that different stimuli share a time scale.

    :param d: dict-like holding the keys 'syllable_props', 'spec', 'spec_t', 'spec_freq', 'lfp',
        'lfp_sample_rate', 'electrode_order' and 'spikes'. It is not modified by this function.
    :param syllable_index: index into d['syllable_props'].
    :param trial_index: trial index into d['lfp'] (first axis) and d['spikes'].
    """
    sprops = d['syllable_props'][syllable_index]
    syllable_start = sprops['start_time'] - 0.030
    syllable_end = sprops['end_time'] + 0.030

    # set the figure width proportional to the length of the syllable for uniformity across stimuli
    max_fig_width = 12.0
    max_stim_duration = 2.5
    fig_width = ((syllable_end - syllable_start) / max_stim_duration) * max_fig_width

    fig = plt.figure(figsize=(fig_width, 10), facecolor='w')
    fig.subplots_adjust(top=0.95,
                        bottom=0.02,
                        right=0.97,
                        left=0.03,
                        hspace=0.20,
                        wspace=0.20)

    gs = plt.GridSpec(100, 1)

    # plot the biosound features, one bar per acoustic property
    ax = plt.subplot(gs[:10])
    aprops = USED_ACOUSTIC_PROPS
    vals = [sprops[a] for a in aprops]
    plt.axhline(0, c='k')
    for bx, (aprop, v) in enumerate(zip(aprops, vals)):
        rgb = np.array(ACOUSTIC_PROP_COLORS_BY_TYPE[aprop]).astype('int')
        clr_hex = to_hex(*rgb)
        plt.bar(bx, v, color=clr_hex, alpha=0.7)
    ax.xaxis.tick_top()
    # plt.xticks(range(len(aprops)), aprops, rotation=45, fontsize=6)
    plt.xticks([])

    # plot the spectrogram. Threshold a copy: doing it in place would modify d['spec'],
    # and repeated calls would then zero out progressively more of the spectrogram.
    ax = plt.subplot(gs[15:40])
    spec = np.array(d['spec'], copy=True)
    spec[spec < np.percentile(spec, 15)] = 0

    # the spectogram is already log transformed, make sure log=False and dBNoise=None
    plot_spectrogram(d['spec_t'],
                     d['spec_freq'] * 1e-3,
                     spec,
                     ax=ax,
                     colormap='SpectroColorMap',
                     ticks=False,
                     log=False,
                     dBNoise=None,
                     colorbar=False)

    plt.xlim(syllable_start, syllable_end)

    # plot the raw LFP, one vertically offset trace per electrode (rows drawn in reverse order)
    ax = plt.subplot(gs[45:70])

    sr = d['lfp_sample_rate']
    raw_lfp = d['lfp'][trial_index, :, :]
    nelectrodes, nt = raw_lfp.shape
    lfp_t = np.arange(nt) / sr
    lfp_i = (lfp_t >= syllable_start) & (lfp_t <= syllable_end)

    voffset = 5
    for n in range(nelectrodes):
        plt.plot(lfp_t[lfp_i],
                 raw_lfp[nelectrodes - n - 1, :][lfp_i] + voffset * n,
                 'k-',
                 linewidth=2.0,
                 alpha=0.75)
    plt.axis('tight')
    ytick_locs = np.arange(nelectrodes) * voffset
    plt.yticks(ytick_locs, list(reversed(d['electrode_order'])))
    plt.xticks([])

    # plot the spike train raster, spike times relative to the start of the window
    ax = plt.subplot(gs[75:])

    # d['spikes'] is a list-of-lists with shape (num_trials, num_neurons), so the number of
    # neurons is the length of a single trial's entry, not of the outer list.
    trial_spike_trains = d['spikes'][trial_index]
    print('# of neurons: ', len(trial_spike_trains))
    raw_spikes = list()
    for spike_train in trial_spike_trains:
        i = (spike_train >= syllable_start) & (spike_train <= syllable_end)
        raw_spikes.append(spike_train[i] - syllable_start)

    plot_raster(raw_spikes,
                ax=ax,
                duration=syllable_end - syllable_start,
                bin_size=0.001,
                time_offset=0.0,
                ylabel='',
                groups=None,
                bgcolor=None,
                spike_color='k')
    plt.xticks([])
コード例 #10
0
    def plot_segment_slice(self,
                           seg,
                           start_time,
                           end_time,
                           output_dir=None,
                           sort_code='0'):
        """Draw a single figure summarizing a slice of time within a segment.

        Layout (one column): the stimulus spectrogram on top, then two panels per
        recording channel group (array) of ``seg.block`` -- a spike raster followed by an
        image of that array's z-scored LFPs.

        :param seg: segment whose ``block.recordingchannelgroups`` define the arrays drawn.
        :param start_time: start of the slice, passed to the experiment's slice getters.
        :param end_time: end of the slice, passed to the experiment's slice getters.
        :param output_dir: accepted for interface compatibility; not used in this method.
        :param sort_code: forwarded to ``self.experiment.get_spike_slice``.
        :return: None
        """
        channel_groups = seg.block.recordingchannelgroups
        total_panels = 2 * len(channel_groups) + 1

        plt.figure()

        # Gather the data: spectrogram slice, then spikes and LFPs keyed by array name.
        stim_spec_t, stim_spec_freq, stim_spec = \
            self.experiment.get_spectrogram_slice(seg, start_time, end_time)
        spikes_by_array = self.experiment.get_spike_slice(seg,
                                                          start_time,
                                                          end_time,
                                                          sort_code=sort_code)
        lfps_by_array = self.experiment.get_lfp_slice(seg, start_time, end_time)

        # Top panel: stimulus spectrogram with the y axis unlabeled.
        spec_ax = plt.subplot(total_panels, 1, 1)
        plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec,
                         ax=spec_ax, colormap=cm.afmhot_r, colorbar=False, fmax=8000.0)
        plt.ylabel('')
        plt.yticks([])

        for array_num, rcg in enumerate(channel_groups):
            raster_panel = 2 * array_num + 2
            lfp_panel = raster_panel + 1

            # spike raster for this array
            trains, train_groups = spikes_by_array[rcg.name]
            raster_ax = plt.subplot(total_panels, 1, raster_panel)
            plot_raster(trains, ax=raster_ax,
                        duration=end_time - start_time, bin_size=0.001,
                        time_offset=start_time, ylabel='Electrode',
                        groups=train_groups)

            electrode_indices, lfp_matrix, _sr = lfps_by_array[rcg.name]
            n_elec = len(electrode_indices)

            # z-score each column of the transposed matrix; columns with zero standard
            # deviation are left alone. lfp_matrix.T is a view, so lfp_matrix is updated.
            columns = lfp_matrix.T
            col_mean = columns.mean(axis=0)
            col_std = columns.std(axis=0, ddof=1)
            varying = col_std > 0.0
            columns[:, varying] -= col_mean[varying]
            columns[:, varying] /= col_std[varying]

            # LFP image for this array
            plt.subplot(total_panels, 1, lfp_panel)
            plt.imshow(lfp_matrix, interpolation='nearest', aspect='auto',
                       cmap=cm.seismic, extent=[start_time, end_time, 1, n_elec])
            tick_labels = ['%d' % e for e in electrode_indices][::-1]
            plt.yticks(range(n_elec), tick_labels)
            plt.axis('tight')
            plt.ylabel('Electrode')
コード例 #11
0
ファイル: pairwise_cf.py プロジェクト: theunissenlab/zeebeez3
    def transform(self,
                  stim_event,
                  lags=np.arange(-10, 11, 1),
                  min_syllable_dur=0.050,
                  post_syllable_dur=0.030,
                  rep_type='raw',
                  debug=False,
                  window_fraction=0.60,
                  noise_db=25.):
        """Compute per-syllable LFP spectra, LFP cross terms, spike rates and spike synchrony.

        For each stimulus, syllables whose duration plus ``post_syllable_dur`` is at least
        ``min_syllable_dur`` are analyzed. For every such syllable the following are produced:
          * one power spectrum per electrode per decomposition ('trial_avg', 'mean_sub',
            'full', 'onewin'), stored in ``self.psds``;
          * one cross term per electrode pair per decomposition (excluding 'onewin'),
            stored in ``self.cross_cfs``;
          * when cells are present, [mean, std] of each cell's per-trial spike rate, stored in
            ``self.spike_rate``; and the trial-averaged synchrony of each cell pair, stored in
            ``self.spike_synchrony``.
        Each stored item gets one row in ``self.data`` / ``self.df``. That row's 'index'
        column points into the array selected by its 'decomp' column.

        :param stim_event: a StimEventTransform providing LFPs, spikes and stimulus metadata.
            ``stim_event.zscore(rep_type)`` is called on it before any analysis.
        :param lags: lags in samples for the correlation functions. ``self.lags`` stores them
            in milliseconds.
        :param min_syllable_dur: minimum syllable duration (including post-syllable time).
        :param post_syllable_dur: time added after each syllable's end time.
        :param rep_type: key into ``stim_event.lfp_reps_by_stim``.
        :param debug: if True, show a diagnostic figure for every syllable.
        :param window_fraction: forwarded to ``compute_lfp_spectra_and_cfs``.
        :param noise_db: forwarded to ``compute_lfp_spectra_and_cfs``.
        """

        assert isinstance(stim_event, StimEventTransform)
        assert rep_type in stim_event.lfp_reps_by_stim
        self.rep_type = rep_type
        self.bird = stim_event.bird

        self.segment_uname = stim_event.seg_uname
        self.rcg_names = stim_event.rcg_names

        # lags are given in samples; store them in milliseconds
        self.lags = (lags / stim_event.lfp_sample_rate) * 1e3
        stim_ids = list(stim_event.lfp_reps_by_stim[rep_type].keys())

        all_psds = list()
        all_cross_cfs = list()

        all_spike_rates = list()
        all_spike_sync = list()

        # zscore the LFPs
        stim_event.zscore(rep_type)

        data = {
            'stim_id': list(),
            'stim_type': list(),
            'stim_duration': list(),
            'order': list(),
            'decomp': list(),
            'electrode1': list(),
            'electrode2': list(),
            'region1': list(),
            'region2': list(),
            'cell_index': list(),
            'cell_index2': list(),
            'index': list()
        }

        # get map of electrode indices to region
        index2region = list()
        for e in stim_event.index2electrode:
            i = stim_event.electrode_data['electrode'] == e
            index2region.append(stim_event.electrode_data['region'][i][0])
        print('index2region=', index2region)

        # map cell indices to electrodes
        ncells = len(stim_event.cell_df)
        if ncells > 0:
            print('ncells=%d' % ncells)
            cell_index2electrode = [0] * ncells
            assert len(stim_event.cell_df.sort_code.unique()) == 1
            for ci, e in zip(stim_event.cell_df['index'],
                             stim_event.cell_df['electrode']):
                cell_index2electrode[ci] = e
        else:
            # sentinel meaning "no cells"; the spike analyses below are skipped when present
            cell_index2electrode = [-1]

        print('len(cell_index2electrode)=%d' % len(cell_index2electrode))

        # make a list of all valid syllables for each stimulus
        num_valid_syllables_per_type = dict()
        stim_syllables = dict()
        lags_max = np.abs(lags).max() / stim_event.lfp_sample_rate
        good_stim_ids = list()
        for stim_id in stim_ids:
            seg_times = list()
            i = stim_event.segment_df['stim_id'] == stim_id
            if i.sum() == 0:
                print('Missing stim information for stim %d!' % stim_id)
                continue

            stype = stim_event.segment_df['stim_type'][i].values[0]
            for stime, etime, order in zip(stim_event.segment_df['start_time'][i],
                                           stim_event.segment_df['end_time'][i],
                                           stim_event.segment_df['order'][i]):
                dur = (etime - stime) + post_syllable_dur
                if dur < min_syllable_dur:
                    continue

                # make sure the duration is long enough to support the lags
                assert dur > lags_max, "Lags is too long, duration=%0.3f, lags_max=%0.3f" % (
                    dur, lags_max)

                # add the syllable to the list, add in the extra post-stimulus time
                seg_times.append((stime, etime + post_syllable_dur, order))
                if stype not in num_valid_syllables_per_type:
                    num_valid_syllables_per_type[stype] = 0
                num_valid_syllables_per_type[stype] += 1

            if len(seg_times) > 0:
                stim_syllables[stim_id] = np.array(seg_times)
                good_stim_ids.append(stim_id)

        print('# of syllables per category:')
        for stype, nstype in list(num_valid_syllables_per_type.items()):
            print('%s: %d' % (stype, nstype))

        # specify the window size for the auto-spectra and cross-coherence. the minimum segment size is 80ms
        psd_window_size = 0.060
        psd_increment = 2 / stim_event.lfp_sample_rate

        # decompositions returned by compute_lfp_spectra_and_cfs ('onewin' has no cross terms)
        decomp_types = ['trial_avg', 'mean_sub', 'full', 'onewin']

        for stim_id in good_stim_ids:
            # get stim type
            i = stim_event.trial_df['stim_id'] == stim_id
            stim_type = stim_event.trial_df['stim_type'][i].values[0]

            print('Computing CFs for stim %d (%s)' % (stim_id, stim_type))

            # get the raw LFP
            X = stim_event.lfp_reps_by_stim[rep_type][stim_id]
            ntrials, nelectrodes, nt = X.shape

            # get the spike trains, a ragged array of shape (num_trials, num_cells, num_spikes)
            # The important thing to know about the spike times is that they are with respect
            # to the stimulus onset. Negative spike times occur prior to stimulus onset.
            spike_mat = stim_event.spikes_by_stim[stim_id]
            assert ntrials == len(
                spike_mat), "Weird number of trials in spike_mat: %d" % (
                    len(spike_mat))

            # get segment data for this stim, start and end times of segments
            seg_times = stim_syllables[stim_id]

            # go through each syllable of the stimulus
            for stime, etime, order in seg_times:
                # compute the start and end indices of the LFP for this syllable, keeping in mind that the LFP
                # segment for this stimulus includes the pre and post stim times.
                lfp_syllable_start = (stim_event.pre_stim_time + stime)
                lfp_syllable_end = (stim_event.pre_stim_time + etime)
                si = int(lfp_syllable_start * stim_event.lfp_sample_rate)
                ei = int(lfp_syllable_end * stim_event.lfp_sample_rate)

                stim_dur = etime - stime

                # because spike times are with respect to the onset of the first syllable
                # of this stimulus, stime and etime define the appropriate window of analysis
                # for this syllable for spike times.
                spike_syllable_start = stime
                spike_syllable_end = etime
                spike_syllable_dur = etime - stime

                if debug:
                    # get the spectrogram, lfp and spikes for the stimulus
                    stim_spec = stim_event.spec_by_stim[stim_id]
                    syllable_spec = stim_spec[:, si:ei]
                    the_lfp = X[:, :, si:ei]
                    the_spikes = list()
                    lfp_t = np.arange(
                        the_lfp.shape[-1]) / stim_event.lfp_sample_rate
                    spec_t = np.arange(
                        syllable_spec.shape[1]) / stim_event.lfp_sample_rate
                    spec_freq = stim_event.spec_freq

                    for k in range(ntrials):
                        the_cell_spikes = list()
                        for n in range(ncells):
                            st = spike_mat[k][n]
                            i = (st >= stime) & (st <= etime)
                            print(
                                'trial=%d, cell=%d, nspikes=%d, (%0.3f, %0.3f)'
                                % (k, n, i.sum(), spike_syllable_start,
                                   spike_syllable_end))
                            print('st=', st)
                            the_cell_spikes.append(st[i] - stime)
                        the_spikes.append(the_cell_spikes)

                    # make some plots to check the raw data, left hand plot is LFP, right hand is spikes
                    nrows = ntrials + 1
                    plt.figure()
                    gs = plt.GridSpec(nrows, 2)

                    ax = plt.subplot(gs[0, 0])
                    plot_spectrogram(spec_t,
                                     spec_freq,
                                     syllable_spec,
                                     ax=ax,
                                     colormap='gray',
                                     colorbar=False)

                    ax = plt.subplot(gs[0, 1])
                    plot_spectrogram(spec_t,
                                     spec_freq,
                                     syllable_spec,
                                     ax=ax,
                                     colormap='gray',
                                     colorbar=False)

                    for k in range(ntrials):
                        ax = plt.subplot(gs[k + 1, 0])
                        lllfp = the_lfp[k, :, :]
                        absmax = np.abs(lllfp).max()
                        plt.imshow(
                            lllfp,
                            interpolation='nearest',
                            aspect='auto',
                            cmap=plt.cm.seismic,
                            vmin=-absmax,
                            vmax=absmax,
                            origin='lower',
                            extent=[lfp_t.min(),
                                    lfp_t.max(), 1, nelectrodes])

                        ax = plt.subplot(gs[k + 1, 1])
                        plot_raster(the_spikes[k],
                                    ax=ax,
                                    duration=spike_syllable_end -
                                    spike_syllable_start,
                                    time_offset=0,
                                    ylabel='')

                    plt.title('stim %d, syllable %d' % (stim_id, order))
                    plt.show()

                # compute the LFP props
                lfp_props = self.compute_lfp_spectra_and_cfs(
                    X[:, :, si:ei], stim_event.lfp_sample_rate, lags,
                    psd_window_size, psd_increment, window_fraction, noise_db)

                # save the power spectra to the data frame and data matrix
                for n, e in enumerate(stim_event.index2electrode):
                    for decomp in decomp_types:

                        data['stim_id'].append(stim_id)
                        data['stim_type'].append(stim_type)
                        data['order'].append(order)
                        data['decomp'].append(decomp)
                        data['electrode1'].append(e)
                        data['electrode2'].append(e)
                        data['region1'].append(index2region[n])
                        data['region2'].append(index2region[n])
                        data['cell_index'].append(-1)
                        data['cell_index2'].append(-1)
                        data['index'].append(len(all_psds))
                        data['stim_duration'].append(stim_dur)

                        pkey = '%s_psds' % decomp
                        all_psds.append(lfp_props[pkey][n])

                # save the cross terms to the data frame and data matrix
                for n1, e1 in enumerate(stim_event.index2electrode):
                    for n2 in range(n1):
                        e2 = stim_event.index2electrode[n2]

                        # get index of cfs for this electrode pair
                        pi = lfp_props['cross_electrodes'].index((n1, n2))
                        for decomp in decomp_types:

                            if decomp == 'onewin':
                                continue

                            data['stim_id'].append(stim_id)
                            data['stim_type'].append(stim_type)
                            data['order'].append(order)
                            data['decomp'].append(decomp)
                            data['electrode1'].append(e1)
                            data['electrode2'].append(e2)
                            data['region1'].append(index2region[n1])
                            data['region2'].append(index2region[n2])
                            data['cell_index'].append(-1)
                            data['cell_index2'].append(-1)
                            data['index'].append(len(all_cross_cfs))
                            data['stim_duration'].append(stim_dur)

                            pkey = '%s_cfs' % decomp
                            all_cross_cfs.append(lfp_props[pkey][pi])

                if len(cell_index2electrode
                       ) == 1 and cell_index2electrode[0] == -1:
                    continue

                # compute the spike rate vector for each neuron
                for ci, e in enumerate(cell_index2electrode):
                    # region of the electrode this cell was recorded on. (Previously this used
                    # the stale PSD loop variable `n`, assigning every cell the last electrode's region.)
                    cell_region = index2region[stim_event.index2electrode.index(e)]

                    rates = list()
                    for k in range(ntrials):
                        st = spike_mat[k][ci]

                        i = (st >= spike_syllable_start) & (st <=
                                                            spike_syllable_end)
                        r = i.sum() / (spike_syllable_end -
                                       spike_syllable_start)
                        rates.append(r)

                    rates = np.array(rates)
                    rv = [rates.mean(), rates.std(ddof=1)]

                    data['stim_id'].append(stim_id)
                    data['stim_type'].append(stim_type)
                    data['order'].append(order)
                    data['decomp'].append('spike_rate')
                    data['electrode1'].append(e)
                    data['electrode2'].append(e)
                    data['region1'].append(cell_region)
                    data['region2'].append(cell_region)
                    data['cell_index'].append(ci)
                    data['cell_index2'].append(ci)
                    data['index'].append(len(all_spike_rates))
                    data['stim_duration'].append(stim_dur)

                    all_spike_rates.append(rv)

                # compute the pairwise synchrony between spike trains.
                # NOTE(review): unlike the rates above, the full spike trains (not restricted to
                # the syllable window) are passed here together with the syllable duration --
                # confirm this is what simple_synchrony expects.
                spike_sync = np.zeros([ntrials, ncells, ncells])
                for k in range(ntrials):
                    for ci1, e1 in enumerate(cell_index2electrode):
                        st1 = spike_mat[k][ci1]
                        if len(st1) == 0:
                            continue

                        for ci2 in range(ci1):
                            st2 = spike_mat[k][ci2]
                            if len(st2) == 0:
                                continue
                            sync12 = simple_synchrony(st1,
                                                      st2,
                                                      spike_syllable_dur,
                                                      bin_size=3e-3)
                            spike_sync[k, ci1, ci2] = sync12
                            spike_sync[k, ci2, ci1] = sync12

                # average pairwise synchrony across trials
                spike_sync_avg = spike_sync.mean(axis=0)

                # save the spike sync data into the data frame
                for ci1, e1 in enumerate(cell_index2electrode):
                    n1 = stim_event.index2electrode.index(e1)
                    for ci2 in range(ci1):
                        e2 = cell_index2electrode[ci2]
                        n2 = stim_event.index2electrode.index(e2)

                        data['stim_id'].append(stim_id)
                        data['stim_type'].append(stim_type)
                        data['order'].append(order)
                        data['decomp'].append('spike_sync')
                        data['electrode1'].append(e1)
                        data['electrode2'].append(e2)
                        data['region1'].append(index2region[n1])
                        data['region2'].append(index2region[n2])
                        data['cell_index'].append(ci1)
                        data['cell_index2'].append(ci2)
                        data['index'].append(len(all_spike_sync))
                        data['stim_duration'].append(stim_dur)

                        all_spike_sync.append(spike_sync_avg[ci1, ci2])

        self.cell_index2electrode = cell_index2electrode

        self.psds = np.array(all_psds)
        self.cross_cfs = np.array(all_cross_cfs)

        self.spike_rate = np.array(all_spike_rates)
        self.spike_synchrony = np.array(all_spike_sync)

        self.data = data
        self.df = pd.DataFrame(self.data)
コード例 #12
0
import scipy.spatial as ss
import scipy.stats as sst
from scipy.special import digamma, gamma
from math import log, pi, exp
from scipy.integrate import odeint
from AttractorReconstructUtilities import TimeSeries3D, TimeDelayReconstruct
from matplotlib import style
from sciPlot import updateParams
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import MultipleLocator, FixedLocator
from soundsig.sound import WavFile, plot_spectrogram, spectrogram
updateParams()

# Load a filtered call (machine-specific path) and keep samples 2900..3999.
X = np.load(
    '/Users/Alicia/Desktop/BirdCallProject/Freq_250_12000/filteredCalls/LblBla4548_130418-DC-46.npy'
)[2900:4000]

# Complex spectrogram of the excerpt; 44100 / 22050 presumably are the recording
# sample rate and its Nyquist frequency -- confirm against the source recording.
tDebug, freqDebug, specDebug, rms = spectrogram(
    X, 44100, 2000.0, 50,
    min_freq=0, max_freq=22050, nstd=6, cmplx=True,
)

fig, ax = plt.subplots(figsize=(5, 2))
plot_spectrogram(tDebug, freqDebug, specDebug, ax=ax, colorbar=False)
ax.set_xlabel('t (s)')
plt.show()
#fig.savefig('Fig5Bifurcations.svg', dpi=300, bbox_inches='tight', transparent=True)
コード例 #13
0
 def _plot_fn(self, data, ax):
     """Render a spectrogram onto ``ax``.

     ``data`` must unpack into exactly three items: the time axis, the frequency axis
     and the spectrogram values. Extra plotting options come from ``self._spec_kwargs``.
     """
     times, freqs, spec_values = data
     plot_spectrogram(times, freqs, spec_values, ax=ax, **self._spec_kwargs)
コード例 #14
0
#allMIs = [MI_lcmin(Seg0, order=2, PLOT=False) for Seg0 in Seg]
#print(allMIs)
# Two stacked spectrograms (Seg[0] on top, Seg[1] below) sharing a kHz frequency axis.
fig = plt.figure(figsize=(5, 8))
for i in range(2):
    ax = fig.add_subplot(2, 1, i + 1)
    (tDebug, freqDebug, specDebug, rms) = spectrogram(Seg[i],
                                                      44100,
                                                      2000,
                                                      100,
                                                      min_freq=0,
                                                      max_freq=22050,
                                                      nstd=6,
                                                      cmplx=True)
    plot_spectrogram(tDebug,
                     freqDebug,
                     specDebug,
                     ax=ax,
                     colorbar=True,
                     dBNoise=None)
    # Pin the tick positions (in Hz, the units passed to spectrogram) before labeling them
    # in kHz; otherwise the fixed labels attach to whatever ticks the auto-locator chose.
    ax.set_yticks([0, 5000, 10000, 15000, 20000])
    ax.set_yticklabels([0, 5, 10, 15, 20])
    ax.set_ylabel('Frequency (kHz)')
    ax.xaxis.set_tick_params(direction='in', width=0.4)
    ax.yaxis.set_tick_params(direction='in', width=0.4)
    if i == 0:
        # top panel shares the time axis with the bottom one: hide its x ticks and label
        ax.set_xlabel('')
        ax.set_xticks([])
    else:
        ax.set_xlabel('t (s)')
#MI = [3,2]
#for i in range(2):
#	ax = fig.add_subplot(2, 1, i+1, projection='3d')
コード例 #15
0
ファイル: spike_event.py プロジェクト: theunissenlab/zeebeez3
    def plot(self, start_time, end_time):
        """Summarize and plot the detected events of every segment.

        For each ``(seg_uname, events)`` pair in ``self.events`` this prints
        min/max/mean/median of the event durations, inter-event intervals and
        amplitudes, then opens two figures: 1-D histograms of those three
        quantities and a 2-D duration-vs-amplitude histogram.

        If ``self.experiment`` is set, a third figure shows the window
        ``[start_time, end_time]``:

        * the stimulus spectrogram with its amplitude envelope;
        * the envelope taken from ``self.spec_envelopes``;
        * the spike raster of the first name in ``self.rcg_names``, with
          ``self.envelopes`` overlaid and markers at event starts (green ^)
          and event ends (blue v).

        Columns of each ``events`` array, as used here: 0 = start time,
        1 = end time, last = peak amplitude. Segments with no events are
        reported and skipped, since their statistics are undefined.
        """

        def _stats(x):
            # (min, max, mean, median) in the order the format strings expect
            return (x.min(), x.max(), x.mean(), np.median(x))

        for seg_uname, events in list(self.events.items()):

            # print some event info
            print('Segment: %s' % seg_uname)
            print('\tIdentified %d events' % len(events))
            if len(events) == 0:
                # min()/max() raise ValueError on empty arrays; nothing to plot
                continue

            durations = events[:, 1] - events[:, 0]
            amplitudes = events[:, -1]
            event_start_times = events[:, 0]
            event_end_times = events[:, 1]

            # gap between the end of one event and the start of the next
            inter_event_intervals = event_start_times[1:] - event_end_times[:-1]

            print('\tDuration: min=%0.6f, max=%0.6f, mean=%0.6f, median=%0.6f'
                  % _stats(durations))
            if len(inter_event_intervals) > 0:
                print('\tIEI: min=%0.6f, max=%0.6f, mean=%0.6f, median=%0.6f'
                      % _stats(inter_event_intervals))
            else:
                print('\tIEI: n/a (fewer than two events)')
            print('\tAmplitude: min=%0.6f, max=%0.6f, mean=%0.6f, median=%0.6f'
                  % _stats(amplitudes))

            # plot some event statistics
            plt.figure()
            plt.subplot(3, 1, 1)
            plt.hist(durations, bins=100, color='r')
            plt.axis('tight')
            plt.title('Event Duration (s)')

            plt.subplot(3, 1, 2)
            plt.hist(inter_event_intervals, bins=100, color='k')
            plt.axis('tight')
            plt.title('Inter-event Interval (s)')

            plt.subplot(3, 1, 3)
            plt.hist(amplitudes, bins=100, color='g')
            plt.title('Peak Amplitude')
            plt.suptitle(seg_uname)
            plt.axis('tight')

            # plot event duration vs amplitude histogram
            plt.figure()
            plt.hist2d(durations,
                       amplitudes,
                       bins=[40, 30],
                       cmap=cm.Greys,
                       norm=LogNorm())
            plt.xlim(0, 2.)
            plt.xlabel('Duration')
            plt.ylabel('Amplitude')
            plt.colorbar(label="# of Joint Events")
            plt.suptitle(seg_uname)

            if self.experiment is None:
                continue

            # get the segment
            seg = segment_from_unique_name(self.experiment.blocks, seg_uname)

            # get a slice of the spike raster
            spike_rasters = self.experiment.get_spike_slice(
                seg,
                start_time,
                end_time,
                rcg_names=self.rcg_names,
                as_matrix=False,
                sort_code=self.sort_code,
                bin_size=self.bin_size)
            spike_trains, spike_train_group = spike_rasters[self.rcg_names[0]]

            # envelopes are sampled every bin_size, so convert times to bins
            si = int(start_time / self.bin_size)
            ei = int(end_time / self.bin_size)
            spike_env = self.envelopes[seg_uname][si:ei]
            stim_env = self.spec_envelopes[seg_uname][si:ei]

            # get stimulus spectrogram and its amplitude envelope
            stim_spec_t, stim_spec_freq, stim_spec = self.experiment.get_spectrogram_slice(
                seg, start_time, end_time)
            stim_spec_env = stim_envelope(stim_spec)

            plt.figure()
            plt.suptitle(seg_uname)

            # spectrogram with its envelope scaled to the frequency range
            ax = plt.subplot(3, 1, 1)
            # set_axis_bgcolor was removed in matplotlib 2.2
            ax.set_facecolor('black')
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=cm.afmhot,
                             colorbar=False,
                             fmax=8000.0)
            plt.plot(stim_spec_t,
                     stim_spec_env * stim_spec_freq.max(),
                     'w-',
                     linewidth=3.0,
                     alpha=0.75)
            plt.axis('tight')
            plt.ylabel('')
            plt.yticks([])

            # plot the stimulus amplitude envelope
            tenv = np.arange(len(spike_env)) * self.bin_size + start_time
            ax = plt.subplot(3, 1, 2)
            ax.set_facecolor('black')
            plt.plot(tenv, stim_env, 'w-', linewidth=2.0)
            plt.axis('tight')

            # make a plot of the spike raster, the envelope, and the events
            ax = plt.subplot(3, 1, 3)
            plot_raster(spike_trains,
                        ax=ax,
                        duration=end_time - start_time,
                        bin_size=self.bin_size,
                        time_offset=start_time,
                        ylabel='',
                        bgcolor='k',
                        spike_color='#ff0000')
            plt.plot(tenv, spike_env * len(spike_trains), 'w-', alpha=0.75)

            # plot start events that are within plotting range
            sei = (event_start_times >= start_time) & (event_start_times <= end_time)
            plt.plot(event_start_times[sei], [1] * sei.sum(), 'g^', markersize=10)

            # plot end events that are within plotting range
            eei = (event_end_times >= start_time) & (event_end_times <= end_time)
            plt.plot(event_end_times[eei], [1] * eei.sum(), 'bv', markersize=10)

            plt.axis('tight')
            plt.yticks([])
            plt.ylabel('Spikes')
コード例 #16
0
ファイル: figure4.py プロジェクト: theunissenlab/zeebeez3
def plot_full_data(d, syllable_index):
    """Plot an overview figure for one syllable.

    Layout on a 100 x 100 GridSpec:

    * Left column: the stimulus spectrogram (top) and stacked LFP traces
      (middle).
    * Right column: a bar chart of the syllable's z-scored acoustic
      properties (top) and an image of its z-scored LFP power spectra per
      electrode (middle).

    Dashed vertical lines mark the syllable, padded by 0.020 s before and
    0.030 s after.

    Args:
        d: dict of precomputed data. Keys read here: 'syllable_props',
            'spec', 'spec_t', 'spec_freq', 'lfp_sample_rate', 'lfp',
            'electrode_order', 'psd_freq'. ``d`` is not modified.
        syllable_index: index into ``d['syllable_props']``; that entry must
            provide 'start_time', 'end_time', 'lfp_psd' and every name in
            USED_ACOUSTIC_PROPS.
    """
    sprops = d['syllable_props'][syllable_index]
    syllable_start = sprops['start_time'] - 0.020
    syllable_end = sprops['end_time'] + 0.030

    figsize = (24.0, 10)
    fig = plt.figure(figsize=figsize, facecolor='w')
    fig.subplots_adjust(top=0.95,
                        bottom=0.02,
                        right=0.97,
                        left=0.03,
                        hspace=0.20,
                        wspace=0.20)

    gs = plt.GridSpec(100, 100)
    left_width = 55
    top_height = 30
    middle_height = 40
    # bottom_height = 40
    top_bottom_sep = 20

    def _mark_syllable():
        # dashed lines at the padded syllable boundaries on the current axes
        for x in (syllable_start, syllable_end):
            plt.axvline(x, c='k', linestyle='--', linewidth=3.0, alpha=0.7)

    # plot the spectrogram
    ax = plt.subplot(gs[:top_height + 1, :left_width])
    # Threshold a copy: modifying d['spec'] in place would alter the caller's
    # data and compound the thresholding on every subsequent call.
    spec = np.array(d['spec'], copy=True)
    spec[spec < np.percentile(spec, 15)] = 0
    plot_spectrogram(d['spec_t'],
                     d['spec_freq'] * 1e-3,
                     spec,
                     ax=ax,
                     colormap='SpectroColorMap',
                     colorbar=False,
                     ticks=True,
                     log=False,
                     dBNoise=None)
    _mark_syllable()
    plt.ylabel('Frequency (kHz)')
    plt.xlabel('Time (s)')

    # plot the LFPs
    sr = d['lfp_sample_rate']
    # NOTE(review): d['lfp'] is indexed as [2, electrode, time] and only
    # index 2 of the first axis is drawn (an earlier version averaged over
    # that axis instead) - confirm which slice is intended.
    lfp_traces = d['lfp'][2, :, :]
    nelectrodes = lfp_traces.shape[0]
    lfp_t = np.arange(lfp_traces.shape[1]) / sr
    gs_i = top_height + top_bottom_sep
    gs_e = gs_i + middle_height + 1

    plt.subplot(gs[gs_i:gs_e, :left_width])

    # Traces are stacked bottom-up from the last row, so row 0 ends up on
    # top; the reversed electrode_order labels follow the same order.
    voffset = 5
    for n in range(nelectrodes):
        plt.plot(lfp_t,
                 lfp_traces[nelectrodes - n - 1, :] + voffset * n,
                 'k-',
                 linewidth=3.0,
                 alpha=0.75)
    plt.axis('tight')
    ytick_locs = np.arange(nelectrodes) * voffset
    plt.yticks(ytick_locs, list(reversed(d['electrode_order'])))
    plt.ylabel('Electrode')
    _mark_syllable()
    plt.xlabel('Time (s)')

    # plot the PSTH
    # NOTE: the string below is a disabled PSTH panel kept for reference
    # (Python 2 syntax); it is never executed.
    """
    gs_i = gs_e + 5
    gs_e = gs_i + bottom_height + 1
    ax = plt.subplot(gs[gs_i:gs_e, :left_width])
    ncells = d['psth'].shape[0]
    plt.imshow(d['psth'], interpolation='nearest', aspect='auto', origin='upper', extent=(0, lfp_t.max(), ncells, 0),
               cmap=psth_colormap(noise_level=0.1))

    cell_i2e = d['cell_index2electrode']
    print 'cell_i2e=',cell_i2e
    last_electrode = cell_i2e[0]
    for k,e in enumerate(cell_i2e):
        if e != last_electrode:
            plt.axhline(k, c='k', alpha=0.5)
            last_electrode = e

    ytick_locs = list()
    for e in d['electrode_order']:
        elocs = np.array([k for k,el in enumerate(cell_i2e) if el == e])
        emean = elocs.mean()
        ytick_locs.append(emean+0.5)
    plt.yticks(ytick_locs, d['electrode_order'])
    plt.ylabel('Electrode')

    plt.axvline(syllable_start, c='k', linestyle='--', linewidth=3.0, alpha=0.7)
    plt.axvline(syllable_end, c='k', linestyle='--', linewidth=3.0, alpha=0.7)
    """

    # plot the biosound properties
    aprops = USED_ACOUSTIC_PROPS
    vals = [sprops[a] for a in aprops]
    plt.subplot(gs[:top_height, (left_width + 5):])
    plt.axhline(0, c='k')
    for k, (aprop, v) in enumerate(zip(aprops, vals)):
        rgb = np.array(ACOUSTIC_PROP_COLORS_BY_TYPE[aprop]).astype('int')
        clr_hex = to_hex(*rgb)
        # Explicit centre alignment keeps bar k centred on x=k regardless of
        # the matplotlib version's default, so the ticks below line up.
        plt.bar(k, v, color=clr_hex, alpha=0.7, align='center')

    plt.axis('tight')
    plt.ylim(-1.5, 1.5)
    plt.xticks(np.arange(len(aprops)),
               [ACOUSTIC_PROP_NAMES[aprop] for aprop in aprops],
               rotation=90)
    plt.ylabel('Z-score')

    # plot the LFP power spectra; with extent (.., nelectrodes, 0) row k
    # spans [k, k+1], so its label goes at k + 0.5
    f = d['psd_freq']
    plt.subplot(gs[gs_i:gs_e, (left_width + 5):])
    plt.imshow(sprops['lfp_psd'],
               interpolation='nearest',
               aspect='auto',
               origin='upper',
               extent=(f.min(), f.max(), nelectrodes, 0),
               cmap=plt.cm.viridis,
               vmin=-2.,
               vmax=2.)
    plt.colorbar(label='Z-scored Log Power')
    plt.xlabel('Frequency (Hz)')
    plt.yticks(np.arange(nelectrodes) + 0.5, d['electrode_order'])
    plt.ylabel('Electrode')

    # fname = os.path.join(get_this_dir(), 'figure.svg')
    # plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()
コード例 #17
0
def _read_manual_segmentation(seg_file):
    """Parse a manual syllable segmentation file.

    Each non-blank line is ``stim_id,start1,end1,start2,end2,...``.

    Returns:
        dict mapping the integer stim id to an ``(n_syllables, 2)`` float
        array of (start, end) times.

    Raises:
        ValueError: if a line has an odd number of time values.
    """
    man_segs = dict()
    with open(seg_file, 'r') as fh:
        for ln in fh:
            ln = ln.strip()
            if not ln:
                continue
            fields = ln.split(",")
            stim_id = int(fields[0])
            stimes = [float(v) for v in fields[1:]]
            if len(stimes) % 2 != 0:
                raise ValueError("Uneven # of syllables for stim %d" % stim_id)
            # one (start, end) row per syllable; reshape(-1, 2) avoids the
            # float row count that len(stimes) / 2 yields under Python 3
            man_segs[stim_id] = np.array(stimes, dtype=float).reshape(-1, 2)
    return man_segs


def _plot_segmentation(ax, spec_t, spec_freq, spec, segs, title):
    """Draw ``spec`` on ``ax`` and overlay numbered (start, end) syllable boundaries."""
    plot_spectrogram(spec_t,
                     spec_freq,
                     spec,
                     ax=ax,
                     colorbar=False,
                     fmin=300.,
                     fmax=8000.,
                     colormap='SpectroColorMap')
    for k, (stime, etime) in enumerate(segs):
        plt.axvline(stime, c='k', linewidth=2.0, alpha=0.8)
        plt.axvline(etime, c='k', linewidth=2.0, alpha=0.8)
        plt.text(stime, 7000., str(k + 1), fontsize=14)
        plt.text(etime, 7000., str(k + 1), fontsize=14)
    plt.title(title)
    plt.axis('tight')


def compare_stims(exp_file, stim_file, seg_file, bs_file):
    """Show algorithmic vs. manual syllable segmentation for every stimulus.

    For each stim id found in the experiment's epoch tables, one figure is
    shown. It contains two copies of the stimulus spectrogram (300-8000 Hz,
    log-transformed with ``dbnoise=70``):

    * top: overlaid with the boundaries from the BiosoundTransform in
      ``bs_file``;
    * bottom: overlaid with the hand-made boundaries in ``seg_file``.

    Args:
        exp_file, stim_file: passed to ``Experiment.load``.
        seg_file: manual segmentation text file (see _read_manual_segmentation).
        bs_file: file loaded with ``BiosoundTransform.load``.

    Raises:
        ValueError: if ``seg_file`` has a line with an odd number of times.
    """
    exp = Experiment.load(exp_file, stim_file)
    spec_colormap()

    all_stim_ids = list()
    for ekey in list(exp.epoch_table.keys()):
        etable = exp.epoch_table[ekey]
        all_stim_ids.extend(etable['id'].unique())
    stim_ids = np.unique(all_stim_ids)

    # read the manual segmentation data
    man_segs = _read_manual_segmentation(seg_file)

    # get the automated segmentation, sorted by start time
    algo_segs = dict()
    bst = BiosoundTransform.load(bs_file)
    for stim_id in stim_ids:
        i = bst.stim_df.stim_id == stim_id
        d = sorted(zip(bst.stim_df[i].start_time, bst.stim_df[i].end_time),
                   key=operator.itemgetter(0))
        algo_segs[stim_id] = np.array(d)

    no_segs = np.zeros([0, 2])
    for stim_id in stim_ids:

        # get the raw sound pressure waveform
        wave = exp.sound_manager.reconstruct(stim_id)
        wave_sr = wave.samplerate
        wave = np.array(wave).squeeze()

        # compute the spectrogram (1 kHz frame rate, 7 ms window parameter)
        spec_sr = 1000.
        spec_t, spec_freq, spec, spec_rms = gaussian_stft(wave,
                                                          float(wave_sr),
                                                          0.007,
                                                          1. / spec_sr,
                                                          min_freq=300.,
                                                          max_freq=8000.)
        spec = np.abs(spec)**2
        log_transform(spec, dbnoise=70)

        manual = man_segs.get(stim_id)
        if manual is None:
            print('No manual segmentation for stim %d' % stim_id)
            manual = no_segs

        plt.figure(figsize=(23, 12))

        ax = plt.subplot(2, 1, 1)
        _plot_segmentation(ax, spec_t, spec_freq, spec, algo_segs[stim_id],
                           'Algorithm Segmentation')

        ax = plt.subplot(2, 1, 2)
        _plot_segmentation(ax, spec_t, spec_freq, spec, manual,
                           'Manual Segmentation')

        plt.show()
コード例 #18
0
ファイル: ica.py プロジェクト: theunissenlab/zeebeez3
    def plot_slice(self,
                   start_time,
                   end_time,
                   exp=None,
                   seg=None,
                   absmax=None,
                   colors=None,
                   output_dir=None):
        """Plot the ICA source signals ``self.X`` over ``[start_time, end_time]``.

        One trace per component (``self.components.shape[0]`` of them) is
        stacked vertically on a black background. When ``exp`` is given, the
        stimulus spectrogram for ``seg`` is drawn in a strip above the traces;
        otherwise the traces fill the whole figure.

        Args:
            start_time, end_time: window bounds in seconds; converted to
                sample indices with ``self.lfp_sample_rate``.
            exp: object providing ``get_spectrogram_slice(seg, start, end)``,
                or None to omit the spectrogram strip.
            seg: segment passed to ``exp.get_spectrogram_slice``.
            absmax: half-height of each trace's lane; defaults to the largest
                absolute value of ``self.X`` inside the window.
            colors: optional array indexed ``colors[k, :]`` giving the colour
                of component ``k``; traces default to white so they are
                visible on the black background.
            output_dir: if given, the figure is saved there as
                ``ica_<start>_<end>.png`` and closed instead of being shown.
        """
        draw_spectrogram = exp is not None
        if draw_spectrogram:
            # get stimulus spectrogram
            stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(
                seg, start_time, end_time)

        fig = plt.figure(figsize=(24.0, 13.5), facecolor='gray')
        fig.subplots_adjust(top=0.98,
                            bottom=0.02,
                            right=0.98,
                            left=0.02,
                            hspace=0.05)
        gs = plt.GridSpec(100, 1)

        if draw_spectrogram:
            # plot the stimulus spectrogram in the top 10% of the figure
            ax = plt.subplot(gs[:10, 0])
            # set_axis_bgcolor was removed in matplotlib 2.2
            ax.set_facecolor('black')
            plot_spectrogram(stim_spec_t,
                             stim_spec_freq,
                             stim_spec,
                             ax=ax,
                             colormap=plt.cm.afmhot,
                             colorbar=False,
                             fmax=8000.0)
            plt.axis('tight')
            plt.ylabel('')
            plt.yticks([])
            plt.xticks([])
            trace_rows = slice(10, None)
        else:
            trace_rows = slice(None)

        si = int(self.lfp_sample_rate * start_time)
        ei = int(self.lfp_sample_rate * end_time)
        X = self.X[si:ei, :]
        ncomps = self.components.shape[0]

        if absmax is None:
            absmax = np.abs(X).max()
        padding = 0.05 * absmax
        lane_height = 2 * absmax + padding
        plot_height = lane_height * ncomps

        # Time of each sample actually in the slice; the slice can be shorter
        # than requested near the end of the recording.
        t = (si + np.arange(X.shape[0])) / self.lfp_sample_rate

        ax = plt.subplot(gs[trace_rows, 0])
        ax.set_facecolor('black')
        for k in range(ncomps):
            # centre of lane k; lanes are stacked upward from y=0
            offset = absmax + padding + k * lane_height
            c = 'w' if colors is None else colors[k, :]
            plt.plot(t, X[:, k] + offset, '-', c=c, linewidth=2.0)

        plt.axis('tight')
        plt.ylim(0, plot_height)
        plt.yticks([])

        if output_dir is not None:
            ofile = os.path.join(
                output_dir, 'ica_%0.6f_%0.6f.png' % (start_time, end_time))
            plt.savefig(ofile, facecolor=fig.get_facecolor(), edgecolor='none')
            # release the figure; batch exports otherwise accumulate open figures
            plt.close(fig)
        else:
            plt.show()