def draw_preds(agg, data_dir='/auto/tdrive/mschachter/data'):
    """Plot real vs. linear vs. RNN LFP predictions for one hard-coded site.

    If the RNN prediction file does not exist yet it is generated from the
    preprocessed data, the RNN encoder file and the linear envelope file; it is
    then visualized and the figure is saved as ``rnn_preds.svg`` in the
    directory returned by ``get_this_dir()``.

    :param agg: Unused; kept so the signature matches the other ``draw_*`` entry points.
    :param data_dir: Root data directory containing one subdirectory per bird.
    """
    # the recording site is hard-coded
    bird = 'GreBlu9508M'
    block = 'Site2'
    segment = 'Call2'
    hemi = 'L'
    # md5 tag embedded in the RNN file names (presumably the best model, per the name)
    best_md5 = 'db5d601c8af2621f341d8e4cbd4108c1'

    site_name = '%s_%s_%s_%s' % (bird, block, segment, hemi)
    preproc_file = os.path.join(data_dir, bird, 'preprocess', 'RNNPreprocess_%s.h5' % site_name)
    rnn_file = os.path.join(data_dir, bird, 'rnn', 'RNNLFPEncoder_%s_%s.h5' % (site_name, best_md5))
    pred_file = os.path.join(data_dir, bird, 'rnn', 'RNNLFPEncoderPred_%s_%s.h5' % (site_name, best_md5))
    linear_file = os.path.join(data_dir, bird, 'rnn', 'LFPEnvelope_%s.h5' % site_name)

    # build the prediction file only once; later calls reuse it
    if not os.path.exists(pred_file):
        rpt = RNNPreprocessTransform.load(preproc_file)
        rpt.write_pred_file(rnn_file, pred_file, lfp_enc_file=linear_file)

    visualize_pred_file(pred_file, 3, 0, 2.4, electrode_order=ROSTRAL_CAUDAL_ELECTRODES_LEFT[::-1], dbnoise=4.0)
    leg = custom_legend(['k', 'r', 'b'], ['Real', 'Linear', 'RNN'])
    plt.legend(handles=leg)

    fig_file = os.path.join(get_this_dir(), 'rnn_preds.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')
    plt.show()
def plot_raw_dists(data_dir='/auto/tdrive/mschachter/data'):
    """Scatter each electrode's distance from the midline against its distance from L2A.

    Bird 'BlaBro09xxF' is excluded and electrodes missing either distance are
    skipped. Each remaining bird gets one color and every point is annotated
    with its region label.

    :param data_dir: Root data directory containing ``aggregate/electrode_data+dist.csv``.
    """
    csv_path = os.path.join(data_dir, 'aggregate', 'electrode_data+dist.csv')
    all_electrodes = pd.read_csv(csv_path)
    electrodes = all_electrodes[all_electrodes.bird != 'BlaBro09xxF']

    bird_names = electrodes.bird.unique()
    colors = ['r', 'g', 'b']

    fig = plt.figure()
    fig.subplots_adjust(top=0.99, bottom=0.01, right=0.99, left=0.01, hspace=0, wspace=0)

    for bird_idx, bird_name in enumerate(bird_names):
        has_both_dists = ~np.isnan(electrodes.dist_midline) & ~np.isnan(electrodes.dist_l2a)
        subset = electrodes[(electrodes.bird == bird_name) & has_both_dists]

        midline_dists = subset.dist_midline.values
        l2a_dists = subset.dist_l2a.values
        if bird_name == 'GreBlu9508M':
            # NOTE(review): this bird's L2A distances are scaled by 4; the reason is not visible here
            l2a_dists = l2a_dists * 4

        color = colors[bird_idx]
        for px, py, region in zip(midline_dists, l2a_dists, subset.region.values):
            plt.plot(px, py, 'o', markersize=8, c=color, alpha=0.4)
            plt.text(px, py, region, fontsize=8)

    plt.axis('tight')
    plt.xlabel('Dist from Midline (mm)')
    plt.ylabel('Dist from LH (mm)')
    plt.legend(handles=custom_legend(colors, bird_names))
    plt.show()
def _plot_freqs(pdata, ax):
    """Draw mean +/- SEM of the normalized likelihood ratio vs. frequency for LFP and spikes.

    :param pdata: dict with keys 'lfp' and 'spike' (arrays whose rows are samples;
        mean and SEM are taken over axis 0), 'freqs' (x values) and 'aprop'
        (used as the title). When 'aprop' is 'category' the dashed y=1 line and
        the fixed 0-5 y-limits are skipped.
    :param ax: matplotlib Axes to draw into; it is made current with ``plt.sca``.
    """
    plt.sca(ax)
    is_category = pdata['aprop'] == 'category'

    def _mean_and_sem(samples):
        # column means and standard errors of the mean across rows
        return samples.mean(axis=0), samples.std(axis=0, ddof=1) / np.sqrt(len(samples))

    lfp_mean, lfp_sem = _mean_and_sem(pdata['lfp'])
    spike_mean, spike_sem = _mean_and_sem(pdata['spike'])

    if not is_category:
        # reference line at a normalized LR of 1
        plt.axhline(1.0, c='k', linestyle='dashed', alpha=0.7, linewidth=2.0)

    plt.errorbar(pdata['freqs'], lfp_mean, yerr=lfp_sem, c=COLOR_BLUE_LFP,
                 linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)
    # the spike curve is shifted 2 units along x, presumably so its error bars don't sit on the LFP ones
    plt.errorbar(pdata['freqs'] + 2., spike_mean, yerr=spike_sem, c=COLOR_YELLOW_SPIKE,
                 linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)

    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Normalized LR')
    legend_handles = custom_legend([COLOR_BLUE_LFP, COLOR_YELLOW_SPIKE], ['LFP', 'Spike'])
    plt.legend(handles=legend_handles, fontsize='x-small')
    plt.title(pdata['aprop'])
    plt.axis('tight')

    if is_category:
        plt.axis('tight')
    else:
        plt.ylim(0, 5)
def draw_pairwise_weights_vs_dist(agg):
    """Plot the binned squared pairwise decoder weight against pairwise distance.

    Uses the weights returned by ``export_pairwise_decoder_weights(agg)``. Two
    panels are drawn, one for the LFP pairwise-correlation decoder and one for
    the spike-synchrony decoder. Each panel has one curve per acoustic property,
    using only fits with r2 > 0.20 and a finite distance. The figure is saved as
    ``pairwise_decoder_effect_vs_dist.svg`` in ``get_this_dir()``.

    :param agg: aggregate object passed to ``export_pairwise_decoder_weights``.
    """
    weights = export_pairwise_decoder_weights(agg)
    r2_thresh = 0.20
    aprops = ['meanspect', 'stdspect', 'sal', 'maxAmp']
    aprop_clrs = {'meanspect': '#FF8000', 'stdspect': '#FFBF00', 'sal': '#088A08', 'maxAmp': 'k'}
    decomps = ['full_psds+full_cfs', 'spike_rate+spike_sync']
    decomp_labels = {'full_psds+full_cfs': 'LFP Pairwise Correlations',
                     'spike_rate+spike_sync': 'Spike Synchrony'}

    print(weights.decomp.unique())

    # (decomposition, acoustic property) -> (pairwise distances, decoder weights)
    scatter_data = {}
    for decomp in decomps:
        in_decomp = weights.decomp == decomp
        assert in_decomp.sum() > 0
        dw = weights[in_decomp]
        finite_dist = ~np.isnan(dw.dist) & ~np.isinf(dw.dist)
        for aprop in aprops:
            keep = (dw.r2 > r2_thresh) & (dw.aprop == aprop) & finite_dist
            scatter_data[(decomp, aprop)] = (dw.dist[keep].values, dw.w[keep].values)

    fig = plt.figure(figsize=(14, 5))
    fig.subplots_adjust(left=0.10, right=0.98)
    for panel, decomp in enumerate(decomps, 1):
        plt.subplot(1, 2, panel)
        for aprop in aprops:
            dists, w = scatter_data[(decomp, aprop)]
            # squared weight, scaled to unit sample standard deviation
            effect = w ** 2
            effect /= effect.std(ddof=1)
            xcenter, ymean, yerr, _ = compute_mean_from_scatter(dists, effect, bins=5, num_smooth_points=300)
            plt.errorbar(xcenter, ymean, linestyle='-', yerr=yerr, c=aprop_clrs[aprop], linewidth=9.0,
                         alpha=0.7, elinewidth=8., ecolor='#d8d8d8', capsize=0.)
        plt.axis('tight')
        plt.ylabel('Decoder Weight Effect')
        plt.xlabel('Pairwise Distance (mm)')
        plt.title(decomp_labels[decomp])

    legend_handles = custom_legend([aprop_clrs[a] for a in aprops], [ACOUSTIC_PROP_NAMES[a] for a in aprops])
    plt.legend(handles=legend_handles, loc='lower left')

    fig_file = os.path.join(get_this_dir(), 'pairwise_decoder_effect_vs_dist.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')
    plt.show()
def draw_spike_rate_vs_power(data_dir='/auto/tdrive/mschachter/data'):
    """Plot, per frequency, the correlation of log spike rate with LFP and spike PSD power.

    Reads ``aggregate/pairwise_cf.h5``. For every row of ``pcf.df`` it builds one
    feature vector [LFP PSD | spike PSD | log spike rate] and drops unusable rows.
    It then z-scores the remaining columns and correlates each PSD frequency bin
    with the log-rate column. The figure is saved as ``power_vs_rate.svg`` in
    ``get_this_dir()``.

    :param data_dir: Root data directory.
    """
    pcf_file = os.path.join(data_dir, 'aggregate', 'pairwise_cf.h5')
    pcf = AggregatePairwiseCF.load(pcf_file)

    # columns: [0, nfreqs) LFP PSD, [nfreqs, 2*nfreqs) spike PSD, last column log spike rate
    nfreqs = len(pcf.freqs)
    lfp_and_spike_psds = np.zeros([len(pcf.df), nfreqs*2 + 1])
    for k, (lfp_index, spike_index) in enumerate(zip(pcf.df['lfp_index'], pcf.df['spike_index'])):
        # each spike_rates row unpacks into two values; only the first (the rate) is used
        srate, _ = pcf.spike_rates[spike_index, :]
        lfp_and_spike_psds[k, :nfreqs] = pcf.lfp_psds[lfp_index, :]
        lfp_and_spike_psds[k, nfreqs:-1] = pcf.spike_psds[spike_index, :]
        lfp_and_spike_psds[k, -1] = np.log(srate)

    # throw out bad data points: a non-finite log rate (rate <= 0 or NaN) or a PSD whose
    # total power is not a finite positive number. np.isfinite also rejects NaN log rates,
    # which a plain ~np.isinf check would let through and which would poison every CC.
    lfp_sum = lfp_and_spike_psds[:, :nfreqs].sum(axis=1)
    spike_sum = lfp_and_spike_psds[:, nfreqs:-1].sum(axis=1)
    good = (np.isfinite(lfp_and_spike_psds[:, -1])
            & np.isfinite(lfp_sum) & (lfp_sum > 0)
            & np.isfinite(spike_sum) & (spike_sum > 0))
    print('# of good data points: %d out of %d' % (good.sum(), lfp_and_spike_psds.shape[0]))

    # zscore the concatenated matrix
    lfp_and_spike_psds = lfp_and_spike_psds[good, :]
    lfp_and_spike_psds -= lfp_and_spike_psds.mean(axis=0)
    lfp_and_spike_psds /= lfp_and_spike_psds.std(axis=0, ddof=1)

    # compute CC between log spike rate and power at each frequency
    log_rate = lfp_and_spike_psds[:, -1]
    lfp_spike_rate_cc = np.zeros(nfreqs)
    spike_spike_rate_cc = np.zeros(nfreqs)
    for k in range(nfreqs):
        lfp_spike_rate_cc[k] = np.corrcoef(lfp_and_spike_psds[:, k], log_rate)[0, 1]
        spike_spike_rate_cc[k] = np.corrcoef(lfp_and_spike_psds[:, k + nfreqs], log_rate)[0, 1]

    plt.figure(figsize=(12, 7))
    plt.axhline(0, c='k')
    plt.plot(pcf.freqs, lfp_spike_rate_cc, '-', linewidth=7.0, alpha=0.7, c=COLOR_BLUE_LFP)
    plt.plot(pcf.freqs, spike_spike_rate_cc, '-', linewidth=7.0, alpha=0.7, c=COLOR_YELLOW_SPIKE)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Correlation Coefficient')
    plt.title('CC Between Log Spike Rate and Spectral Power')
    plt.axis('tight')
    plt.ylim(-0.1, 0.6)
    leg = custom_legend([COLOR_BLUE_LFP, COLOR_YELLOW_SPIKE], ['LFP PSD', 'Spike PSD'])
    plt.legend(handles=leg, fontsize='x-small')

    fig_file = os.path.join(get_this_dir(), 'power_vs_rate.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')
    plt.show()
def draw_rate_weight_by_dist(agg):
    """Fit and plot squared encoder weight vs. distance from the predicted electrode.

    For each of three frequency bands it fits ``a * exp(-b * x) + c`` to the
    squared encoder weights. Only rows with r2 > 0.05 and a positive, non-NaN
    distance are used. It prints the fitted parameters, ``1/b`` and the fit's
    R^2, then plots the fitted curve. The figure is neither saved nor shown here.

    :param agg: aggregate object passed to ``get_encoder_weight_data_for_psd``.
    """
    wdf = get_encoder_weight_data_for_psd(agg, include_sync=False, write_to_file=False)

    def exp_func(_x, _a, _b, _c):
        return _a * np.exp(-_b * _x) + _c

    # band centre frequencies with their legend labels and line colors
    freqs = [15, 55, 135]
    band_labels = ['0-30Hz', '30-80Hz', '80-190Hz']
    clrs = {15: 'k', 55: 'r', 135: 'b'}
    npts = 100

    for f in freqs:
        i = ~np.isnan(wdf.dist_from_electrode.values) & (wdf.r2 > 0.05) & (wdf.dist_from_electrode > 0) & (wdf.f == f)
        x = wdf.dist_from_electrode[i].values
        y = (wdf.w[i].values) ** 2

        popt, _ = curve_fit(exp_func, x, y)

        # coefficient of determination of the fit
        ypred = exp_func(x, *popt)
        sserr = np.sum((y - ypred) ** 2)
        sstot = np.sum((y - y.mean()) ** 2)
        r2 = 1. - (sserr / sstot)
        print('f=%dHz, a=%0.6f, space_const=%0.6f, bias=%0.6f, r2=%0.2f: ' % (f, popt[0], 1. / popt[1], popt[2], r2))

        # draw the fitted curve over a grid just inside the observed distance range
        xreg = np.linspace(x.min() + 1e-1, x.max() - 1e-1, npts)
        plt.plot(xreg, exp_func(xreg, *popt), clrs[f], alpha=0.7, linewidth=5.0)

    plt.xlabel('Distance From Predicted Electrode (mm)')
    plt.ylabel('Encoder Weight^2')
    plt.axis('tight')

    freq_clrs = [clrs[f] for f in freqs]
    leg = custom_legend(colors=freq_clrs, labels=band_labels)
    plt.legend(handles=leg, loc='lower right')
def test_cross_psd(self):
    """Visually compare auto/cross spectra, correlation functions and coherence.

    Builds a sum of sinusoids (s1) and a band-passed (40-90 Hz), noise-corrupted
    copy of it (s2). It then checks that the FFTs of the auto-correlation
    functions are real, and plots:
      * the signals;
      * Welch PSDs vs. ACF-derived power spectra;
      * correlation functions and coherency in the lag domain;
      * the CSD vs. the coherence from ``coherence_jn``.
    The plots are shown with ``plt.show()`` for manual inspection.
    """
    np.random.seed(1234567)
    sr = 1000.0
    dur = 1.0
    nt = int(dur * sr)
    t = np.arange(nt) / sr

    # create a simple signal: a sum of unit sinusoids in three frequency bands
    freqs = list()
    freqs.extend(np.arange(8, 12))
    freqs.extend(np.arange(60, 71))
    freqs.extend(np.arange(130, 151))
    s1 = np.zeros([nt])
    for f in freqs:
        s1 += np.sin(2 * np.pi * f * t)
    s1 /= s1.max()

    # create a noise corrupted, bandpass filtered (40-90 Hz) version of s1
    noise = np.random.randn(nt) * 1e-1
    s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
    s2 /= s2.max()
    s2 += noise

    # compute the signals' Welch power spectra, normalized by their common maximum
    welch_freq1, welch_psd1 = welch(s1, fs=sr)
    welch_freq2, welch_psd2 = welch(s2, fs=sr)
    welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
    welch_psd1 /= welch_psd_max
    welch_psd2 /= welch_psd_max

    # compute the auto-correlation functions
    lags = np.arange(-200, 201)
    acf1 = correlation_function(s1, s1, lags, normalize=True)
    acf2 = correlation_function(s2, s2, lags, normalize=True)

    # compute the cross correlation function and the coherency
    cf12 = correlation_function(s1, s2, lags, normalize=True)
    coh12 = coherency(s1, s2, lags, window_fraction=0.75, noise_floor_db=100.)

    # The DFT treats index 0 as lag 0, so reorder the lags to start at zero
    # (0, 1, ..., 200, -200, ..., -1). Without this the FFT of an ACF picks up a
    # linear phase factor and is not the (real) power spectrum.
    shift_lags = fftshift(lags)
    if len(lags) % 2 == 1:
        # shift zero from end of shift_lags to beginning
        shift_lags = np.roll(shift_lags, 1)
    acf1_shift = correlation_function(s1, s1, shift_lags)
    acf2_shift = correlation_function(s2, s2, shift_lags)

    # compute the power spectra from the auto-correlation functions (positive frequencies)
    ps1 = fft(acf1_shift)
    ps1_freq = fftfreq(len(acf1), d=1.0 / sr)
    fi = ps1_freq > 0
    ps1 = ps1[fi]
    assert np.sum(np.abs(ps1.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf1) (%d)" % np.sum(np.abs(ps1.imag) > 1e-8)
    ps1_auto = np.abs(ps1.real)
    ps1_auto_freq = ps1_freq[fi]

    ps2 = fft(acf2_shift)
    ps2_freq = fftfreq(len(acf2), d=1.0 / sr)
    fi = ps2_freq > 0
    ps2 = ps2[fi]
    assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
    ps2_auto = np.abs(ps2.real)
    ps2_auto_freq = ps2_freq[fi]

    # NOTE(review): these always hold because ps*_auto are absolute values
    assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
    assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

    # compute the cross spectral density from the cross-correlation function
    cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
    psd12 = fft(cf12_shift)
    psd12_freq = fftfreq(len(cf12_shift), d=1.0 / sr)
    fi = psd12_freq > 0
    psd12 = np.abs(psd12[fi])
    psd12_freq = psd12_freq[fi]

    # compute the coherence and the frequency-domain coherency
    cfreq, coherence, coherence_var, phase_coherence, phase_coherence_var, coh12_freqspace, coh12_freqspace_t = \
        coherence_jn(s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)
    coh12_freqspace /= np.abs(coh12_freqspace).max()

    # penalize the coherence by one standard deviation, clipped at zero
    coherence_std = np.sqrt(coherence_var)
    coherence_weighted = coherence - coherence_std
    coherence_weighted[coherence_weighted < 0] = 0

    # normalize the cross-spectral density and power spectra
    psd12 /= psd12.max()
    ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
    ps1_auto /= ps_auto_max
    ps2_auto /= ps_auto_max

    # make some plots
    plt.figure()
    nrows = 2
    ncols = 2

    # plot the signals
    plt.subplot(nrows, ncols, 1)
    plt.plot(t, s1, 'k-', linewidth=2.0)
    plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal')
    plt.axis('tight')

    # plot the spectra: Welch (solid) vs. FFT of the ACF (dashed)
    plt.subplot(nrows, ncols, 2)
    plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
    plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
    plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
    plt.axis('tight')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')

    # plot the correlation functions
    plt.subplot(nrows, ncols, 3)
    plt.axhline(0, c='k')
    plt.plot(lags, acf1, 'k-', linewidth=2.0)
    plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
    plt.plot(coh12_freqspace_t * 1e3, coh12_freqspace, 'm-', linewidth=1.0, alpha=0.95)
    plt.xlabel('Lag (ms)')
    plt.ylabel('Correlation Function')
    plt.axis('tight')
    plt.ylim(-0.5, 1.0)
    # legend colors must match the line colors above (coh12_f is drawn in magenta)
    handles = custom_legend(['k', 'r', 'g', 'b', 'm'], ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
    plt.legend(handles=handles)

    # plot the cross spectral density and the coherence
    plt.subplot(nrows, ncols, 4)
    handles = custom_legend(['g', 'k', 'b'], ['CSD', 'Coherence', 'Weighted'])
    plt.axhline(0, c='k')
    plt.axhline(1, c='k')
    plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
    plt.errorbar(cfreq, coherence, yerr=coherence_std, fmt='k-', ecolor='r', linewidth=3.0, elinewidth=5.0, alpha=0.8)
    plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Cross-spectral Density/Coherence')
    plt.legend(handles=handles)

    plt.show()
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Plot multi-trial LFPs for two electrodes plus their spectra and coherency for one syllable.

    The left column shows the stimulus spectrogram and the z-scored LFPs of
    electrodes ``e1`` and ``e2`` (first few trials plus the trial average), with
    the padded syllable window marked. The right columns, restricted to the
    syllable window, show:
      * the trial-averaged and mean-subtracted LFPs;
      * their power spectra;
      * the raw, trial-averaged and mean-subtracted coherency.
    The figure is saved as ``raw+coherency.svg`` in ``get_this_dir()``.

    :param bird, block, segment, hemi: identify the recording site.
    :param e1, e2: electrode numbers to compare; must appear in the hemisphere's electrode list.
    :param stim_id: stimulus id; every presentation of it is used as a trial.
    :param syllable_index: index of the syllable (detected event) to analyze.
    :param data_dir: root data directory; only used when ``exp`` is None.
    :param exp: already-loaded Experiment; loaded from ``data_dir`` when None.
    :raises ValueError: if the stimulus has no presentations in this segment.
    """
    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # get the start and end times of every presentation, sorted by start time
    # (sorted() rather than list.sort() so this also works where zip() returns an iterator)
    etable = exp.get_epoch_table(seg)
    i = etable['id'] == stim_id
    stim_times = sorted(zip(etable[i]['start_time'].values, etable[i]['end_time'].values),
                        key=operator.itemgetter(0))
    if len(stim_times) == 0:
        raise ValueError('No presentations of stim_id=%s in %s/%s' % (stim_id, block, segment))
    stim_times = np.array(stim_times)
    stim_dur = (stim_times[:, 1] - stim_times[:, 0]).min()

    # aggregate the LFPs across trials, truncating every trial to the shortest presentation
    lfps = list()
    sample_rate = None
    electrode_indices = None
    for start_time, end_time in stim_times:
        lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
        electrode_indices, the_lfps, sample_rate = lfp_data[hemi]
        stim_dur_i = int(stim_dur * sample_rate)
        lfps.append(the_lfps[:, :stim_dur_i])
    lfps = np.array(lfps)  # (trials, electrodes, time), per the unpacking below
    # rescale the LFPs so they are in uV
    lfps *= 1e6

    # get the log spectrogram of the first presentation, floored at 0
    start_time_0, end_time_0 = stim_times[0]
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time_0, end_time_0)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20 * np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, scaled to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): merge_thresh is computed from the LFP sample rate, but the envelope is
    # indexed by spectrogram bins - confirm this is intended
    merge_thresh = int(0.002 * sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)
    # translate the event indices into actual times (not in place: an integer array can't hold floats)
    events = events * stim_spec_dt
    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    # pad the syllable window: 5 ms before, 10 ms after
    syllable_start -= 0.005
    syllable_end += 0.010
    last_syllable_end = events[-1, 1] + 0.025

    # z-score each electrode's LFP over all trials and time points (copies, so lfps is untouched)
    i1 = electrode_indices.index(e1)
    i2 = electrode_indices.index(e2)
    lfp1 = lfps[:, i1, :].copy()
    lfp2 = lfps[:, i2, :].copy()
    lfp1 -= lfp1.mean()
    lfp1 /= lfp1.std(ddof=1)
    lfp2 -= lfp2.mean()
    lfp2 /= lfp2.std(ddof=1)

    ntrials, nelectrodes, nt = lfps.shape
    ntrials_to_plot = min(5, ntrials)
    # float() guards against integer division when sample_rate is an int under Python 2
    t = np.arange(nt) / float(sample_rate)

    # plot the stimulus and raw LFP for two electrodes
    plt.figure(figsize=(24.0, 10))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(3, 100)

    ax = plt.subplot(gs[0, :60])
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    def _plot_trials(row, lfp, electrode):
        # first trials in color, trial average in black, syllable window dashed
        plt.subplot(gs[row, :60])
        for k in range(ntrials_to_plot):
            plt.plot(t, lfp[k, :], '-', linewidth=2.0, alpha=0.75)
        plt.plot(t, lfp.mean(axis=0), 'k-', linewidth=3.0)
        plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
        plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
        plt.xlabel('Time (s)')
        plt.ylabel('E%d (z-scored)' % electrode)
        plt.axis('tight')
        plt.xlim(0, last_syllable_end)

    _plot_trials(1, lfp1, e1)
    _plot_trials(2, lfp2, e2)

    # restrict the lfps to a single syllable (start clamped so padding can't produce a negative index)
    print('syllable_start=%f, syllable_end=%f' % (syllable_start, syllable_end))
    syllable_si = max(0, int(syllable_start * sample_rate))
    syllable_ei = int(syllable_end * sample_rate)
    lfp1 = lfp1[:, syllable_si:syllable_ei]
    lfp2 = lfp2[:, syllable_si:syllable_ei]
    t_syllable = t[syllable_si:syllable_ei]

    # compute the trial averaged and mean subtracted lfps
    lfp1_mean, lfp1_ms = compute_avg_and_ms(lfp1)
    lfp2_mean, lfp2_ms = compute_avg_and_ms(lfp2)

    psd_stats = get_psd_stats(bird, block, segment, hemi)
    freqs, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total = \
        compute_spectra_and_coherence_single_electrode(lfp1, lfp2, sample_rate, e1, e2, psd_stats=psd_stats)
    lags_ms = get_lags_ms(sample_rate)

    lfp_absmax = max(np.abs(lfp1_mean).max(), np.abs(lfp2_mean).max())
    lfp_ms_absmax = max(np.abs(lfp1_ms).max(), np.abs(lfp2_ms).max())
    leg = custom_legend(['k', '#c0c0c0'], ['E%d' % e1, 'E%d' % e2])

    plt.subplot(gs[0, 65:80])
    plt.axhline(0, c='k')
    plt.plot(t_syllable, lfp1_mean, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syllable, lfp2_mean, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Trial-avg LFP')
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_absmax, lfp_absmax)

    plt.subplot(gs[0, 85:])
    plt.axhline(0, c='k')
    plt.plot(t_syllable, lfp1_ms, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syllable, lfp2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Mean-sub LFP')
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_ms_absmax, lfp_ms_absmax)

    plt.subplot(gs[1, 65:80])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Trial-avg Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[1, 85:])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1_ms, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Mean-sub Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[2, 65:80])
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, c12_total, 'k-', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12, '-', c='r', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12_nonlocked, '-', c='b', linewidth=3.0, alpha=0.75)
    plt.xlabel('Lags (ms)')
    plt.ylabel('Coherency')
    leg = custom_legend(['k', 'r', 'b'], ['Raw', 'Trial-avg', 'Mean-sub'])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-0.2, 0.3)

    fig_file = os.path.join(get_this_dir(), 'raw+coherency.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, trial, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Plot one trial's raw LFPs, two electrodes' spectra/coherency, and an all-pairs grid.

    Three figures are saved in ``get_this_dir()``:
      * ``raw.svg`` - stimulus spectrogram and the z-scored LFPs of ``e1`` and
        ``e2``, with the syllable window marked;
      * ``auto+cross.svg`` - power spectra of ``e1``/``e2`` and their coherency;
      * ``cross_all.svg`` - lower-triangular grid of spectra (diagonal) and
        coherencies (off-diagonal) for all electrodes, in rostral-caudal order.

    :param bird, block, segment, hemi: identify the recording site.
    :param e1, e2: electrode numbers to highlight.
    :param stim_id: stimulus id.
    :param trial: index of the presentation (sorted by start time) to plot.
    :param syllable_index: index of the syllable (detected event) to mark.
    :param data_dir: root data directory; only used when ``exp`` is None.
    :param exp: already-loaded Experiment; loaded from ``data_dir`` when None.
    """
    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # get the start and end times of the requested presentation
    # (sorted() rather than list.sort() so this also works where zip() returns an iterator)
    etable = exp.get_epoch_table(seg)
    i = etable['id'] == stim_id
    stim_times = sorted(zip(etable[i]['start_time'].values, etable[i]['end_time'].values),
                        key=operator.itemgetter(0))
    start_time, end_time = stim_times[trial]
    stim_dur = float(end_time - start_time)

    # get a (z-scored) slice of the LFP
    lfp_data = exp.get_lfp_slice(seg, start_time, end_time, zscore=True)
    electrode_indices, lfps, sample_rate = lfp_data[hemi]

    # get the log spectrogram of the stimulus, floored at 0
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time, end_time)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20 * np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, scaled to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): merge_thresh is computed from the LFP sample rate, but the envelope is
    # indexed by spectrogram bins - confirm this is intended
    merge_thresh = int(0.002 * sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)
    # translate the event indices into actual times (not in place: an integer array can't hold floats)
    events = events * stim_spec_dt
    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    last_syllable_end = events[-1, 1] + 0.025

    i1 = electrode_indices.index(e1)
    i2 = electrode_indices.index(e2)
    lfp1 = lfps[i1, :]
    lfp2 = lfps[i2, :]
    # float() guards against integer division when sample_rate is an int under Python 2
    t = np.arange(len(lfp1)) / float(sample_rate)
    legend = ['E%d' % e1, 'E%d' % e2]

    if hemi == 'L':
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    # get the power spectrum stats for this site
    psd_stats = get_psd_stats(bird, block, segment, hemi)

    # compute the power spectra and cross coherence for all electrodes.
    # NOTE: these use the full-trial lfps, not just the syllable window marked in the plots.
    lags_ms = get_lags_ms(sample_rate)
    spectra, cross_mat = compute_spectra_and_coherence_multi_electrode_single_trial(
        lfps, sample_rate, electrode_indices, electrode_order, psd_stats=psd_stats)

    # plot the stimulus and raw LFP for two electrodes
    plt.figure(figsize=(24.0, 10))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    ax = plt.subplot(2, 1, 1)
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    plt.subplot(2, 1, 2)
    plt.plot(t, lfp1, 'b-', linewidth=3.0)
    plt.plot(t, lfp2, 'r-', linewidth=3.0, alpha=0.7)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.xlabel('Time (s)')
    plt.ylabel('LFP (z-scored)')
    plt.legend(legend)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    fig_file = os.path.join(get_this_dir(), 'raw.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')

    # plot the two power spectra; spectra/cross_mat rows follow electrode_order
    psd_ub = 6
    psd_lb = 0
    i1 = electrode_order.index(e1)
    i2 = electrode_order.index(e2)
    a1 = spectra[i1, :]
    a2 = spectra[i2, :]
    freqs = get_freqs(sample_rate)

    plt.figure(figsize=(10.0, 4.0))
    plt.subplots_adjust(top=0.90, bottom=0.10, left=0.10, right=0.99, hspace=0.10)

    plt.subplot(1, 2, 1)
    plt.plot(freqs, a1, 'b-', linewidth=3.0)
    plt.plot(freqs, a2, 'r-', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (z-scored)')
    plt.title('Power Spectrum')
    handles = custom_legend(['b', 'r'], legend)
    plt.legend(handles=handles, fontsize='small')
    plt.axis('tight')
    plt.ylim(psd_lb, psd_ub)

    # plot the coherency (x axis is lag, not frequency)
    cf_lb = -0.1
    cf_ub = 0.3
    coh = cross_mat[i1, i2, :]
    plt.subplot(1, 2, 2)
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, coh, 'g-', linewidth=3.0)
    plt.xlabel('Lags (ms)')
    plt.title('Coherency')
    plt.axis('tight')
    plt.ylim(cf_lb, cf_ub)

    fig_file = os.path.join(get_this_dir(), 'auto+cross.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')

    # grid of all auto spectra (diagonal) and cross coherencies (lower triangle)
    nelectrodes = len(electrode_order)
    plt.figure(figsize=(24.0, 13.5))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(nelectrodes, nelectrodes)
    for i in range(nelectrodes):
        for j in range(i + 1):
            plt.subplot(gs[i, j])
            plt.axhline(0, c='k')
            plt.axvline(0, c='k')

            _e1 = electrode_order[i]
            _e2 = electrode_order[j]

            # highlight e1 (blue) and e2 (red) on the diagonal, and their pair (green) off it.
            # Only the lower triangle is drawn, so the pair may appear in either order.
            clr = 'k'
            if i == j:
                if _e1 == e1:
                    clr = 'b'
                elif _e1 == e2:
                    clr = 'r'
            elif (_e1, _e2) in ((e1, e2), (e2, e1)):
                clr = 'g'

            if i == j:
                plt.plot(freqs, spectra[i, :], '-', c=clr, linewidth=2.0)
            else:
                plt.plot(lags_ms, cross_mat[i, j, :], '-', c=clr, linewidth=2.0)
            plt.xticks([])
            plt.yticks([])
            plt.axis('tight')
            if i != j:
                plt.ylim(cf_lb, cf_ub)
            else:
                plt.axhline(0, c='k')
                plt.ylim(psd_lb, psd_ub)

            if j == 0:
                plt.ylabel('E%d' % electrode_order[i])
            if i == nelectrodes - 1:
                plt.xlabel('E%d' % electrode_order[j])

    fig_file = os.path.join(get_this_dir(), 'cross_all.svg')
    plt.savefig(fig_file, facecolor='w', edgecolor='none')
def test_cross_psd(self):
    """Visually compare auto/cross spectra, correlation functions and coherence.

    Builds a sum of sinusoids (s1) and a band-passed (40-90 Hz), noise-corrupted
    copy of it (s2). It then checks that the FFTs of the auto-correlation
    functions are real, and plots:
      * the signals;
      * Welch PSDs vs. ACF-derived power spectra;
      * correlation functions and coherency in the lag domain;
      * the CSD vs. the coherence from ``coherence_jn``.
    The plots are shown with ``plt.show()`` for manual inspection.
    """
    np.random.seed(1234567)
    sr = 1000.0
    dur = 1.0
    nt = int(dur * sr)
    t = np.arange(nt) / sr

    # create a simple signal: a sum of unit sinusoids in three frequency bands
    freqs = list()
    freqs.extend(np.arange(8, 12))
    freqs.extend(np.arange(60, 71))
    freqs.extend(np.arange(130, 151))
    s1 = np.zeros([nt])
    for f in freqs:
        s1 += np.sin(2 * np.pi * f * t)
    s1 /= s1.max()

    # create a noise corrupted, bandpass filtered (40-90 Hz) version of s1
    noise = np.random.randn(nt) * 1e-1
    s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
    s2 /= s2.max()
    s2 += noise

    # compute the signals' Welch power spectra, normalized by their common maximum
    welch_freq1, welch_psd1 = welch(s1, fs=sr)
    welch_freq2, welch_psd2 = welch(s2, fs=sr)
    welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
    welch_psd1 /= welch_psd_max
    welch_psd2 /= welch_psd_max

    # compute the auto-correlation functions
    lags = np.arange(-200, 201)
    acf1 = correlation_function(s1, s1, lags, normalize=True)
    acf2 = correlation_function(s2, s2, lags, normalize=True)

    # compute the cross correlation function and the coherency
    cf12 = correlation_function(s1, s2, lags, normalize=True)
    coh12 = coherency(s1, s2, lags, window_fraction=0.75, noise_floor_db=100.)

    # The DFT treats index 0 as lag 0, so reorder the lags to start at zero
    # (0, 1, ..., 200, -200, ..., -1). Without this the FFT of an ACF picks up a
    # linear phase factor and is not the (real) power spectrum.
    shift_lags = fftshift(lags)
    if len(lags) % 2 == 1:
        # shift zero from end of shift_lags to beginning
        shift_lags = np.roll(shift_lags, 1)
    acf1_shift = correlation_function(s1, s1, shift_lags)
    acf2_shift = correlation_function(s2, s2, shift_lags)

    # compute the power spectra from the auto-correlation functions (positive frequencies)
    ps1 = fft(acf1_shift)
    ps1_freq = fftfreq(len(acf1), d=1.0 / sr)
    fi = ps1_freq > 0
    ps1 = ps1[fi]
    assert np.sum(np.abs(ps1.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf1) (%d)" % np.sum(np.abs(ps1.imag) > 1e-8)
    ps1_auto = np.abs(ps1.real)
    ps1_auto_freq = ps1_freq[fi]

    ps2 = fft(acf2_shift)
    ps2_freq = fftfreq(len(acf2), d=1.0 / sr)
    fi = ps2_freq > 0
    ps2 = ps2[fi]
    assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
    ps2_auto = np.abs(ps2.real)
    ps2_auto_freq = ps2_freq[fi]

    # NOTE(review): these always hold because ps*_auto are absolute values
    assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
    assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

    # compute the cross spectral density from the cross-correlation function
    cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
    psd12 = fft(cf12_shift)
    psd12_freq = fftfreq(len(cf12_shift), d=1.0 / sr)
    fi = psd12_freq > 0
    psd12 = np.abs(psd12[fi])
    psd12_freq = psd12_freq[fi]

    # compute the coherence and the frequency-domain coherency
    cfreq, coherence, coherence_var, phase_coherence, phase_coherence_var, coh12_freqspace, coh12_freqspace_t = \
        coherence_jn(s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)
    coh12_freqspace /= np.abs(coh12_freqspace).max()

    # penalize the coherence by one standard deviation, clipped at zero
    coherence_std = np.sqrt(coherence_var)
    coherence_weighted = coherence - coherence_std
    coherence_weighted[coherence_weighted < 0] = 0

    # normalize the cross-spectral density and power spectra
    psd12 /= psd12.max()
    ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
    ps1_auto /= ps_auto_max
    ps2_auto /= ps_auto_max

    # make some plots
    plt.figure()
    nrows = 2
    ncols = 2

    # plot the signals
    plt.subplot(nrows, ncols, 1)
    plt.plot(t, s1, 'k-', linewidth=2.0)
    plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal')
    plt.axis('tight')

    # plot the spectra: Welch (solid) vs. FFT of the ACF (dashed)
    plt.subplot(nrows, ncols, 2)
    plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
    plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
    plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
    plt.axis('tight')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')

    # plot the correlation functions
    plt.subplot(nrows, ncols, 3)
    plt.axhline(0, c='k')
    plt.plot(lags, acf1, 'k-', linewidth=2.0)
    plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
    plt.plot(coh12_freqspace_t * 1e3, coh12_freqspace, 'm-', linewidth=1.0, alpha=0.95)
    plt.xlabel('Lag (ms)')
    plt.ylabel('Correlation Function')
    plt.axis('tight')
    plt.ylim(-0.5, 1.0)
    # legend colors must match the line colors above (coh12_f is drawn in magenta)
    handles = custom_legend(['k', 'r', 'g', 'b', 'm'], ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
    plt.legend(handles=handles)

    # plot the cross spectral density and the coherence
    plt.subplot(nrows, ncols, 4)
    handles = custom_legend(['g', 'k', 'b'], ['CSD', 'Coherence', 'Weighted'])
    plt.axhline(0, c='k')
    plt.axhline(1, c='k')
    plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
    plt.errorbar(cfreq, coherence, yerr=coherence_std, fmt='k-', ecolor='r', linewidth=3.0, elinewidth=5.0, alpha=0.8)
    plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Cross-spectral Density/Coherence')
    plt.legend(handles=handles)

    plt.show()
def _scatter_by_call_type(ax, stim_types, coords, call_types, **scatter_kwargs):
    """Scatter one point cloud per call type on ``ax``, coloured via CALL_TYPE_COLORS.

    :param ax: axis to draw on (2D or 3D; ``coords`` must match its dimensionality).
    :param stim_types: array of call-type labels, one per point.
    :param coords: sequence of arrays (x, y[, z]) aligned with ``stim_types``.
    :param call_types: call types to draw, in drawing order; absent ones are skipped.
    :param scatter_kwargs: extra keyword arguments forwarded to ``ax.scatter``.
    """
    for ct in call_types:
        i = stim_types == ct
        if i.sum() == 0:
            continue
        ax.scatter(*[v[i] for v in coords], s=49, c=CALL_TYPE_COLORS[ct], **scatter_kwargs)


def draw_joint_relationships(data_dir='/auto/tdrive/mschachter/data'):
    """Plot maximum amplitude, saliency and mean spectral frequency per syllable.

    Loads the per-syllable table ``<data_dir>/aggregate/acoustic_big3.csv``.
    If it does not exist, the table is built from
    ``<data_dir>/aggregate/biosound.h5`` (duplicates removed via
    ``AggregateBiosounds.remove_duplicates``) and written out first.

    Two figures are then drawn, with points coloured by call type:
      1. a 3D scatter of (rescaled amplitude, saliency, mean frequency);
      2. the three pairwise 2D scatters.

    The amplitude axis is shifted to a minimum of zero, log-transformed, then
    z-scored (sample std, ddof=1). The DataFrame itself is not modified.

    :param data_dir: root data directory.
    """
    big3_file = os.path.join(data_dir, 'aggregate', 'acoustic_big3.csv')
    if not os.path.exists(big3_file):
        bs_file = os.path.join(data_dir, 'aggregate', 'biosound.h5')
        bs_agg = AggregateBiosounds.load(bs_file)
        aprops = ['maxAmp', 'sal', 'meanspect']
        Xz, good_indices = bs_agg.remove_duplicates(aprops=aprops, thresh=np.inf)
        stim_types = bs_agg.df.stim_type[good_indices].values
        pdata = {'maxAmp': Xz[:, 0], 'sal': Xz[:, 1], 'meanspect': Xz[:, 2], 'stim_type': stim_types}
        df = pd.DataFrame(pdata)
        df.to_csv(big3_file, header=True, index=False)
    else:
        df = pd.read_csv(big3_file)

    print('# of syllables used: %d' % len(df))

    # Rescale amplitude out of place. The previous in-place operators on
    # df.maxAmp silently rewrote the DataFrame's column.
    log_amp = np.log(df.maxAmp - df.maxAmp.min() + 1e-6)
    amp_z = ((log_amp - log_amp.mean()) / log_amp.std(ddof=1)).values

    stim_types = df.stim_type.values
    sal = df.sal.values
    meanspect = df.meanspect.values

    # 3D view of all three features
    plt.figure()
    ax = plt.subplot(111, projection='3d')
    _scatter_by_call_type(ax, stim_types, (amp_z, sal, meanspect), DECODER_CALL_TYPES)
    ax.set_xlabel('Amplitude')
    ax.set_ylabel('Saliency')
    ax.set_zlabel('Mean Frequency')
    leg = custom_legend([CALL_TYPE_COLORS[ct] for ct in DECODER_CALL_TYPES], DECODER_CALL_TYPES)
    plt.legend(handles=leg)

    # pairwise 2D views
    plt.figure()
    nrows = 1
    ncols = 3
    call_type_order = ['song', 'Te', 'DC', 'Be', 'LT', 'Ne', 'Ag', 'Di', 'Th']

    # plot amplitude vs saliency
    ax = plt.subplot(nrows, ncols, 1)
    _scatter_by_call_type(ax, stim_types, (amp_z, sal), call_type_order, alpha=0.7)
    plt.xlabel('Maximum Amplitude')
    plt.ylabel('Saliency')
    plt.xlim(-3, 2)

    # plot amplitude vs meanspect
    ax = plt.subplot(nrows, ncols, 2)
    _scatter_by_call_type(ax, stim_types, (amp_z, meanspect), call_type_order, alpha=0.7)
    plt.xlabel('Maximum Amplitude')
    plt.ylabel('Mean Spectral Freq')
    plt.xlim(-3, 2)

    # plot saliency vs meanspect
    ax = plt.subplot(nrows, ncols, 3)
    _scatter_by_call_type(ax, stim_types, (sal, meanspect), call_type_order, alpha=0.7)
    plt.xlabel('Saliency')
    plt.ylabel('Mean Spectral Freq')

    plt.show()
def draw_decoder_perf_barplots(data_dir='/auto/tdrive/mschachter/data', show_all=True):
    """Bar plot of decoder R2 per acoustic property, grouped by neural decomposition.

    Reads ``<data_dir>/aggregate/decoder_perfs_for_glm.csv``. For each acoustic
    property in USED_ACOUSTIC_PROPS (except 'stdtime') and each decomposition,
    it draws the mean of the ``r2`` column with a sample-std (ddof=1) error bar.
    Properties are ordered by decreasing mean LFP-PSD ('full_psds') R2.

    :param data_dir: root data directory.
    :param show_all: if True, also plot the 'spike_rate+spike_sync'
        decomposition and save to ``decoder_perf_barplots_all.svg``; otherwise
        plot only spike rate and LFP PSD and save to
        ``decoder_perf_barplots.svg``. Both files go in ``get_this_dir()``.
    """
    aprops_to_display = list(USED_ACOUSTIC_PROPS)
    aprops_to_display.remove('stdtime')

    if not show_all:
        decomps = ['spike_rate', 'full_psds']
        sub_names = ['Spike Rate', 'LFP PSD']
        sub_clrs = [COLOR_RED_SPIKE_RATE, COLOR_BLUE_LFP]
    else:
        decomps = ['spike_rate', 'full_psds', 'spike_rate+spike_sync']
        sub_names = ['Spike Rate', 'LFP PSD', 'Spike Rate + Sync']
        sub_clrs = [COLOR_RED_SPIKE_RATE, COLOR_BLUE_LFP, COLOR_CRIMSON_SPIKE_SYNC]

    df_me = pd.read_csv(os.path.join(data_dir, 'aggregate', 'decoder_perfs_for_glm.csv'))

    # collect the R2 values for every (acoustic property, decomposition) pair
    bprop_data = list()
    for aprop in aprops_to_display:
        bd = dict()
        for decomp in decomps:
            i = (df_me.decomp == decomp) & (df_me.aprop == aprop)
            bd[decomp] = df_me.r2[i].values
        bprop_data.append({'bd': bd, 'lfp_mean': bd['full_psds'].mean(), 'aprop': aprop})

    # order acoustic properties by decreasing mean LFP decoder performance
    bprop_data.sort(key=operator.itemgetter('lfp_mean'), reverse=True)

    aprops_xticks = [ACOUSTIC_PROP_NAMES[bdict['aprop']] for bdict in bprop_data]

    # one (means, stds) pair per decomposition, in the same order as decomps/sub_clrs
    bar_data = list()
    for decomp in decomps:
        means = [bdict['bd'][decomp].mean() for bdict in bprop_data]
        stds = [bdict['bd'][decomp].std(ddof=1) for bdict in bprop_data]
        bar_data.append((means, stds))

    figsize = (23, 7.)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.15, left=0.05, right=0.99, hspace=0.20, wspace=0.20)

    # narrower bars when three decompositions share a group
    bar_width = 0.2 if show_all else 0.4
    bar_x = np.arange(len(bprop_data))
    for k, (br2, bstd) in enumerate(bar_data):
        bx = bar_x + bar_width*k
        # align='edge' is explicit so the bar positions, and the tick centring
        # below, do not depend on matplotlib's version-specific default
        plt.bar(bx, br2, yerr=bstd, width=bar_width, color=sub_clrs[k], alpha=0.9, ecolor='k',
                align='edge')

    plt.ylabel('Decoder R2')
    # each group spans [bar_x, bar_x + n_bars*bar_width); label its midpoint
    group_centers = bar_x + bar_width*len(bar_data) / 2.0
    plt.xticks(group_centers, aprops_xticks, rotation=90, fontsize=12)

    leg = custom_legend(sub_clrs, sub_names)
    plt.legend(handles=leg, loc='upper right')
    plt.axis('tight')
    plt.xlim(-0.5, bar_x.max() + 1)
    plt.ylim(0, 0.6)

    if show_all:
        fname = os.path.join(get_this_dir(), 'decoder_perf_barplots_all.svg')
    else:
        fname = os.path.join(get_this_dir(), 'decoder_perf_barplots.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()