def draw_coherency_matrix(
    bird,
    block,
    segment,
    hemi,
    stim_id,
    trial,
    syllable_index,
    data_dir="/auto/tdrive/mschachter/data",
    exp=None,
    save=True,
):
    """Plot the pairwise LFP coherency matrix for one syllable of one stimulus trial.

    Finds the ``trial``-th presentation (ordered by start time) of stimulus
    ``stim_id`` in the given segment. It then segments the stimulus amplitude
    envelope into syllable events and restricts the LFP to syllable
    ``syllable_index``. For every electrode pair in hemisphere ``hemi`` it
    computes the coherency (off-diagonal) or the correlation function
    (diagonal). The lower triangle of the resulting matrix is drawn as a grid
    of small lag plots, with electrodes in rostral-caudal order.

    :param bird: Bird name; the experiment is loaded from
        ``<data_dir>/<bird>/<bird>.h5`` and ``<data_dir>/<bird>/stims.h5``.
    :param block: Block identifier passed to ``exp.get_segment``.
    :param segment: Segment identifier passed to ``exp.get_segment``.
    :param hemi: Hemisphere key into the LFP slice. ``"L"`` selects the left
        rostral-caudal electrode ordering; anything else selects the right.
    :param stim_id: Stimulus id looked up in the epoch table's ``id`` column.
    :param trial: Index of the presentation, after sorting by start time.
    :param syllable_index: Index of the syllable event to analyze.
    :param data_dir: Root data directory (only used when ``exp`` is None).
    :param exp: A pre-loaded Experiment; loaded from disk when None.
    :param save: If True, save the figure as ``coherency_matrix.svg`` in this
        file's directory.
    """
    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, "%s.h5" % bird)
        stim_file = os.path.join(bird_dir, "stims.h5")
        exp = Experiment.load(exp_file, stim_file)

    seg = exp.get_segment(block, segment)

    # get the start and end times of the stimulus presentations, ordered by
    # start time. zip() is an iterator on Python 3, so use sorted() rather
    # than list.sort().
    etable = exp.get_epoch_table(seg)
    stim_mask = etable["id"] == stim_id
    stim_times = sorted(
        zip(etable[stim_mask]["start_time"].values, etable[stim_mask]["end_time"].values),
        key=operator.itemgetter(0),
    )
    start_time, end_time = stim_times[trial]
    stim_dur = float(end_time - start_time)

    # get a slice of the LFP
    lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
    electrode_indices, lfps, sample_rate = lfp_data[hemi]

    # rescale the LFPs so they are in uV
    lfps *= 1e6

    # get the log spectrogram of the stimulus, with times relative to stimulus onset
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time, end_time)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20 * np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, rescaled to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): amp_env appears to be sampled at the spectrogram rate
    # (events are later scaled by stim_spec_dt), but merge_thresh is computed
    # from the LFP sample_rate -- confirm which units
    # break_envelope_into_events expects.
    merge_thresh = int(0.002 * sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into actual times, then into LFP sample indices
    events *= stim_spec_dt
    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    syllable_si = int(syllable_start * sample_rate)
    syllable_ei = int(syllable_end * sample_rate)

    # compute all cross and auto-correlations
    if hemi == "L":
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    lags = np.arange(-20, 21)
    lags_ms = (lags / sample_rate) * 1e3

    window_fraction = 0.35
    noise_db = 25.0
    nelectrodes = len(electrode_order)

    # position of each recorded electrode (row of lfps) in the rostral-caudal
    # ordering, and the syllable portion of each LFP; neither depends on the
    # electrode pair, so compute them once outside the pair loop
    order_pos = [electrode_order.index(electrode_indices[k]) for k in range(nelectrodes)]
    syllable_lfps = lfps[:, syllable_si:syllable_ei]

    cross_mat = np.zeros([nelectrodes, nelectrodes, len(lags)])
    for i in range(nelectrodes):
        lfp1 = syllable_lfps[i]
        for j in range(nelectrodes):
            lfp2 = syllable_lfps[j]
            if i != j:
                x = coherency(lfp1, lfp2, lags, window_fraction=window_fraction, noise_floor_db=noise_db)
            else:
                # the diagonal holds the auto-correlation function
                x = correlation_function(lfp1, lfp2, lags)
            cross_mat[order_pos[i], order_pos[j], :] = x

    # plot the lower triangle (including the diagonal) of the matrix
    figsize = (24.0, 13.5)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(nelectrodes, nelectrodes)

    for i in range(nelectrodes):
        for j in range(i + 1):
            plt.subplot(gs[i, j])
            plt.axhline(0, c="k")
            plt.axvline(0, c="k")
            plt.plot(lags_ms, cross_mat[i, j, :], "k-", linewidth=2.0)
            plt.xticks([])
            plt.yticks([])
            plt.axis("tight")
            plt.ylim(-0.5, 1.0)
            if j == 0:
                plt.ylabel("E%d" % electrode_order[i])
            if i == nelectrodes - 1:
                plt.xlabel("E%d" % electrode_order[j])

    if save:
        fname = os.path.join(get_this_dir(), "coherency_matrix.svg")
        plt.savefig(fname, facecolor="w", edgecolor="none")
def test_cross_psd(self):
    """Visually compare spectral estimates for a signal and a filtered, noisy copy.

    NOTE(review): an identical ``test_cross_psd`` is defined again later in
    this file. If both live in the same class/module, the later definition
    replaces this one.

    ``s1`` is a normalized sum of sinusoids in three frequency bands. ``s2`` is
    a 40-90 Hz bandpassed, renormalized copy of ``s1`` plus Gaussian noise. The
    test compares:

    * Welch PSDs with PSDs obtained as the FFT of the auto-correlation functions;
    * the cross-spectral density from the FFT of the cross-correlation function;
    * the coherence and coherency returned by ``coherence_jn``.

    All of these are plotted. The only automatic checks are on the imaginary
    part of the ACF-derived spectra.
    """
    np.random.seed(1234567)

    sr = 1000.0
    dur = 1.0
    nt = int(dur * sr)
    t = np.arange(nt) / sr

    # create a simple signal: a sum of sinusoids in three frequency bands
    freqs = list()
    freqs.extend(np.arange(8, 12))
    freqs.extend(np.arange(60, 71))
    freqs.extend(np.arange(130, 151))

    s1 = np.zeros([nt])
    for f in freqs:
        s1 += np.sin(2 * np.pi * f * t)
    s1 /= s1.max()

    # create a noise corrupted, bandpass filtered version of s1
    noise = np.random.randn(nt) * 1e-1
    s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
    s2 /= s2.max()
    s2 += noise

    # compute the signals' power spectra, normalized by their common maximum
    welch_freq1, welch_psd1 = welch(s1, fs=sr)
    welch_freq2, welch_psd2 = welch(s2, fs=sr)
    welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
    welch_psd1 /= welch_psd_max
    welch_psd2 /= welch_psd_max

    # compute the auto-correlation functions
    lags = np.arange(-200, 201)
    acf1 = correlation_function(s1, s1, lags, normalize=True)
    acf2 = correlation_function(s2, s2, lags, normalize=True)

    # compute the cross correlation function and the coherency
    cf12 = correlation_function(s1, s2, lags, normalize=True)
    coh12 = coherency(s1, s2, lags, window_fraction=0.75, noise_floor_db=100.)

    # reorder the lags so that zero lag comes first, as the FFT assumes;
    # otherwise the FFT of an ACF picks up a linear phase and does not equal
    # the power spectrum
    shift_lags = fftshift(lags)
    if len(lags) % 2 == 1:
        # for odd lengths fftshift leaves zero at the end; move it to the front
        shift_lags = np.roll(shift_lags, 1)
    acf1_shift = correlation_function(s1, s1, shift_lags)
    acf2_shift = correlation_function(s2, s2, shift_lags)

    # compute the power spectra from the auto-correlation functions
    ps1 = fft(acf1_shift)
    ps1_freq = fftfreq(len(acf1), d=1.0 / sr)
    fi = ps1_freq > 0
    ps1 = ps1[fi]
    n_imag1 = np.sum(np.abs(ps1.imag) > 1e-8)
    assert n_imag1 == 0, "Nonzero imaginary part for fft(acf1) (%d)" % n_imag1
    ps1_auto = np.abs(ps1.real)
    ps1_auto_freq = ps1_freq[fi]

    ps2 = fft(acf2_shift)
    ps2_freq = fftfreq(len(acf2), d=1.0 / sr)
    fi = ps2_freq > 0
    ps2 = ps2[fi]
    assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
    ps2_auto = np.abs(ps2.real)
    ps2_auto_freq = ps2_freq[fi]

    # NOTE(review): these can never fail, because ps*_auto are absolute values
    assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
    assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

    # compute the cross spectral density from the cross-correlation function
    cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
    psd12 = fft(cf12_shift)
    psd12_freq = fftfreq(len(cf12_shift), d=1.0 / sr)
    fi = psd12_freq > 0
    psd12 = np.abs(psd12[fi])
    psd12_freq = psd12_freq[fi]

    # compute the coherence (and a frequency-domain coherency) with coherence_jn
    cfreq, coherence, coherence_var, _, _, coh12_freqspace, coh12_freqspace_t = coherence_jn(
        s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)
    coh12_freqspace /= np.abs(coh12_freqspace).max()

    # conservative coherence: subtract one standard deviation of the estimate
    # and clip at zero
    coherence_std = np.sqrt(coherence_var)
    coherence_weighted = coherence - coherence_std
    coherence_weighted[coherence_weighted < 0] = 0

    # normalize the cross-spectral density and power spectra
    psd12 /= psd12.max()
    ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
    ps1_auto /= ps_auto_max
    ps2_auto /= ps_auto_max

    # make some plots
    plt.figure()
    nrows = 2
    ncols = 2

    # plot the signals
    plt.subplot(nrows, ncols, 1)
    plt.plot(t, s1, 'k-', linewidth=2.0)
    plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal')
    plt.axis('tight')

    # plot the spectra (solid: Welch, dashed: FFT of the ACF)
    plt.subplot(nrows, ncols, 2)
    plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
    plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
    plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
    plt.axis('tight')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')

    # plot the correlation functions
    plt.subplot(nrows, ncols, 3)
    plt.axhline(0, c='k')
    plt.plot(lags, acf1, 'k-', linewidth=2.0)
    plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
    plt.plot(coh12_freqspace_t * 1e3, coh12_freqspace, 'm-', linewidth=1.0, alpha=0.95)
    plt.xlabel('Lag (ms)')
    plt.ylabel('Correlation Function')
    plt.axis('tight')
    plt.ylim(-0.5, 1.0)
    # legend colors must match the line colors above ('m' for coh12_f)
    handles = custom_legend(['k', 'r', 'g', 'b', 'm'], ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
    plt.legend(handles=handles)

    # plot the cross spectral density and coherence
    plt.subplot(nrows, ncols, 4)
    handles = custom_legend(['g', 'k', 'b'], ['CSD', 'Coherence', 'Weighted'])
    plt.axhline(0, c='k')
    plt.axhline(1, c='k')
    plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
    plt.errorbar(cfreq, coherence, yerr=coherence_std, fmt='k-', ecolor='r', linewidth=3.0, elinewidth=5.0, alpha=0.8)
    plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Cross-spectral Density/Coherence')
    plt.legend(handles=handles)

    plt.show()
def test_cross_psd(self):
    """Visually compare spectral estimates for a signal and a filtered, noisy copy.

    NOTE(review): an identical ``test_cross_psd`` is defined earlier in this
    file. If both live in the same class/module, this later definition is the
    one that takes effect.

    ``s1`` is a normalized sum of sinusoids in three frequency bands. ``s2`` is
    a 40-90 Hz bandpassed, renormalized copy of ``s1`` plus Gaussian noise. The
    test compares:

    * Welch PSDs with PSDs obtained as the FFT of the auto-correlation functions;
    * the cross-spectral density from the FFT of the cross-correlation function;
    * the coherence and coherency returned by ``coherence_jn``.

    All of these are plotted. The only automatic checks are on the imaginary
    part of the ACF-derived spectra.
    """
    np.random.seed(1234567)

    sr = 1000.0
    dur = 1.0
    nt = int(dur * sr)
    t = np.arange(nt) / sr

    # create a simple signal: a sum of sinusoids in three frequency bands
    freqs = list()
    freqs.extend(np.arange(8, 12))
    freqs.extend(np.arange(60, 71))
    freqs.extend(np.arange(130, 151))

    s1 = np.zeros([nt])
    for f in freqs:
        s1 += np.sin(2 * np.pi * f * t)
    s1 /= s1.max()

    # create a noise corrupted, bandpass filtered version of s1
    noise = np.random.randn(nt) * 1e-1
    s2 = bandpass_filter(s1, sample_rate=sr, low_freq=40., high_freq=90.)
    s2 /= s2.max()
    s2 += noise

    # compute the signals' power spectra, normalized by their common maximum
    welch_freq1, welch_psd1 = welch(s1, fs=sr)
    welch_freq2, welch_psd2 = welch(s2, fs=sr)
    welch_psd_max = max(welch_psd1.max(), welch_psd2.max())
    welch_psd1 /= welch_psd_max
    welch_psd2 /= welch_psd_max

    # compute the auto-correlation functions
    lags = np.arange(-200, 201)
    acf1 = correlation_function(s1, s1, lags, normalize=True)
    acf2 = correlation_function(s2, s2, lags, normalize=True)

    # compute the cross correlation function and the coherency
    cf12 = correlation_function(s1, s2, lags, normalize=True)
    coh12 = coherency(s1, s2, lags, window_fraction=0.75, noise_floor_db=100.)

    # reorder the lags so that zero lag comes first, as the FFT assumes;
    # otherwise the FFT of an ACF picks up a linear phase and does not equal
    # the power spectrum
    shift_lags = fftshift(lags)
    if len(lags) % 2 == 1:
        # for odd lengths fftshift leaves zero at the end; move it to the front
        shift_lags = np.roll(shift_lags, 1)
    acf1_shift = correlation_function(s1, s1, shift_lags)
    acf2_shift = correlation_function(s2, s2, shift_lags)

    # compute the power spectra from the auto-correlation functions
    ps1 = fft(acf1_shift)
    ps1_freq = fftfreq(len(acf1), d=1.0 / sr)
    fi = ps1_freq > 0
    ps1 = ps1[fi]
    n_imag1 = np.sum(np.abs(ps1.imag) > 1e-8)
    assert n_imag1 == 0, "Nonzero imaginary part for fft(acf1) (%d)" % n_imag1
    ps1_auto = np.abs(ps1.real)
    ps1_auto_freq = ps1_freq[fi]

    ps2 = fft(acf2_shift)
    ps2_freq = fftfreq(len(acf2), d=1.0 / sr)
    fi = ps2_freq > 0
    ps2 = ps2[fi]
    assert np.sum(np.abs(ps2.imag) > 1e-8) == 0, "Nonzero imaginary part for fft(acf2)"
    ps2_auto = np.abs(ps2.real)
    ps2_auto_freq = ps2_freq[fi]

    # NOTE(review): these can never fail, because ps*_auto are absolute values
    assert np.sum(ps1_auto < 0) == 0, "negatives in ps1_auto"
    assert np.sum(ps2_auto < 0) == 0, "negatives in ps2_auto"

    # compute the cross spectral density from the cross-correlation function
    cf12_shift = correlation_function(s1, s2, shift_lags, normalize=True)
    psd12 = fft(cf12_shift)
    psd12_freq = fftfreq(len(cf12_shift), d=1.0 / sr)
    fi = psd12_freq > 0
    psd12 = np.abs(psd12[fi])
    psd12_freq = psd12_freq[fi]

    # compute the coherence (and a frequency-domain coherency) with coherence_jn
    cfreq, coherence, coherence_var, _, _, coh12_freqspace, coh12_freqspace_t = coherence_jn(
        s1, s2, sample_rate=sr, window_length=0.100, increment=0.050, return_coherency=True)
    coh12_freqspace /= np.abs(coh12_freqspace).max()

    # conservative coherence: subtract one standard deviation of the estimate
    # and clip at zero
    coherence_std = np.sqrt(coherence_var)
    coherence_weighted = coherence - coherence_std
    coherence_weighted[coherence_weighted < 0] = 0

    # normalize the cross-spectral density and power spectra
    psd12 /= psd12.max()
    ps_auto_max = max(ps1_auto.max(), ps2_auto.max())
    ps1_auto /= ps_auto_max
    ps2_auto /= ps_auto_max

    # make some plots
    plt.figure()
    nrows = 2
    ncols = 2

    # plot the signals
    plt.subplot(nrows, ncols, 1)
    plt.plot(t, s1, 'k-', linewidth=2.0)
    plt.plot(t, s2, 'r-', alpha=0.75, linewidth=2.0)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal')
    plt.axis('tight')

    # plot the spectra (solid: Welch, dashed: FFT of the ACF)
    plt.subplot(nrows, ncols, 2)
    plt.plot(welch_freq1, welch_psd1, 'k-', linewidth=2.0, alpha=0.85)
    plt.plot(ps1_auto_freq, ps1_auto, 'k--', linewidth=2.0, alpha=0.85)
    plt.plot(welch_freq2, welch_psd2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(ps2_auto_freq, ps2_auto, 'r--', linewidth=2.0, alpha=0.75)
    plt.axis('tight')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')

    # plot the correlation functions
    plt.subplot(nrows, ncols, 3)
    plt.axhline(0, c='k')
    plt.plot(lags, acf1, 'k-', linewidth=2.0)
    plt.plot(lags, acf2, 'r-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, cf12, 'g-', alpha=0.75, linewidth=2.0)
    plt.plot(lags, coh12, 'b-', linewidth=2.0, alpha=0.75)
    plt.plot(coh12_freqspace_t * 1e3, coh12_freqspace, 'm-', linewidth=1.0, alpha=0.95)
    plt.xlabel('Lag (ms)')
    plt.ylabel('Correlation Function')
    plt.axis('tight')
    plt.ylim(-0.5, 1.0)
    # legend colors must match the line colors above ('m' for coh12_f)
    handles = custom_legend(['k', 'r', 'g', 'b', 'm'], ['acf1', 'acf2', 'cf12', 'coh12', 'coh12_f'])
    plt.legend(handles=handles)

    # plot the cross spectral density and coherence
    plt.subplot(nrows, ncols, 4)
    handles = custom_legend(['g', 'k', 'b'], ['CSD', 'Coherence', 'Weighted'])
    plt.axhline(0, c='k')
    plt.axhline(1, c='k')
    plt.plot(psd12_freq, psd12, 'g-', linewidth=3.0)
    plt.errorbar(cfreq, coherence, yerr=coherence_std, fmt='k-', ecolor='r', linewidth=3.0, elinewidth=5.0, alpha=0.8)
    plt.plot(cfreq, coherence_weighted, 'b-', linewidth=3.0, alpha=0.75)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Cross-spectral Density/Coherence')
    plt.legend(handles=handles)

    plt.show()
def draw_figures():
    """Draw the pairwise LFP coherency grid and mean LFP traces for one syllable.

    Loads data with ``get_full_data('GreBlu9508M', 'Site4', 'Call1', 'L', 287)``
    and averages the LFP over axis 0 (the trial axis, presumably). It restricts
    the mean LFP to syllable 1, padded by 20 ms before and 30 ms after, and
    plots the coherency of every electrode pair in the strict lower triangle of
    a grid. The trial-averaged traces go in the upper-right corner. The figure
    is saved as ``figure.svg`` in this file's directory and then shown.
    """
    d = get_full_data('GreBlu9508M', 'Site4', 'Call1', 'L', 287)

    syllable_index = 1
    syllable_start = d['syllable_props'][syllable_index]['start_time'] - 0.020
    syllable_end = d['syllable_props'][syllable_index]['end_time'] + 0.030

    sr = d['lfp_sample_rate']
    lfp_mean = d['lfp'].mean(axis=0)
    # use the electrode count of the data instead of a hard-coded 16, so the
    # grid, traces and tick labels stay consistent with what was loaded
    nelectrodes = lfp_mean.shape[0]
    lfp_t = np.arange(lfp_mean.shape[1]) / sr
    lfp_i = (lfp_t >= syllable_start) & (lfp_t <= syllable_end)

    electrode_order = d['electrode_order']

    # compute the cross coherency between each pair of electrodes
    lags = np.arange(-20, 21)
    lags_ms = (lags / sr) * 1e3
    nlags = len(lags)

    window_fraction = 0.60
    # a plain number; the original trailing comma made this the tuple (25,)
    noise_floor_db = 25

    cross_coherency = np.zeros([nelectrodes, nelectrodes, nlags])
    for i in range(nelectrodes):
        lfp1 = lfp_mean[i, lfp_i]
        # only the strict lower triangle (j < i) is plotted below
        for j in range(i):
            lfp2 = lfp_mean[j, lfp_i]
            cross_coherency[i, j, :] = coherency(lfp1, lfp2, lags,
                                                 window_fraction=window_fraction,
                                                 noise_floor_db=noise_floor_db)

    figsize = (24, 13)
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(top=0.95, bottom=0.02, right=0.97, left=0.03, hspace=0.20, wspace=0.20)

    gs = plt.GridSpec(nelectrodes, nelectrodes)
    for k in range(nelectrodes):
        for j in range(k):
            plt.subplot(gs[k, j])
            plt.axhline(0, c='k')
            plt.plot(lags_ms, cross_coherency[k, j], 'k-', linewidth=2.0, alpha=0.8)
            plt.axis('tight')
            plt.ylim(-.25, 0.5)

            plt.yticks([])
            plt.xticks([])

            # label only the bottom row and the left column
            if k == nelectrodes - 1:
                plt.xlabel('E%d' % electrode_order[j])
                xtks = [-40, 0, 40]
                plt.xticks(xtks, ['%d' % x for x in xtks])
            if j == 0:
                plt.ylabel('E%d' % electrode_order[k])
                ytks = [-0.2, 0.4]
                plt.yticks(ytks, ['%0.1f' % x for x in ytks])

    # stacked trial-averaged LFP traces in the upper-right corner; row 0 of
    # lfp_mean is drawn at the top
    plt.subplot(gs[:7, (nelectrodes - 8):])
    voffset = 5
    for n in range(nelectrodes):
        plt.plot(lfp_t, lfp_mean[nelectrodes - n - 1, :] + voffset * n, 'k-', linewidth=3.0, alpha=0.75)
    plt.axis('tight')
    ytick_locs = np.arange(nelectrodes) * voffset
    plt.yticks(ytick_locs, list(reversed(electrode_order)))
    plt.ylabel('Electrode')
    plt.axvline(syllable_start, c='k', linestyle='--', linewidth=3.0, alpha=0.7)
    plt.axvline(syllable_end, c='k', linestyle='--', linewidth=3.0, alpha=0.7)
    plt.xlabel('Time (s)')

    fname = os.path.join(get_this_dir(), 'figure.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    plt.show()
def compute_spectra_and_coherence_single_electrode(lfp1, lfp2, sample_rate, e1, e2, window_length=0.060, increment=None,
                                                   log=True, window_fraction=0.60, noise_floor_db=25,
                                                   lags=np.arange(-20, 21, 1), psd_stats=None):
    """Compute locked and non-locked spectra and coherency for one electrode pair.

    :param lfp1: An array of shape (ntrials, nt) for the first electrode.
    :param lfp2: An array of shape (ntrials, nt) for the second electrode.
    :param sample_rate: Sample rate of the LFPs in Hz.
    :param e1: Key into ``psd_stats`` for the first electrode.
    :param e2: Key into ``psd_stats`` for the second electrode.
    :param window_length: Window length (s) passed to ``power_spectrum_jn``.
    :param increment: Window increment (s); defaults to two samples.
    :param log: If True, apply ``log_transform`` to every power spectrum.
    :param window_fraction: Passed to ``coherency``.
    :param noise_floor_db: Passed to ``coherency``.
    :param lags: Lags (in samples) at which ``coherency`` is evaluated.
    :param psd_stats: Optional mapping from electrode key to ``(mean, std)``.
        When given, ``psd1``/``psd1_ms`` are z-scored with ``psd_stats[e1]``
        and ``psd2``/``psd2_ms`` with ``psd_stats[e2]``.
    :return: ``(pfreq, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total)``:

        * ``psd1``/``psd2``: spectra of the trial-averaged (locked) LFPs;
        * ``psd1_ms``/``psd2_ms``: mean spectra of the per-trial residuals
          (each trial minus the mean of the other trials);
        * ``c12``: coherency of the trial-averaged LFPs;
        * ``c12_nonlocked``: mean coherency of the residuals;
        * ``c12_total``: mean of the per-trial coherencies.
    """
    # compute the mean (locked) spectra
    lfp1_mean = lfp1.mean(axis=0)
    lfp2_mean = lfp2.mean(axis=0)

    if increment is None:
        increment = 2.0 / sample_rate

    pfreq, psd1, _, _ = power_spectrum_jn(lfp1_mean, sample_rate, window_length, increment)
    pfreq, psd2, _, _ = power_spectrum_jn(lfp2_mean, sample_rate, window_length, increment)

    if log:
        log_transform(psd1)
        log_transform(psd2)

    # coherency of the trial-averaged (locked) LFPs
    c12 = coherency(lfp1_mean, lfp2_mean, lags, window_fraction=window_fraction, noise_floor_db=noise_floor_db)

    # compute the nonlocked spectra and coherency from the per-trial residuals:
    # each trial minus the mean of all *other* trials (leave-one-out)
    ntrials, _ = lfp1.shape
    c12_pertrial = list()
    psd1_ms_all = list()
    psd2_ms_all = list()
    for k in range(ntrials):
        others = np.ones([ntrials], dtype='bool')
        others[k] = False
        lfp1_ms = lfp1[k, :] - lfp1[others, :].mean(axis=0)
        lfp2_ms = lfp2[k, :] - lfp2[others, :].mean(axis=0)

        pfreq, psd1_ms, _, _ = power_spectrum_jn(lfp1_ms, sample_rate, window_length, increment)
        pfreq, psd2_ms, _, _ = power_spectrum_jn(lfp2_ms, sample_rate, window_length, increment)
        if log:
            log_transform(psd1_ms)
            log_transform(psd2_ms)

        psd1_ms_all.append(psd1_ms)
        psd2_ms_all.append(psd2_ms)

        c12_pertrial.append(coherency(lfp1_ms, lfp2_ms, lags,
                                      window_fraction=window_fraction, noise_floor_db=noise_floor_db))

    psd1_ms = np.array(psd1_ms_all).mean(axis=0)
    psd2_ms = np.array(psd2_ms_all).mean(axis=0)

    if psd_stats is not None:
        psd_mean1, psd_std1 = psd_stats[e1]
        psd_mean2, psd_std2 = psd_stats[e2]

        psd1 -= psd_mean1
        psd1 /= psd_std1
        psd2 -= psd_mean2
        psd2 /= psd_std2

        psd1_ms -= psd_mean1
        psd1_ms /= psd_std1
        psd2_ms -= psd_mean2
        psd2_ms /= psd_std2

    c12_nonlocked = np.array(c12_pertrial).mean(axis=0)

    # compute the coherency per trial then take the average. Use a distinct
    # name so the locked coherency c12 computed above is not overwritten (the
    # previous version returned the last trial's coherency as c12).
    c12_totals = np.array([
        coherency(lfp1[k, :], lfp2[k, :], lags, window_fraction=window_fraction, noise_floor_db=noise_floor_db)
        for k in range(ntrials)
    ])
    c12_total = c12_totals.mean(axis=0)

    return pfreq, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total
def compute_spectra_and_coherence_multi_electrode_single_trial(lfps, sample_rate, electrode_indices, electrode_order,
                                                               window_length=0.060, increment=None, log=True,
                                                               window_fraction=0.60, noise_floor_db=25,
                                                               lags=np.arange(-20, 21, 1), psd_stats=None):
    """Compute per-electrode power spectra and the pairwise coherency matrix for one trial.

    :param lfps: an array of shape (nelectrodes, nt); row k is the LFP of
        electrode ``electrode_indices[k]``.
    :param sample_rate: sample rate of the LFPs in Hz.
    :param electrode_indices: electrode id for each row of ``lfps``.
    :param electrode_order: list of electrode ids that fixes the row/column
        order of the outputs; every id in ``electrode_indices`` must appear in it.
    :param window_length: window length (s) passed to ``power_spectrum_jn``.
    :param increment: window increment (s); defaults to two samples.
    :param log: if True, apply ``log_transform`` to each power spectrum.
    :param window_fraction: passed to ``coherency``.
    :param noise_floor_db: passed to ``coherency``.
    :param lags: lags (in samples) at which ``coherency`` is evaluated.
    :param psd_stats: optional mapping from electrode id to ``(mean, std)``,
        used to z-score each spectrum.
    :return: ``(spectra, cross_mat)``, both indexed by position in ``electrode_order``:

        * ``spectra`` has shape (nelectrodes, nfreqs);
        * ``cross_mat`` has shape (nelectrodes, nelectrodes, nlags).
          ``cross_mat[a, b]`` is the coherency of electrode a with b, and
          ``cross_mat[b, a]`` is the same curve reversed along the lag axis.
          The diagonal is left as zeros.
    """
    if increment is None:
        increment = 2.0 / sample_rate

    nelectrodes, nt = lfps.shape
    freqs = get_freqs(sample_rate, window_length, increment)
    lags_ms = get_lags_ms(sample_rate, lags)

    spectra = np.zeros([nelectrodes, len(freqs)])
    cross_mat = np.zeros([nelectrodes, nelectrodes, len(lags_ms)])

    for k in range(nelectrodes):
        _e1 = electrode_indices[k]
        i1 = electrode_order.index(_e1)
        lfp1 = lfps[k, :]

        _, psd1, _, _ = power_spectrum_jn(lfp1, sample_rate, window_length, increment)
        if log:
            log_transform(psd1)

        if psd_stats is not None:
            # z-score the spectrum with this electrode's statistics
            psd_mean, psd_std = psd_stats[_e1]
            psd1 -= psd_mean
            psd1 /= psd_std

        spectra[i1, :] = psd1

        # each unordered pair is computed once; the mirrored entry is the
        # lag-reversed coherency
        for j in range(k):
            _e2 = electrode_indices[j]
            i2 = electrode_order.index(_e2)
            lfp2 = lfps[j, :]

            cf = coherency(lfp1, lfp2, lags, window_fraction=window_fraction, noise_floor_db=noise_floor_db)

            cross_mat[i1, i2] = cf
            cross_mat[i2, i1] = cf[::-1]

    return spectra, cross_mat