def plot_grid(self):
    """Render `self.data_matrix` as an annotated heatmap and write it to `self.output_grid`.

    Rows are indexed by `self.SW_thresholds` (y-axis) and columns by
    `self.ripple_thresholds` (x-axis). Every cell is labelled with its value,
    formatted through `self.fstring` and coloured by `self.text_color(value)`.
    """
    fig, ax = subplots(figsize=paperfig(0.55, 0.55))
    ax.imshow(
        self.data_matrix,
        cmap=self.cmap,
        norm=self.norm,
        origin="lower",
        aspect="auto",
    )
    ax.set_xlabel("Ripple definition (μV)")
    ax.set_ylabel("Sharp wave definition (μV)")

    n_cols = len(self.ripple_thresholds)
    n_rows = len(self.SW_thresholds)
    # Tick only every second threshold to keep the axes readable.
    ax.set_xticks(arange(n_cols)[::2])
    ax.set_yticks(arange(n_rows)[::2])
    ax.set_xticklabels([f"{x:.0f}" for x in self.ripple_thresholds][::2])
    ax.set_yticklabels([f"{x:.0f}" for x in self.SW_thresholds][::2])

    for row in range(n_rows):
        for col in range(n_cols):
            cell = self.data_matrix[row][col]
            ax.text(
                col,
                row,
                self.fstring.format(x=cell),
                color=self.text_color(cell),
                ha="center",
                va="center",
                size="smaller",
            )

    ax.grid(False)
    fig.tight_layout()
    self.output_grid.write(fig)
def work(self):
    """Compose the PR-curve figure with its two delay panels and write it to `self.output()`.

    Layout is a 2 x 2 grid with the top-right cell removed:
        top-left     -> ax_delay_R
        bottom-left  -> ax_PR
        bottom-right -> ax_delay_P
    """
    fig, axes = subplots(
        nrows=2,
        ncols=2,
        figsize=paperfig(1, 0.82),
        gridspec_kw=dict(width_ratios=[1.63, 1], height_ratios=[1, 1.63]),
    )
    (ax_delay_R, ax_unused), (ax_PR, ax_delay_P) = axes
    ax_unused.remove()

    panels = (ax_PR, ax_delay_P, ax_delay_R)
    self.setup_axes(*panels)
    self.plot_PR_curves(ax_PR)
    self.shade_under_PR_curves(ax_PR)
    self.plot_delays(ax_delay_P, ax_delay_R)
    self.mark_cutoffs(*panels)

    add_colored_legend(
        fig,
        self.titles,
        self.colors,
        loc="upper right",
        bbox_to_anchor=(0.965, 0.965),
    )
    fig.tight_layout()
    self.output().write(fig)
def plot_grid(self):
    """Draw the per-channel-combination search grids and write the figure to `self.output_grid`.

    There is one column per entry of `self.channel_combo_names`. There is one grid row
    per entry of `self.num_delays`, plus one final row whose relative height is
    `2.5 * SearchGrid.rowheight / self.rowheight` (presumably for the channel maps
    drawn by `plot_channelmaps`).
    """
    downscale = 1.65
    # NOTE(review): these rescale instance attributes in place, so every call to
    # `plot_grid` shrinks them again. The helpers below may rely on the new values.
    self.rowheight /= downscale
    self.col_pad /= downscale

    n_grid_rows = len(self.num_delays)
    n_rows = n_grid_rows + 1
    n_cols = len(self.channel_combo_names)
    width = 1 / downscale + 1.8 * n_cols / downscale
    height = 1 / downscale + self.rowheight * n_rows

    rel_rowheight = SearchGrid.rowheight / self.rowheight
    height_ratios = [1] * n_grid_rows + [2.5 * rel_rowheight]
    fig, axes = subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(width, height),
        gridspec_kw=dict(height_ratios=height_ratios),
    )

    self.plot_grid_cells(axes)
    self.plot_channelmaps(axes)
    self.label_cell_x(axes[0, 0])
    self.label_cell_y(axes[0, -1])
    add_colored_legend(
        fig,
        ("GEVec", "Online BPF"),
        (self.GEVec_color, self.sota_color),
        ncol=2,
        loc="upper center",
    )
    self.label_grid(fig, axes)

    layout_rect = (0.07, rel_rowheight * 0.013, 1, 1 - rel_rowheight * 0.04)
    fig.tight_layout(w_pad=self.col_pad, rect=layout_rect)
    self.output_grid.write(fig)
def make_figure(self, time_range):
    """Build and return the figure for `time_range`.

    The input signal goes in the top row, with one row per entry of
    `self.extra_signals` below it. When `self.reference_channel_only` is False, the
    input row's height ratio becomes `1 + num_channels / 4` (channels of
    `self.multichannel_test`) and the figure is made 1.2x taller.
    """
    n_extra = len(self.extra_signals)
    height_ratios = ones(1 + n_extra)
    base_height = 0.30 + 0.14 * n_extra
    if not self.reference_channel_only:
        height_ratios[0] = 1 + self.multichannel_test.num_channels / 4
        base_height *= 1.2

    # NOTE(review): with no extra signals, `subplots` returns a single Axes (not an
    # array), so the indexing below fails. Callers must supply at least one.
    fig, axes = subplots(
        nrows=1 + n_extra,
        sharex=True,
        figsize=paperfig(width=self.figwidth, height=0.9 * base_height),
        gridspec_kw=dict(height_ratios=height_ratios),
    )
    input_ax, extra_axes = axes[0], axes[1:]

    self.plot_input_signal(time_range, input_ax)
    self.plot_other_signals(time_range, extra_axes)
    self.post_plot(time_range, input_ax, extra_axes)
    add_segments(input_ax, self.reference_segs_test)
    add_time_scalebar(
        extra_axes[0],
        select_scalebar_time(time_range),
        "ms",
        pos_along=0.73,
        pos_across=1.25,
        in_layout=False,
    )
    fig.tight_layout(rect=(0.02, 0, 1, 1))
    return fig
def work(self):
    """Plot every trained GEVec as a channel-by-delay weight image, one figure per output target.

    Iterates `self.trainers`, `self.convolvers`, `self.colors` and `self.output()` in
    lockstep, stopping at the shortest. Each image uses the diverging "PiYG" colormap
    with the symmetric range [-wmax, wmax], so zero weight maps to its centre. The
    delay (x) axis is drawn reversed.
    """
    jobs = zip(self.trainers, self.convolvers, self.colors, self.output())
    for trainer, convolver, color, filetarget in jobs:
        fig, ax = subplots(figsize=(5 + convolver.num_delays / 3, 5))
        GEVec = trainer.output().read()
        signal = trainer.multichannel_train
        weights = shape_GEVec(GEVec, signal.num_channels)
        # Reduce over *all* elements. The builtin `max()` iterates a 2-D array row by
        # row and raises "truth value of an array ... is ambiguous"; `.max()` does not.
        # (Assumes `shape_GEVec` returns a NumPy array, as the `imshow` below implies.)
        wmax = abs(weights).max()
        image = ax.imshow(
            weights, origin="lower", cmap="PiYG", vmin=-wmax, vmax=wmax
        )
        cbar = fig.colorbar(image, ax=ax, shrink=0.6)
        cbar.set_label("Weight")
        ax.set_xticks(convolver.delays)
        ax.set_yticks(convolver.channels)
        # Reverse the delay axis.
        ax.set_xlim(ax.get_xlim()[::-1])
        if signal.fs != 1000:
            # The "(ms)" label is only correct at fs == 1000 (delays presumably in samples).
            warn('Delay axis scale "(ms)" in GEVec plot is not correct.')
        ax.set_xlabel("Delay (ms)")
        ax.set_ylabel("Channel")
        title = f"GEVec\n({convolver.filename})"
        ax.set_title(title, color=color)
        ax.grid(False)
        fig.tight_layout()
        filetarget.write(fig)
def work(self):
    """Draw two vertically stacked panels that share an x-axis, and write them to `self.output()`.

    The shared x-axis is labelled "Number of delays". Drawing is delegated to
    `self.plot_on_axes`, and the legend goes on the top panel.
    """
    fig, (ax_top, ax_btm) = subplots(nrows=2, sharex=True, figsize=paperfig())
    ax_btm.set_xlabel("Number of delays")
    self.plot_on_axes(ax_top, ax_btm)
    self.add_legend(ax_top)
    fig.tight_layout()
    self.output().write(fig)
def work(self):
    """Write a multi-page PDF to `self.output().path_string`.

    Each entry of `self.time_ranges` becomes one page, showing `self.signal` over that
    range with a 500 ms time scalebar and a voltage scalebar.
    """

    def render(time_range):
        # Build the figure (one PDF page) for a single time window.
        fig, ax = subplots()
        plot_signal(self.signal, time_range, ax=ax, time_grid=False)
        add_time_scalebar(ax, 500, "ms", pos_across=-0.04)
        add_voltage_scalebar(ax, pos_across=-0.04)
        start = format_duration(time_range[0], auto_ms=False)
        ax.set_title(f"{start} since start of recording.")
        fig.tight_layout()
        return fig

    with PdfPages(self.output().path_string) as pdf:
        for time_range in self.time_ranges:
            page = render(time_range)
            pdf.savefig(page)
            # Free the figure right away so many pages don't accumulate in memory.
            close(page)
def plot_colorbar(self):
    """Draw a stand-alone (default, vertical) colorbar and write it to `self.output_colorbar`.

    The colorbar is built from `self.norm` and `self.cmap`, with extension triangles at
    both ends. Tick labels use the `fraction` formatter.
    """
    fig, ax = subplots(figsize=(1, 2))
    _cbar = ColorbarBase(
        ax=ax,
        cmap=self.cmap,
        norm=self.norm,
        label=self.colorbar_label,
        format=fraction,
        extend="both",
    )
    fig.tight_layout()
    self.output_colorbar.write(fig)
def plot_colorbar(self):
    """Draw a stand-alone horizontal colorbar and write it to `self.output_colorbar`.

    The colorbar is built from `self.norm` and `self.cmap`, with extension triangles at
    both ends. Tick labels are formatted with `self.fstring`.
    """
    tick_format = StrMethodFormatter(self.fstring)
    fig, ax = subplots(figsize=paperfig(0.42, 0.16))
    _cbar = ColorbarBase(
        ax=ax,
        orientation="horizontal",
        cmap=self.cmap,
        norm=self.norm,
        label=self.colorbar_label,
        format=tick_format,
        extend="both",
    )
    fig.tight_layout()
    self.output_colorbar.write(fig)
def work(self):
    """Plot PR curves with iso-F2 reference curves and the selected threshold, then write to `self.output()`.

    The legend extends the module-level `labels` / `colors` tuples with one extra
    entry for the iso-F2 curves.
    """
    fig, ax = subplots(figsize=paperfig(0.59, 0.497))
    self.setup_ax(ax)
    self.plot_PR(ax)
    self.shade_under_PR_curves(ax)
    self.plot_iso_F_curves(ax)
    self.mark_selected_threshold(ax)

    legend_labels = labels + ("Iso-$F_2$-curves",)
    legend_colors = colors + (iso_F_color,)
    add_colored_legend(ax, legend_labels, legend_colors, loc=(0.03, 0.03))

    fig.tight_layout()
    self.output().write(fig)
def work(self):
    """Write one five-row figure per time range in `config.time_ranges` to the matching entry of `self.outputs`.

    Row 0 (height ratio 2) goes to `plot_input`. Rows 1-2 (ratio 0.84 each) go to
    `plot_offline`, and rows 3-4 to `plot_online`.
    """
    height_ratios = ones(5)
    height_ratios[0] = 2
    height_ratios[1:3] = 0.84

    for trange, output in zip(config.time_ranges, self.outputs):
        log.info(f"Generating figure {output.filename}")
        fig, axes = subplots(
            nrows=len(height_ratios),
            figsize=paperfig(0.57, 0.75),
            gridspec_kw=dict(height_ratios=height_ratios),
        )
        input_ax, offline_axes, online_axes = axes[0], axes[1:3], axes[3:]
        self.plot_input(input_ax, trange)
        self.plot_offline(offline_axes, trange)
        self.plot_online(online_axes, trange)
        add_time_scalebar(input_ax, 200, in_layout=False, pos_along=0.56)
        fig.tight_layout()
        output.write(fig)
def work(self):
    """Draw one row per processing stage (6 x 2 grid) and write the figure to `self.output()`.

    The stages, top to bottom, are: wideband, filter output, analytic signal, smoothed
    envelope, thresholded envelope and segments. The narrow right-hand column is only
    used by the thresholded-envelope row; `remove_empty_axes` handles the rest.
    """
    fig, axes = subplots(
        nrows=6,
        ncols=2,
        figsize=paperfig(1.2, 1.31),
        gridspec_kw=dict(
            width_ratios=(1, 0.24),
            height_ratios=(1, 1, 1, 1, 1.22, 0.8),
        ),
    )
    self.remove_empty_axes(axes)

    main = axes[:, 0]
    self.plot_wideband(main[0])
    self.plot_filter_output(main[1])
    self.plot_analytic(main[2])
    self.plot_smoothed_envelope(main[3])
    self.plot_thresholded_envelope(main[4], axes[4, 1])
    self.plot_segments(main[5])

    # tight_layout is run inside `ignore(UserWarning)` (presumably to silence
    # layout warnings).
    with ignore(UserWarning):
        fig.tight_layout()
    self.output().write(fig)
def work(self):
    """Plot BPF vs RNN detection-latency distributions and write the figure to `self.output()`.

    The top panel shows absolute latencies, multiplied by 1000 (presumably s -> ms),
    keeping only values below 100. The bottom panel shows relative latencies. Axis
    labels and y tick labels are deliberately left blank.
    """
    fig, (ax_top, ax_btm) = subplots(nrows=2, figsize=paperfig(0.60, 0.57))
    te_bpf, te_rnn = get_tes()

    def as_frame(te, algorithm):
        # One row per detection, tagged with the algorithm name.
        return DataFrame(
            dict(
                Absolute=te.abs_delays,
                Relative=te.rel_delays,
                Algorithm=algorithm,
            )
        )

    delays = concat([as_frame(te_rnn, "RNN"), as_frame(te_bpf, "BPF")])
    delays["Absolute"] *= 1000
    delays = delays[delays["Absolute"] < 100]

    shared = dict(
        data=delays,
        y="Algorithm",
        alpha_dot=0.4,
        palette=colors,
        order=("BPF", "RNN"),
        width_kde=0.8,
        width_box=0.1,
        jitter=0.12,
        ms=2,
    )
    distplot(x="Absolute", ax=ax_top, **shared)
    distplot(x="Relative", ax=ax_btm, **shared)
    ax_btm.xaxis.set_major_formatter(fraction)

    for panel in (ax_top, ax_btm):
        panel.set_yticklabels([])
        panel.set_ylabel("")
        panel.set_xlabel("")

    fig.tight_layout(h_pad=4)
    self.output().write(fig)
def work(self):
    """Strip-plot signal vs noise projections for PCA and GEVec, then write the figure to `self.output()`.

    `S` and `N` are projected onto `-SVecs[:, 0]` ("PCA") and onto `GEVecs[:, -1]`
    ("GEVec"). Each method's projections are then divided by their largest absolute
    value. The y tick labels are coloured via the module-level `colors` mapping.
    """
    # fmt: off
    parts = [
        DataFrame(dict(proj=S.T @ -SVecs[:, 0], method='PCA', src='Signal')),
        DataFrame(dict(proj=N.T @ -SVecs[:, 0], method='PCA', src='Noise')),
        DataFrame(dict(proj=S.T @ GEVecs[:, -1], method='GEVec', src='Signal')),
        DataFrame(dict(proj=N.T @ GEVecs[:, -1], method='GEVec', src='Noise')),
    ]
    # fmt: on
    df: DataFrame = concat(parts)

    # Rescale each method independently so that its max |projection| is 1.
    for method in ("PCA", "GEVec"):
        rows = (df.method == method, "proj")
        df.loc[rows] /= max(abs(df.loc[rows]))

    fig, ax = subplots()
    stripplot(data=df, x="proj", y="method", hue="src", dodge=0.5, ax=ax)
    ax.legend_.remove()
    ax.spines["bottom"].set_visible(True)
    ax.set_xticks([])
    ax.set_xlabel("Projection on 1st eigenvector")
    ax.set_ylabel("")
    ax.yaxis.set_tick_params(labelcolor="black", labelsize=fontsize, pad=16)
    for tick_label in ax.get_yticklabels():
        tick_label.set_color(colors[tick_label.get_text()])

    fig.tight_layout()
    self.output().write(fig)
def work(self):
    """Compare the frequency responses of `self.filter_original` and `self.filter_replica`.

    Panels: gain in dB (top-left), linear gain (bottom-left) and group delay
    (bottom-right). The top-right cell is removed. The figure is written to
    `self.output()`.
    """
    fig, axes = subplots(
        nrows=2, ncols=2, figsize=paperfig(width=1.2, height=0.55)
    )
    (ax_gain_dB, ax_empty), (ax_gain, ax_grpdelay) = axes
    ax_empty.remove()

    ax_gain.set_xlabel("Frequency (Hz)")
    ax_grpdelay.set_xlabel("Frequency (Hz)")
    ax_gain.set_ylabel("Gain")
    ax_gain_dB.set_ylabel("Gain (dB)")
    ax_gain_dB.set_ylim(-63, 4)
    ax_grpdelay.set_ylabel("Group delay (ms)")
    # Invisible line at y = 0, so that zero always stays within the y-limits.
    ax_grpdelay.axhline(y=0, color="none")

    f_max = 500
    # Keep this many Hz away from 0 and f_max to avoid phase discontinuities.
    margin = 10
    f = linspace(margin, f_max - margin, 10000)
    for filt in (self.filter_original, self.filter_replica):
        H = filt.freqresp(f)
        g = gain(H)
        ax_gain.plot(f, g)
        ax_gain_dB.plot(f, dB(g))
        ax_grpdelay.plot(f, group_delay(H, f))

    add_colored_legend(
        fig,
        (self.label_original, self.label_replica),
        loc="lower left",
        bbox_to_anchor=(0.5, 0.6),
    )
    fig.tight_layout(w_pad=3)
    self.output().write(fig)
def work(self):
    """Plot iso-F_beta curves in recall/precision space for beta = 1 and beta = 2, side by side.

    Each panel draws curves for F in {40%, 60%, 80%, 90%, 95%} (only where the
    precision is positive), plus a grey diagonal. Each curve is labelled just right of
    the axes, at its precision value for recall = 1. The figure is written to
    `self.output()`.
    """
    fig, axes = subplots(ncols=2, figsize=paperfig(1.1, 0.5))
    eps = 0.001
    lims = (-eps, 1 + eps)
    ticks = [0, 0.25, 0.50, 0.75, 1]
    R = linspace(0, 1, 1000)

    for i, (ax, beta) in enumerate(zip(axes, [1, 2])):
        color = f"C{i}"
        ax.set_title(f"$F_{beta}$", color=color)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_aspect("equal")
        ax.plot(R, R, "grey", lw=0.5)
        ax.set_xlabel("Recall")
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        if i == 0:
            ax.set_ylabel("Precision")
        ax.xaxis.set_major_formatter(fraction)
        ax.yaxis.set_major_formatter(fraction)

        for F in [0.4, 0.6, 0.8, 0.9, 0.95]:
            P = iso_F_line(R, F, beta)
            positive = P > 0
            ax.plot(R[positive], P[positive], color=color)
            ax.text(
                x=lims[1] + 0.02,
                y=P[-1],
                s=f"{F:.0%}",
                color=color,
                va="center",
            )

    fig.tight_layout(w_pad=3)
    self.output().write(fig)
def work(self):
    """Scatter `S` and `N` and overlay the PCA and GEVec directions, then write the figure to `self.output()`.

    The PCA vector is `SVals[0] * -SVecs[:, 0]`. The GEVec vector is
    `40 * sqrt(GEVals[-1]) * GEVecs[:, -1]`. The axes are square, span [-11, 11] and
    have no ticks.
    """
    fig, ax = subplots(figsize=(5, 5))
    scatter(S, ax, colors["Signal"])
    scatter(N, ax, colors["Noise"])
    pca_vector = SVals[0] * -SVecs[:, 0]
    plot_vector(pca_vector, ax, colors["PCA"])
    gevec_vector = 40 * sqrt(GEVals[-1]) * GEVecs[:, -1]
    plot_vector(gevec_vector, ax, colors["GEVec"])

    ax.set_aspect("equal")
    half_width = 11
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)

    # Hand-placed cloud labels, in figure coordinates.
    label_positions = {"Signal": (0.52, 0.26), "Noise": (0.7, 0.46)}
    for name, (x, y) in label_positions.items():
        fig.text(x, y, name, color=colors[name], fontsize=fontsize)

    ax.set_xlabel("Electrode 3")
    ax.set_ylabel("Electrode 12")
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines["bottom"].set_visible(True)
    ax.spines["left"].set_visible(True)
    fig.tight_layout()
    self.output().write(fig)
def plot_signal(
    signal: Signal,
    time_range: Tuple[float, float],
    y_scale: float = 500,
    height: float = 0.5,
    channels: Optional[ndarray] = None,
    bottom_first: bool = True,
    tight_ylims: bool = False,
    zero_lines: bool = True,
    time_grid: bool = True,
    y_grid: Optional[bool] = None,
    color: Color = "black",
    lw: float = 0.9,
    ax: Optional[Axes] = None,
    **kwargs,
) -> Tuple[Figure, Axes]:
    """
    Plot a time-slice of a single- or multichannel signal.

    When the signal is multichannel, each channel is plotted with the same scale and
    offset vertically by `y_scale` from its neighbour.

    :param time_range: Time slice to plot. In seconds.
    :param channels: Which channels to plot. Plots all channels by default. The
        vertical offsets and (when no Axes is given) the figure height are based on
        the number of channels actually plotted.
    :param height: Height of each plotted channel, in inches. Only relevant when no
        Axes is given.
    :param y_scale: How many data-y-units the visual vertical spacing between channels
        represents.
    :param bottom_first: If True (default), the first channel is plotted at the bottom
        of the figure.
    :param tight_ylims: If True, the ylims adapt tightly to the data visible in
        `time_range`. If False (default), plots of different time ranges of `signal`
        all get the same ylims.
    :param zero_lines: Whether to plot grey y == 0 lines in each channel.
    :param time_grid: Whether to plot vertical gridlines, with the corresponding
        absolute time ticks and tick labels.
    :param y_grid: Whether to plot horizontal gridlines, with the corresponding
        y-ticks and -tick labels. By default, a y-grid is only plotted if the signal is
        single-channel and `zero_lines` is False.
    :param color: Line colour for the signal traces.
    :param lw: Line width for the signal traces.
    :param ax: The Axes to plot on. If None (default), a new figure and Axes are
        created.
    :param kwargs: Passed on to `ax.plot()` (`color` and `lw` override any entries
        of the same name).
    :return: The `(figure, axes)` that was drawn on.
    """
    signal = signal.as_matrix()
    if channels is None:
        channels = arange(signal.num_channels)
    # Offsets and figure height must match the columns of `y` below, not the full signal.
    num_plotted = len(channels)

    if ax is None:
        fig, ax = subplots(figsize=(12, height * num_plotted))
    else:
        fig = ax.get_figure()

    if y_grid is None:
        y_grid = signal.num_channels == 1 and not zero_lines

    ix = time_to_index(time_range, signal.fs, arr_size=signal.num_samples, clip=True)
    y: Signal = signal[slice(*ix), channels]
    t = y.get_time_vector(t0=time_range[0])

    if bottom_first:
        y_offsets = y_scale * arange(0, num_plotted)
    else:
        y_offsets = y_scale * arange(0, -num_plotted, -1)
    y_separated = y + y_offsets

    if zero_lines:
        ax.hlines(y_offsets, *time_range, colors="grey", lw=1)
    kwargs.update(dict(color=color, lw=lw))
    ax.plot(t, y_separated, **kwargs)
    ax.set_xlim(time_range)
    if not tight_ylims:
        # NOTE(review): limits are derived from the full `signal`, not only from
        # `channels`. Confirm that this is intended when plotting a channel subset.
        ax.set_ylim(_get_global_ylims(signal, y_scale))

    # `axis=` selects x/y. `which=` only accepts "major"/"minor"/"both", so the
    # former `which="x"` was either ignored or raised ValueError.
    if time_grid:
        ax.set_xlabel("Time (s)")
    else:
        ax.grid(False, axis="x")
        ax.set_xticks([])
    if not y_grid:
        ax.grid(False, axis="y")
        ax.set_yticks([])
    return (fig, ax)