def plot_frequency_response(f_db, db, ax=None, **kwargs):
    """Draw a filter's frequency response as attenuation versus frequency.

    Parameters
    ----------
    f_db : 1d array
        Frequencies, in Hz, at which the attenuation was evaluated.
    db : 1d array
        Attenuation at each frequency in `f_db`, in dB.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot the frequency response of an FIR bandpass filter:

    >>> from neurodsp.filt import design_fir_filter
    >>> from neurodsp.filt.utils import compute_frequency_response
    >>> filter_coefs = design_fir_filter(fs=500, pass_type='bandpass', f_range=(1, 40))
    >>> f_db, db = compute_frequency_response(filter_coefs, 1, fs=500)
    >>> plot_frequency_response(f_db, db)
    """
    target_ax = check_ax(ax, (5, 5))

    target_ax.plot(f_db, db, 'k')
    target_ax.set(title='Frequency response',
                  xlabel='Frequency (Hz)',
                  ylabel='Attenuation (dB)')
def plot_timeseries(signals, shade=None, colors=None, xlim=None, ylim=None, offset=0,
                    ax=None, **plt_kwargs):
    """Plot time series.

    Parameters
    ----------
    signals : 1d array or list of 1d array
        Time series to plot. A single array is plotted as one signal.
    shade : str, optional
        If given, color used to shade the full x-range of the axes.
    colors : str or list of str, optional
        Line color(s). A single color (or None) applies to every signal; a list of
        colors is cycled if there are more signals than colors.
    xlim, ylim : tuple of float, optional
        Axis limits to apply.
    offset : float, optional, default: 0
        Vertical offset per signal: signal ``i`` is drawn as ``signal + i * offset``.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **plt_kwargs
        Keyword arguments passed to ``ax.plot`` for every signal.

    Notes
    -----
    Tick marks and axis labels are cleared from the axes.
    """
    from itertools import cycle

    ax = check_ax(ax)

    if isinstance(signals, np.ndarray):
        signals = [signals]

    # Cycle a list of colors so zip never silently drops signals when fewer
    # colors than signals are given
    if isinstance(colors, str) or colors is None:
        colors = repeat(colors)
    else:
        colors = cycle(colors)

    for ind, (signal, color) in enumerate(zip(signals, colors)):
        ax.plot(signal + ind * offset, color=color, **plt_kwargs)

    if xlim:
        ax.set_xlim(xlim)
    else:
        # Despite seeming like this redundantly resets the limits to what they already are,
        # for some reason this seems to make sure that shading goes the full length
        ax.set_xlim(*ax.get_xlim())

    if ylim:
        ax.set_ylim(ylim)

    if shade:
        ax.axvspan(*ax.get_xlim(), alpha=0.2, color=shade)

    ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)
def plot_impulse_response(fs, impulse_response, ax=None):
    """Draw a filter's impulse response on a time axis centered at zero.

    Parameters
    ----------
    fs : float
        Sampling rate, in Hz, used to convert sample offsets to seconds.
    impulse_response : 1d array
        Impulse response (kernel) values of the filter.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    """
    axes = check_ax(ax, (5, 5))

    # Sample offsets shifted so the middle sample sits at t = 0, then scaled to seconds
    n_taps = len(impulse_response)
    centered = np.arange(n_taps) - (n_taps - 1) / 2
    time = centered / fs

    axes.plot(time, impulse_response, 'k')
    axes.set(title='Kernel', xlabel='Time (seconds)', ylabel='Response')
def plot_swm_pattern(pattern, ax=None, **kwargs):
    """Draw the average window found by a sliding window matching analysis.

    Parameters
    ----------
    pattern : 1d array
        Average pattern returned by sliding window matching, indexed by sample.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (4, 4).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot the average pattern from a sliding window matching analysis:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.rhythm import sliding_window_matching
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_powerlaw': {'f_range': (2, None)},
    ...                                'sim_bursty_oscillation': {'freq': 20,
    ...                                                           'enter_burst': .25,
    ...                                                           'leave_burst': .25}})
    >>> avg_window, _, _ = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.5)
    >>> plot_swm_pattern(avg_window)
    """
    axes = check_ax(ax, (4, 4))

    axes.plot(pattern, 'k')
    axes.set(title='Average Pattern',
             xlabel='Time (samples)',
             ylabel='Voltage (a.u.)')
def plot_scv_rs_lines(freqs, scv_rs, ax=None, **kwargs):
    """Draw resampled spectral coefficients of variation as log-log lines.

    Each resample is drawn as a faint black line, their mean across axis 1 as a
    thick line, and a constant reference line at SCV = 1.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector.
    scv_rs : 2d array
        Resampled SCV values; averaged across axis 1 for the summary line.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (8, 8).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot the spectral coefficient of variation using a resampling method:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.spectral import compute_scv_rs
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}})
    >>> freqs, t_inds, scv_rs = compute_scv_rs(sig, fs=500, nperseg=500, method='bootstrap',
    ...                                        rs_params=(5, 200))
    >>> plot_scv_rs_lines(freqs, scv_rs)
    """
    axes = check_ax(ax, (8, 8))

    mean_scv = np.mean(scv_rs, axis=1)
    unity_line = len(freqs) * [1.]

    axes.loglog(freqs, scv_rs, 'k', alpha=0.1)
    axes.loglog(freqs, mean_scv, lw=2)
    axes.loglog(freqs, unity_line)

    axes.set(xlabel='Frequency (Hz)', ylabel='SCV')
def plot_scv(freqs, scv, ax=None, **kwargs):
    """Draw the spectral coefficient of variation on log-log axes.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector.
    scv : 1d array
        SCV value at each frequency.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot the spectral coefficient of variation:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.spectral import compute_scv
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}})
    >>> freqs, scv = compute_scv(sig, fs=500)
    >>> plot_scv(freqs, scv)
    """
    axes = check_ax(ax, (5, 5))

    axes.loglog(freqs, scv)
    axes.set(xlabel='Frequency (Hz)', ylabel='SCV')
def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwargs):
    """Plot power spectra.

    Parameters
    ----------
    freqs : 1d or 2d array or list of 1d array
        Frequency vector. A 1d array is shared by every spectrum.
    powers : 1d or 2d array or list of 1d array
        Power values. A 1d array is plotted as a single spectrum.
    labels : str or list of str, optional
        Labels for each spectrum. Spectra without a matching label are plotted unlabeled.
    colors : str or list of str
        Colors to use to plot lines. A list is cycled if shorter than the number of spectra.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot a power spectrum:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.spectral import compute_spectrum
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation' : {'freq': 10}},
    ...                    component_variances=(0.5, 1.))
    >>> freqs, powers = compute_spectrum(sig, fs=500)
    >>> plot_power_spectra(freqs, powers)
    """
    from itertools import chain

    ax = check_ax(ax, (6, 6))

    freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
    powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers

    # Pad labels with None so a single label (or a short list) never makes zip
    # silently drop the remaining spectra
    if labels is None:
        labels = []
    elif not isinstance(labels, list):
        labels = [labels]
    labels = chain(labels, repeat(None))

    colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

    for freq, power, color, label in zip(freqs, powers, colors, labels):
        ax.loglog(freq, power, color=color, label=label)

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power (V^2/Hz)')
def plot_data(x_values=None, y_values=None, ax=None, **kwargs):
    """Generic plot for plotting data.

    Parameters
    ----------
    x_values : array_like, optional
        X values, or the only data series when `y_values` is not given.
    y_values : array_like, optional
        Y values. If None, `x_values` is passed to ``ax.plot`` on its own.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of [5, 5].
    **kwargs
        Keyword arguments forwarded to ``ax.plot``.
    """
    axes = check_ax(ax, [5, 5])

    series = (x_values,) if y_values is None else (x_values, y_values)
    axes.plot(*series, **kwargs)
def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):
    """Plot a time series.

    Parameters
    ----------
    times : 1d or 2d array, or list of 1d array
        Time definition(s) for the time series to be plotted. A 1d array is shared
        by every signal.
    sigs : 1d or 2d array, or list of 1d array
        Time series to plot. A 1d array is plotted as a single signal.
    labels : str or list of str, optional
        Labels for each time series. Signals without a matching label are plotted unlabeled.
    colors : str or list of str
        Colors to use to plot lines. A list is cycled if shorter than the number of
        signals. If not given and there are at most two signals, black and red are used.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Create a time series plot:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.utils import create_times
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_powerlaw': {'exponent': -1.5, 'f_range': (2, None)},
    ...                                'sim_oscillation' : {'freq': 10}})
    >>> times = create_times(n_seconds=10, fs=500)
    >>> plot_time_series(times, sig)
    """
    from itertools import chain

    ax = check_ax(ax, (15, 3))

    times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times
    sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs

    # Pad labels with None so a single label (or a short list) never makes zip
    # silently drop the remaining signals
    if labels is None:
        labels = []
    elif not isinstance(labels, list):
        labels = [labels]
    labels = chain(labels, repeat(None))

    # If not provided, default colors for up to two signals to be black & red
    if not colors and len(sigs) <= 2:
        colors = ['k', 'r']
    colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

    for time, sig, color, label in zip(times, sigs, colors, labels):
        ax.plot(time, sig, color=color, label=label)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Voltage (uV)')
def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None,
                       spectrum=None, ax=None, **kwargs):
    """Plot spectral histogram.

    Parameters
    ----------
    freqs : 1d array
        Frequencies over which the histogram is calculated.
    power_bins : 1d array
        Power bins within which histogram is aggregated.
    spectral_hist : 2d array
        Spectral histogram to be plotted.
    spectrum_freqs : 1d array, optional
        Frequency axis of the power spectrum to be plotted. Required if `spectrum` is given.
    spectrum : 1d array, optional
        Spectrum to be plotted over the histograms, as log10 power.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Raises
    ------
    ValueError
        If `spectrum` is given without `spectrum_freqs`.

    Examples
    --------
    Plot a spectral histogram:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.spectral import compute_spectral_hist
    >>> sig = sim_combined(n_seconds=100, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation' : {'freq': 10}},
    ...                    component_variances=(0.5, 1))
    >>> freqs, bins, spect_hist = compute_spectral_hist(sig, fs=500, nbins=40, f_range=(1, 75),
    ...                                                 cut_pct=(0.1, 99.9))
    >>> plot_spectral_hist(freqs, bins, spect_hist)
    """
    # Get axis, by default scaling figure height based on number of bins
    figsize = (8, 12 * len(power_bins) / len(freqs))
    ax = check_ax(ax, figsize)

    # Plot histogram intensity as image and automatically adjust aspect ratio
    # NOTE(review): imshow's default origin='upper' draws row 0 at the top of the
    # extent (power_bins[-1]); confirm the row order of spectral_hist matches this.
    im = ax.imshow(spectral_hist, extent=[freqs[0], freqs[-1], power_bins[0], power_bins[-1]],
                   aspect='auto')

    # Attach the colorbar to this axes' own figure, rather than pyplot's current figure
    ax.figure.colorbar(im, ax=ax, label='Probability')

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Log10 Power')

    # If a power spectrum is provided, plot over the histogram data
    if spectrum is not None:

        if spectrum_freqs is None:
            raise ValueError('spectrum_freqs must be provided when spectrum is given.')

        plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
        ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8)
def plot_sig_kernel(sig, samps, kernel, ax=None):
    """Plot a signal with an overlying kernel.

    The full signal is drawn faintly, the values at `samps` are marked, and the
    kernel is drawn at those sample positions, scaled by 25 and shifted down by 0.5.

    Parameters
    ----------
    sig : 1d array
        Signal to draw.
    samps : array of int
        Sample indices into `sig` covered by the kernel.
    kernel : 1d array
        Kernel values, drawn against `samps` (presumably the same length as `samps`).
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of [12, 2].
    """
    axes = check_ax(ax, [12, 2])

    sampled_points = sig[samps]
    shifted_kernel = kernel * 25 - 0.5

    axes.plot(sig, color='black', alpha=0.25)
    axes.plot(samps, sampled_points, marker='.', markersize=2.5, linewidth=0, color='blue')
    axes.plot(samps, shifted_kernel, color='red', alpha=0.75)

    axes.set(xlim=[0, len(sig)], ylim=[-3.5, 3.5])
    axes.set(xticks=[], yticks=[], xlabel='', ylabel='')
def plot_convolution(samples, convolved, ax=None):
    """Plot the output of a convolution.

    The convolution output is drawn as a line, and its last non-NaN value is
    highlighted with a marker.

    Parameters
    ----------
    samples : 1d array
        Sample positions for each value of `convolved`.
    convolved : 1d array
        Convolution output; NaN entries are treated as not-yet-computed.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of [12, 2].
    """
    ax = check_ax(ax, [12, 2])

    ax.plot(samples, convolved, alpha=0.5, color='green')

    # Mark the most recent valid output; skip the marker if every value is NaN
    # (indexing [-1] on an empty result would raise IndexError)
    valid_inds = np.flatnonzero(~np.isnan(convolved))
    if valid_inds.size:
        ind = valid_inds[-1]
        ax.plot(samples[ind], convolved[ind], '.', markersize=12, color='green', alpha=0.75)

    ax.set(xlim=[0, len(samples)], ylim=[-3.5, 3.5])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
def plot_swm_pattern(pattern, ax=None):
    """Plot the resulting pattern from a sliding window matching analysis.

    Parameters
    ----------
    pattern : 1d array
        The resulting average pattern from applying sliding window matching.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    """
    ax = check_ax(ax, (4, 4))

    # Draw on the resolved axes rather than pyplot's "current" axes, so that a
    # caller-supplied ax is actually used
    ax.plot(pattern, 'k')
    ax.set_title('Average Pattern')
    ax.set_xlabel('Time (samples)')
    ax.set_ylabel('Voltage (a.u.)')
def plot_lagged_coherence(freqs, lcs, ax=None):
    """Plot lagged coherence values across frequencies.

    Parameters
    ----------
    freqs : 1d array
        Vector of frequencies at which lagged coherence was computed.
    lcs : 1d array
        Lagged coherence values across the computed frequencies.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    """
    ax = check_ax(ax, (6, 3))

    # Draw on the resolved axes rather than pyplot's "current" axes, so that a
    # caller-supplied ax is actually used
    ax.plot(freqs, lcs, 'k.-')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Lagged Coherence')
def plot_scv(freqs, scv, ax=None):
    """Draw the spectral coefficient of variation on log-log axes.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector.
    scv : 1d array
        SCV value at each frequency.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    """
    axes = check_ax(ax, (5, 5))

    axes.loglog(freqs, scv)
    axes.set(xlabel='Frequency (Hz)', ylabel='SCV')
def plot_spectra(freqs, powers, log_freqs=True, log_powers=True, xlim=None, ylim=None,
                 colors=None, shade_ranges=None, shade_colors=None, ax=None, **plt_kwargs):
    """Plot power spectra.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector shared by every spectrum.
    powers : 1d array or list of 1d array
        Power values. A single array is plotted as one spectrum.
    log_freqs, log_powers : bool, optional, default: True
        Whether to take log10 of the frequencies / powers before plotting.
    xlim, ylim : tuple of float, optional
        Axis limits to apply (in the plotted, possibly logged, units).
    colors : str or list of str, optional
        Line color(s). A single color (or None) applies to every spectrum; a list of
        colors is cycled if there are more spectra than colors.
    shade_ranges, shade_colors : optional
        Passed to ``add_shades`` (with ``logged=log_freqs``) to shade frequency ranges.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **plt_kwargs
        Keyword arguments passed to ``ax.plot`` for every spectrum.

    Notes
    -----
    Tick marks and axis labels are cleared from the axes.
    """
    from itertools import cycle

    ax = check_ax(ax)

    if isinstance(powers, np.ndarray):
        powers = [powers]

    # Cycle a list of colors so zip never silently drops spectra when fewer
    # colors than spectra are given
    if isinstance(colors, str) or colors is None:
        colors = repeat(colors)
    else:
        colors = cycle(colors)

    if log_freqs:
        # Silence numpy warnings from log10 (e.g. presumably a 0 Hz bin)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            freqs = np.log10(freqs)

    for power, color in zip(powers, colors):
        if log_powers:
            power = np.log10(power)
        ax.plot(freqs, power, color=color, **plt_kwargs)

    if shade_ranges:
        add_shades(ax, shade_ranges, shade_colors, logged=log_freqs)

    if xlim:
        ax.set_xlim(xlim)
    if ylim:
        ax.set_ylim(ylim)

    ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)
def plot_frequency_response(f_db, db, ax=None):
    """Draw a filter's frequency response as attenuation versus frequency.

    Parameters
    ----------
    f_db : 1d array
        Frequencies, in Hz, at which the attenuation was evaluated.
    db : 1d array
        Attenuation at each frequency in `f_db`, in dB.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    """
    target_ax = check_ax(ax, (5, 5))

    target_ax.plot(f_db, db, 'k')
    target_ax.set(title='Frequency response',
                  xlabel='Frequency (Hz)',
                  ylabel='Attenuation (dB)')
def plot_bursts(times, sig, bursting, ax=None, **plt_kwargs):
    """Draw a time series together with a copy masked to its bursting samples.

    Parameters
    ----------
    times : 1d array
        Time definition for the time series to be plotted.
    sig : 1d array
        Time series to plot.
    bursting : 1d array
        Boolean flags marking burst samples; non-burst samples are masked out of the
        overlaid copy.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (15, 3).
    **plt_kwargs
        Keyword arguments forwarded to `plot_time_series`.
    """
    axes = check_ax(ax, (15, 3))

    not_bursting = np.invert(bursting)
    burst_only = ma.array(sig, mask=not_bursting)

    plot_time_series(times, [sig, burst_only], ax=axes, **plt_kwargs)
def plot_scv_rs_lines(freqs, scv_rs, ax=None):
    """Draw resampled spectral coefficients of variation as log-log lines.

    Each resample is drawn as a faint black line, their mean across axis 1 as a
    thick line, and a constant reference line at SCV = 1.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector.
    scv_rs : 2d array
        Resampled SCV values; averaged across axis 1 for the summary line.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (8, 8).
    """
    axes = check_ax(ax, (8, 8))

    mean_scv = np.mean(scv_rs, axis=1)
    unity_line = len(freqs) * [1.]

    axes.loglog(freqs, scv_rs, 'k', alpha=0.1)
    axes.loglog(freqs, mean_scv, lw=2)
    axes.loglog(freqs, unity_line)

    axes.set(xlabel='Frequency (Hz)', ylabel='SCV')
def plot_scv_rs_matrix(freqs, t_inds, scv_rs, ax=None, **kwargs):
    """Plot spectral coefficient of variation, from the resampling method, as a matrix.

    Values are shown as log10(SCV), with time along x and frequency along y
    (the first frequency at the top of the image).

    Parameters
    ----------
    freqs : 1d array
        Frequency vector.
    t_inds : 1d array
        Time indices.
    scv_rs : 2d array
        Spectral coefficient of variation, from resampling procedure, with one row
        per frequency and one column per time index.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot a SCV matrix from a simulated signal with a high probability of bursting at 10Hz:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.spectral import compute_scv_rs
    >>> sig = sim_combined(n_seconds=100, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation': {'freq': 10, 'enter_burst':0.75}})
    >>> freqs, t_inds, scv_rs = compute_scv_rs(sig, fs=500, method='rolling', rs_params=(10, 2))
    >>> # Plot the computed scv, plotting frequencies up to 20 Hz (index of 21)
    >>> plot_scv_rs_matrix(freqs[:21], t_inds, scv_rs[:21])
    """
    ax = check_ax(ax, (10, 5))

    im = ax.imshow(np.log10(scv_rs), aspect='auto',
                   extent=(t_inds[0], t_inds[-1], freqs[-1], freqs[0]))

    # Attach the colorbar to this axes' own figure, rather than pyplot's current figure
    ax.figure.colorbar(im, ax=ax, label='SCV')

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
def plot_time_series(times, sigs, labels=None, colors=None, ax=None):
    """Plot a time series.

    Parameters
    ----------
    times : 1d array or list of 1d array
        Time definition(s) for the time series to be plotted. A single array is
        shared by every signal.
    sigs : 1d array or list of 1d array
        Time series to plot. A single array is plotted as one signal.
    labels : str or list of str, optional
        Labels for each time series. Signals without a matching label are plotted unlabeled.
    colors : str or list of str, optional
        Colors to use to plot lines. A list is cycled if shorter than the number of
        signals. If not given, cycles through black, red, blue, green, magenta, cyan.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    """
    from itertools import chain

    ax = check_ax(ax, (15, 3))

    times = repeat(times) if isinstance(times, np.ndarray) else times
    sigs = [sigs] if isinstance(sigs, np.ndarray) else sigs

    # Pad labels with None so a single label (or a short list) never makes zip
    # silently drop the remaining signals
    if labels is None:
        labels = []
    elif not isinstance(labels, list):
        labels = [labels]
    labels = chain(labels, repeat(None))

    if colors is not None:
        colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)
    else:
        colors = cycle(['k', 'r', 'b', 'g', 'm', 'c'])

    for time, sig, color, label in zip(times, sigs, colors, labels):
        # Pass the color by keyword: positionally it is parsed as a format string,
        # and non-string colors (e.g. RGB tuples) would be misread as extra data
        ax.plot(time, sig, color=color, label=label)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Voltage (uV)')
def plot_impulse_response(fs, impulse_response, ax=None, **kwargs):
    """Draw a filter's impulse response on a time axis centered at zero.

    Parameters
    ----------
    fs : float
        Sampling rate, in Hz, used to convert sample offsets to seconds.
    impulse_response : 1d array
        Impulse response (kernel) values of the filter.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (5, 5).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot the impulse response of an FIR bandpass filter:

    >>> from neurodsp.filt import design_fir_filter
    >>> from neurodsp.filt.utils import compute_frequency_response
    >>> fs = 500
    >>> filter_coefs = design_fir_filter(fs, pass_type='bandpass', f_range=(1, 40))
    >>> plot_impulse_response(fs, filter_coefs)
    """
    axes = check_ax(ax, (5, 5))

    # Sample offsets shifted so the middle sample sits at t = 0, then scaled to seconds
    n_taps = len(impulse_response)
    centered = np.arange(n_taps) - (n_taps - 1) / 2
    time = centered / fs

    axes.plot(time, impulse_response, 'k')
    axes.set(title='Kernel', xlabel='Time (seconds)', ylabel='Response')
def plot_lagged_coherence(freqs, lcs, ax=None, **kwargs):
    """Draw lagged coherence as a function of frequency.

    Parameters
    ----------
    freqs : 1d array
        Frequencies at which lagged coherence was computed.
    lcs : 1d array
        Lagged coherence value at each frequency.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (6, 3).
    **kwargs
        Keyword arguments for customizing the plot (not consumed by this body itself).

    Examples
    --------
    Plot lagged coherence:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.rhythm import compute_lagged_coherence
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation': {'freq': 20,
    ...                                                           'enter_burst': .50,
    ...                                                           'leave_burst': .25}})
    >>> lag_cohs, freqs = compute_lagged_coherence(sig, fs=500, freqs=(5, 35),
    ...                                            return_spectrum=True)
    >>> plot_lagged_coherence(freqs, lag_cohs)
    """
    axes = check_ax(ax, (6, 3))

    axes.plot(freqs, lcs, 'k.-')
    axes.set(xlabel='Frequency (Hz)', ylabel='Lagged Coherence')
def plot_bursts(times, sig, bursting, ax=None, **kwargs):
    """Draw a time series together with a copy masked to its bursting samples.

    Parameters
    ----------
    times : 1d array
        Time definition for the time series to be plotted.
    sig : 1d array
        Time series to plot.
    bursting : 1d array
        Boolean flags marking burst samples; non-burst samples are masked out of the
        overlaid copy.
    ax : matplotlib.Axes, optional
        Axes to draw into. Resolved via ``check_ax`` with a default size of (15, 3).
    **kwargs
        Keyword arguments forwarded to `plot_time_series`.

    Examples
    --------
    Create a plot of burst activity:

    >>> from neurodsp.sim import sim_combined
    >>> from neurodsp.utils import create_times
    >>> from neurodsp.burst import detect_bursts_dual_threshold
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation' : {'freq': 10}},
    ...                    component_variances=(0.1, 0.9))
    >>> is_burst = detect_bursts_dual_threshold(sig, fs=500, dual_thresh=(1, 2), f_range=(8, 12))
    >>> times = create_times(n_seconds=10, fs=500)
    >>> plot_bursts(times, sig, is_burst, labels=['Raw Data', 'Detected Bursts'])
    """
    axes = check_ax(ax, (15, 3))

    not_bursting = np.invert(bursting)
    burst_only = ma.array(sig, mask=not_bursting)

    plot_time_series(times, [sig, burst_only], ax=axes, **kwargs)
def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None):
    """Plot power spectra.

    Parameters
    ----------
    freqs : 1d array or list of 1d array
        Frequency vector. A single array is shared by every spectrum.
    powers : 1d array or list of 1d array
        Power values. A single array is plotted as one spectrum.
    labels : str or list of str, optional
        Labels for each spectrum. Spectra without a matching label are plotted unlabeled.
    colors : str or list of str, optional
        Colors to use to plot lines. A list is cycled if shorter than the number of
        spectra. If not given, matplotlib's default color cycle is used.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    """
    from itertools import chain

    ax = check_ax(ax, (6, 6))

    freqs = repeat(freqs) if isinstance(freqs, np.ndarray) else freqs
    powers = [powers] if isinstance(powers, np.ndarray) else powers

    # Pad labels with None so a single label (or a short list) never makes zip
    # silently drop the remaining spectra
    if labels is None:
        labels = []
    elif not isinstance(labels, list):
        labels = [labels]
    labels = chain(labels, repeat(None))

    # color=None lets matplotlib pick from its default cycle
    if colors is None:
        colors = repeat(None)
    elif isinstance(colors, list):
        colors = cycle(colors)
    else:
        colors = repeat(colors)

    for freq, power, color, label in zip(freqs, powers, colors, labels):
        ax.loglog(freq, power, color=color, label=label)

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power (V^2/Hz)')
def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs):
    """Plot a time-frequency representation of data.

    Parameters
    ----------
    times : 1d array
        The time dimension for the time-frequency representation (one per column).
    freqs : 1d array
        The frequency dimension for the time-frequency representation (one per row).
    powers : 2d array
        Power values to plot, with rows matched to `freqs` and columns to `times`.
        If array is complex, its magnitude (absolute value) is plotted.
    x_ticks, y_ticks : int or array_like
        Defines the tick labels to add to the plot.
        If int, is the number of evenly sampled labels to add to the plot.
        If array_like, is a set of labels to add to the plot, each placed at the
        nearest sample of `times` / `freqs`.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments passed to ``ax.imshow``.

    Examples
    --------
    Plot a Morlet transformation:

    >>> import numpy as np
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> from neurodsp.timefrequency.wavelets import compute_wavelet_transform
    >>> fs=1000
    >>> sig = sim_bursty_oscillation(n_seconds=10, fs=fs, freq=10)
    >>> times = np.arange(0, len(sig)/fs, 1/fs)
    >>> freqs = np.arange(1, 50, 1)
    >>> mwt = compute_wavelet_transform(sig, fs, freqs)
    >>> plot_timefrequency(times, freqs, mwt)
    """
    ax = check_ax(ax, None)

    if np.iscomplexobj(powers):
        powers = abs(powers)

    ax.imshow(powers, aspect='auto', **kwargs)
    ax.invert_yaxis()

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    # Image pixel centers sit at integer indices 0..N-1, so evenly spaced tick
    # positions must end at N-1 to line up with the labels times[-1] / freqs[-1]
    if isinstance(x_ticks, (int, np.integer)):
        x_tick_pos = np.linspace(0, times.size - 1, x_ticks)
        x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2)
    else:
        x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks]

    ax.set(xticks=x_tick_pos, xticklabels=x_ticks)

    if isinstance(y_ticks, (int, np.integer)):
        y_ticks_pos = np.linspace(0, freqs.size - 1, y_ticks)
        y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2)
    else:
        y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks]

    ax.set(yticks=y_ticks_pos, yticklabels=y_ticks)
def plot_spikes(df_features, sig, fs, spikes=None, index=None, xlim=None, ax=None):
    """Plot a group of spikes or the cyclepoints for an individual spike.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe containing shape and burst features for each spike. The columns
        'sample_start' and 'sample_end' are read here, plus 'period' for the
        continuous view and the cyclepoint columns chosen by ``_infer_labels``.
    sig : 1d or 2d array
        Voltage timeseries. May be 2d if spikes are split.
    fs : float
        Sampling rate, in Hz.
    spikes : 2d array, optional, default: None
        Spikes that have been split into a 2d array (``len(spikes[0])`` samples per
        spike). Ignored if ``index`` is passed.
    index : int, optional, default: None
        The row position (used with ``.iloc``) in ``df_features`` to plot.
        If None, plot all spikes.
    xlim : tuple of float, optional
        Lower and upper time limits, in seconds. Ignored if spikes or index is passed.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    """
    # Choose the default figure size before creating the axes: the continuous view
    # is wide and short, the single-spike and stacked views use a smaller canvas.
    # NOTE(review): assumes check_ax only applies figsize when it creates new axes.
    continuous = index is None and spikes is None
    ax = check_ax(ax, (15, 3) if continuous else (10, 4))

    center_e, _ = get_extrema_df(df_features)

    # Time vectors are built from the sample count so their length always matches
    # the data (a float-step np.arange(0, n / fs, 1 / fs) can yield n + 1 values)
    if index is not None:

        # Plot a single spike
        times = np.arange(len(sig)) / fs

        # Get where spike starts/ends
        start = df_features.iloc[index]['sample_start'].astype(int)
        end = df_features.iloc[index]['sample_end'].astype(int)

        sig_lim = sig[start:end + 1]
        times_lim = times[start:end + 1]

        # Plot the spike waveform
        plot_time_series(times_lim, sig_lim, ax=ax)

        # Plot cyclepoints
        labels, keys = _infer_labels(center_e)
        colors = ['C0', 'C1', 'C2', 'C3']

        for idx, key in enumerate(keys):

            sample = df_features.iloc[index][key].astype('int')

            plot_time_series(np.array([times[sample]]), np.array([sig[sample]]),
                             colors=colors[idx], labels=labels[idx], ls='', marker='o', ax=ax)

    elif spikes is not None:

        # Plot as stack of spikes
        times = np.arange(len(spikes[0])) / fs

        plot_time_series(times, spikes, ax=ax)

    else:

        # Plot as continuous timeseries
        times = np.arange(len(sig)) / fs

        plot_time_series(times, sig, ax=ax, xlim=xlim)

        if xlim is None:
            sig_lim = sig
            df_lim = df_features
            times_lim = times
            starts = df_lim['sample_start']
        else:
            # Keep only spikes lying entirely within the time limits, and shift their
            # start samples so they index into the limited signal
            cyc_idxs = (df_features['sample_start'].values >= xlim[0] * fs) & \
                       (df_features['sample_end'].values <= xlim[1] * fs)

            df_lim = df_features.iloc[cyc_idxs].copy()

            sig_lim, times_lim = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

            starts = df_lim['sample_start'] - int(fs * xlim[0])

        ends = starts + df_lim['period'].values

        # Boolean mask of samples that fall inside a spike, used to highlight them
        is_spike = np.zeros(len(sig_lim), dtype='bool')

        for start, end in zip(starts, ends):
            is_spike[start:end] = True

        plot_bursts(times_lim, sig_lim, is_spike, ax=ax)