def plot_cyclepoints_df(df_samples, sig, fs, plot_sig=True, plot_extrema=True,
                        plot_zerox=True, xlim=None, ax=None, **kwargs):
    """Plot extrema and/or zero-crossings stored in a cyclepoints dataframe.

    Parameters
    ----------
    df_samples : pandas.DataFrame
        Dataframe output of :func:`~.compute_cyclepoints`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    plot_extrema : bool, optional, default: True
        Whether to plot the peaks and troughs.
    plot_zerox : bool, optional, default: True
        Whether to plot the zero-crossings.
    xlim : tuple of (float, float), optional
        Start and stop times.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'

    Examples
    --------
    Plot cyclepoints using a dataframe from :func:`~.compute_cyclepoints`:

    >>> from bycycle.features import compute_cyclepoints
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> df_samples = compute_cyclepoints(sig, fs, f_range=(8, 12))
    >>> plot_cyclepoints_df(df_samples, sig, fs)
    """

    # The sampling rate must lie in the accepted range before anything is drawn
    check_param_range(fs, 'fs', (0, np.inf))

    # Names of the center / side extrema, used to build the column names below
    center_e, side_e = get_extrema_df(df_samples)

    # Any group of points that is not requested is forwarded as None
    cyclepoints = {'peaks': None, 'troughs': None, 'rises': None, 'decays': None}

    if plot_extrema:
        cyclepoints['peaks'] = df_samples['sample_' + center_e].values

        # Each side extremum is shared by neighbouring cycles, so merge the
        # "last" and "next" columns and drop the duplicates
        last_side = df_samples['sample_last_' + side_e].values
        next_side = df_samples['sample_next_' + side_e].values
        cyclepoints['troughs'] = np.unique(np.append(last_side, next_side))

    if plot_zerox:
        cyclepoints['rises'] = df_samples['sample_zerox_rise'].values
        cyclepoints['decays'] = df_samples['sample_zerox_decay'].values

    plot_cyclepoints_array(sig, fs, plot_sig=plot_sig, xlim=xlim, ax=ax,
                           **cyclepoints, **kwargs)
def plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs, xlim=None,
                              figsize=(15, 3), plot_only_result=False, interp=True):
    """Plot the cycle-by-cycle burst detection parameters and burst detection summary.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The df must contain sample indices
        (i.e. when ``return_samples = True``).
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'threshold_kwargs'
        argument of :func:`.compute_features`. The dictionary is not modified.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - The plots are drawn on a newly created figure; nothing is returned.
    - If plot_only_result = True: a single plot of the burst detection is drawn.
    - If plot_only_result = False: the burst detection plot is followed by one plot per
      thresholded burst parameter (``min_n_cycles`` is skipped, it has no per-cycle value).
    - In the top plot, the (z-scored) signal is plotted together with the periods defined as
      oscillatory bursts. The highlighted regions indicate when each burst requirement was
      violated (parameter value less than or equal to its threshold), color-coded
      consistently with the plots below. Colors are assigned in the order of the keys in
      ``threshold_kwargs``; for the order used in the example below:

      - blue: amp_fraction_threshold
      - red: amp_consistency_threshold
      - yellow: period_consistency_threshold
      - green: monotonicity_threshold

    Examples
    --------
    Plot the burst detection summary of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Normalize signal
    sig = zscore(sig)

    # Determine time array and limits. Times are built from integer sample indices:
    # a float step in np.arange can round up to one extra sample, which would leave
    # `times` longer than `sig`.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Determine if peak of troughs are the sides of an oscillation
    _, side_e = get_extrema_df(df_features)

    # Drop min_n_cycles: it isn't stored cycle by cycle in the df (nothing to plot).
    # A new dict is built so the caller's threshold_kwargs is left untouched.
    thresholds = {key: val for key, val in threshold_kwargs.items() if key != 'min_n_cycles'}
    n_kwargs = len(thresholds)

    # Create figure and subplots. squeeze=False always yields a 2d array of axes, so
    # indexing works even when only a single row is created (e.g. no thresholds to plot).
    n_rows = 1 if plot_only_result else n_kwargs + 1
    _, axes = plt.subplots(figsize=(figsize[0], figsize[1] * n_rows), nrows=n_rows,
                           sharex=True, squeeze=False)
    axes = axes[:, 0]

    # Determine which samples are defined as bursting
    is_osc = np.zeros(len(sig), dtype=bool)
    df_osc = df_features.loc[df_features['is_burst']]

    for _, cyc in df_osc.iterrows():
        samp_start_burst = int(cyc['sample_last_' + side_e])
        samp_end_burst = int(cyc['sample_next_' + side_e] + 1)
        is_osc[samp_start_burst:samp_end_burst] = True

    # Plot bursts with extrema points
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    plot_bursts(times, sig, is_osc, ax=axes[0], xlim=xlim, lw=2,
                labels=['Signal', 'Bursts'], xlabel='', ylabel='')

    plot_cyclepoints_df(df_features, sig, fs, ax=axes[0], xlim=xlim, plot_zerox=False,
                        plot_sig=False, xlabel=xlabel, ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Plot each burst param
    colors = cycle(['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    for idx, (osc_key, thresh) in enumerate(thresholds.items()):

        column = osc_key.replace('_threshold', '')
        color = next(colors)

        # Highlight where a burst param fails its threshold. The comparison is `<=`
        # so the shading agrees with plot_burst_detect_param, which is drawn below
        # and treats only values greater than the threshold as passing.
        for _, cyc in df_features.iterrows():

            if cyc[column] <= thresh:

                axes[0].axvspan(times[int(cyc['sample_last_' + side_e])],
                                times[int(cyc['sample_next_' + side_e])],
                                alpha=0.5, color=color, lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df_features, sig, fs, column, thresh,
                                    figsize=figsize, ax=axes[idx + 1], xlim=xlim,
                                    xlabel=xlabel, ylabel=ylabel, color=color, interp=interp)
def plot_bg(dfs_features, sigs, fs, titles=None, btn=True, xlim=None):
    """Plot 2D bycycle results.

    Parameters
    ----------
    dfs_features : list of pandas.DataFrame
        Dataframes containing shape and burst features for each cycle.
    sigs : 2d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    titles : list, optional, default: None
        The titles for each figure (one per group of up to 10 signals). Missing trailing
        titles default to "Indices: start-end". The passed list is not modified.
    btn : bool, optional, default: True
        Adds a recompute bursts button when True. Omits when False.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. When None, the limits are taken from the time
        array of the first signal and reused for the remaining signals.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.
    """

    # Initialize figures in groups of 10
    #   Plotly doesn't render single figures well with 100+ plots
    n_per_fig = 10
    n_figs = int(np.ceil(len(sigs) / n_per_fig))

    # Work on a copy: default titles are appended below and must not leak into
    # the list object owned by the caller
    titles = [] if titles is None else list(titles)

    figs = []
    n_rows = []

    for idx in range(n_figs):

        # Subplot titles
        start = idx * n_per_fig
        end = start + len(sigs[start:start + n_per_fig])

        if len(titles) == idx:
            titles.append("Indices: {start}-{end}".format(start=start, end=end))

        # Create subplots
        n_rows.append(end - start)
        figs.append(make_subplots(rows=n_rows[idx], cols=1,
                                  vertical_spacing=0.005, shared_xaxes=True))

    for idx, df_features in enumerate(dfs_features):

        # Normalize signal
        sig = zscore(sigs[idx])

        # Determine time array and limits. Times come from integer sample indices:
        # a float step in np.arange can round up to one extra sample, which would
        # leave `times` longer than `sig`.
        times = np.arange(len(sig)) / fs
        xlim = (times[0], times[-1]) if xlim is None else xlim

        # Limit signal and dataframe
        sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])
        df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

        # Get extrema
        center_e, side_e = get_extrema_df(df_features)

        # Plot bursts: signal `idx` goes to figure idx // 10, row idx % 10 (0-based)
        fig_idx, row_idx = divmod(idx, n_per_fig)

        _plot_bursts(df_features, sig, times, center_e, side_e, figs[fig_idx],
                     plot_cps=False, row=row_idx + 1, col=1)

    graphs = []

    for idx, fig in enumerate(figs):

        # The size of plotly subplots don't scale properly, this is a workaround
        height = (60 * n_rows[idx]) + ((n_per_fig - n_rows[idx]) * 18)

        # Update the figures
        fig.update_layout(autosize=False, width=1000, height=height, showlegend=False,
                          margin_autoexpand=False, title_text=titles[idx])

        fig.update_yaxes(showticklabels=False, showgrid=False)
        fig.update_xaxes(showgrid=False)

        # The range slider is only attached to the bottom row of each figure
        fig.update_xaxes(rangeslider={'visible': True}, row=n_rows[idx], col=1)

        # Convert to html
        html = fig.to_html(include_plotlyjs=False, full_html=False)

        if idx == 0:
            graphs.append(html)
        else:
            graphs.append(re.search("<div>.*</div>", html)[0] + "\n")

    # Custom js callback
    js_callback = ["<script type=\"text/javascript\">\n"]

    # Get plot div ids
    div_ids = [re.search("<div id=.+?\"", graph)[0][9:-1] for graph in graphs]

    # Recolor (non)bursts on click
    for div_id in div_ids:
        js_callback.append(""" recolorBursts('{plotID}'); """.format(plotID=div_id))

    js_callback.append("</script>\n")

    # Flatten lists into single string
    graphs.append("".join(js_callback))

    if btn:

        rewrite_call = "rewriteBursts({div_ids})".format(div_ids=div_ids)

        # Add a button (kept in its own local so the `btn` flag is not rebound)
        button_html = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
        button_html = button_html + \
            "title=\"update is_burst column\">Update Bursts</button></center></p>"

        graphs.append(button_html)

    graphs[0] = re.sub("</body>\n</html>", "", graphs[0])

    graphs = "".join(graphs)

    return graphs
def plot_burst_detect_param(df_features, sig, fs, burst_param, thresh,
                            xlim=None, interp=True, ax=None, **kwargs):
    """Plot a burst detection parameter and threshold.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    burst_param : str
        Column name of the parameter of interest in ``df``.
    thresh : float
        The burst parameter threshold. Parameter values greater
        than ``thresh`` are considered bursts.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    interp : bool, optional, default: True
        Interpolates points if true. Otherwise the parameter is drawn step-wise,
        from side to side of each cycle.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot. A new figure is created when None.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: the value of ``burst_param``
    - ``color``: str, default: 'r'.

      - Note: ``color`` here is the fill color, rather than line color.

    Cycles whose parameter value is less than or equal to ``thresh`` are highlighted.

    Examples
    --------
    Plot the monotonicity of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12),
    ...                                threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_param(df_features, sig, fs, 'monotonicity', .8)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Set default kwargs (popped so they are not forwarded twice through **kwargs)
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', burst_param)
    color = kwargs.pop('color', 'r')

    # Determine time array and limits. Times come from integer sample indices:
    # a float step in np.arange can round up to one extra sample, which would
    # leave `times` longer than `sig`.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Determine extrema strings
    center_e, side_e = get_extrema_df(df_features)
    last_col = 'sample_last_' + side_e
    next_col = 'sample_next_' + side_e

    # Limit dataframe, sig and times
    df = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])
    _, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Remove start / end cycles that tlims falls between
    df = df[(df[last_col] >= 0) & (df[next_col] < xlim[1] * fs)]

    # The y-axis label reports the threshold next to the parameter name
    ylabel = "{0:s}\nthreshold={1:.2f}".format(ylabel, thresh)

    # Plot burst param
    if interp:

        plot_time_series([times[df['sample_' + center_e]], xlim],
                         [df[burst_param], [thresh] * 2],
                         ax=ax, colors=['k', 'k'], ls=['-', '--'], marker=["o", None],
                         xlabel=xlabel, ylabel=ylabel, **kwargs)

    else:

        # Create steps, from side to side of each cycle, and set the y-value
        # to the burst parameter value for that cycle. Values are collected in
        # lists and converted once, rather than re-allocating an array per cycle.
        side_times = []
        side_param = []

        for _, cyc in df.iterrows():

            # Get the times for the last and next side of a cycle
            side_times.extend([times[int(cyc[last_col])], times[int(cyc[next_col])]])

            # Set the y-value, from side to side, to the burst param for each cycle
            side_param.extend([cyc[burst_param]] * 2)

        side_times = np.array(side_times)
        side_param = np.array(side_param)

        plot_time_series([side_times, xlim], [side_param, [thresh] * 2],
                         ax=ax, colors=['k', 'k'], ls=['-', '--'], marker=["o", None],
                         xlim=xlim, xlabel=xlabel, ylabel=ylabel, **kwargs)

    # Highlight where param falls below threshold
    for _, cyc in df.iterrows():

        if cyc[burst_param] <= thresh:

            ax.axvspan(times[int(cyc[last_col])], times[int(cyc[next_col])],
                       alpha=0.5, color=color, lw=0)
def plot_bm(df_features, sig, fs, threshold_kwargs, df_idx, xlim=None, plot_only_result=True):
    """Plot a single bycycle fit using plotly.

    Parameters
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    sig : 1d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Feature thresholds for cycles to be considered bursts. The dictionary is not
        modified; a ``min_n_cycles`` entry, if present, is ignored for plotting.
    df_idx : int
        The index of the dataframe in the javascript array. This is only used for
        fetching data for the relabel js callback.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    plot_only_result : bool, optional, default: True
        Plot only the signal and bursts, excluding burst parameter plots.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.

    Notes
    -----
    The df_idx argument is used to fetch a javascript array containing the dataframe
    results during the js relabel callback. This is done since javascript can't access
    the local filesystem to load csv files.
    """

    # Normalize signal
    sig = zscore(sig)

    # Determine time array and limits. Times come from integer sample indices:
    # a float step in np.arange can round up to one extra sample, which would
    # leave `times` longer than `sig`.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Limit signal and dataframe
    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])
    df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

    # Determine if peak of troughs are the sides of an oscillation
    center_e, side_e = get_extrema_df(df_features)

    # Drop min_n_cycles: it isn't stored cycle by cycle in the df (nothing to plot).
    # A new dict is built so the caller's threshold_kwargs is left untouched.
    thresholds = {key: val for key, val in threshold_kwargs.items() if key != 'min_n_cycles'}

    # Create figure and subplots
    if plot_only_result:
        fig = go.Figure(make_subplots(rows=1, cols=1))
    else:
        fig = go.Figure(make_subplots(rows=5, cols=1, vertical_spacing=0.01))

    # Plot bursts
    fig = _plot_bursts(df_features, sig, times, center_e, side_e, fig, row=1, col=1)

    if not plot_only_result:

        # Plot params: only those that are present in the dataframe
        burst_params = ['amp_fraction', 'amp_consistency', 'period_consistency',
                        'monotonicity', 'burst_fraction']
        burst_params = [param for param in burst_params if param in df_features.columns]

        fig = _plot_params(df_features, sig, fs, times, center_e, side_e,
                           burst_params, thresholds, fig, row=2, col=1)

        # Add time label to last subplot
        # NOTE(review): params start at row 2, so the last param row looks like
        #   len(burst_params) + 1 rather than + 2 -- confirm against _plot_params.
        fig.update_xaxes(title_text='Time', showticklabels=True,
                         row=len(burst_params) + 2, col=1)

        # Update signal axes
        fig.update_xaxes(showticklabels=False, row=1, col=1)
        fig.update_yaxes(title_text="Voltage<br>(normalized)", row=1, col=1)

        # Zoom link across all subplots
        fig.update_xaxes(matches="x")

        # Update layout across all subplots
        fig.update_layout(width=1000, height=1000)

    else:

        fig.update_layout(width=1000, height=325, xaxis_title="Time",
                          yaxis_title="Voltage<br>(normalized)")

    fig.update_layout(
        autosize=True,
        showlegend=False,
        title_text="Burst Detection Plots",
    )

    # NOTE(review): row 5 only exists in the multi-panel layout; in the single-panel
    #   layout this selector presumably matches no axis -- confirm intended behavior.
    fig.update_xaxes(rangeslider={'visible': True}, row=5, col=1)

    graph = fig.to_html(include_plotlyjs=False, full_html=False)

    # Get burst traces
    trace_id = len(fig.data) - 2
    burst_traces = list(range(len(df_features)))

    # Get the div id containing the plot
    div_id = re.search("<div id=.+?\"", graph)[0]
    div_id = div_id[9:-1]

    # Create js callback
    js_template = (
        " var burstPlot = document.getElementById('{plot_id}');"
        " var burstTraces = {burst_traces};"
        " var traceId = {trace_id};"
        " burstPlot.on('plotly_click', function(data){{"
        " relabel1DBursts(data, burstPlot, {idx}, burstTraces, traceId);"
        " }}); "
    )

    js_callback = "".join([
        "<script type=\"text/javascript\">\n",
        js_template.format(trace_id=trace_id, burst_traces=str(burst_traces),
                           plot_id=div_id, idx=df_idx),
        "</script>\n",
    ])

    # Strip any closing document tags so the callback and button can be appended
    graph = re.sub("</body>\n</html>", "\n", graph)

    # Add recompute burst btn below the plots
    rewrite_call = "rewriteBursts({div_id})".format(div_id=str([div_id]))

    btn = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
    btn = btn + "title=\"update is_burst column\">Update Bursts</button></center></p>"

    graph = graph + js_callback + btn + "\n</body>\n</html>"

    return graph