Beispiel #1
0
 def plot_grid(self):
     """Draw `self.data_matrix` as an annotated heat-map and write it out.

     Rows of the matrix correspond to sharp-wave thresholds (y axis) and
     columns to ripple thresholds (x axis). Each cell is annotated with its
     value, formatted through `self.fstring` and coloured by
     `self.text_color(value)`. Only every second threshold gets a tick label.
     The figure is written to `self.output_grid`.
     """
     fig, ax = subplots(figsize=paperfig(0.55, 0.55))
     ax.imshow(
         self.data_matrix,
         cmap=self.cmap,
         origin="lower",
         aspect="auto",
         norm=self.norm,
     )
     ax.set_xlabel("Ripple definition (μV)")
     ax.set_ylabel("Sharp wave definition (μV)")
     n_cols = len(self.ripple_thresholds)
     n_rows = len(self.SW_thresholds)
     # Label every other threshold only, to keep the axes uncluttered.
     ax.set_xticks(arange(n_cols)[::2])
     ax.set_yticks(arange(n_rows)[::2])
     col_labels = [f"{thr:.0f}" for thr in self.ripple_thresholds]
     row_labels = [f"{thr:.0f}" for thr in self.SW_thresholds]
     ax.set_xticklabels(col_labels[::2])
     ax.set_yticklabels(row_labels[::2])
     # Annotate every cell (x = column index, y = row index) with its value.
     for row in range(n_rows):
         for col in range(n_cols):
             cell = self.data_matrix[row][col]
             ax.text(
                 col,
                 row,
                 self.fstring.format(x=cell),
                 ha="center",
                 va="center",
                 color=self.text_color(cell),
                 size="smaller",
             )
     ax.grid(False)
     fig.tight_layout()
     self.output_grid.write(fig)
Beispiel #2
0
 def work(self):
     """Build the PR-curve / detection-delay figure and write it out.

     Uses a 2x2 grid: the PR curves go bottom-left, the delay axes named for
     precision (bottom-right) and recall (top-left) flank it, and the
     top-right slot is removed. A coloured legend of `self.titles` is
     anchored near the top-right corner of the figure.
     """
     layout = dict(width_ratios=[1.63, 1], height_ratios=[1, 1.63])
     fig, axes = subplots(
         nrows=2,
         ncols=2,
         figsize=paperfig(1, 0.82),
         gridspec_kw=layout,
     )
     (ax_delay_R, ax_top_right), (ax_PR, ax_delay_P) = axes
     ax_top_right.remove()
     self.setup_axes(ax_PR, ax_delay_P, ax_delay_R)
     self.plot_PR_curves(ax_PR)
     self.shade_under_PR_curves(ax_PR)
     self.plot_delays(ax_delay_P, ax_delay_R)
     self.mark_cutoffs(ax_PR, ax_delay_P, ax_delay_R)
     add_colored_legend(
         fig,
         self.titles,
         self.colors,
         loc="upper right",
         bbox_to_anchor=(0.965, 0.965),
     )
     fig.tight_layout()
     self.output().write(fig)
Beispiel #3
0
 def plot_grid(self):
     """Lay out the search-grid figure at reduced scale and write it out.

     One row of grid cells is drawn per entry of `self.num_delays`, plus a
     final, taller row for the channel maps; there is one column per channel
     combination. A GEVec / Online BPF legend is added at the top.

     NOTE(review): `self.rowheight` and `self.col_pad` are divided in place,
     so calling this method twice on one instance shrinks them again.
     Presumably the helper methods read the downscaled values from `self`;
     confirm before making this idempotent.
     """
     scale = 1.65
     self.rowheight /= scale
     self.col_pad /= scale
     n_grid_rows = len(self.num_delays)
     n_rows = n_grid_rows + 1  # extra bottom row for the channel maps
     n_cols = len(self.channel_combo_names)
     fig_size = (
         1 / scale + 1.8 * n_cols / scale,
         1 / scale + self.rowheight * n_rows,
     )
     # Ratio of the SearchGrid class row height to the downscaled one.
     rel_rowheight = SearchGrid.rowheight / self.rowheight
     row_ratios = [1] * n_grid_rows + [2.5 * rel_rowheight]
     fig, axes = subplots(
         nrows=n_rows,
         ncols=n_cols,
         figsize=fig_size,
         gridspec_kw=dict(height_ratios=row_ratios),
     )
     self.plot_grid_cells(axes)
     self.plot_channelmaps(axes)
     self.label_cell_x(axes[0, 0])
     self.label_cell_y(axes[0, -1])
     add_colored_legend(
         fig,
         ("GEVec", "Online BPF"),
         (self.GEVec_color, self.sota_color),
         ncol=2,
         loc="upper center",
     )
     self.label_grid(fig, axes)
     layout_rect = (0.07, rel_rowheight * 0.013, 1, 1 - rel_rowheight * 0.04)
     fig.tight_layout(w_pad=self.col_pad, rect=layout_rect)
     self.output_grid.write(fig)
Beispiel #4
0
 def make_figure(self, time_range):
     """Build the stacked signal figure for `time_range` and return it.

     The top axes shows the input signal. It is made taller, scaled by the
     number of channels in `self.multichannel_test`, when
     `self.reference_channel_only` is False. One further axes is added below
     per entry in `self.extra_signals`. All axes share the x (time) axis.

     :param time_range:  Time slice to plot; passed on to the plot helpers.
     :return:  The created figure.
     """
     num_extra = len(self.extra_signals)
     nrows = 1 + num_extra
     axheights = ones(nrows)
     figheight = 0.30 + 0.14 * num_extra
     if not self.reference_channel_only:
         num_channels = self.multichannel_test.num_channels
         axheights[0] = 1 + num_channels / 4
         figheight *= 1.2
     fig, axes = subplots(
         nrows=nrows,
         sharex=True,
         # Always get a 2-D array back: with nrows == 1 a squeezed result
         # would be a bare Axes, and `axes[0]` below would fail.
         # Assumes `subplots` is matplotlib's (accepts `squeeze`) - confirm.
         squeeze=False,
         figsize=paperfig(width=self.figwidth, height=0.9 * figheight),
         gridspec_kw=dict(height_ratios=axheights),
     )
     axes = axes[:, 0]  # (nrows, 1) -> (nrows,), same as the squeezed form
     input_ax = axes[0]
     extra_axes = axes[1:]
     self.plot_input_signal(time_range, input_ax)
     self.plot_other_signals(time_range, extra_axes)
     self.post_plot(time_range, input_ax, extra_axes)
     add_segments(input_ax, self.reference_segs_test)
     scalebar_length = select_scalebar_time(time_range)
     if len(extra_axes) > 0:
         add_time_scalebar(
             extra_axes[0],
             scalebar_length,
             "ms",
             pos_along=0.73,
             pos_across=1.25,
             in_layout=False,
         )
     else:
         # No extra panels: put the scalebar on the input axes instead, at
         # the helper's default across-position.
         add_time_scalebar(
             input_ax,
             scalebar_length,
             "ms",
             pos_along=0.73,
             in_layout=False,
         )
     fig.tight_layout(rect=(0.02, 0, 1, 1))
     return fig
Beispiel #5
0
 def work(self):
     """Plot each trained GEVec filter as a weight image and write it out.

     For every (trainer, convolver, colour, output target) combination, the
     trainer's GEVec is reshaped with `shape_GEVec` into a weight matrix.
     The matrix is shown with a diverging colormap whose limits are
     symmetric about zero. Delays lie along x (reversed) and channels
     along y.

     Warns when the training signal's sampling rate is not 1000. The delay
     axis is labelled "(ms)", which is only correct at one sample per ms.
     """
     tups = zip(self.trainers, self.convolvers, self.colors, self.output())
     for trainer, convolver, color, filetarget in tups:
         fig, ax = subplots(figsize=(5 + convolver.num_delays / 3, 5))
         GEVec = trainer.output().read()
         signal = trainer.multichannel_train
         num_channels = signal.num_channels
         weights = shape_GEVec(GEVec, num_channels)
         # Symmetric colour limits, so zero weight sits at the colormap's
         # neutral midpoint. Use the array's own `.max()` (global maximum).
         # If `max` here is the builtin, it iterates the rows of a 2-D array
         # and raises "truth value of an array ... is ambiguous".
         wmax = abs(weights).max()
         image = ax.imshow(
             weights, origin="lower", cmap="PiYG", vmin=-wmax, vmax=wmax
         )
         cbar = fig.colorbar(image, ax=ax, shrink=0.6)
         cbar.set_label("Weight")
         ax.set_xticks(convolver.delays)
         ax.set_yticks(convolver.channels)
         ax.set_xlim(ax.get_xlim()[::-1])  # reverse the delay axis
         if signal.fs != 1000:
             warn('Delay axis scale "(ms)" in GEVec plot is not correct.')
         ax.set_xlabel("Delay (ms)")
         ax.set_ylabel("Channel")
         title = f"GEVec\n({convolver.filename})"
         ax.set_title(title, color=color)
         ax.grid(False)
         fig.tight_layout()
         filetarget.write(fig)
Beispiel #6
0
 def work(self):
     """Draw two vertically stacked panels that share an x axis ("Number of
     delays"), add a legend to the top panel, and write the figure."""
     fig, (ax_upper, ax_lower) = subplots(
         nrows=2, sharex=True, figsize=paperfig()
     )
     ax_lower.set_xlabel("Number of delays")
     self.plot_on_axes(ax_upper, ax_lower)
     self.add_legend(ax_upper)
     fig.tight_layout()
     self.output().write(fig)
Beispiel #7
0
 def work(self):
     """Write one PDF page per entry in `self.time_ranges`.

     Each page shows `self.signal` over that time range, with time (500 ms)
     and voltage scalebars, titled with the range's start time.
     """
     scalebar_offset = -0.04
     with PdfPages(self.output().path_string) as pdf:
         for time_range in self.time_ranges:
             fig, ax = subplots()
             try:
                 plot_signal(self.signal, time_range, ax=ax, time_grid=False)
                 add_time_scalebar(ax, 500, "ms", pos_across=scalebar_offset)
                 add_voltage_scalebar(ax, pos_across=scalebar_offset)
                 t = format_duration(time_range[0], auto_ms=False)
                 ax.set_title(f"{t} since start of recording.")
                 fig.tight_layout()
                 pdf.savefig(fig)
             finally:
                 # Release the figure even when plotting a page fails, so
                 # an error does not leave it open.
                 close(fig)
Beispiel #8
0
 def plot_colorbar(self):
     """Render a stand-alone colorbar for `self.cmap` / `self.norm`.

     Uses matplotlib's default (vertical) orientation and extended ends on
     both sides, with tick labels formatted by `fraction`. The figure is
     written to `self.output_colorbar`.
     """
     fig, ax = subplots(figsize=(1, 2))
     colorbar_options = dict(
         label=self.colorbar_label,
         norm=self.norm,
         cmap=self.cmap,
         extend="both",
         format=fraction,
     )
     ColorbarBase(ax=ax, **colorbar_options)
     fig.tight_layout()
     self.output_colorbar.write(fig)
Beispiel #9
0
 def plot_colorbar(self):
     """Render a stand-alone horizontal colorbar for `self.cmap` / `self.norm`.

     Both ends are extended, and tick labels use the same format string
     (`self.fstring`) as the grid annotations. The figure is written to
     `self.output_colorbar`.
     """
     fig, ax = subplots(figsize=paperfig(0.42, 0.16))
     tick_format = StrMethodFormatter(self.fstring)
     colorbar_options = dict(
         orientation="horizontal",
         label=self.colorbar_label,
         norm=self.norm,
         cmap=self.cmap,
         extend="both",
         format=tick_format,
     )
     ColorbarBase(ax=ax, **colorbar_options)
     fig.tight_layout()
     self.output_colorbar.write(fig)
Beispiel #10
0
 def work(self):
     """Plot PR curves with shading, iso-F2 curves and the selected
     threshold on a single axes, add a coloured legend, and write it out."""
     fig, ax = subplots(figsize=paperfig(0.59, 0.497))
     self.setup_ax(ax)
     self.plot_PR(ax)
     self.shade_under_PR_curves(ax)
     self.plot_iso_F_curves(ax)
     self.mark_selected_threshold(ax)
     # The module-level labels/colors get one extra entry for the iso-F lines.
     legend_labels = labels + ("Iso-$F_2$-curves", )
     legend_colors = colors + (iso_F_color, )
     add_colored_legend(ax, legend_labels, legend_colors, loc=(0.03, 0.03))
     fig.tight_layout()
     self.output().write(fig)
Beispiel #11
0
 def work(self):
     """Write one five-row figure per time range in `config.time_ranges`.

     Row 0 (double height) shows the input signal, rows 1-2 the offline
     results and rows 3-4 the online results. A time scalebar of length 200
     is placed on the input axes.
     """
     num_rows = 5
     row_heights = ones(num_rows)
     row_heights[0] = 2       # input signal
     row_heights[1:3] = 0.84  # offline rows
     for trange, output in zip(config.time_ranges, self.outputs):
         log.info(f"Generating figure {output.filename}")
         fig, axes = subplots(
             nrows=num_rows,
             figsize=paperfig(0.57, 0.75),
             gridspec_kw=dict(height_ratios=row_heights),
         )
         input_ax, offline_axes, online_axes = axes[0], axes[1:3], axes[3:]
         self.plot_input(input_ax, trange)
         self.plot_offline(offline_axes, trange)
         self.plot_online(online_axes, trange)
         add_time_scalebar(input_ax, 200, in_layout=False, pos_along=0.56)
         fig.tight_layout()
         output.write(fig)
Beispiel #12
0
 def work(self):
     """Draw the six-stage pipeline figure and write it out.

     The left column holds one stage per row: wideband signal, filter
     output, analytic signal, smoothed envelope, thresholded envelope and
     segments. The narrow right column is only passed to the
     thresholded-envelope plot; `remove_empty_axes` presumably removes the
     other right-column axes (confirm). `UserWarning`s raised during
     `tight_layout` are ignored.
     """
     fig, axes = subplots(
         nrows=6,
         ncols=2,
         figsize=paperfig(1.2, 1.31),
         gridspec_kw=dict(
             width_ratios=(1, 0.24),
             height_ratios=(1, 1, 1, 1, 1.22, 0.8),
         ),
     )
     self.remove_empty_axes(axes)
     main_col = axes[:, 0]
     self.plot_wideband(main_col[0])
     self.plot_filter_output(main_col[1])
     self.plot_analytic(main_col[2])
     self.plot_smoothed_envelope(main_col[3])
     self.plot_thresholded_envelope(main_col[4], axes[4, 1])
     self.plot_segments(main_col[5])
     with ignore(UserWarning):
         fig.tight_layout()
     self.output().write(fig)
Beispiel #13
0
 def work(self):
     """Compare BPF and RNN detection latencies with two distribution plots.

     The top panel shows absolute latencies (scaled by 1000, presumably
     s -> ms, and keeping only values < 100). The bottom panel shows
     relative latencies, with fractional tick labels. Y tick labels and
     axis labels are cleared on both panels.
     """
     fig, (ax_top, ax_btm) = subplots(nrows=2, figsize=paperfig(0.60, 0.57))
     te_bpf, te_rnn = get_tes()

     def latency_frame(te, algorithm):
         # One row per detection, tagged with the algorithm's name.
         return DataFrame(
             dict(
                 Absolute=te.abs_delays,
                 Relative=te.rel_delays,
                 Algorithm=algorithm,
             ))

     delays = concat([latency_frame(te_rnn, "RNN"), latency_frame(te_bpf, "BPF")])
     delays["Absolute"] *= 1000
     delays = delays[delays["Absolute"] < 100]
     shared_options = dict(
         data=delays,
         y="Algorithm",
         alpha_dot=0.4,
         palette=colors,
         order=("BPF", "RNN"),
         width_kde=0.8,
         width_box=0.1,
         jitter=0.12,
         ms=2,
     )
     distplot(x="Absolute", ax=ax_top, **shared_options)
     distplot(x="Relative", ax=ax_btm, **shared_options)
     ax_btm.xaxis.set_major_formatter(fraction)
     for panel in (ax_top, ax_btm):
         panel.set_yticklabels([])
         panel.set_ylabel("")
         panel.set_xlabel("")
     fig.tight_layout(h_pad=4)
     self.output().write(fig)
Beispiel #14
0
    def work(self):
        """Strip-plot signal and noise projections for PCA and GEVec.

        Samples `S` (signal) and `N` (noise) are projected onto the negated
        first column of `SVecs` (PCA) and onto the last column of `GEVecs`.
        Each method's projections are divided by their largest absolute
        value. The result is drawn as one strip per method, split by
        source, with method tick labels coloured via `colors`.
        """
        pca_direction = -SVecs[:, 0]
        gevec_direction = GEVecs[:, -1]
        parts = (
            DataFrame(dict(proj=S.T @ pca_direction, method='PCA', src='Signal')),
            DataFrame(dict(proj=N.T @ pca_direction, method='PCA', src='Noise')),
            DataFrame(dict(proj=S.T @ gevec_direction, method='GEVec', src='Signal')),
            DataFrame(dict(proj=N.T @ gevec_direction, method='GEVec', src='Noise')),
        )
        df: DataFrame = concat(parts)
        # Scale each method's projections to a peak absolute value of 1.
        for method in ("PCA", "GEVec"):
            is_method = df.method == method
            peak = max(abs(df.loc[is_method, "proj"]))
            df.loc[is_method, "proj"] /= peak

        fig, ax = subplots()
        stripplot(data=df, x="proj", y="method", hue="src", dodge=0.5, ax=ax)
        ax.legend_.remove()
        ax.spines["bottom"].set_visible(True)
        ax.set_xticks([])
        ax.set_xlabel("Projection on 1st eigenvector")
        ax.set_ylabel("")
        ax.yaxis.set_tick_params(labelcolor="black", labelsize=fontsize, pad=16)
        # Colour each method's tick label with that method's colour.
        for tick_label in ax.get_yticklabels():
            tick_label.set_color(colors[tick_label.get_text()])
        fig.tight_layout()
        self.output().write(fig)
Beispiel #15
0
 def work(self):
     """Compare the frequency responses of the original and replica filters.

     Plots gain in dB (top left), linear gain (bottom left) and group delay
     (bottom right) for both filters over 10-490 Hz. The top-right slot is
     removed; the legend is anchored in that region of the figure.
     """
     fig, ((ax_gain_dB, ax_spare), (ax_gain, ax_grpdelay)) = subplots(
         nrows=2, ncols=2, figsize=paperfig(width=1.2, height=0.55)
     )
     ax_spare.remove()
     for bottom_ax in (ax_gain, ax_grpdelay):
         bottom_ax.set_xlabel("Frequency (Hz)")
     ax_gain.set_ylabel("Gain")
     ax_gain_dB.set_ylabel("Gain (dB)")
     ax_gain_dB.set_ylim(-63, 4)
     ax_grpdelay.set_ylabel("Group delay (ms)")
     # Invisible horizontal line at y=0 so autoscaling keeps zero in view.
     ax_grpdelay.axhline(y=0, color="none")
     f_max = 500
     margin = 10  # In Hz. To avoid phase discontinuities.
     freqs = linspace(margin, f_max - margin, 10000)
     for filta in (self.filter_original, self.filter_replica):
         H = filta.freqresp(freqs)
         g = gain(H)
         ax_gain.plot(freqs, g)
         ax_gain_dB.plot(freqs, dB(g))
         ax_grpdelay.plot(freqs, group_delay(H, freqs))
     add_colored_legend(
         fig,
         (self.label_original, self.label_replica),
         loc="lower left",
         bbox_to_anchor=(0.5, 0.6),
     )
     fig.tight_layout(w_pad=3)
     self.output().write(fig)
Beispiel #16
0
    def work(self):
        """Draw iso-F_beta curves in recall/precision space for beta = 1, 2.

        Each panel uses its own colour (C0, C1), shows the P = R diagonal in
        grey, and draws one curve per F level (40 % ... 95 %). Each curve is
        labelled just right of the panel at its value for R = 1. Only points
        with positive precision are plotted.
        """
        fig, axes = subplots(ncols=2, figsize=paperfig(1.1, 0.5))
        eps = 0.001
        lims = (-eps, 1 + eps)
        R = linspace(0, 1, 1000)
        tick_positions = [0, 0.25, 0.50, 0.75, 1]
        F_levels = [0.4, 0.6, 0.8, 0.9, 0.95]
        label_x = lims[1] + 0.02

        for i, (ax, beta) in enumerate(zip(axes, [1, 2])):
            color = f"C{i}"
            ax.set_title(f"$F_{beta}$", color=color)
            ax.set_xlim(lims)
            ax.set_ylim(lims)
            ax.set_aspect("equal")
            ax.plot(R, R, "grey", lw=0.5)
            ax.set_xlabel("Recall")
            ax.set_xticks(tick_positions)
            ax.set_yticks(tick_positions)
            if i == 0:
                # Only the left panel carries the y label.
                ax.set_ylabel("Precision")
            ax.xaxis.set_major_formatter(fraction)
            ax.yaxis.set_major_formatter(fraction)
            for F in F_levels:
                P = iso_F_line(R, F, beta)
                positive = P > 0
                ax.plot(R[positive], P[positive], color=color)
                ax.text(
                    s=f"{F:.0%}",
                    color=color,
                    x=label_x,
                    y=P[-1],
                    va="center",
                )

        fig.tight_layout(w_pad=3)
        self.output().write(fig)
Beispiel #17
0
 def work(self):
     """Scatter signal and noise samples and overlay the PCA and GEVec
     directions.

     Both point clouds are drawn, then the first (negated) PCA vector scaled
     by `SVals[0]` and the last GEVec scaled by `40 * sqrt(GEVals[-1])`.
     The axes are square, span [-11, 11] on both axes, and have no ticks or
     grid. "Signal" and "Noise" text labels are placed at fixed figure
     coordinates.
     """
     fig, ax = subplots(figsize=(5, 5))
     scatter(S, ax, colors["Signal"])
     scatter(N, ax, colors["Noise"])
     plot_vector(SVals[0] * -SVecs[:, 0], ax, colors["PCA"])
     plot_vector(40 * sqrt(GEVals[-1]) * GEVecs[:, -1], ax, colors["GEVec"])
     ax.set_aspect("equal")
     half_span = 11
     ax.set_xlim(-half_span, half_span)
     ax.set_ylim(-half_span, half_span)
     label_positions = {"Signal": (0.52, 0.26), "Noise": (0.7, 0.46)}
     for name, (x, y) in label_positions.items():
         fig.text(x, y, name, color=colors[name], fontsize=fontsize)
     ax.set_xlabel("Electrode 3")
     ax.set_ylabel("Electrode 12")
     ax.grid(False)
     ax.set_xticks([])
     ax.set_yticks([])
     for side in ("bottom", "left"):
         ax.spines[side].set_visible(True)
     fig.tight_layout()
     self.output().write(fig)
Beispiel #18
0
def plot_signal(signal: Signal,
                time_range: Tuple[float, float],
                y_scale: float = 500,
                height: float = 0.5,
                channels: Optional[ndarray] = None,
                bottom_first: bool = True,
                tight_ylims: bool = False,
                zero_lines: bool = True,
                time_grid: bool = True,
                y_grid: Optional[bool] = None,
                color: Color = "black",
                lw: float = 0.9,
                ax: Optional[Axes] = None,
                **kwargs) -> Tuple[Figure, Axes]:
    """
    Plot a time-slice of a single- or multichannel signal.

    When the signal is multichannel, each channel will be plotted with the same
    scale, offset vertically from the previous one by `y_scale`.

    :param signal:  The signal to plot; converted with `signal.as_matrix()`.
    :param time_range:  Time slice to plot. In seconds.
    :param channels:  1-D array of integer indices of the channels to plot.
                Plots all channels by default.
    :param height:  Height of each plotted channel, in inches. Only relevant
                when no Axes is given.
    :param y_scale:  How much data-y-units the visual vertical spacing between
                channels represents.
    :param bottom_first:  If True (default), the first channel will be
                plotted at the bottom of the figure.
    :param tight_ylims:  If True, adapts the ylims to tightly fit the data
                visible in `time_range`. If False (default), makes sure that
                plots of different time ranges of `signal` will all have the
                same ylims.
    :param zero_lines:  Whether to plot grey y==0 lines in each channel.
    :param time_grid:  Whether to plot vertical gridlines, with corresponding
                absolute time ticks and ticklabels.
    :param y_grid:  Whether to plot horizontal gridlines, with corresponding
                y-ticks and -ticklabels. By default, only plots a y-grid if the
                signal is single-channel and `zero_lines` is False.
    :param color:  Line colour of the traces.
    :param lw:  Line width of the traces.
    :param ax:  The axes to plot on. If None (default), creates a new figure and
                axes.
    :param kwargs:  Passed on to `ax.plot()`. `color` and `lw` are always
                taken from the explicit parameters above.
    :return:  The `(figure, axes)` that was plotted on.
    """
    signal = signal.as_matrix()
    if channels is None:
        channels = arange(signal.num_channels)
    # Offsets, zero-lines and figure height must follow the number of
    # channels actually plotted, not the total in `signal`. Otherwise a
    # channel subset fails to broadcast against the offsets (or a single
    # selected channel is silently drawn once per offset).
    num_plotted = len(channels)
    if ax is None:
        fig, ax = subplots(figsize=(12, height * num_plotted))
    else:
        fig = ax.get_figure()
    if y_grid is None:
        y_grid = signal.num_channels == 1 and not zero_lines
    ix = time_to_index(time_range,
                       signal.fs,
                       arr_size=signal.num_samples,
                       clip=True)
    y: Signal = signal[slice(*ix), channels]
    t = y.get_time_vector(t0=time_range[0])
    if bottom_first:
        y_offsets = y_scale * arange(0, num_plotted)
    else:
        y_offsets = y_scale * arange(0, -num_plotted, -1)
    y_separated = y + y_offsets
    if zero_lines:
        ax.hlines(y_offsets, *time_range, colors="grey", lw=1)
    kwargs.update(dict(color=color, lw=lw))
    ax.plot(t, y_separated, **kwargs)
    ax.set_xlim(time_range)
    if not tight_ylims:
        # NOTE(review): computed from the full signal; with a channel subset
        # these ylims may leave room for unplotted channels - confirm.
        ax.set_ylim(_get_global_ylims(signal, y_scale))
    # `Axes.grid`'s `which` selects major/minor ticks; the axis to act on is
    # chosen with `axis=`. (Passing "x"/"y" as `which` is rejected or ignored.)
    if time_grid:
        ax.set_xlabel("Time (s)")
    else:
        ax.grid(False, axis="x")
        ax.set_xticks([])
    if not y_grid:
        ax.grid(False, axis="y")
        ax.set_yticks([])
    return (fig, ax)