Beispiel #1
0
def draw_preds(agg, data_dir='/auto/tdrive/mschachter/data'):
    """Plot real vs. linear vs. RNN LFP predictions for one hard-coded recording site.

    If the RNN prediction file does not exist yet, it is generated from the
    preprocessed data, the RNN encoder file and the linear envelope file.
    The figure is saved as ``rnn_preds.svg`` in this script's directory and shown.

    :param agg: unused; kept so existing callers keep working.
    :param data_dir: root data directory containing one subdirectory per bird.
    """
    # Hard-coded recording site and the md5 identifying the RNN model to plot.
    bird = 'GreBlu9508M'
    block = 'Site2'
    segment = 'Call2'
    hemi = 'L'
    best_md5 = 'db5d601c8af2621f341d8e4cbd4108c1'

    site_name = '%s_%s_%s_%s' % (bird, block, segment, hemi)
    rnn_dir = os.path.join(data_dir, bird, 'rnn')
    preproc_file = os.path.join(data_dir, bird, 'preprocess', 'RNNPreprocess_%s.h5' % site_name)
    rnn_file = os.path.join(rnn_dir, 'RNNLFPEncoder_%s_%s.h5' % (site_name, best_md5))
    pred_file = os.path.join(rnn_dir, 'RNNLFPEncoderPred_%s_%s.h5' % (site_name, best_md5))
    linear_file = os.path.join(rnn_dir, 'LFPEnvelope_%s.h5' % site_name)

    # Only generate the prediction file once; later calls reuse it.
    if not os.path.exists(pred_file):
        rpt = RNNPreprocessTransform.load(preproc_file)
        rpt.write_pred_file(rnn_file, pred_file, lfp_enc_file=linear_file)

    visualize_pred_file(pred_file, 3, 0, 2.4, electrode_order=ROSTRAL_CAUDAL_ELECTRODES_LEFT[::-1], dbnoise=4.0)
    leg = custom_legend(['k', 'r', 'b'], ['Real', 'Linear', 'RNN'])
    plt.legend(handles=leg)

    out_file = os.path.join(get_this_dir(), 'rnn_preds.svg')
    plt.savefig(out_file, facecolor='w', edgecolor='none')

    plt.show()
Beispiel #2
0
def plot_raw_dists(data_dir='/auto/tdrive/mschachter/data'):
    """Scatter electrode distance-from-midline vs. distance-from-LH, one colour per bird.

    Reads ``aggregate/electrode_data+dist.csv`` under ``data_dir`` and excludes
    bird 'BlaBro09xxF' and rows where either distance is NaN. Each point is
    annotated with its region label. The plot is shown.

    :param data_dir: root data directory.
    """
    edata = pd.read_csv(os.path.join(data_dir, 'aggregate', 'electrode_data+dist.csv'))
    edata = edata[edata.bird != 'BlaBro09xxF']
    birds = edata.bird.unique()

    palette = ['r', 'g', 'b']
    # One colour per bird, cycling the palette so >3 birds doesn't raise IndexError
    # and the legend gets exactly as many colours as labels.
    clrs = [palette[k % len(palette)] for k in range(len(birds))]

    fig = plt.figure()
    fig.subplots_adjust(top=0.99, bottom=0.01, right=0.99, left=0.01, hspace=0, wspace=0)

    for b, clr in zip(birds, clrs):
        i = (edata.bird == b) & ~np.isnan(edata.dist_midline) & ~np.isnan(edata.dist_l2a)
        x = edata[i].dist_midline.values
        y = edata[i].dist_l2a.values
        if b == 'GreBlu9508M':
            # NOTE(review): this bird's dist_l2a is scaled by 4; the reason
            # (units/calibration?) is not visible here — confirm.
            y = y * 4
        reg = edata[i].region.values

        for xx, yy, r in zip(x, y, reg):
            plt.plot(xx, yy, 'o', markersize=8, c=clr, alpha=0.4)
            plt.text(xx, yy, r, fontsize=8)

    plt.axis('tight')
    plt.xlabel('Dist from Midline (mm)')
    plt.ylabel('Dist from LH (mm)')

    leg = custom_legend(clrs, birds)
    plt.legend(handles=leg)
    plt.show()
Beispiel #3
0
    def _plot_freqs(pdata, ax):
        """Plot LFP and spike normalized LR (mean +/- standard error) vs. frequency on ``ax``.

        :param pdata: dict with keys 'lfp' and 'spike' (arrays with samples along
            axis 0), 'freqs' (x values; presumably a NumPy array, since a 2 Hz
            offset is added elementwise) and 'aprop' (used as the title; the
            value 'category' disables the LR == 1 reference line and the 0-5 y range).
        :param ax: matplotlib Axes to draw into (made the current axes).
        """
        plt.sca(ax)

        def _mean_and_sem(vals):
            # mean and standard error of the mean across samples (axis 0)
            return vals.mean(axis=0), vals.std(axis=0, ddof=1) / np.sqrt(len(vals))

        lfp_mean, lfp_sem = _mean_and_sem(pdata['lfp'])
        spike_mean, spike_sem = _mean_and_sem(pdata['spike'])

        is_category = pdata['aprop'] == 'category'
        if not is_category:
            # reference line at LR == 1
            plt.axhline(1.0, c='k', linestyle='dashed', alpha=0.7, linewidth=2.0)
        plt.errorbar(pdata['freqs'], lfp_mean, yerr=lfp_sem, c=COLOR_BLUE_LFP, linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)
        # spike curve is shifted by 2 Hz, presumably so its error bars don't overlap the LFP ones
        plt.errorbar(pdata['freqs']+2., spike_mean, yerr=spike_sem, c=COLOR_YELLOW_SPIKE, linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Normalized LR')
        leg = custom_legend([COLOR_BLUE_LFP, COLOR_YELLOW_SPIKE], ['LFP', 'Spike'])
        plt.legend(handles=leg, fontsize='x-small')
        plt.title(pdata['aprop'])
        plt.axis('tight')
        if not is_category:
            plt.ylim(0, 5)
Beispiel #4
0
def draw_pairwise_weights_vs_dist(agg):
    """Plot binned squared pairwise decoder weights vs. pairwise distance.

    Two panels (LFP pairwise correlations, spike synchrony). Each has one curve
    per acoustic property, built from decoders with r2 > 0.20 and finite
    distances. Squared weights are divided by their sample std before binning.
    Saves ``pairwise_decoder_effect_vs_dist.svg`` in this script's directory
    and shows it.

    :param agg: aggregate object passed to ``export_pairwise_decoder_weights``.
    """
    wdf = export_pairwise_decoder_weights(agg)

    r2_thresh = 0.20
    # aprops = ALL_ACOUSTIC_PROPS
    aprops = ['meanspect', 'stdspect', 'sal', 'maxAmp']
    aprop_clrs = {'meanspect':'#FF8000', 'stdspect':'#FFBF00', 'sal':'#088A08', 'maxAmp':'k'}
    decomps = ['full_psds+full_cfs', 'spike_rate+spike_sync']
    decomp_labels = {'full_psds+full_cfs':'LFP Pairwise Correlations', 'spike_rate+spike_sync':'Spike Synchrony'}

    print(wdf.decomp.unique())

    # collect (distance, weight) scatter data per (decomposition, acoustic property)
    scatter_data = dict()
    for decomp in decomps:
        df = wdf[wdf.decomp == decomp]
        assert len(df) > 0, 'no decoder weights for decomp=%s' % decomp
        for aprop in aprops:
            ii = (df.r2 > r2_thresh) & (df.aprop == aprop) & np.isfinite(df.dist)
            scatter_data[(decomp, aprop)] = (df.dist[ii].values, df.w[ii].values)

    fig = plt.figure(figsize=(14, 5))
    fig.subplots_adjust(left=0.10, right=0.98)

    aprop_lbls = [ACOUSTIC_PROP_NAMES[aprop] for aprop in aprops]
    aclrs = [aprop_clrs[aprop] for aprop in aprops]

    for k, decomp in enumerate(decomps):
        plt.subplot(1, 2, k+1)
        for aprop in aprops:
            d, w = scatter_data[(decomp, aprop)]
            if len(w) < 2:
                # a sample std (ddof=1) needs >= 2 points; skip rather than plot NaNs
                continue

            wz = w**2
            wz /= wz.std(ddof=1)

            xcenter, ymean, yerr, _ = compute_mean_from_scatter(d, wz, bins=5, num_smooth_points=300)
            plt.errorbar(xcenter, ymean, linestyle='-', yerr=yerr, c=aprop_clrs[aprop], linewidth=9.0, alpha=0.7,
                         elinewidth=8., ecolor='#d8d8d8', capsize=0.)
        plt.axis('tight')

        plt.ylabel('Decoder Weight Effect')
        plt.xlabel('Pairwise Distance (mm)')
        plt.title(decomp_labels[decomp])

        leg = custom_legend(aclrs, aprop_lbls)
        plt.legend(handles=leg, loc='lower left')

    fname = os.path.join(get_this_dir(), 'pairwise_decoder_effect_vs_dist.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()
Beispiel #5
0
def draw_spike_rate_vs_power(data_dir='/auto/tdrive/mschachter/data'):
    """Plot, per frequency, the correlation of LFP / spike PSD power with log spike rate.

    Loads ``aggregate/pairwise_cf.h5``. For each row of ``pcf.df`` it gathers the
    LFP PSD (via lfp_index), the spike PSD and spike rate (via spike_index) and
    takes the log of the rate. Rows with a non-finite log rate, or a non-positive
    or non-finite PSD sum, are dropped. Every column is z-scored, and each
    frequency column is correlated with the log spike rate. The figure is saved
    as ``power_vs_rate.svg`` in this script's directory and shown.

    :param data_dir: root data directory.
    """
    pcf_file = os.path.join(data_dir, 'aggregate', 'pairwise_cf.h5')
    pcf = AggregatePairwiseCF.load(pcf_file)

    nfreqs = len(pcf.freqs)
    lfp_index = np.asarray(pcf.df['lfp_index'].values)
    spike_index = np.asarray(pcf.df['spike_index'].values)

    # columns: [LFP PSD (nfreqs) | spike PSD (nfreqs) | log spike rate]
    lfp_and_spike_psds = np.zeros([len(pcf.df), nfreqs*2 + 1])
    lfp_and_spike_psds[:, :nfreqs] = pcf.lfp_psds[lfp_index, :]
    lfp_and_spike_psds[:, nfreqs:-1] = pcf.spike_psds[spike_index, :]
    # spike_rates rows hold two values (rate, second value unused here); only the rate is used
    with np.errstate(divide='ignore', invalid='ignore'):
        lfp_and_spike_psds[:, -1] = np.log(pcf.spike_rates[spike_index, 0])

    # throw bad data points out. isfinite also rejects NaN log-rates (from
    # negative/NaN rates) and +inf sums; either would poison the column mean.
    lfp_sum = lfp_and_spike_psds[:, :nfreqs].sum(axis=1)
    spike_sum = lfp_and_spike_psds[:, nfreqs:-1].sum(axis=1)
    good = (np.isfinite(lfp_and_spike_psds[:, -1])
            & np.isfinite(lfp_sum) & (lfp_sum > 0)
            & np.isfinite(spike_sum) & (spike_sum > 0))
    print('# of good data points: %d out of %d' % (good.sum(), lfp_and_spike_psds.shape[0]))

    # zscore the concatenated matrix
    zdata = lfp_and_spike_psds[good, :]
    zdata -= zdata.mean(axis=0)
    zdata /= zdata.std(axis=0, ddof=1)

    # correlation coefficient between log spike rate and power at each frequency
    log_rate = zdata[:, -1]
    lfp_spike_rate_cc = np.array([np.corrcoef(zdata[:, k], log_rate)[0, 1] for k in range(nfreqs)])
    spike_spike_rate_cc = np.array([np.corrcoef(zdata[:, nfreqs + k], log_rate)[0, 1] for k in range(nfreqs)])

    plt.figure(figsize=(12, 7))
    plt.axhline(0, c='k')
    plt.plot(pcf.freqs, lfp_spike_rate_cc, '-', linewidth=7.0, alpha=0.7, c=COLOR_BLUE_LFP)
    plt.plot(pcf.freqs, spike_spike_rate_cc, '-', linewidth=7.0, alpha=0.7, c=COLOR_YELLOW_SPIKE)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Correlation Coefficient')
    plt.title('CC Between Log Spike Rate and Spectral Power')
    plt.axis('tight')
    plt.ylim(-0.1, 0.6)
    leg = custom_legend([COLOR_BLUE_LFP, COLOR_YELLOW_SPIKE], ['LFP PSD', 'Spike PSD'])
    plt.legend(handles=leg, fontsize='x-small')

    fname = os.path.join(get_this_dir(), 'power_vs_rate.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()
Beispiel #6
0
def draw_rate_weight_by_dist(agg):
    """Fit and plot squared encoder weight vs. distance from the predicted electrode.

    For each of three frequencies (15, 55, 135 Hz) the function:

    * selects encoders with r2 > 0.05 and a positive, non-NaN distance;
    * fits ``a*exp(-b*x) + c`` to the squared weights;
    * prints a, the space constant 1/b, c and the fit's r2;
    * plots the fitted curve.

    Draws into the current figure; does not save or show it.

    :param agg: aggregate object passed to ``get_encoder_weight_data_for_psd``.
    """
    wdf = get_encoder_weight_data_for_psd(agg, include_sync=False, write_to_file=False)

    def exp_func(_x, _a, _b, _c):
        return _a * np.exp(-_b * _x) + _c

    freqs = [15, 55, 135]
    band_labels = ['0-30Hz', '30-80Hz', '80-190Hz']
    clrs = {15:'k', 55:'r', 135:'b'}
    npts = 100

    for f in freqs:
        i = ~np.isnan(wdf.dist_from_electrode.values) & (wdf.r2 > 0.05) & (wdf.dist_from_electrode > 0) & (wdf.f == f)

        x = wdf.dist_from_electrode[i].values
        y = (wdf.w[i].values)**2

        popt, _ = curve_fit(exp_func, x, y)

        # coefficient of determination of the exponential fit
        sserr = np.sum((y - exp_func(x, *popt))**2)
        sstot = np.sum((y - y.mean())**2)
        r2 = 1. - (sserr / sstot)

        print('f=%dHz, a=%0.6f, space_const=%0.6f, bias=%0.6f, r2=%0.2f: ' % (f, popt[0], 1. / popt[1], popt[2], r2))

        # evaluate the fit on a grid just inside the observed distance range
        xreg = np.linspace(x.min()+1e-1, x.max()-1e-1, npts)
        yreg = exp_func(xreg, *popt)
        plt.plot(xreg, yreg, clrs[f], alpha=0.7, linewidth=5.0)

    plt.xlabel('Distance From Predicted Electrode (mm)')
    plt.ylabel('Encoder Weight^2')
    plt.axis('tight')
    freq_clrs = [clrs[f] for f in freqs]
    leg = custom_legend(colors=freq_clrs, labels=band_labels)
    plt.legend(handles=leg, loc='lower right')
Beispiel #7
0
    def test_cross_psd(self):
        """Compare power/cross spectra and coherence computed several different ways.

        The two signals are:

        * s1: a sum of sinusoids in three bands;
        * s2: s1 band-passed to 40-90 Hz, normalized, plus noise.

        The test computes:

        * Welch power spectra;
        * power spectra as FFTs of the auto-correlation functions, asserting
          these FFTs are real;
        * the cross-spectral density as the FFT of the cross-correlation;
        * coherency/coherence via ``coherency`` and ``coherence_jn``.

        It then plots everything for visual comparison.
        """
        np.random.seed(1234567)
        sr = 1000.0
        dur = 1.0
        nt = int(dur * sr)
        t = np.arange(nt) / sr

        # create a simple signal: sum of sinusoids in three frequency bands
        freqs = list()
        freqs.extend(np.arange(8, 12))
        freqs.extend(np.arange(60, 71))
        freqs.extend(np.arange(130, 151))

        s1 = np.zeros([nt])
        for f in freqs:
            s1 += np.sin(2 * np.pi * f * t)
        s1 /= s1.max()

        # create a noise corrupted, bandpassed filtered version of s1
        noise = np.random.randn(nt) * 1e-1
        s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
        s2 /= s2.max()
        s2 += noise

        # compute the signal's power spectrums
        welch_freq1, welch_psd1 = welch(s1, fs=sr)
        welch_freq2, welch_psd2 = welch(s2, fs=sr)

        welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
        welch_psd1 /= welch_psd_max
        welch_psd2 /= welch_psd_max

        # compute the auto-correlation functions
        lags = np.arange(-200, 201)
        acf1 = correlation_function(s1, s1, lags, normalize=True)
        acf2 = correlation_function(s2, s2, lags, normalize=True)

        # compute the cross correlation functions
        cf12 = correlation_function(s1, s2, lags, normalize=True)
        coh12 = coherency(s1,
                          s2,
                          lags,
                          window_fraction=0.75,
                          noise_floor_db=100.)

        # do an FFT shift to the lags and the window, otherwise the FFT of the ACFs is not equal to the power
        # spectrum for some numerical reason
        shift_lags = fftshift(lags)
        if len(lags) % 2 == 1:
            # shift zero from end of shift_lags to beginning
            shift_lags = np.roll(shift_lags, 1)
        acf1_shift = correlation_function(s1, s1, shift_lags)
        acf2_shift = correlation_function(s2, s2, shift_lags)

        # compute the power spectra from the auto-spectra (positive frequencies only)
        ps1 = fft(acf1_shift)
        ps1_freq = fftfreq(len(acf1), d=1.0 / sr)
        fi = ps1_freq > 0
        ps1 = ps1[fi]
        n_imag1 = np.sum(np.abs(ps1.imag) > 1e-8)
        assert n_imag1 == 0, "Nonzero imaginary part for fft(acf1) (%d)" % n_imag1
        ps1_auto = np.abs(ps1.real)
        ps1_auto_freq = ps1_freq[fi]

        ps2 = fft(acf2_shift)
        ps2_freq = fftfreq(len(acf2), d=1.0 / sr)
        fi = ps2_freq > 0
        ps2 = ps2[fi]
        assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
        ps2_auto = np.abs(ps2.real)
        ps2_auto_freq = ps2_freq[fi]

        # NOTE(review): ps*_auto come from np.abs(), so these can never fail
        assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
        assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

        # compute the cross spectral density from the correlation function
        cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
        psd12 = fft(cf12_shift)
        psd12_freq = fftfreq(len(cf12_shift), d=1.0 / sr)
        fi = psd12_freq > 0

        psd12 = np.abs(psd12[fi])
        psd12_freq = psd12_freq[fi]

        # compute the cross spectral density from the power spectra
        psd12_welch = welch_psd1 * welch_psd2
        psd12_welch /= psd12_welch.max()

        # compute the coherence from the cross spectral density
        cfreq,coherence,coherence_var,phase_coherence,phase_coherence_var,coh12_freqspace,coh12_freqspace_t = \
            coherence_jn(s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)

        coh12_freqspace /= np.abs(coh12_freqspace).max()

        # weight the coherence by subtracting its standard deviation, floored at zero
        coherence_std = np.sqrt(coherence_var)
        coherence_weighted = coherence - coherence_std
        coherence_weighted[coherence_weighted < 0] = 0

        # compute the coherence from the fft of the coherency
        coherence2 = fft(fftshift(coh12))
        coherence2_freq = fftfreq(len(coherence2), d=1.0 / sr)
        fi = coherence2_freq > 0
        coherence2 = np.abs(coherence2[fi])
        coherence2_freq = coherence2_freq[fi]

        # normalize the cross-spectral density and power spectra
        psd12 /= psd12.max()
        ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
        ps1_auto /= ps_auto_max
        ps2_auto /= ps_auto_max

        # make some plots
        plt.figure()

        nrows = 2
        ncols = 2

        # plot the signals
        plt.subplot(nrows, ncols, 1)
        plt.plot(t, s1, 'k-', linewidth=2.0)
        plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
        plt.xlabel('Time (s)')
        plt.ylabel('Signal')
        plt.axis('tight')

        # plot the spectra
        plt.subplot(nrows, ncols, 2)
        plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
        plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
        plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
        plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
        plt.axis('tight')

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power')

        # plot the correlation functions
        plt.subplot(nrows, ncols, 3)
        plt.axhline(0, c='k')
        plt.plot(lags, acf1, 'k-', linewidth=2.0)
        plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
        plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
        plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
        plt.plot(coh12_freqspace_t * 1e3,
                 coh12_freqspace,
                 'm-',
                 linewidth=1.0,
                 alpha=0.95)
        plt.xlabel('Lag (ms)')
        plt.ylabel('Correlation Function')
        plt.axis('tight')
        plt.ylim(-0.5, 1.0)
        # legend colours must match the plot colours above ('m' for coh12_f)
        handles = custom_legend(['k', 'r', 'g', 'b', 'm'],
                                ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
        plt.legend(handles=handles)

        # plot the cross spectral density
        plt.subplot(nrows, ncols, 4)
        handles = custom_legend(['g', 'k', 'b'],
                                ['CSD', 'Coherence', 'Weighted'])
        plt.axhline(0, c='k')
        plt.axhline(1, c='k')
        plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
        plt.errorbar(cfreq,
                     coherence,
                     yerr=coherence_std,
                     fmt='k-',
                     ecolor='r',
                     linewidth=3.0,
                     elinewidth=5.0,
                     alpha=0.8)
        plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Cross-spectral Density/Coherence')
        plt.legend(handles=handles)

        plt.show()
Beispiel #8
0
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Plot a stimulus, raw LFPs on two electrodes, and their spectra/coherency for one syllable.

    Left column: the stimulus spectrogram (first presentation) and the z-scored
    LFPs of electrodes ``e1`` and ``e2``. Up to 5 trials are drawn plus the
    trial average, and dashed lines mark the selected syllable. Right columns,
    restricted to the (padded) syllable window: trial-averaged and
    mean-subtracted LFPs, their power spectra, and the coherency between the
    two electrodes. The figure is saved as ``raw+coherency.svg`` in this
    script's directory.

    :param bird, block, segment: identify the recording segment.
    :param hemi: key into the dict returned by ``exp.get_lfp_slice``.
    :param e1, e2: electrode numbers to compare.
    :param stim_id: id of the stimulus whose presentations are aggregated.
    :param syllable_index: index into the detected syllable events.
    :param data_dir: root data dir used to load ``exp`` when it is not given.
    :param exp: optional pre-loaded Experiment.
    """
    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # (start, end) times of every presentation of the stimulus, sorted by start
    etable = exp.get_epoch_table(seg)
    is_stim = etable['id'] == stim_id
    stim_times = sorted(zip(etable[is_stim]['start_time'].values, etable[is_stim]['end_time'].values),
                        key=operator.itemgetter(0))
    stim_times = np.array(stim_times)

    # crop every trial to the shortest presentation so they stack into one array
    stim_dur = (stim_times[:, 1] - stim_times[:, 0]).min()

    # aggregate the LFPs across trials -> (ntrials, nelectrodes, nt)
    lfps = list()
    sample_rate = None
    electrode_indices = None
    for start_time, end_time in stim_times:
        lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
        electrode_indices, the_lfps, sample_rate = lfp_data[hemi]

        stim_dur_i = int(stim_dur*sample_rate)
        lfps.append(the_lfps[:, :stim_dur_i])
    lfps = np.array(lfps)

    # rescale the LFPs so they are in uV
    lfps *= 1e6

    # get the log spectrogram of the stimulus (first presentation)
    start_time_0 = stim_times[0][0]
    end_time_0 = stim_times[0][1]
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time_0, end_time_0)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20*np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # amplitude envelope: std of the log spectrogram across frequency, scaled to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): merge_thresh is derived from the LFP sample_rate, but amp_env is
    # sampled at 1/stim_spec_dt — confirm which rate break_envelope_into_events expects.
    merge_thresh = int(0.002*sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into actual times
    events *= stim_spec_dt

    # pad the selected syllable window by 5 ms before and 10 ms after
    syllable_start, syllable_end, _ = events[syllable_index]
    syllable_start -= 0.005
    syllable_end += 0.010
    last_syllable_end = events[-1, 1] + 0.025

    i1 = electrode_indices.index(e1)
    i2 = electrode_indices.index(e2)

    lfp1 = lfps[:, i1, :]
    lfp2 = lfps[:, i2, :]

    # zscore each electrode's LFP over all trials and time points
    lfp1 -= lfp1.mean()
    lfp1 /= lfp1.std(ddof=1)
    lfp2 -= lfp2.mean()
    lfp2 /= lfp2.std(ddof=1)

    ntrials, nelectrodes, nt = lfps.shape
    # never try to plot more trials than exist
    ntrials_to_plot = min(5, ntrials)

    # time axis in seconds
    t = np.arange(nt) / sample_rate

    # plot the stimulus and raw LFP for two electrodes
    plt.figure(figsize=(24.0, 10))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(3, 100)

    ax = plt.subplot(gs[0, :60])
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    # z-scored LFP of each electrode: a few single trials plus the trial average
    for row, lfp, electrode in ((1, lfp1, e1), (2, lfp2, e2)):
        plt.subplot(gs[row, :60])
        for k in range(ntrials_to_plot):
            plt.plot(t, lfp[k, :], '-', linewidth=2.0, alpha=0.75)
        plt.plot(t, lfp.mean(axis=0), 'k-', linewidth=3.0)
        plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
        plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
        plt.xlabel('Time (s)')
        plt.ylabel('E%d (z-scored)' % electrode)
        plt.axis('tight')
        plt.xlim(0, last_syllable_end)

    # restrict the lfps to a single syllable
    print('syllable_start=%f, syllable_end=%f' % (syllable_start, syllable_end))
    # clamp at 0: the 5 ms pre-pad can make the start negative, and a negative
    # slice index would wrap around to the end of the array
    syllable_si = max(0, int(syllable_start*sample_rate))
    syllable_ei = int(syllable_end*sample_rate)
    lfp1 = lfp1[:, syllable_si:syllable_ei]
    lfp2 = lfp2[:, syllable_si:syllable_ei]
    t_syl = t[syllable_si:syllable_ei]

    # compute the trial averaged and mean subtracted lfps
    lfp1_mean, lfp1_ms = compute_avg_and_ms(lfp1)
    lfp2_mean, lfp2_ms = compute_avg_and_ms(lfp2)

    psd_stats = get_psd_stats(bird, block, segment, hemi)
    freqs, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total = compute_spectra_and_coherence_single_electrode(lfp1, lfp2, sample_rate, e1, e2, psd_stats=psd_stats)
    lags_ms = get_lags_ms(sample_rate)

    # symmetric y limits shared by the two electrodes
    lfp_absmax = max(np.abs(lfp1_mean).max(), np.abs(lfp2_mean).max())
    lfp_ms_absmax = max(np.abs(lfp1_ms).max(), np.abs(lfp2_ms).max())

    plt.subplot(gs[0, 65:80])
    plt.axhline(0, c='k')
    plt.plot(t_syl, lfp1_mean, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syl, lfp2_mean, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Trial-avg LFP')
    leg = custom_legend(['k', '#c0c0c0'], ['E%d' % e1, 'E%d' % e2])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_absmax, lfp_absmax)

    plt.subplot(gs[0, 85:])
    plt.axhline(0, c='k')
    plt.plot(t_syl, lfp1_ms, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syl, lfp2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Mean-sub LFP')
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_ms_absmax, lfp_ms_absmax)

    plt.subplot(gs[1, 65:80])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Trial-avg Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[1, 85:])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1_ms, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Mean-sub Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[2, 65:80])
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, c12_total, 'k-', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12, '-', c='r', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12_nonlocked, '-', c='b', linewidth=3.0, alpha=0.75)
    plt.xlabel('Lags (ms)')
    plt.ylabel('Coherency')
    leg = custom_legend(['k', 'r', 'b'], ['Raw', 'Trial-avg', 'Mean-sub'])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-0.2, 0.3)

    fname = os.path.join(get_this_dir(), 'raw+coherency.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
Beispiel #9
0
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, trial, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Draw raw-LFP, spectrum/coherency and all-pairs figures for one trial.

    Three SVG files are written to get_this_dir():

    * ``raw.svg``: stimulus spectrogram above the z-scored LFPs of electrodes
      ``e1`` and ``e2``, with the chosen syllable's start/end marked.
    * ``auto+cross.svg``: power spectra of ``e1``/``e2`` and their coherency.
    * ``cross_all.svg``: lower-triangular grid with every electrode's power
      spectrum on the diagonal and every pairwise coherency off the diagonal.

    Args:
        bird, block, segment: identify the recording segment to load.
        hemi: key into the LFP slice; 'L' selects the left rostral-caudal
            electrode order, anything else the right one.
        e1, e2: electrode numbers to highlight. Both must be present in the
            hemisphere's electrode list (``list.index`` raises otherwise).
        stim_id: stimulus id looked up in the segment's epoch table.
        trial: index into that stimulus' presentations, sorted by start time.
        syllable_index: index of the envelope-segmented syllable to mark.
        data_dir: root data directory; only used when ``exp`` is None.
        exp: an already-loaded Experiment; loaded from ``data_dir`` if None.
    """

    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # (start, end) times of every presentation of the stimulus, in chronological
    # order. sorted() works whether zip() returns a list (Python 2) or an
    # iterator (Python 3); zip(...).sort() only works on Python 2.
    etable = exp.get_epoch_table(seg)
    i = etable['id'] == stim_id
    stim_times = sorted(zip(etable[i]['start_time'].values, etable[i]['end_time'].values),
                        key=operator.itemgetter(0))

    start_time, end_time = stim_times[trial]
    stim_dur = float(end_time - start_time)

    # get a z-scored slice of the LFP for the requested hemisphere
    lfp_data = exp.get_lfp_slice(seg, start_time, end_time, zscore=True)
    electrode_indices, lfps, sample_rate = lfp_data[hemi]

    # log spectrogram of the stimulus, re-timed to start at 0; positive values are
    # converted to dB with a +100 offset and anything below 0 is clipped to 0
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time, end_time)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20*np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # amplitude envelope: std over axis 0 (presumably frequency), rescaled to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): merge_thresh is derived from the LFP sample rate while the
    # envelope is sampled on the spectrogram time grid - confirm intended units.
    merge_thresh = int(0.002*sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into times (s). Multiply into a new float array:
    # an in-place `*=` on an integer index array truncates (old numpy) or raises a
    # casting error (newer numpy).
    events = np.asarray(events, dtype=float) * stim_spec_dt

    syllable_start, syllable_end, _ = events[syllable_index]
    last_syllable_end = events[-1, 1] + 0.025

    lfp1 = lfps[electrode_indices.index(e1), :]
    lfp2 = lfps[electrode_indices.index(e2), :]

    # float() avoids integer division under Python 2 if sample_rate is an int
    t = np.arange(len(lfp1)) / float(sample_rate)
    legend = ['E%d' % e1, 'E%d' % e2]

    if hemi == 'L':
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    # get the power spectrum stats for this site
    psd_stats = get_psd_stats(bird, block, segment, hemi)

    # compute the power spectra and cross coherence for all electrodes
    lags_ms = get_lags_ms(sample_rate)
    spectra, cross_mat = compute_spectra_and_coherence_multi_electrode_single_trial(lfps, sample_rate, electrode_indices, electrode_order, psd_stats=psd_stats)

    # ---- figure 1: the stimulus and raw LFP for two electrodes ----
    plt.figure(figsize=(24.0, 10))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    ax = plt.subplot(2, 1, 1)
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    plt.subplot(2, 1, 2)
    plt.plot(t, lfp1, 'b-', linewidth=3.0)
    plt.plot(t, lfp2, 'r-', linewidth=3.0, alpha=0.7)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    # t is arange(n)/sample_rate, i.e. seconds
    plt.xlabel('Time (s)')
    plt.ylabel('LFP (z-scored)')
    plt.legend(legend)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    fname = os.path.join(get_this_dir(), 'raw.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # ---- figure 2: the two power spectra and their coherency ----
    psd_ub = 6
    psd_lb = 0
    o1 = electrode_order.index(e1)
    o2 = electrode_order.index(e2)

    a1 = spectra[o1, :]
    a2 = spectra[o2, :]
    freqs = get_freqs(sample_rate)

    plt.figure(figsize=(10.0, 4.0))
    plt.subplots_adjust(top=0.90, bottom=0.10, left=0.10, right=0.99, hspace=0.10)

    plt.subplot(1, 2, 1)
    plt.plot(freqs, a1, 'b-', linewidth=3.0)
    plt.plot(freqs, a2, 'r-', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (z-scored)')
    plt.title('Power Spectrum')
    handles = custom_legend(['b', 'r'], legend)
    plt.legend(handles=handles, fontsize='small')
    plt.axis('tight')
    plt.ylim(psd_lb, psd_ub)

    # plot the coherency (a function of lag, not frequency)
    cf_lb = -0.1
    cf_ub = 0.3
    coh = cross_mat[o1, o2, :]
    plt.subplot(1, 2, 2)
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, coh, 'g-', linewidth=3.0)
    plt.xlabel('Lags (ms)')
    plt.title('Coherency')
    plt.axis('tight')
    plt.ylim(cf_lb, cf_ub)

    fname = os.path.join(get_this_dir(), 'auto+cross.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # ---- figure 3: all auto spectra (diagonal) and coherencies (lower triangle) ----
    nelectrodes = len(electrode_order)

    plt.figure(figsize=(24.0, 13.5))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    highlighted_pair = set((e1, e2))
    gs = plt.GridSpec(nelectrodes, nelectrodes)
    for i in range(nelectrodes):
        for j in range(i+1):
            plt.subplot(gs[i, j])
            plt.axhline(0, c='k')
            plt.axvline(0, c='k')

            _e1 = electrode_order[i]
            _e2 = electrode_order[j]

            # highlight e1's spectrum in blue, e2's in red, and their coherency in
            # green. Only the lower triangle is drawn, so match the pair regardless
            # of which electrode comes first in electrode_order.
            clr = 'k'
            if i == j:
                if _e1 == e1:
                    clr = 'b'
                elif _e1 == e2:
                    clr = 'r'
            elif set((_e1, _e2)) == highlighted_pair:
                clr = 'g'

            if i == j:
                plt.plot(freqs, spectra[i, :], '-', c=clr, linewidth=2.0)
            else:
                plt.plot(lags_ms, cross_mat[i, j, :], '-', c=clr, linewidth=2.0)
            plt.xticks([])
            plt.yticks([])
            plt.axis('tight')
            if i != j:
                plt.ylim(cf_lb, cf_ub)
            else:
                plt.ylim(psd_lb, psd_ub)

            if j == 0:
                plt.ylabel('E%d' % electrode_order[i])
            if i == nelectrodes-1:
                plt.xlabel('E%d' % electrode_order[j])

    fname = os.path.join(get_this_dir(), 'cross_all.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
Beispiel #10
0
    def test_cross_psd(self):
        """Compare auto/cross spectra and coherence computed several ways.

        s1 is a max-normalised sum of sinusoids at 8-11, 60-70 and 130-150 Hz.
        s2 is s1 band-passed to 40-90 Hz, max-normalised, plus Gaussian noise
        (std 0.1). The test asserts that the FFTs of the fft-shifted
        auto-correlation functions have no imaginary part above 1e-8. It then
        plots four panels: the signals; Welch vs ACF-derived power spectra; the
        correlation/coherency functions; and the ACF-derived cross-spectral
        density next to the jackknifed coherence. It ends with plt.show(),
        which blocks until the figure is closed.
        """

        np.random.seed(1234567)
        sr = 1000.0
        dur = 1.0
        nt = int(dur*sr)
        t = np.arange(nt) / sr

        # create a simple signal
        freqs = list()
        freqs.extend(np.arange(8, 12))
        freqs.extend(np.arange(60, 71))
        freqs.extend(np.arange(130, 151))

        s1 = np.zeros([nt])
        for f in freqs:
            s1 += np.sin(2*np.pi*f*t)
        s1 /= s1.max()

        # create a noise corrupted, bandpassed filtered version of s1
        noise = np.random.randn(nt)*1e-1
        s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
        s2 /= s2.max()
        s2 += noise

        # compute the signal's power spectrums
        welch_freq1, welch_psd1 = welch(s1, fs=sr)
        welch_freq2, welch_psd2 = welch(s2, fs=sr)

        welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
        welch_psd1 /= welch_psd_max
        welch_psd2 /= welch_psd_max

        # compute the auto-correlation functions
        lags = np.arange(-200, 201)
        acf1 = correlation_function(s1, s1, lags, normalize=True)
        acf2 = correlation_function(s2, s2, lags, normalize=True)

        # compute the cross correlation functions
        cf12 = correlation_function(s1, s2, lags, normalize=True)
        coh12 = coherency(s1, s2, lags, window_fraction=0.75, noise_floor_db=100.)

        # do an FFT shift to the lags and the window, otherwise the FFT of the ACFs is not equal to the power
        # spectrum for some numerical reason
        shift_lags = fftshift(lags)
        if len(lags) % 2 == 1:
            # shift zero from end of shift_lags to beginning
            shift_lags = np.roll(shift_lags, 1)
        acf1_shift = correlation_function(s1, s1, shift_lags)
        acf2_shift = correlation_function(s2, s2, shift_lags)

        # compute the power spectra from the auto-spectra
        ps1 = fft(acf1_shift)
        ps1_freq = fftfreq(len(acf1), d=1.0/sr)
        fi = ps1_freq > 0
        ps1 = ps1[fi]
        assert np.sum(np.abs(ps1.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf1) (%d)" % np.sum(np.abs(ps1.imag) > 1e-8)
        ps1_auto = np.abs(ps1.real)
        ps1_auto_freq = ps1_freq[fi]

        ps2 = fft(acf2_shift)
        ps2_freq = fftfreq(len(acf2), d=1.0/sr)
        fi = ps2_freq > 0
        ps2 = ps2[fi]
        assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
        ps2_auto = np.abs(ps2.real)
        ps2_auto_freq = ps2_freq[fi]

        # NOTE(review): ps*_auto are np.abs(...) so these checks always pass; if the
        # intent was to check the real part before abs(), they need to move up.
        assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
        assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

        # compute the cross spectral density from the correlation function
        cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
        psd12 = fft(cf12_shift)
        psd12_freq = fftfreq(len(cf12_shift), d=1.0/sr)
        fi = psd12_freq > 0

        psd12 = np.abs(psd12[fi])
        psd12_freq = psd12_freq[fi]

        # compute the coherence from the cross spectral density
        cfreq, coherence, coherence_var, phase_coherence, phase_coherence_var, coh12_freqspace, coh12_freqspace_t = \
            coherence_jn(s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)

        coh12_freqspace /= np.abs(coh12_freqspace).max()

        # weight the coherence by subtracting one standard deviation, floored at zero
        coherence_std = np.sqrt(coherence_var)
        coherence_weighted = coherence - coherence_std
        coherence_weighted[coherence_weighted < 0] = 0

        # normalize the cross-spectral density and power spectra
        psd12 /= psd12.max()
        ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
        ps1_auto /= ps_auto_max
        ps2_auto /= ps_auto_max

        # make some plots
        plt.figure()

        nrows = 2
        ncols = 2

        # plot the signals
        plt.subplot(nrows, ncols, 1)
        plt.plot(t, s1, 'k-', linewidth=2.0)
        plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
        plt.xlabel('Time (s)')
        plt.ylabel('Signal')
        plt.axis('tight')

        # plot the spectra: solid = Welch, dashed = FFT of the ACF
        plt.subplot(nrows, ncols, 2)
        plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
        plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
        plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
        plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
        plt.axis('tight')

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power')

        # plot the correlation functions
        plt.subplot(nrows, ncols, 3)
        plt.axhline(0, c='k')
        plt.plot(lags, acf1, 'k-', linewidth=2.0)
        plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
        plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
        plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
        plt.plot(coh12_freqspace_t*1e3, coh12_freqspace, 'm-', linewidth=1.0, alpha=0.95)
        plt.xlabel('Lag (ms)')
        plt.ylabel('Correlation Function')
        plt.axis('tight')
        plt.ylim(-0.5, 1.0)
        # legend colours must match the plotted line colours ('m' for coh12_f)
        handles = custom_legend(['k', 'r', 'g', 'b', 'm'], ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
        plt.legend(handles=handles)

        # plot the cross spectral density
        plt.subplot(nrows, ncols, 4)
        handles = custom_legend(['g', 'k', 'b'], ['CSD', 'Coherence', 'Weighted'])
        plt.axhline(0, c='k')
        plt.axhline(1, c='k')
        plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
        plt.errorbar(cfreq, coherence, yerr=np.sqrt(coherence_var), fmt='k-', ecolor='r', linewidth=3.0, elinewidth=5.0, alpha=0.8)
        plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Cross-spectral Density/Coherence')
        plt.legend(handles=handles)

        plt.show()
Beispiel #11
0
def _scatter_call_types(ax, stim_types, call_types, coords, **scatter_kwargs):
    """Scatter ``coords`` on ``ax`` one call type at a time, coloured by CALL_TYPE_COLORS.

    ``coords`` is a sequence of equal-length arrays/Series (one per axis) that
    are masked with ``stim_types == ct``. Call types with no rows are skipped.
    """
    for ct in call_types:
        i = stim_types == ct
        if i.sum() == 0:
            continue
        ax.scatter(*[coord[i] for coord in coords], c=CALL_TYPE_COLORS[ct], **scatter_kwargs)


def draw_joint_relationships(data_dir='/auto/tdrive/mschachter/data'):
    """Scatter-plot max amplitude, saliency and mean spectral frequency by call type.

    The per-syllable table (columns maxAmp, sal, meanspect, stim_type) is cached
    in ``<data_dir>/aggregate/acoustic_big3.csv``. On the first call it is built
    from ``biosound.h5`` via AggregateBiosounds.remove_duplicates; later calls
    read it back.

    Draws a 3-D scatter of the three features (call types from
    DECODER_CALL_TYPES) and a second figure with the three pairwise 2-D
    scatters, then calls plt.show(). Max amplitude is shifted to a zero
    minimum, log-transformed and z-scored before plotting.
    """

    big3_file = os.path.join(data_dir, 'aggregate', 'acoustic_big3.csv')

    if not os.path.exists(big3_file):
        bs_file = os.path.join(data_dir, 'aggregate', 'biosound.h5')
        bs_agg = AggregateBiosounds.load(bs_file)

        aprops = ['maxAmp', 'sal', 'meanspect']
        Xz, good_indices = bs_agg.remove_duplicates(aprops=aprops, thresh=np.inf)
        stim_types = bs_agg.df.stim_type[good_indices].values

        pdata = {'maxAmp': Xz[:, 0], 'sal': Xz[:, 1], 'meanspect': Xz[:, 2], 'stim_type': stim_types}
        df = pd.DataFrame(pdata)
        df.to_csv(big3_file, header=True, index=False)
    else:
        df = pd.read_csv(big3_file)

    # single-argument print(...) behaves the same on Python 2 and 3
    print('# of syllables used: %d' % len(df))

    # Work on a copy: `s = df.maxAmp; s -= ...` modifies the DataFrame's own
    # column in place on pandas versions without copy-on-write.
    maxAmp_rescaled = df.maxAmp.copy()
    maxAmp_rescaled -= maxAmp_rescaled.min()
    # the +1e-6 keeps log() finite for the minimum element (which is now 0)
    maxAmp_rescaled = np.log(maxAmp_rescaled + 1e-6)
    maxAmp_rescaled -= maxAmp_rescaled.mean()
    maxAmp_rescaled /= maxAmp_rescaled.std(ddof=1)

    plt.figure()
    ax = plt.subplot(111, projection='3d')
    _scatter_call_types(ax, df.stim_type, DECODER_CALL_TYPES,
                        (maxAmp_rescaled, df.sal, df.meanspect), s=49)

    ax.set_xlabel('Amplitude')
    ax.set_ylabel('Saliency')
    ax.set_zlabel('Mean Frequency')

    leg = custom_legend([CALL_TYPE_COLORS[ct] for ct in DECODER_CALL_TYPES], DECODER_CALL_TYPES)
    plt.legend(handles=leg)

    plt.figure()
    nrows = 1
    ncols = 3

    call_type_order = ['song', 'Te', 'DC', 'Be', 'LT', 'Ne', 'Ag', 'Di', 'Th']

    # plot amplitude vs saliency
    ax = plt.subplot(nrows, ncols, 1)
    _scatter_call_types(ax, df.stim_type, call_type_order, (maxAmp_rescaled, df.sal), s=49, alpha=0.7)
    plt.xlabel('Maximum Amplitude')
    plt.ylabel('Saliency')
    plt.xlim(-3, 2)

    # plot amplitude vs meanspect
    ax = plt.subplot(nrows, ncols, 2)
    _scatter_call_types(ax, df.stim_type, call_type_order, (maxAmp_rescaled, df.meanspect), s=49, alpha=0.7)
    plt.xlabel('Maximum Amplitude')
    plt.ylabel('Mean Spectral Freq')
    plt.xlim(-3, 2)

    # plot saliency vs meanspect
    ax = plt.subplot(nrows, ncols, 3)
    _scatter_call_types(ax, df.stim_type, call_type_order, (df.sal, df.meanspect), s=49, alpha=0.7)
    plt.xlabel('Saliency')
    plt.ylabel('Mean Spectral Freq')

    plt.show()
Beispiel #12
0
def draw_decoder_perf_barplots(data_dir='/auto/tdrive/mschachter/data', show_all=True):
    """Bar-plot decoder R2 per acoustic property for each neural decomposition.

    Reads ``<data_dir>/aggregate/decoder_perfs_for_glm.csv`` and uses its
    ``decomp``, ``aprop`` and ``r2`` columns. There is one bar group for every
    property in USED_ACOUSTIC_PROPS except 'stdtime'. Within a group, each
    decomposition gets a bar of mean R2 with the sample std (ddof=1) as its
    error bar. Groups are sorted by mean LFP-PSD R2, highest first.

    With ``show_all`` a third 'spike_rate+spike_sync' decomposition is added.
    The figure is saved to decoder_perf_barplots.svg (or
    decoder_perf_barplots_all.svg) in get_this_dir() and then shown.
    """

    # filter rather than list.remove(), which raises if 'stdtime' is absent
    aprops_to_display = [aprop for aprop in USED_ACOUSTIC_PROPS if aprop != 'stdtime']

    if not show_all:
        decomps = ['spike_rate', 'full_psds']
        sub_names = ['Spike Rate', 'LFP PSD']
        sub_clrs = [COLOR_RED_SPIKE_RATE, COLOR_BLUE_LFP]
    else:
        decomps = ['spike_rate', 'full_psds', 'spike_rate+spike_sync']
        sub_names = ['Spike Rate', 'LFP PSD', 'Spike Rate + Sync']
        sub_clrs = [COLOR_RED_SPIKE_RATE, COLOR_BLUE_LFP, COLOR_CRIMSON_SPIKE_SYNC]

    df_me = pd.read_csv(os.path.join(data_dir, 'aggregate', 'decoder_perfs_for_glm.csv'))

    # per acoustic property: the R2 values of every decomposition
    bprop_data = list()
    for aprop in aprops_to_display:
        bd = dict()
        for decomp in decomps:
            i = (df_me.decomp == decomp) & (df_me.aprop == aprop)
            bd[decomp] = df_me.r2[i].values
        bprop_data.append({'bd': bd, 'lfp_mean': bd['full_psds'].mean(), 'aprop': aprop})

    bprop_data.sort(key=operator.itemgetter('lfp_mean'), reverse=True)

    aprops_xticks = [ACOUSTIC_PROP_NAMES[bdict['aprop']] for bdict in bprop_data]

    plt.figure(figsize=(23, 7.))
    plt.subplots_adjust(top=0.95, bottom=0.15, left=0.05, right=0.99, hspace=0.20, wspace=0.20)

    bar_width = 0.2 if show_all else 0.4

    # one bar per decomposition in each group, edge-aligned so that bar k of a
    # group spans [x + k*w, x + (k+1)*w] regardless of matplotlib's default align
    bar_x = np.arange(len(bprop_data))
    for k, decomp in enumerate(decomps):
        means = [bdict['bd'][decomp].mean() for bdict in bprop_data]
        stds = [bdict['bd'][decomp].std(ddof=1) for bdict in bprop_data]
        plt.bar(bar_x + bar_width*k, means, yerr=stds, width=bar_width, align='edge',
                color=sub_clrs[k], alpha=0.9, ecolor='k')

    plt.ylabel('Decoder R2')
    # centre each tick label under its group of len(decomps) bars
    group_center = bar_width*len(decomps) / 2.0
    plt.xticks(bar_x + group_center, aprops_xticks, rotation=90, fontsize=12)

    leg = custom_legend(sub_clrs, sub_names)
    plt.legend(handles=leg, loc='upper right')
    plt.axis('tight')
    plt.xlim(-0.5, bar_x.max() + 1)
    plt.ylim(0, 0.6)

    fname = os.path.join(get_this_dir(), 'decoder_perf_barplots.svg')
    if show_all:
        fname = os.path.join(get_this_dir(), 'decoder_perf_barplots_all.svg')

    plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()