Exemple #1
0
def test_vlayout(lcfg):
    """
    Seven mono waves should fill a 3x3 grid, with each wave owning a single
    region, placed column-by-column (region .pos is (row, col)).
    """
    nplots = 7
    layout = RendererLayout(lcfg, [1] * nplots)

    assert layout.wave_ncol == 3
    assert layout.wave_nrow == 3

    region2d: List[List[RegionSpec]] = layout.arrange(lambda spec: spec)
    assert len(region2d) == nplots

    # Every mono wave gets exactly one region.
    for wave_idx, wave_regions in enumerate(region2d):
        nregion = len(wave_regions)
        assert nregion == 1, (wave_idx, nregion)

    # wave index -> expected (row, col) of its only region.
    expected_positions = {
        0: (0, 0),
        2: (2, 0),
        3: (0, 1),
        6: (0, 2),
    }
    for wave_idx, expected in expected_positions.items():
        np.testing.assert_equal(region2d[wave_idx][0].pos, expected)
Exemple #2
0
def test_hlayout(lcfg):
    """
    Fifteen mono waves should fill an 8-row x 2-column grid row-by-row,
    one region per wave (region .pos is (row, col)).
    """
    nplots = 15
    layout = RendererLayout(lcfg, [1] * nplots)

    assert layout.wave_ncol == 2
    assert layout.wave_nrow == 8

    region2d: List[List[RegionSpec]] = layout.arrange(lambda spec: spec)
    assert len(region2d) == nplots

    # Every mono wave gets exactly one region.
    for wave_idx, wave_regions in enumerate(region2d):
        nregion = len(wave_regions)
        assert nregion == 1, (wave_idx, nregion)

    # wave index -> expected (row, col) of its only region.
    expected_positions = {
        0: (0, 0),
        1: (0, 1),
        2: (1, 0),
    }
    for wave_idx, expected in expected_positions.items():
        np.testing.assert_equal(region2d[wave_idx][0].pos, expected)

    # The last wave lands at (last // 2, last % 2) in a 2-column grid.
    last = nplots - 1
    np.testing.assert_equal(region2d[last][0].pos, divmod(last, 2))
Exemple #3
0
def test_stereo_layout(
    orientation: Orientation,
    stereo_orientation: StereoOrientation,
    wave_nchans: List[int],
    nrow_ncol: int,
    is_nrows: bool,
):
    """
    Not-entirely-rigorous test for layout computation.
    Mind-numbingly boring to write (and read?).

    Honestly I prefer a good naming scheme in RendererLayout.arrange()
    over unit tests.

    - This is a regression test...
    - And an obstacle to refactoring or feature development.
    """
    # region Setup
    # Exactly one of nrows/ncols is configured; the other is left as None.
    if is_nrows:
        nrows = nrow_ncol
        ncols = None
    else:
        nrows = None
        ncols = nrow_ncol

    lcfg = LayoutConfig(
        orientation=orientation,
        nrows=nrows,
        ncols=ncols,
        stereo_orientation=stereo_orientation,
    )
    nwaves = len(wave_nchans)
    layout = RendererLayout(lcfg, wave_nchans)
    # endregion

    # Assert layout dimensions correct.
    # The configured dimension is used as-is; the other one is derived from the
    # wave count. The `or` must be parenthesized: `x == a or b` parses as
    # `(x == a) or b`, and `b` (a ceildiv result) is always truthy, so the
    # unparenthesized assert could never fail.
    assert layout.wave_ncol == (ncols or ceildiv(nwaves, nrows))
    assert layout.wave_nrow == (nrows or ceildiv(nwaves, ncols))

    region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec)

    # Loop through layout regions
    assert len(region2d) == len(wave_nchans)
    for wave_i, wave_chans in enumerate(region2d):
        stereo_nchan = wave_nchans[wave_i]
        assert len(wave_chans) == stereo_nchan

        # Compute channel dims within wave.
        if stereo_orientation == StereoOrientation.overlay:
            chans_per_wave = [1, 1]
        elif stereo_orientation == StereoOrientation.v:  # pos[0]++
            chans_per_wave = [stereo_nchan, 1]
        else:
            assert stereo_orientation == StereoOrientation.h  # pos[1]++
            chans_per_wave = [1, stereo_nchan]

        # Sanity-check position of channel 0 relative to origin (wave grid).
        assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0)
        npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0)

        for chan_j, chan in enumerate(wave_chans):
            # Assert 0 <= position < size.
            assert chan.pos.shape == chan.size.shape == (2, )
            assert (0 <= chan.pos).all()
            assert (chan.pos < chan.size).all()

            # Sanity-check position of chan relative to origin (wave grid).
            npt.assert_equal(chan.pos // chans_per_wave,
                             wave_chans[0].pos // chans_per_wave)

            # Check position of region (relative to channel 0)
            chan_wave_pos = chan.pos - wave_chans[0].pos

            if stereo_orientation == StereoOrientation.overlay:
                npt.assert_equal(chan_wave_pos, [0, 0])
            elif stereo_orientation == StereoOrientation.v:  # pos[0]++
                npt.assert_equal(chan_wave_pos, [chan_j, 0])
            else:
                assert stereo_orientation == StereoOrientation.h  # pos[1]++
                npt.assert_equal(chan_wave_pos, [0, chan_j])

            # Check screen edges
            screen_edges = chan.screen_edges
            is_top_row = chan.row == 0
            is_left_col = chan.col == 0
            is_bottom_row = chan.row == chan.nrow - 1
            is_right_col = chan.col == chan.ncol - 1
            assert bool(screen_edges & Edges.Top) == is_top_row
            assert bool(screen_edges & Edges.Left) == is_left_col
            assert bool(screen_edges & Edges.Bottom) == is_bottom_row
            assert bool(screen_edges & Edges.Right) == is_right_col

            # Check stereo edges
            wave_edges = chan.wave_edges
            if stereo_orientation == StereoOrientation.overlay:
                assert wave_edges == ~Edges.NONE
            elif stereo_orientation == StereoOrientation.v:  # pos[0]++
                lr = Edges.Left | Edges.Right
                assert wave_edges & lr == lr
                assert bool(wave_edges & Edges.Top) == (
                    chan.row % stereo_nchan == 0)
                assert bool(wave_edges & Edges.Bottom) == (
                    (chan.row + 1) % stereo_nchan == 0)
            else:
                assert stereo_orientation == StereoOrientation.h  # pos[1]++
                tb = Edges.Top | Edges.Bottom
                assert wave_edges & tb == tb
                assert bool(wave_edges & Edges.Left) == (
                    chan.col % stereo_nchan == 0)
                assert bool(wave_edges & Edges.Right) == (
                    (chan.col + 1) % stereo_nchan == 0)
Exemple #4
0
    def _setup_axes(self, wave_nchans: List[int]) -> None:
        """
        Creates a flat array of Matplotlib Axes, with the new layout.
        Sets up each Axes with correct region limits.

        Builds two arrangements from self.lcfg:
        - self._axes2d[wave][chan]: one Axes per channel region, where wave
          `i` has wave_nchans[i] channels.
        - self._axes_mono[wave]: one Axes per wave (self.nplots total).

        Then sizes the Figure to exactly (self.w, self.h) pixels, applies x/y
        limits and optional midlines, and caches the static background.

        :raises Exception: if called more than once on the same instance.
        """
        # Check before touching any state, so a rejected second call does not
        # replace self.layout while leaving the old Axes in place.
        if hasattr(self, "_fig"):
            raise Exception(
                "I don't currently expect to call _setup_axes() twice")

        self.layout = RendererLayout(self.lcfg, wave_nchans)
        self.layout_mono = RendererLayout(self.lcfg, [1] * self.nplots)

        cfg = self.cfg

        self._fig = Figure()
        # Constructing the canvas attaches it to the figure (self._fig.canvas).
        self._canvas_type(self._fig)

        # Requirements:
        # - px_inch /= res_divisor (to scale visual elements correctly)
        # - int(set_size_inches * px_inch) == self.w,h
        #     - matplotlib uses int instead of round. Who knows why.
        # - round(set_size_inches * px_inch) == self.w,h
        #     - just in case matplotlib changes its mind
        #
        # Solution:
        # - (set_size_inches * px_inch) == self.w,h + 0.25
        # - set_size_inches == (self.w,h + 0.25) / px_inch
        px_inch = PX_INCH / cfg.res_divisor
        self._fig.set_dpi(px_inch)

        offset = 0.25
        self._fig.set_size_inches((self.w + offset) / px_inch,
                                  (self.h + offset) / px_inch)

        real_dims = self._fig.canvas.get_width_height()
        assert (self.w, self.h) == real_dims, [(self.w, self.h), real_dims]

        # Setup background
        self._fig.set_facecolor(cfg.bg_color)

        # Create Axes (using self.lcfg, wave_nchans)
        # _axes2d[wave][chan] = Axes
        self._axes2d = self.layout.arrange(self._axes_factory)

        # Matplotlib: adding an Axes with the same arguments as a previous one
        # may reuse the earlier instance, unless each gets a unique label.
        # Mono Axes can occupy the same rectangle as an Axes in _axes2d (e.g.
        # for 1-channel waves), so they are labeled "mono".
        # (fig.add_axes(label=) is unused, even if you call ax.legend().)
        #
        # arrange() returns a 2D list of [self.nplots][1] Axes.
        axes_mono_2d = self.layout_mono.arrange(self._axes_factory,
                                                label="mono")

        # Color cycle shared by every mono Axes; looked up once.
        # List of colors at
        # https://matplotlib.org/gallery/color/colormap_reference.html
        # Discussion at https://github.com/matplotlib/matplotlib/issues/10840
        cmap: ListedColormap = get_cmap("Accent")
        colors = cmap.colors

        # _axes_mono[wave] = Axes
        self._axes_mono = []
        for axes_list in axes_mono_2d:
            (axes, ) = axes_list  # type: Axes
            axes.set_prop_cycle(color=colors)
            self._axes_mono.append(axes)

        # Values shared by every Axes.
        ylim = cfg.viewport_height
        # Not quite sure if midlines or gridlines draw on top
        midline_kw = dict(color=cfg.midline_color,
                          linewidth=cfg.grid_line_width)

        def scale_axes(ax: "Axes", xlim) -> None:
            ax.set_xlim(*xlim)
            ax.set_ylim(-ylim, ylim)

        # Setup axes
        for idx, N in enumerate(self.wave_nsamps):
            wave_axes = self._axes2d[idx]

            viewport_stride = self.render_strides[idx] * cfg.viewport_width
            xlim = calc_limits(N, viewport_stride)

            scale_axes(self._axes_mono[idx], xlim)
            # unique_by_id(): the same Axes may appear under several channels
            # (e.g. overlaid stereo); configure each Axes only once.
            for ax in unique_by_id(wave_axes):
                scale_axes(ax, xlim)

                # Setup midlines (depends on the x limits)
                if cfg.v_midline:
                    ax.axvline(x=calc_center(viewport_stride), **midline_kw)
                if cfg.h_midline:
                    ax.axhline(y=0, **midline_kw)

        self._save_background()
Exemple #5
0
class AbstractMatplotlibRenderer(_RendererBackend, ABC):
    """Matplotlib renderer which can use any backend (agg, mplcairo).
    To pick a backend, subclass and set _canvas_type at the class level.
    """

    _canvas_type: Type["FigureCanvasBase"] = abstract_classvar

    @staticmethod
    @abstractmethod
    def _canvas_to_bytes(canvas: "FigureCanvasBase") -> ByteBuffer:
        """Converts the rendered canvas into a raw pixel buffer."""
        pass

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): dict.__setitem__ writes the value without going through
        # RcParams.__setitem__ (and its validation); presumably intentional.
        dict.__setitem__(matplotlib.rcParams, "lines.antialiased",
                         self.cfg.antialiasing)

        self._setup_axes(self.wave_nchans)

        # Animated artists, redrawn over the cached background every frame.
        self._artists: List["Artist"] = []

    _fig: "Figure"

    # _axes2d[wave][chan] = Axes
    # Primary, used to draw oscilloscope lines and gridlines.
    _axes2d: List[List["Axes"]]  # set by _setup_axes()

    # _axes_mono[wave] = Axes
    # Secondary, used for titles and debug plots.
    _axes_mono: List["Axes"]

    def _setup_axes(self, wave_nchans: List[int]) -> None:
        """
        Creates a flat array of Matplotlib Axes, with the new layout.
        Sets up each Axes with correct region limits.

        Builds two arrangements from self.lcfg:
        - self._axes2d[wave][chan]: one Axes per channel region, where wave
          `i` has wave_nchans[i] channels.
        - self._axes_mono[wave]: one Axes per wave (self.nplots total).

        Then sizes the Figure to exactly (self.w, self.h) pixels, applies x/y
        limits and optional midlines, and caches the static background.

        :raises Exception: if called more than once on the same instance.
        """
        # Check before touching any state, so a rejected second call does not
        # replace self.layout while leaving the old Axes in place.
        if hasattr(self, "_fig"):
            raise Exception(
                "I don't currently expect to call _setup_axes() twice")

        self.layout = RendererLayout(self.lcfg, wave_nchans)
        self.layout_mono = RendererLayout(self.lcfg, [1] * self.nplots)

        cfg = self.cfg

        self._fig = Figure()
        # Constructing the canvas attaches it to the figure (self._fig.canvas).
        self._canvas_type(self._fig)

        # Requirements:
        # - px_inch /= res_divisor (to scale visual elements correctly)
        # - int(set_size_inches * px_inch) == self.w,h
        #     - matplotlib uses int instead of round. Who knows why.
        # - round(set_size_inches * px_inch) == self.w,h
        #     - just in case matplotlib changes its mind
        #
        # Solution:
        # - (set_size_inches * px_inch) == self.w,h + 0.25
        # - set_size_inches == (self.w,h + 0.25) / px_inch
        px_inch = PX_INCH / cfg.res_divisor
        self._fig.set_dpi(px_inch)

        offset = 0.25
        self._fig.set_size_inches((self.w + offset) / px_inch,
                                  (self.h + offset) / px_inch)

        real_dims = self._fig.canvas.get_width_height()
        assert (self.w, self.h) == real_dims, [(self.w, self.h), real_dims]

        # Setup background
        self._fig.set_facecolor(cfg.bg_color)

        # Create Axes (using self.lcfg, wave_nchans)
        # _axes2d[wave][chan] = Axes
        self._axes2d = self.layout.arrange(self._axes_factory)

        # Matplotlib: adding an Axes with the same arguments as a previous one
        # may reuse the earlier instance, unless each gets a unique label.
        # Mono Axes can occupy the same rectangle as an Axes in _axes2d (e.g.
        # for 1-channel waves), so they are labeled "mono".
        # (fig.add_axes(label=) is unused, even if you call ax.legend().)
        #
        # arrange() returns a 2D list of [self.nplots][1] Axes.
        axes_mono_2d = self.layout_mono.arrange(self._axes_factory,
                                                label="mono")

        # Color cycle shared by every mono Axes; looked up once.
        # List of colors at
        # https://matplotlib.org/gallery/color/colormap_reference.html
        # Discussion at https://github.com/matplotlib/matplotlib/issues/10840
        cmap: ListedColormap = get_cmap("Accent")
        colors = cmap.colors

        # _axes_mono[wave] = Axes
        self._axes_mono = []
        for axes_list in axes_mono_2d:
            (axes, ) = axes_list  # type: Axes
            axes.set_prop_cycle(color=colors)
            self._axes_mono.append(axes)

        # Values shared by every Axes.
        ylim = cfg.viewport_height
        # Not quite sure if midlines or gridlines draw on top
        midline_kw = dict(color=cfg.midline_color,
                          linewidth=cfg.grid_line_width)

        def scale_axes(ax: "Axes", xlim) -> None:
            ax.set_xlim(*xlim)
            ax.set_ylim(-ylim, ylim)

        # Setup axes
        for idx, N in enumerate(self.wave_nsamps):
            wave_axes = self._axes2d[idx]

            viewport_stride = self.render_strides[idx] * cfg.viewport_width
            xlim = calc_limits(N, viewport_stride)

            scale_axes(self._axes_mono[idx], xlim)
            # unique_by_id(): the same Axes may appear under several channels
            # (e.g. overlaid stereo); configure each Axes only once.
            for ax in unique_by_id(wave_axes):
                scale_axes(ax, xlim)

                # Setup midlines (depends on the x limits)
                if cfg.v_midline:
                    ax.axvline(x=calc_center(viewport_stride), **midline_kw)
                if cfg.h_midline:
                    ax.axhline(y=0, **midline_kw)

        self._save_background()

    transparent = "#00000000"

    # satisfies RegionFactory
    def _axes_factory(self, r: RegionSpec, label: str = "") -> "Axes":
        """
        Adds an Axes to self._fig covering region r's cell of an
        r.nrow x r.ncol grid (row 0 at the top), and styles its borders
        according to cfg.grid_color and r's screen/wave edges.
        """
        cfg = self.cfg

        width = 1 / r.ncol
        left = r.col / r.ncol
        assert 0 <= left < 1

        # Figure coordinates grow upwards, grid rows grow downwards.
        height = 1 / r.nrow
        bottom = (r.nrow - r.row - 1) / r.nrow
        assert 0 <= bottom < 1

        # Disabling xticks/yticks is unnecessary, since we hide Axises.
        ax = self._fig.add_axes([left, bottom, width, height],
                                xticks=[],
                                yticks=[],
                                label=label)

        grid_color = cfg.grid_color
        if grid_color:
            # Initialize borders
            # Hide Axises
            # (drawing them is very slow, and we disable ticks+labels anyway)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # Background color
            # ax.patch.set_fill(False) sets _fill=False,
            # then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
            # It is no faster than below.
            ax.set_facecolor(self.transparent)

            # Set border colors
            for spine in ax.spines.values():  # type: Spine
                spine.set_linewidth(cfg.grid_line_width)
                spine.set_color(grid_color)

            def hide(key: str):
                ax.spines[key].set_visible(False)

            # Hide all spines except bottom and right, so neighboring cells
            # don't draw doubled borders.
            hide("top")
            hide("left")

            # If bottom of screen, hide bottom. If right of screen, hide right.
            if r.screen_edges & Edges.Bottom:
                hide("bottom")
            if r.screen_edges & Edges.Right:
                hide("right")

            # Dim stereo gridlines
            if cfg.stereo_grid_opacity > 0:
                dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
                dim_color[-1] = cfg.stereo_grid_opacity

                def dim(key: str):
                    ax.spines[key].set_color(dim_color)

            else:
                dim = hide

            # If not bottom of wave, dim bottom. If not right of wave, dim right.
            if not r.wave_edges & Edges.Bottom:
                dim("bottom")
            if not r.wave_edges & Edges.Right:
                dim("right")

        else:
            ax.set_axis_off()

        return ax

    # Public API
    def add_lines_stereo(self, dummy_datas: List[np.ndarray],
                         strides: List[int]) -> UpdateLines:
        """
        Plots one Line2D per channel of each wave, initialized to zeros shaped
        like dummy_datas[wave] (indexed [samp][chan]).

        Returns a callback which replaces the lines' y-data with new datas.
        """
        cfg = self.cfg

        # Plot lines over background
        line_width = cfg.line_width

        # Foreach wave, plot dummy data.
        lines2d = []
        for wave_idx, wave_data in enumerate(dummy_datas):
            wave_zeros = np.zeros_like(wave_data)

            wave_axes = self._axes2d[wave_idx]
            line_color = self._line_params[wave_idx].color
            wave_lines = []

            xs = calc_xs(len(wave_zeros), strides[wave_idx])

            # Foreach chan
            for chan_idx, chan_zeros in enumerate(wave_zeros.T):
                ax = wave_axes[chan_idx]
                chan_line: Line2D = ax.plot(xs,
                                            chan_zeros,
                                            color=line_color,
                                            linewidth=line_width)[0]
                wave_lines.append(chan_line)

            lines2d.append(wave_lines)
            self._artists.extend(wave_lines)

        return lambda datas: self._update_lines_stereo(lines2d, datas)

    @staticmethod
    def _update_lines_stereo(lines2d: "List[List[Line2D]]",
                             datas: List[np.ndarray]) -> None:
        """
        Preconditions:
        - lines2d[wave][chan] = Line2D
        - datas[wave] = ndarray, [samp][chan] = FLOAT

        :raises ValueError: if the number of waves differs from len(lines2d).
        """
        nplots = len(lines2d)
        ndata = len(datas)
        if nplots != ndata:
            raise ValueError(
                f"incorrect data to plot: {nplots} plots but {ndata} dummy_datas"
            )

        # Draw waveform data
        # Foreach wave
        for wave_idx, wave_data in enumerate(datas):
            wave_lines = lines2d[wave_idx]

            # Foreach chan
            for chan_idx, chan_data in enumerate(wave_data.T):
                chan_line = wave_lines[chan_idx]
                chan_line.set_ydata(chan_data)

    def _add_xy_line_mono(self, wave_idx: int, xs: Sequence[float],
                          ys: Sequence[float], stride: int) -> CustomLine:
        """Plots (xs, ys) on wave_idx's mono Axes, and registers it as animated."""
        cfg = self.cfg

        # Plot lines over background
        line_width = cfg.line_width

        ax = self._axes_mono[wave_idx]
        mono_line: Line2D = ax.plot(xs, ys, linewidth=line_width)[0]

        self._artists.append(mono_line)

        # noinspection PyTypeChecker
        return CustomLine(stride, xs, mono_line.set_xdata, mono_line.set_ydata)

    # Channel labels
    def add_labels(self, labels: List[str]) -> List["Text"]:
        """
        Updates background, adds text.
        Do NOT call after calling self.add_lines().

        :raises ValueError: if len(labels) != self.nplots.
        """
        nlabel = len(labels)
        if nlabel != self.nplots:
            raise ValueError(
                f"incorrect labels: {self.nplots} plots but {nlabel} labels")

        cfg = self.cfg
        color = cfg.get_label_color

        size_pt = cfg.label_font.size
        # NOTE(review): named "_px", but it is passed below as an
        # "offset points" distance.
        distance_px = cfg.label_padding_ratio * size_pt

        @attr.dataclass
        class AxisPosition:
            pos_axes: float
            offset_px: float
            align: str

        xpos = cfg.label_position.x.match(
            left=AxisPosition(0, distance_px, "left"),
            right=AxisPosition(1, -distance_px, "right"),
        )
        ypos = cfg.label_position.y.match(
            bottom=AxisPosition(0, distance_px, "bottom"),
            top=AxisPosition(1, -distance_px, "top"),
        )

        pos_axes = (xpos.pos_axes, ypos.pos_axes)
        offset_pt = (xpos.offset_px, ypos.offset_px)

        out: List["Text"] = []
        for label_text, ax in zip(labels, self._axes_mono):
            # https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.annotate.html
            # Annotation subclasses Text.
            text: "Annotation" = ax.annotate(
                label_text,
                # Positioning
                xy=pos_axes,
                xycoords="axes fraction",
                xytext=offset_pt,
                textcoords="offset points",
                horizontalalignment=xpos.align,
                verticalalignment=ypos.align,
                # Cosmetics
                color=color,
                fontsize=px_from_points(size_pt),
                fontfamily=cfg.label_font.family,
                fontweight=("bold" if cfg.label_font.bold else "normal"),
                fontstyle=("italic" if cfg.label_font.italic else "normal"),
            )
            out.append(text)

        self._save_background()
        return out

    # Output frames
    def get_frame(self) -> ByteBuffer:
        """Returns bytes with shape (h, w, self.bytes_per_pixel).
        The actual return value's shape may be flat.
        """
        self._redraw_over_background()

        canvas = self._fig.canvas

        # Agg is the default noninteractive backend except on OSX.
        # https://matplotlib.org/faq/usage_faq.html
        if not isinstance(canvas, self._canvas_type):
            raise RuntimeError(
                f"oh shit, cannot read data from {obj_name(canvas)} != {self._canvas_type.__name__}"
            )

        buffer_rgb = self._canvas_to_bytes(canvas)
        assert len(buffer_rgb) == self.w * self.h * self.bytes_per_pixel

        return buffer_rgb

    # Pre-rendered background
    bg_cache: Any  # "matplotlib.backends._backend_agg.BufferRegion"

    def _save_background(self) -> None:
        """ Draw static background. """
        # https://stackoverflow.com/a/8956211
        # https://matplotlib.org/api/animation_api.html#funcanimation
        fig = self._fig

        fig.canvas.draw()
        self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)

    def _redraw_over_background(self) -> None:
        """ Redraw animated elements of the image. """

        # Both FigureCanvasAgg and FigureCanvasCairo, but not FigureCanvasBase,
        # support restore_region().
        canvas: FigureCanvasAgg = self._fig.canvas
        canvas.restore_region(self.bg_cache)

        for artist in self._artists:
            artist.axes.draw_artist(artist)
Exemple #6
0
    def _set_layout(self, wave_nchans: List[int]) -> None:
        """
        Creates a Figure (backed by an Agg canvas) and one Matplotlib Axes per
        channel region, arranged according to self.lcfg and wave_nchans.

        Inputs: self.cfg, self.lcfg
        Outputs: self.layout, self._fig, self._axes2d (_axes2d[wave][chan])

        :raises Exception: if called more than once on the same instance.
        """
        # Check before touching any state, so a rejected second call does not
        # replace self.layout while leaving the old Axes in place.
        if hasattr(self, "_fig"):
            raise Exception(
                "I don't currently expect to call _set_layout() twice")

        self.layout = RendererLayout(self.lcfg, wave_nchans)

        grid_color = self.cfg.grid_color
        self._fig = Figure()
        # Constructing the canvas attaches it to the figure (self._fig.canvas).
        FigureCanvasAgg(self._fig)

        # RegionFactory
        def axes_factory(r: RegionSpec) -> "Axes":
            """Adds an Axes covering region r (row 0 at the top) and styles it."""
            width = 1 / r.ncol
            left = r.col / r.ncol
            assert 0 <= left < 1

            # Figure coordinates grow upwards, grid rows grow downwards.
            height = 1 / r.nrow
            bottom = (r.nrow - r.row - 1) / r.nrow
            assert 0 <= bottom < 1

            # Disabling xticks/yticks is unnecessary, since we hide Axises.
            ax = self._fig.add_axes([left, bottom, width, height],
                                    xticks=[],
                                    yticks=[])

            if grid_color:
                # Initialize borders
                # Hide Axises
                # (drawing them is very slow, and we disable ticks+labels anyway)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

                # Background color
                # ax.patch.set_fill(False) sets _fill=False,
                # then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
                # It is no faster than below.
                ax.set_facecolor(self.transparent)

                # Set border colors
                for spine in ax.spines.values():
                    spine.set_color(grid_color)

                def hide(key: str):
                    ax.spines[key].set_visible(False)

                # Hide all spines except bottom and right, so neighboring
                # cells don't draw doubled borders.
                hide("top")
                hide("left")

                # If bottom of screen, hide bottom. If right of screen, hide right.
                if r.screen_edges & Edges.Bottom:
                    hide("bottom")
                if r.screen_edges & Edges.Right:
                    hide("right")

                # Dim stereo gridlines
                if self.cfg.stereo_grid_opacity > 0:
                    dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
                    dim_color[-1] = self.cfg.stereo_grid_opacity

                    def dim(key: str):
                        ax.spines[key].set_color(dim_color)

                else:
                    dim = hide

                # If not bottom of wave, dim bottom. If not right of wave, dim right.
                if not r.wave_edges & Edges.Bottom:
                    dim("bottom")
                if not r.wave_edges & Edges.Right:
                    dim("right")

            else:
                ax.set_axis_off()

            return ax

        # Generate arrangement (using self.lcfg, wave_nchans)
        # _axes2d[wave][chan] = Axes
        self._axes2d = self.layout.arrange(axes_factory)

        # Setup figure geometry.
        # Matplotlib truncates (int()) inches * dpi to get the pixel size, so
        # width / DPI * DPI can come out as width - 1 after float rounding.
        # Aim a quarter pixel past the target so both int() and round()
        # produce exactly cfg.width x cfg.height.
        offset = 0.25
        self._fig.set_dpi(DPI)
        self._fig.set_size_inches((self.cfg.width + offset) / DPI,
                                  (self.cfg.height + offset) / DPI)
Exemple #7
0
class MatplotlibRenderer(Renderer):
    """
    Renderer backend which takes data and produces images.
    Does not touch Wave or Channel.

    If __init__ reads cfg, cfg cannot be hotswapped.

    Reasons to hotswap cfg: RendererCfg:
    - GUI preview size
    - Changing layout
    - Changing #smp drawn (samples_visible)
    (see RendererCfg)

        Original OVGen does not support hotswapping.
        It disables changing options during rendering.

    Reasons to hotswap trigger algorithms:
    - changing scan_nsamp (cannot be hotswapped, since correlation buffer is incompatible)
    So don't.
    """
    def __init__(self, *args, **kwargs):
        Renderer.__init__(self, *args, **kwargs)

        # NOTE(review): dict.__setitem__ writes the value without going through
        # RcParams.__setitem__ (and its validation); presumably intentional.
        dict.__setitem__(matplotlib.rcParams, "lines.antialiased",
                         self.cfg.antialiasing)

        self._fig: "Figure"

        # _axes2d[wave][chan] = Axes
        self._axes2d: List[List["Axes"]]  # set by _set_layout()

        # _lines2d[wave][chan] = Line2D
        # Empty until the first render_frame() call creates the lines.
        self._lines2d: List[List[Line2D]] = []
        self._lines_flat: List["Line2D"] = []

    transparent = "#00000000"

    layout: RendererLayout

    def _set_layout(self, wave_nchans: List[int]) -> None:
        """
        Creates a Figure (backed by an Agg canvas) and one Matplotlib Axes per
        channel region, arranged according to self.lcfg and wave_nchans.

        Inputs: self.cfg, self.lcfg
        Outputs: self.layout, self._fig, self._axes2d (_axes2d[wave][chan])

        :raises Exception: if called more than once on the same instance.
        """
        # Check before touching any state, so a rejected second call does not
        # replace self.layout while leaving the old Axes in place.
        if hasattr(self, "_fig"):
            raise Exception(
                "I don't currently expect to call _set_layout() twice")

        self.layout = RendererLayout(self.lcfg, wave_nchans)

        grid_color = self.cfg.grid_color
        self._fig = Figure()
        # Constructing the canvas attaches it to the figure (self._fig.canvas).
        FigureCanvasAgg(self._fig)

        # RegionFactory
        def axes_factory(r: RegionSpec) -> "Axes":
            """Adds an Axes covering region r (row 0 at the top) and styles it."""
            width = 1 / r.ncol
            left = r.col / r.ncol
            assert 0 <= left < 1

            # Figure coordinates grow upwards, grid rows grow downwards.
            height = 1 / r.nrow
            bottom = (r.nrow - r.row - 1) / r.nrow
            assert 0 <= bottom < 1

            # Disabling xticks/yticks is unnecessary, since we hide Axises.
            ax = self._fig.add_axes([left, bottom, width, height],
                                    xticks=[],
                                    yticks=[])

            if grid_color:
                # Initialize borders
                # Hide Axises
                # (drawing them is very slow, and we disable ticks+labels anyway)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

                # Background color
                # ax.patch.set_fill(False) sets _fill=False,
                # then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
                # It is no faster than below.
                ax.set_facecolor(self.transparent)

                # Set border colors
                for spine in ax.spines.values():
                    spine.set_color(grid_color)

                def hide(key: str):
                    ax.spines[key].set_visible(False)

                # Hide all spines except bottom and right, so neighboring
                # cells don't draw doubled borders.
                hide("top")
                hide("left")

                # If bottom of screen, hide bottom. If right of screen, hide right.
                if r.screen_edges & Edges.Bottom:
                    hide("bottom")
                if r.screen_edges & Edges.Right:
                    hide("right")

                # Dim stereo gridlines
                if self.cfg.stereo_grid_opacity > 0:
                    dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
                    dim_color[-1] = self.cfg.stereo_grid_opacity

                    def dim(key: str):
                        ax.spines[key].set_color(dim_color)

                else:
                    dim = hide

                # If not bottom of wave, dim bottom. If not right of wave, dim right.
                if not r.wave_edges & Edges.Bottom:
                    dim("bottom")
                if not r.wave_edges & Edges.Right:
                    dim("right")

            else:
                ax.set_axis_off()

            return ax

        # Generate arrangement (using self.lcfg, wave_nchans)
        # _axes2d[wave][chan] = Axes
        self._axes2d = self.layout.arrange(axes_factory)

        # Setup figure geometry.
        # Matplotlib truncates (int()) inches * dpi to get the pixel size, so
        # width / DPI * DPI can come out as width - 1 after float rounding.
        # Aim a quarter pixel past the target so both int() and round()
        # produce exactly cfg.width x cfg.height (checked in get_frame()).
        offset = 0.25
        self._fig.set_dpi(DPI)
        self._fig.set_size_inches((self.cfg.width + offset) / DPI,
                                  (self.cfg.height + offset) / DPI)

    def render_frame(self, datas: List[np.ndarray]) -> None:
        """
        Draws one frame. datas[wave] is a 2D array indexed [samp][chan].

        On the first call, builds the layout from each wave's channel count,
        draws and caches the static background (limits, midlines), and creates
        one Line2D per channel. Later calls only replace the lines' y-data.
        Either way, the lines are then redrawn over the cached background.

        :raises ValueError: if len(datas) != self.nplots.
        """
        ndata = len(datas)
        if self.nplots != ndata:
            raise ValueError(
                f"incorrect data to plot: {self.nplots} plots but {ndata} datas"
            )

        # Initialize axes and draw waveform data
        if not self._lines2d:
            assert len(datas[0].shape) == 2, datas[0].shape

            wave_nchans = [data.shape[1] for data in datas]
            self._set_layout(wave_nchans)

            cfg = self.cfg

            # Setup background/axes
            self._fig.set_facecolor(cfg.bg_color)

            # Midline style is the same for every Axes.
            # zorder=-100 still draws on top of gridlines :(
            midline_kw = dict(color=cfg.midline_color, linewidth=pixels(1))

            for idx, wave_data in enumerate(datas):
                wave_axes = self._axes2d[idx]
                max_x = len(wave_data) - 1

                # unique_by_id(): the same Axes may appear under several
                # channels (e.g. overlaid stereo); configure each only once.
                for ax in unique_by_id(wave_axes):
                    ax.set_xlim(0, max_x)
                    ax.set_ylim(-1, 1)

                    # Setup midlines (depends on max_x and wave_data)
                    if cfg.v_midline:
                        ax.axvline(x=max_x / 2, **midline_kw)
                    if cfg.h_midline:
                        ax.axhline(y=0, **midline_kw)

            self._save_background()

            # Plot lines over background
            line_width = pixels(cfg.line_width)

            # Foreach wave
            for wave_idx, wave_data in enumerate(datas):
                wave_axes = self._axes2d[wave_idx]
                line_color = self._line_params[wave_idx].color
                wave_lines = []

                # Foreach chan
                for chan_idx, chan_data in enumerate(wave_data.T):
                    ax = wave_axes[chan_idx]
                    chan_line: Line2D = ax.plot(chan_data,
                                                color=line_color,
                                                linewidth=line_width)[0]
                    wave_lines.append(chan_line)

                self._lines2d.append(wave_lines)
                self._lines_flat.extend(wave_lines)

        # Draw waveform data
        else:
            # Foreach wave
            for wave_idx, wave_data in enumerate(datas):
                wave_lines = self._lines2d[wave_idx]

                # Foreach chan
                for chan_idx, chan_data in enumerate(wave_data.T):
                    chan_line = wave_lines[chan_idx]
                    chan_line.set_ydata(chan_data)

        self._redraw_over_background()

    bg_cache: Any  # "matplotlib.backends._backend_agg.BufferRegion"

    def _save_background(self) -> None:
        """ Draw static background. """
        # https://stackoverflow.com/a/8956211
        # https://matplotlib.org/api/animation_api.html#funcanimation
        fig = self._fig

        fig.canvas.draw()
        self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)

    def _redraw_over_background(self) -> None:
        """ Redraw animated elements of the image. """

        canvas: FigureCanvasAgg = self._fig.canvas
        canvas.restore_region(self.bg_cache)

        for line in self._lines_flat:
            line.axes.draw_artist(line)

        # https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
        # thinks fig.canvas.blit(ax.bbox) leaks memory
        # and fig.canvas.update() works.
        # Except I found no memory leak...
        # and update() doesn't exist in FigureCanvasBase when no GUI is present.

        canvas.blit(self._fig.bbox)

    def get_frame(self) -> ByteBuffer:
        """
        Returns the current frame as a flat RGB byte buffer
        (canvas.tostring_rgb()), of length width * height * RGB_DEPTH.

        :raises RuntimeError: if the figure's canvas is not a FigureCanvasAgg.
        """
        canvas = self._fig.canvas

        # Agg is the default noninteractive backend except on OSX.
        # https://matplotlib.org/faq/usage_faq.html
        if not isinstance(canvas, FigureCanvasAgg):
            raise RuntimeError(
                f"oh shit, cannot read data from {type(canvas)} != FigureCanvasAgg"
            )

        w = self.cfg.width
        h = self.cfg.height
        assert (w, h) == canvas.get_width_height()

        buffer_rgb = canvas.tostring_rgb()
        assert len(buffer_rgb) == w * h * RGB_DEPTH

        return buffer_rgb