Esempio n. 1
0
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Plot a stimulus, the LFPs it evoked on two electrodes, and their spectra/coherency.

    Left column of the figure:
      * the log spectrogram of the first presentation of ``stim_id``;
      * for each of ``e1`` and ``e2``, a few single-trial z-scored LFPs plus
        their trial average.
    Dashed vertical lines mark the (padded) window of syllable ``syllable_index``.

    Right column, restricted to that syllable window:
      * the trial-averaged and mean-subtracted LFPs of both electrodes;
      * their power spectra;
      * the raw, trial-averaged and mean-subtracted coherency between them.

    The figure is saved as ``raw+coherency.svg`` in ``get_this_dir()``.

    :param bird: bird name; used to locate ``<data_dir>/<bird>/<bird>.h5`` when
        ``exp`` is None, and passed to ``get_psd_stats``.
    :param block: block name passed to ``exp.get_segment`` and ``get_psd_stats``.
    :param segment: segment name passed to ``exp.get_segment`` and ``get_psd_stats``.
    :param hemi: key into the dict returned by ``exp.get_lfp_slice``; also
        passed to ``get_psd_stats``.
    :param e1: first electrode id, looked up in the electrode list for ``hemi``.
    :param e2: second electrode id, looked up the same way.
    :param stim_id: stimulus id; every epoch with this id is used as a trial.
    :param syllable_index: index into the events found by
        ``break_envelope_into_events`` on the stimulus amplitude envelope.
    :param data_dir: root data directory; only used when ``exp`` is None.
    :param exp: an already-loaded ``Experiment``; loaded from disk when None.
    :raises ValueError: if the epoch table has no entries for ``stim_id``.
    """
    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # start/end times of every presentation of the stimulus, sorted by start time
    etable = exp.get_epoch_table(seg)
    i = etable['id'] == stim_id
    stim_times = sorted(zip(etable[i]['start_time'].values, etable[i]['end_time'].values),
                        key=operator.itemgetter(0))
    if not stim_times:
        raise ValueError('No trials found for stim_id=%s' % stim_id)
    stim_times = np.array(stim_times)

    # every trial is truncated to the shortest presentation so they stack into one array
    stim_dur = (stim_times[:, 1] - stim_times[:, 0]).min()

    # aggregate the LFPs across trials
    lfps = list()
    sample_rate = None
    electrode_indices = None
    for start_time, end_time in stim_times:
        lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
        electrode_indices, the_lfps, sample_rate = lfp_data[hemi]
        stim_dur_i = int(stim_dur*sample_rate)
        lfps.append(the_lfps[:, :stim_dur_i])

    # stack as (trials, electrodes, time) and rescale so the LFPs are in uV;
    # float dtype so the scaling also works if the raw data is integer-typed
    lfps = np.array(lfps, dtype=float) * 1e6

    # get the log spectrogram of the first presentation of the stimulus
    start_time_0, end_time_0 = stim_times[0]
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time_0, end_time_0)
    # work on a float copy so the array handed back by the experiment is not modified in place
    stim_spec = np.array(stim_spec, dtype=float)
    # NOTE(review): the time axis is stretched over stim_dur (the shortest trial),
    # while the spectrogram covers trial 0, which may be longer -- confirm intent.
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20*np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # amplitude envelope: std over axis 0 for each spectrogram time bin, normalized to [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    # NOTE(review): 2 ms is converted to samples with the LFP sample rate, but the
    # envelope is sampled every stim_spec_dt seconds -- confirm the intended units.
    merge_thresh = int(0.002*sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices (spectrogram bins) into times in seconds
    events = np.asarray(events) * stim_spec_dt

    # pad the syllable window by 5 ms before and 10 ms after; clamp the start at 0
    # because a negative value would wrap around when used as a slice index below
    syllable_start, syllable_end, _ = events[syllable_index]
    syllable_start = max(0.0, syllable_start - 0.005)
    syllable_end += 0.010
    last_syllable_end = events[-1, 1] + 0.025

    i1 = electrode_indices.index(e1)
    i2 = electrode_indices.index(e2)

    # z-score each electrode's LFP over all trials and time points
    lfp1 = lfps[:, i1, :]
    lfp1 = (lfp1 - lfp1.mean()) / lfp1.std(ddof=1)
    lfp2 = lfps[:, i2, :]
    lfp2 = (lfp2 - lfp2.mean()) / lfp2.std(ddof=1)

    ntrials, _, nt = lfps.shape
    # never try to plot more single trials than were recorded
    ntrials_to_plot = min(5, ntrials)

    # time axis in seconds; float() avoids integer division under Python 2
    t = np.arange(nt) / float(sample_rate)

    def _mark_syllable():
        # dashed lines at the padded syllable boundaries on the current axes
        plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
        plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)

    # plot the stimulus and raw LFP for two electrodes
    plt.figure(figsize=(24.0, 10))
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(3, 100)

    ax = plt.subplot(gs[0, :60])
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    _mark_syllable()
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    # a few single trials plus the trial average, one row per electrode
    for row, (electrode, lfp) in enumerate([(e1, lfp1), (e2, lfp2)], 1):
        plt.subplot(gs[row, :60])
        for k in range(ntrials_to_plot):
            plt.plot(t, lfp[k, :], '-', linewidth=2.0, alpha=0.75)
        plt.plot(t, lfp.mean(axis=0), 'k-', linewidth=3.0)
        _mark_syllable()
        plt.xlabel('Time (s)')
        plt.ylabel('E%d (z-scored)' % electrode)
        plt.axis('tight')
        plt.xlim(0, last_syllable_end)

    # restrict the lfps to a single syllable
    print('syllable_start=%f, syllable_end=%f' % (syllable_start, syllable_end))
    syllable_si = int(syllable_start*sample_rate)
    syllable_ei = int(syllable_end*sample_rate)
    lfp1 = lfp1[:, syllable_si:syllable_ei]
    lfp2 = lfp2[:, syllable_si:syllable_ei]
    t_syllable = t[syllable_si:syllable_ei]

    # compute the trial averaged and mean subtracted lfps
    lfp1_mean, lfp1_ms = compute_avg_and_ms(lfp1)
    lfp2_mean, lfp2_ms = compute_avg_and_ms(lfp2)

    psd_stats = get_psd_stats(bird, block, segment, hemi)
    freqs, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total = \
        compute_spectra_and_coherence_single_electrode(lfp1, lfp2, sample_rate, e1, e2, psd_stats=psd_stats)
    lags_ms = get_lags_ms(sample_rate)

    # symmetric y-limits shared by both electrodes in each LFP panel
    lfp_absmax = max(np.abs(lfp1_mean).max(), np.abs(lfp2_mean).max())
    lfp_ms_absmax = max(np.abs(lfp1_ms).max(), np.abs(lfp2_ms).max())

    # trial-averaged LFPs within the syllable
    plt.subplot(gs[0, 65:80])
    plt.axhline(0, c='k')
    plt.plot(t_syllable, lfp1_mean, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syllable, lfp2_mean, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Trial-avg LFP')
    leg = custom_legend(['k', '#c0c0c0'], ['E%d' % e1, 'E%d' % e2])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_absmax, lfp_absmax)

    # mean-subtracted LFPs within the syllable
    plt.subplot(gs[0, 85:])
    plt.axhline(0, c='k')
    plt.plot(t_syllable, lfp1_ms, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(t_syllable, lfp2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Mean-sub LFP')
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_ms_absmax, lfp_ms_absmax)

    # power spectra of the trial-averaged LFPs (x axis is `freqs`)
    plt.subplot(gs[1, 65:80])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Trial-avg Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    # power spectra of the mean-subtracted LFPs
    plt.subplot(gs[1, 85:])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1_ms, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Mean-sub Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    # coherency between the two electrodes as a function of lag
    plt.subplot(gs[2, 65:80])
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, c12_total, 'k-', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12, '-', c='r', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12_nonlocked, '-', c='b', linewidth=3.0, alpha=0.75)
    plt.xlabel('Lags (ms)')
    plt.ylabel('Coherency')
    leg = custom_legend(['k', 'r', 'b'], ['Raw', 'Trial-avg', 'Mean-sub'])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-0.2, 0.3)

    fname = os.path.join(get_this_dir(), 'raw+coherency.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
Esempio n. 2
0
def draw_figures(stim_event, stim_ids, syllable_indices, e1=5, e2=2):

    assert isinstance(stim_event, StimEventTransform)

    sample_rate = stim_event.lfp_sample_rate
    lags_ms = get_lags_ms(sample_rate)

    cross_functions_total = dict()
    cross_functions_locked = dict()
    cross_functions_nonlocked = dict()
    specs = dict()
    syllable_times = dict()
    stim_end_time = dict()

    hemi = ','.join(stim_event.rcg_names)

    # compute all cross and auto-correlations
    if hemi == 'L':
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    freqs = None
    nelectrodes = None
    index2electrode = stim_event.index2electrode

    seg_uname = stim_event.seg_uname
    bird,block,seg = seg_uname.split('_')
    psd_stats = get_psd_stats(bird, stim_event.block_name, stim_event.segment_name, hemi)

    for stim_id,syllable_index in zip(stim_ids, syllable_indices):
        lfp = stim_event.lfp_reps_by_stim['raw'][stim_id]
        ntrials,nelectrodes,nt = lfp.shape

        # get the start and end time of the syllable
        i = (stim_event.segment_df['stim_id'] == stim_id) & (stim_event.segment_df['order'] == syllable_index)
        assert i.sum() > 0, "No syllable for stim_id=%d, order=%d" % (stim_id, syllable_index)
        assert i.sum() == 1, "More than one syllable for stim_id=%d, order=%d" % (stim_id, syllable_index)
        start_time = stim_event.segment_df[i]['start_time'].values[0]
        end_time = stim_event.segment_df[i]['end_time'].values[0]
        syllable_times[stim_id] = (start_time, end_time)

        # get the end time of the last syllable
        i = (stim_event.segment_df['stim_id'] == stim_id)
        stim_end_time[stim_id] = stim_event.segment_df[i]['end_time'].max()

        si = int((stim_event.pre_stim_time + start_time)*sample_rate)
        ei = int((stim_event.pre_stim_time + end_time)*sample_rate)

        # restrict the lfp to just the syllable time
        lfp = lfp[:, :, si:ei]
        specs[stim_id] = stim_event.spec_by_stim[stim_id]

        i1 = index2electrode.index(e1)
        lfp1 = lfp[:, i1, :]

        i2 = index2electrode.index(e2)
        lfp2 = lfp[:, i2, :]

        # compute the covariance functions
        pfreq, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total = compute_spectra_and_coherence_single_electrode(lfp1, lfp2,
                                                                                                          sample_rate,
                                                                                                          e1, e2,
                                                                                                          psd_stats=psd_stats)
        cross_functions_total[stim_id] = (psd1, psd2, c12_total)
        cross_functions_locked[stim_id] = (psd1, psd2, c12)
        cross_functions_nonlocked[stim_id] = (psd1, psd2, c12_nonlocked)

    # plot the cross covariance functions to overlap for each stim
    plot_cross_pair(stim_ids, cross_functions_total, electrode_order, freqs, lags_ms)
    plt.suptitle('Total Covariance')
    fname = os.path.join(get_this_dir(), 'cross_total.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # plot_cross_mat(stim_ids, cross_functions_locked, electrode_order, freqs)
    plot_cross_pair(stim_ids, cross_functions_locked, electrode_order, freqs, lags_ms)
    plt.suptitle('Stim-locked Covariance')
    fname = os.path.join(get_this_dir(), 'cross_locked.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # plot_cross_mat(stim_ids, cross_functions_nonlocked, electrode_order, freqs)
    plot_cross_pair(stim_ids, cross_functions_nonlocked, electrode_order, freqs, lags_ms)
    plt.suptitle('Non-locked Covariance')
    fname = os.path.join(get_this_dir(), 'cross_nonlocked.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # plot the spectrograms
    fig_height = 2
    fig_max_width = 6
    spec_lens = np.array([stim_end_time[stim_id] for stim_id in stim_ids])
    spec_ratios = spec_lens / spec_lens.max()
    spec_sample_rate = sample_rate

    for k,stim_id in enumerate(stim_ids):
        spec = specs[stim_id]
        syllable_start,syllable_end = syllable_times[stim_id]
        print 'stim_id=%d, syllable_start=%0.3f, syllable_end=%0.3f' % (stim_id, syllable_start, syllable_end)
        spec_t = np.arange(spec.shape[1]) / spec_sample_rate
        stim_end = stim_end_time[stim_id]

        figsize = (spec_ratios[k]*fig_max_width, fig_height)
        fig = plt.figure(figsize=figsize)
        ax = plt.gca()
        plot_spectrogram(spec_t, stim_event.spec_freq, spec, ax=ax, ticks=True, fmin=300., fmax=8000.,
                         colormap='SpectroColorMap', colorbar=False)
        plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
        plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
        plt.xlim(0, stim_end+0.005)
        fname = os.path.join(get_this_dir(), 'stim_spec_%d.svg' % stim_id)