Exemplo n.º 1
0
def draw_coherency_matrix(
    bird,
    block,
    segment,
    hemi,
    stim_id,
    trial,
    syllable_index,
    data_dir="/auto/tdrive/mschachter/data",
    exp=None,
    save=True,
):
    """Plot the lag-domain coherency of every electrode pair during one syllable.

    One presentation of a stimulus is selected, the stimulus is segmented into
    syllables from the amplitude envelope of its log spectrogram, and for the
    chosen syllable ``coherency`` is computed for every pair of distinct
    electrodes and ``correlation_function`` for each electrode with itself.
    The lower triangle of the resulting matrix is drawn as a grid of line
    plots, with rows and columns following the rostral-caudal electrode order.

    Args:
        bird: Bird name; used to build the data file paths when ``exp`` is None.
        block: Block identifier passed to ``exp.get_segment``.
        segment: Segment identifier passed to ``exp.get_segment``.
        hemi: Key into the LFP slice; ``"L"`` selects
            ``ROSTRAL_CAUDAL_ELECTRODES_LEFT``, anything else selects
            ``ROSTRAL_CAUDAL_ELECTRODES_RIGHT``.
        stim_id: Value matched against the ``id`` column of the epoch table.
        trial: Index into the presentations of ``stim_id``, sorted by start time.
        syllable_index: Index of the syllable among the detected envelope events.
        data_dir: Root directory containing one sub-directory per bird.
        exp: Already-loaded experiment; loaded from ``data_dir`` when None.
        save: When true, write the figure to ``coherency_matrix.svg`` in the
            directory returned by ``get_this_dir()``.

    Raises:
        IndexError: If ``trial`` is out of range for the presentations of
            ``stim_id``.
    """

    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, "%s.h5" % bird)
        stim_file = os.path.join(bird_dir, "stims.h5")
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # get the start and end times of the stimulus
    etable = exp.get_epoch_table(seg)
    stim_rows = etable[etable["id"] == stim_id]

    # sorted() rather than zip(...).sort(): zip returns an iterator on Python 3
    stim_times = sorted(
        zip(stim_rows["start_time"].values, stim_rows["end_time"].values),
        key=operator.itemgetter(0),
    )

    start_time, end_time = stim_times[trial]
    stim_dur = float(end_time - start_time)

    # get a slice of the LFP
    lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
    electrode_indices, lfps, sample_rate = lfp_data[hemi]

    # rescale the LFPs so they are in uV; build a new array instead of scaling
    # in place so the array handed back by the experiment is left untouched
    lfps = lfps * 1e6

    # get the log spectrogram of the stimulus
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time, end_time)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20 * np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, normalized to the range [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    merge_thresh = int(0.002 * sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into actual times
    events *= stim_spec_dt

    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    syllable_si = int(syllable_start * sample_rate)
    syllable_ei = int(syllable_end * sample_rate)

    # compute all cross and auto-correlations
    if hemi == "L":
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    lags = np.arange(-20, 21)
    # float() guards against integer division on Python 2 if the rate is an int
    lags_ms = (lags / float(sample_rate)) * 1e3

    window_fraction = 0.35
    noise_db = 25.0
    nelectrodes = len(electrode_order)

    # restrict every channel to the syllable once, and look up where each
    # recorded channel sits in the rostral-caudal ordering once
    syllable_lfps = lfps[:, syllable_si:syllable_ei]
    order_pos = [electrode_order.index(electrode_indices[k]) for k in range(nelectrodes)]

    cross_mat = np.zeros([nelectrodes, nelectrodes, len(lags)])
    for i in range(nelectrodes):
        lfp1 = syllable_lfps[i]
        for j in range(nelectrodes):
            lfp2 = syllable_lfps[j]

            if i != j:
                x = coherency(lfp1, lfp2, lags, window_fraction=window_fraction, noise_floor_db=noise_db)
            else:
                x = correlation_function(lfp1, lfp2, lags)

            # store by anatomical position rather than by recording channel
            cross_mat[order_pos[i], order_pos[j], :] = x

    # make a plot
    figsize = (24.0, 13.5)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    gs = plt.GridSpec(nelectrodes, nelectrodes)
    for i in range(nelectrodes):
        # only the lower triangle (including the diagonal) is drawn
        for j in range(i + 1):
            plt.subplot(gs[i, j])
            plt.axhline(0, c="k")
            plt.axvline(0, c="k")

            plt.plot(lags_ms, cross_mat[i, j, :], "k-", linewidth=2.0)
            plt.xticks([])
            plt.yticks([])
            plt.axis("tight")
            plt.ylim(-0.5, 1.0)
            if j == 0:
                plt.ylabel("E%d" % electrode_order[i])
            if i == nelectrodes - 1:
                plt.xlabel("E%d" % electrode_order[j])

    if save:
        fname = os.path.join(get_this_dir(), "coherency_matrix.svg")
        plt.savefig(fname, facecolor="w", edgecolor="none")
Exemplo n.º 2
0
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, trial, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Draw single-trial LFP, power-spectrum and coherency figures for two electrodes.

    Three figures are written to the directory returned by ``get_this_dir()``:

    * ``raw.svg`` - the stimulus spectrogram above the z-scored LFP of
      electrodes ``e1`` and ``e2``, with the selected syllable marked by
      dashed lines.
    * ``auto+cross.svg`` - the power spectra of the two electrodes and the
      coherency between them.
    * ``cross_all.svg`` - a lower-triangular grid with the power spectrum of
      every electrode on the diagonal and pairwise coherencies below it; the
      panels belonging to ``e1``, ``e2`` and their pair are colored.

    NOTE(review): the syllable boundaries are only drawn as markers here. The
    spectra and coherency come from
    ``compute_spectra_and_coherence_multi_electrode_single_trial``, which is
    handed the full stimulus-length LFP slice - confirm that is intended.

    Args:
        bird: Bird name; used to build file paths and passed to ``get_psd_stats``.
        block: Block identifier passed to ``exp.get_segment``.
        segment: Segment identifier passed to ``exp.get_segment``.
        hemi: Key into the LFP slice; ``'L'`` selects
            ``ROSTRAL_CAUDAL_ELECTRODES_LEFT``, anything else the right ordering.
        e1: First electrode number; must be present in the recorded electrodes.
        e2: Second electrode number; must be present in the recorded electrodes.
        stim_id: Value matched against the ``id`` column of the epoch table.
        trial: Index into the presentations of ``stim_id``, sorted by start time.
        syllable_index: Index of the syllable among the detected envelope events.
        data_dir: Root directory containing one sub-directory per bird.
        exp: Already-loaded experiment; loaded from ``data_dir`` when None.

    Raises:
        IndexError: If ``trial`` is out of range for the presentations of
            ``stim_id``.
    """

    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # get the start and end times of the stimulus
    etable = exp.get_epoch_table(seg)
    stim_rows = etable[etable['id'] == stim_id]

    # sorted() rather than zip(...).sort(): zip returns an iterator on Python 3
    stim_times = sorted(zip(stim_rows['start_time'].values, stim_rows['end_time'].values),
                        key=operator.itemgetter(0))

    start_time, end_time = stim_times[trial]
    stim_dur = float(end_time - start_time)

    # get a slice of the LFP
    lfp_data = exp.get_lfp_slice(seg, start_time, end_time, zscore=True)
    electrode_indices, lfps, sample_rate = lfp_data[hemi]

    # get the log spectrogram of the stimulus
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time, end_time)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20*np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, normalized to the range [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    merge_thresh = int(0.002*sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into actual times
    events *= stim_spec_dt

    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    # leave a little room after the last syllable when cropping the time axis
    last_syllable_end = events[-1, 1] + 0.025

    # positions of the two electrodes among the recorded channels
    raw_i1 = electrode_indices.index(e1)
    raw_i2 = electrode_indices.index(e2)

    lfp1 = lfps[raw_i1, :]
    lfp2 = lfps[raw_i2, :]

    # float() guards against integer division on Python 2 if the rate is an int
    t = np.arange(len(lfp1)) / float(sample_rate)
    legend = ['E%d' % e1, 'E%d' % e2]

    if hemi == 'L':
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    # get the power spectrum stats for this site
    psd_stats = get_psd_stats(bird, block, segment, hemi)

    # compute the power spectra and cross coherence for all electrodes
    lags_ms = get_lags_ms(sample_rate)
    spectra, cross_mat = compute_spectra_and_coherence_multi_electrode_single_trial(lfps, sample_rate, electrode_indices, electrode_order, psd_stats=psd_stats)

    # plot the stimulus and raw LFP for two electrodes
    figsize = (24.0, 10)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    ax = plt.subplot(2, 1, 1)
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    plt.subplot(2, 1, 2)
    plt.plot(t, lfp1, 'b-', linewidth=3.0)
    plt.plot(t, lfp2, 'r-', linewidth=3.0, alpha=0.7)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    # t is sample index / sample rate, i.e. seconds (was mislabelled as ms)
    plt.xlabel('Time (s)')
    plt.ylabel('LFP (z-scored)')
    plt.legend(legend)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    fname = os.path.join(get_this_dir(), 'raw.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # plot the two power spectra
    psd_ub = 6
    psd_lb = 0
    # positions of the two electrodes in the rostral-caudal ordering, which is
    # how spectra and cross_mat are indexed below
    ord_i1 = electrode_order.index(e1)
    ord_i2 = electrode_order.index(e2)

    a1 = spectra[ord_i1, :]
    a2 = spectra[ord_i2, :]
    freqs = get_freqs(sample_rate)

    figsize = (10.0, 4.0)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.90, bottom=0.10, left=0.10, right=0.99, hspace=0.10)

    plt.subplot(1, 2, 1)
    plt.plot(freqs, a1, 'b-', linewidth=3.0)
    plt.plot(freqs, a2, 'r-', linewidth=3.0)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (z-scored)')
    plt.title('Power Spectrum')
    handles = custom_legend(['b', 'r'], legend)
    plt.legend(handles=handles, fontsize='small')
    plt.axis('tight')
    plt.ylim(psd_lb, psd_ub)

    # plot the coherency
    cf_lb = -0.1
    cf_ub = 0.3
    coh = cross_mat[ord_i1, ord_i2, :]
    plt.subplot(1, 2, 2)
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, coh, 'g-', linewidth=3.0)
    # the x axis is lags_ms, not frequency (was mislabelled as 'Frequency (Hz)')
    plt.xlabel('Lags (ms)')
    plt.title('Coherency')
    plt.axis('tight')
    plt.ylim(cf_lb, cf_ub)

    fname = os.path.join(get_this_dir(), 'auto+cross.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')

    # plot all spectra (diagonal) and pairwise coherencies (below the diagonal)
    nelectrodes = len(electrode_order)
    highlighted_pair = set((e1, e2))

    figsize = (24.0, 13.5)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)

    gs = plt.GridSpec(nelectrodes, nelectrodes)
    for i in range(nelectrodes):
        for j in range(i+1):
            plt.subplot(gs[i, j])
            plt.axhline(0, c='k')
            plt.axvline(0, c='k')

            row_e = electrode_order[i]
            col_e = electrode_order[j]
            on_diagonal = i == j

            clr = 'k'
            if on_diagonal:
                if row_e == e1:
                    clr = 'b'
                elif row_e == e2:
                    clr = 'r'
            elif set((row_e, col_e)) == highlighted_pair:
                # only the lower triangle is drawn, so the pair may appear as
                # (e1, e2) or (e2, e1) depending on the electrode ordering
                clr = 'g'

            if on_diagonal:
                plt.plot(freqs, spectra[i, :], '-', c=clr, linewidth=2.0)
            else:
                plt.plot(lags_ms, cross_mat[i, j, :], '-', c=clr, linewidth=2.0)
            plt.xticks([])
            plt.yticks([])
            plt.axis('tight')
            if on_diagonal:
                plt.axhline(0, c='k')
                plt.ylim(psd_lb, psd_ub)
            else:
                plt.ylim(cf_lb, cf_ub)

            if j == 0:
                plt.ylabel('E%d' % electrode_order[i])
            if i == nelectrodes-1:
                plt.xlabel('E%d' % electrode_order[j])

    fname = os.path.join(get_this_dir(), 'cross_all.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
Exemplo n.º 3
0
def draw_figures(bird, block, segment, hemi, e1, e2, stim_id, syllable_index, data_dir='/auto/tdrive/mschachter/data', exp=None):
    """Draw multi-trial LFP, power and coherency panels for two electrodes.

    The LFP is collected for every presentation of ``stim_id`` (each trial
    truncated to the shortest presentation), z-scored per electrode across all
    trials, and a single figure is written to ``raw+coherency.svg`` in the
    directory returned by ``get_this_dir()``. The left column shows the
    stimulus spectrogram and the per-trial LFP of ``e1`` and ``e2`` (up to five
    trials plus the across-trial mean). The right columns show, for the
    selected syllable, the trial-averaged and mean-subtracted LFP, their power
    spectra, and the coherency between the two electrodes.

    The syllable window is widened by 5 ms before its start and 10 ms after
    its end before the LFP is cropped.

    Args:
        bird: Bird name; used to build file paths and passed to ``get_psd_stats``.
        block: Block identifier passed to ``exp.get_segment``.
        segment: Segment identifier passed to ``exp.get_segment``.
        hemi: Key into the LFP slice returned by ``exp.get_lfp_slice``.
        e1: First electrode number; must be present in the recorded electrodes.
        e2: Second electrode number; must be present in the recorded electrodes.
        stim_id: Value matched against the ``id`` column of the epoch table.
        syllable_index: Index of the syllable among the detected envelope events.
        data_dir: Root directory containing one sub-directory per bird.
        exp: Already-loaded experiment; loaded from ``data_dir`` when None.
    """

    # load up the experiment
    if exp is None:
        bird_dir = os.path.join(data_dir, bird)
        exp_file = os.path.join(bird_dir, '%s.h5' % bird)
        stim_file = os.path.join(bird_dir, 'stims.h5')
        exp = Experiment.load(exp_file, stim_file)
    seg = exp.get_segment(block, segment)

    # get the start and end times of the stimulus
    etable = exp.get_epoch_table(seg)
    stim_rows = etable[etable['id'] == stim_id]

    # sorted() rather than zip(...).sort(): zip returns an iterator on Python 3
    stim_times = np.array(sorted(zip(stim_rows['start_time'].values, stim_rows['end_time'].values),
                                 key=operator.itemgetter(0)))

    # every trial is truncated to the shortest presentation
    stim_durs = stim_times[:, 1] - stim_times[:, 0]
    stim_dur = stim_durs.min()

    # aggregate the LFPs across trials
    lfps = list()
    sample_rate = None
    electrode_indices = None
    for start_time, end_time in stim_times:

        # get a slice of the LFP
        lfp_data = exp.get_lfp_slice(seg, start_time, end_time)
        electrode_indices, the_lfps, sample_rate = lfp_data[hemi]

        stim_dur_i = int(stim_dur*sample_rate)
        lfps.append(the_lfps[:, :stim_dur_i])
    lfps = np.array(lfps)

    # rescale the LFPs so they are in uV
    lfps *= 1e6

    # get the log spectrogram of the stimulus (taken from the first presentation)
    start_time_0 = stim_times[0][0]
    end_time_0 = stim_times[0][1]
    stim_spec_t, stim_spec_freq, stim_spec = exp.get_spectrogram_slice(seg, start_time_0, end_time_0)
    stim_spec_t = np.linspace(0, stim_dur, len(stim_spec_t))
    stim_spec_dt = np.diff(stim_spec_t)[0]
    nz = stim_spec > 0
    stim_spec[nz] = 20*np.log10(stim_spec[nz]) + 100
    stim_spec[stim_spec < 0] = 0

    # get the amplitude envelope, normalized to the range [0, 1]
    amp_env = stim_spec.std(axis=0, ddof=1)
    amp_env -= amp_env.min()
    amp_env /= amp_env.max()

    # segment the amplitude envelope into syllables
    merge_thresh = int(0.002*sample_rate)
    events = break_envelope_into_events(amp_env, threshold=0.05, merge_thresh=merge_thresh)

    # translate the event indices into actual times
    events *= stim_spec_dt

    syllable_start, syllable_end, syllable_max_amp = events[syllable_index]
    # widen the analysis window slightly around the detected syllable
    syllable_start -= 0.005
    syllable_end += 0.010
    last_syllable_end = events[-1, 1] + 0.025

    i1 = electrode_indices.index(e1)
    i2 = electrode_indices.index(e2)

    lfp1 = lfps[:, i1, :]
    lfp2 = lfps[:, i2, :]

    # zscore the LFP (one mean/std per electrode, pooled over trials and time)
    lfp1 -= lfp1.mean()
    lfp1 /= lfp1.std(ddof=1)
    lfp2 -= lfp2.mean()
    lfp2 /= lfp2.std(ddof=1)

    ntrials, nelectrodes, nt = lfps.shape
    # never try to draw more trials than were recorded
    ntrials_to_plot = min(5, ntrials)

    # float() guards against integer division on Python 2 if the rate is an int
    t = np.arange(nt) / float(sample_rate)

    # plot the stimulus and raw LFP for two electrodes
    figsize = (24.0, 10)
    plt.figure(figsize=figsize)
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.03, right=0.99, hspace=0.10)
    gs = plt.GridSpec(3, 100)

    ax = plt.subplot(gs[0, :60])
    plot_spectrogram(stim_spec_t, stim_spec_freq, stim_spec, ax=ax, ticks=True, fmin=300, fmax=8000,
                     colormap='SpectroColorMap', colorbar=False)
    plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
    plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
    plt.axis('tight')
    plt.xlim(0, last_syllable_end)

    # plot the per-trial LFP of each electrode with the across-trial mean on top
    for grid_row, electrode, trial_lfps in ((1, e1, lfp1), (2, e2, lfp2)):
        plt.subplot(gs[grid_row, :60])
        for k in range(ntrials_to_plot):
            plt.plot(t, trial_lfps[k, :], '-', linewidth=2.0, alpha=0.75)
        plt.plot(t, trial_lfps.mean(axis=0), 'k-', linewidth=3.0)
        plt.axvline(syllable_start, c='k', linestyle='dashed', linewidth=3.0)
        plt.axvline(syllable_end, c='k', linestyle='dashed', linewidth=3.0)
        # t is sample index / sample rate, i.e. seconds (was mislabelled as ms)
        plt.xlabel('Time (s)')
        plt.ylabel('E%d (z-scored)' % electrode)
        plt.axis('tight')
        plt.xlim(0, last_syllable_end)

    # restrict the lfps to a single syllable
    print('syllable_start=%f, syllable_end=%f' % (syllable_start, syllable_end))
    # clamp at zero: the 5 ms padding can push the start before the stimulus
    # onset, and a negative index would silently slice from the end instead
    syllable_si = max(0, int(syllable_start*sample_rate))
    syllable_ei = int(syllable_end*sample_rate)
    lfp1 = lfp1[:, syllable_si:syllable_ei]
    lfp2 = lfp2[:, syllable_si:syllable_ei]
    syllable_t = t[syllable_si:syllable_ei]

    # compute the trial averaged and mean subtracted lfps
    lfp1_mean, lfp1_ms = compute_avg_and_ms(lfp1)
    lfp2_mean, lfp2_ms = compute_avg_and_ms(lfp2)

    psd_stats = get_psd_stats(bird, block, segment, hemi)
    freqs, psd1, psd2, psd1_ms, psd2_ms, c12, c12_nonlocked, c12_total = compute_spectra_and_coherence_single_electrode(lfp1, lfp2, sample_rate, e1, e2, psd_stats=psd_stats)
    lags_ms = get_lags_ms(sample_rate)

    # symmetric y limits shared by both electrodes
    lfp_absmax = max(np.abs(lfp1_mean).max(), np.abs(lfp2_mean).max())
    lfp_ms_absmax = max(np.abs(lfp1_ms).max(), np.abs(lfp2_ms).max())

    plt.subplot(gs[0, 65:80])
    plt.axhline(0, c='k')
    plt.plot(syllable_t, lfp1_mean, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(syllable_t, lfp2_mean, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Trial-avg LFP')
    leg = custom_legend(['k', '#c0c0c0'], ['E%d' % e1, 'E%d' % e2])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_absmax, lfp_absmax)

    plt.subplot(gs[0, 85:])
    plt.axhline(0, c='k')
    plt.plot(syllable_t, lfp1_ms, 'k-', linewidth=3.0, alpha=1.)
    plt.plot(syllable_t, lfp2_ms, '-', c='#c0c0c0', linewidth=3.0)
    plt.ylabel('Mean-sub LFP')
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-lfp_ms_absmax, lfp_ms_absmax)

    plt.subplot(gs[1, 65:80])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2, '-', c='#c0c0c0', linewidth=3.0)
    # the x axis is freqs (was mislabelled as 'Time (s)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Trial-avg Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[1, 85:])
    plt.axhline(0, c='k')
    plt.plot(freqs, psd1_ms, 'k-', linewidth=3.0)
    plt.plot(freqs, psd2_ms, '-', c='#c0c0c0', linewidth=3.0)
    # the x axis is freqs (was mislabelled as 'Time (s)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Mean-sub Power')
    plt.legend(handles=leg, fontsize='x-small', loc=2)
    plt.axis('tight')

    plt.subplot(gs[2, 65:80])
    plt.axhline(0, c='k')
    plt.axvline(0, c='k')
    plt.plot(lags_ms, c12_total, 'k-', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12, '-', c='r', linewidth=3.0, alpha=0.75)
    plt.plot(lags_ms, c12_nonlocked, '-', c='b', linewidth=3.0, alpha=0.75)
    plt.xlabel('Lags (ms)')
    plt.ylabel('Coherency')
    leg = custom_legend(['k', 'r', 'b'], ['Raw', 'Trial-avg', 'Mean-sub'])
    plt.legend(handles=leg, fontsize='x-small')
    plt.axis('tight')
    plt.ylim(-0.2, 0.3)

    fname = os.path.join(get_this_dir(), 'raw+coherency.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
Exemplo n.º 4
0
def draw_figures(data_dir='/auto/tdrive/mschachter/data', bird='GreBlu9508M', output_dir='/auto/tdrive/mschachter/data/sounds'):
    """Export every stimulus used in an experiment as a WAV file and a spectrogram PNG.

    The stimulus ids are gathered from the epoch tables of all segments. For
    each id the sound is reconstructed through ``exp.sound_manager`` and
    written to ``<output_dir>/<type>_stim_<id>.wav``, and its spectrogram is
    saved next to it as a ``.png``. For stimuli of type ``'call'`` the
    ``callid`` column is used in place of the type. The figure width grows
    linearly with stimulus duration, from 5 for the shortest stimulus to 15
    for the longest.

    Args:
        data_dir: Root directory containing one sub-directory per bird.
        bird: Bird name; selects ``<data_dir>/<bird>/<bird>.h5`` and
            ``<data_dir>/<bird>/stims.h5``.
        output_dir: Directory the WAV and PNG files are written to.

    Raises:
        AssertionError: If a stimulus id does not match exactly one row of
            ``exp.stim_table``.
    """

    spec_colormap()

    exp_dir = os.path.join(data_dir, bird)
    exp_file = os.path.join(exp_dir, '%s.h5' % bird)
    stim_file = os.path.join(exp_dir, 'stims.h5')

    exp = Experiment.load(exp_file, stim_file)

    # iterate through the segments and get the stim ids from each epoch table
    all_stim_ids = list()
    for seg in exp.get_all_segments():
        etable = exp.get_epoch_table(seg)
        all_stim_ids.extend(etable['id'].unique())

    stim_ids = np.unique(all_stim_ids)

    stim_info = list()
    for stim_id in stim_ids:
        si = exp.stim_table['id'] == stim_id
        nmatches = si.sum()
        # the message reports the actual count, since zero matches fails here too
        assert nmatches == 1, "Expected exactly one stimulus defined for id=%d, found %d" % (stim_id, nmatches)
        stype = exp.stim_table['type'][si].values[0]
        if stype == 'call':
            stype = exp.stim_table['callid'][si].values[0]

        # get sound pressure waveform
        sound = exp.sound_manager.reconstruct(stim_id)
        waveform = np.array(sound.squeeze())
        sample_rate = float(sound.samplerate)
        stim_dur = len(waveform) / sample_rate

        stim_info.append((stim_id, stype, sample_rate, waveform, stim_dur))

    if not stim_info:
        # nothing to export; durations.max() below would fail on an empty array
        return

    durations = np.array([x[-1] for x in stim_info])
    max_dur = durations.max()
    min_dur = durations.min()
    dur_range = max_dur - min_dur

    max_fig_size = 15.
    min_fig_size = 5.

    for stim_id, stype, sample_rate, waveform, stim_dur in stim_info:

        fname = os.path.join(output_dir, '%s_stim_%d.wav' % (stype, stim_id))
        print('Writing %s...' % fname)
        # the WAV header stores the rate as an integer, so do not pass the float
        wavfile.write(fname, int(round(sample_rate)), waveform)

        # fraction of the way from the shortest to the longest stimulus; when
        # every stimulus has the same duration fall back to the minimum width
        # instead of dividing by zero
        if dur_range > 0:
            dfrac = (stim_dur - min_dur) / dur_range
        else:
            dfrac = 0.0
        fig_width = dfrac*(max_fig_size - min_fig_size) + min_fig_size

        spec_t, spec_freq, spec, rms = spectrogram(waveform, sample_rate, sample_rate, 136., min_freq=300, max_freq=8000,
                                                   log=True, noise_level_db=80, rectify=True, cmplx=False)

        figsize = (fig_width, 5)
        plt.figure(figsize=figsize)
        plot_spectrogram(spec_t, spec_freq, spec, colormap='SpectroColorMap', colorbar=False)
        plt.title('Stim %d: %s' % (stim_id, stype))

        fname = os.path.join(output_dir, '%s_stim_%d.png' % (stype, stim_id))
        plt.savefig(fname, facecolor='w')
        # close after every stimulus so figures do not accumulate in memory
        plt.close('all')