class C:
    """
    Small fixture with one private attribute and one private method.

    ``attr`` and ``meth`` are public names created with
    ``_api.deprecate_privatize_attribute``; presumably they forward to
    ``_attr`` and ``_meth`` respectively — confirm against ``_api``.
    """

    def __init__(self):
        # Backing value for the deprecated ``attr`` name.
        self._attr = 1

    def _meth(self, arg):
        """Identity helper: hand *arg* straight back."""
        result = arg
        return result

    # Deprecated public aliases (deprecation version "0.0").
    attr = _api.deprecate_privatize_attribute("0.0")
    meth = _api.deprecate_privatize_attribute("0.0")
Example #2
0
class CbarAxesBase:
    """
    Axes mixin that hosts a colorbar attached to one side of a grid.

    *orientation* names the axis side ("top", "bottom", or presumably
    "left"/"right") on which the colorbar ticks and labels are shown.
    """

    def __init__(self, *args, orientation, **kwargs):
        # Set our own state before delegating: cla() -> _config_axes() reads
        # both attributes, and the base initializer may well call cla().
        self.orientation = orientation
        self._default_label_on = True
        self._locator = None  # deprecated.
        super().__init__(*args, **kwargs)

    def colorbar(self, mappable, *, ticks=None, **kwargs):
        """
        Draw a colorbar for *mappable* into this axes and return it.

        The colorbar is horizontal when this axes sits on the "top" or
        "bottom" side and vertical otherwise.  *ticks* and any extra keyword
        arguments are forwarded to `matplotlib.colorbar.Colorbar`.
        """
        horizontal = self.orientation in ["top", "bottom"]
        cb = mpl.colorbar.Colorbar(
            self,
            mappable,
            orientation="horizontal" if horizontal else "vertical",
            ticks=ticks,
            **kwargs)
        # Only kept alive for the deprecated ``cbid``/``locator`` aliases.
        self._cbid = mappable.colorbar_cid  # deprecated in 3.3.
        self._locator = cb.locator  # deprecated in 3.3.

        self._config_axes()
        return cb

    cbid = _api.deprecate_privatize_attribute(
        "3.3", alternative="mappable.colorbar_cid")
    locator = _api.deprecate_privatize_attribute(
        "3.3", alternative=".colorbar().locator")

    def _config_axes(self):
        """
        Disable navigation and hide every axis artist except the one on the
        colorbar's side, whose visibility follows ``_default_label_on``.
        """
        self.set_navigate(False)
        self.axis[:].toggle(all=False)
        self.axis[self.orientation].toggle(all=self._default_label_on)

    def toggle_label(self, b):
        """Show (*b* true) or hide the tick labels and label of the colorbar."""
        self._default_label_on = b
        side_axis = self.axis[self.orientation]
        side_axis.toggle(ticklabels=b, label=b)

    def cla(self):
        # Re-apply the colorbar axis configuration after the base reset.
        super().cla()
        self._config_axes()
Example #3
0
class NavigationToolbar2WebAgg(backend_bases.NavigationToolbar2):
    """Navigation toolbar that relays its state to the web client as events."""

    # The standard toolbar entries plus a "Download" button, filtered down to
    # the methods listed in _ALLOWED_TOOL_ITEMS.
    toolitems = [
        (label, tooltip, image, method)
        for label, tooltip, image, method in (
            *backend_bases.NavigationToolbar2.toolitems,
            ('Download', 'Download plot', 'filesave', 'download'))
        if method in _ALLOWED_TOOL_ITEMS
    ]

    cursor = _api.deprecate_privatize_attribute("3.5")

    def __init__(self, canvas):
        self.message = ''
        self._cursor = None  # Remove with deprecation.
        super().__init__(canvas)

    def set_message(self, message):
        """Forward *message* to the client, skipping repeats of the last one."""
        is_new = message != self.message
        if is_new:
            self.canvas.send_event("message", message=message)
        self.message = message

    def draw_rubberband(self, event, x0, y0, x1, y1):
        self.canvas.send_event("rubberband", x0=x0, y0=y0, x1=x1, y1=y1)

    def release_zoom(self, event):
        super().release_zoom(event)
        # All -1 coordinates tell the client to clear the rubberband.
        self.canvas.send_event("rubberband", x0=-1, y0=-1, x1=-1, y1=-1)

    def save_figure(self, *args):
        """Save the current figure"""
        self.canvas.send_event('save')

    def pan(self):
        super().pan()
        self.canvas.send_event('navigate_mode', mode=self.mode.name)

    def zoom(self):
        super().zoom()
        self.canvas.send_event('navigate_mode', mode=self.mode.name)

    def set_history_buttons(self):
        """Tell the client whether Back/Forward are currently possible."""
        stack = self._nav_stack
        position = stack._pos
        self.canvas.send_event('history_buttons',
                               Back=position > 0,
                               Forward=position < len(stack._elements) - 1)
Example #4
0
class ScalarMappable:
    """
    Mixin that turns scalar data into RGBA colors.

    Values are first scaled by ``self.norm`` (a `.Normalize`) and the result
    is then looked up in ``self.cmap`` (a `.Colormap`).
    """
    def __init__(self, norm=None, cmap=None):
        """
        Parameters
        ----------
        norm : `matplotlib.colors.Normalize` (or subclass thereof)
            The normalizing object which scales data, typically into the
            interval ``[0, 1]``.
            If *None*, *norm* defaults to a *colors.Normalize* object which
            initializes its scaling based on the first data processed.
        cmap : str or `~matplotlib.colors.Colormap`
            The colormap used to map normalized data values to RGBA colors.
        """
        self._A = None
        # set_norm()/set_cmap() treat a None attribute as "still inside
        # __init__" and then skip self.changed(), which needs callbacksSM and
        # _update_dict -- both are only created further down.
        self.norm = None
        self.set_norm(norm)
        self.cmap = None
        self.set_cmap(cmap)
        #: The last colorbar associated with this ScalarMappable. May be None.
        self.colorbar = None
        self.callbacksSM = cbook.CallbackRegistry()
        # checker name -> "changed since last _check_update" flag.
        self._update_dict = {'array': False}

    def _scale_norm(self, norm, vmin, vmax):
        """
        Apply *vmin*/*vmax* and resolve autoscaling right away.

        Meant for public plotting functions that build a ScalarMappable and
        accept *vmin*, *vmax* and *norm*.  Passing limits together with a
        *norm* emits a deprecation warning.  The norm itself is not set here.
        """
        limits_given = vmin is not None or vmax is not None
        if limits_given:
            self.set_clim(vmin, vmax)
            if norm is not None:
                _api.warn_deprecated(
                    "3.3",
                    message="Passing parameters norm and vmin/vmax "
                    "simultaneously is deprecated since %(since)s and "
                    "will become an error %(removal)s. Please pass "
                    "vmin/vmax directly to the norm when creating it.")

        # Resolve autoscaling now so the limits are concrete immediately
        # rather than being deferred until draw time.
        self.autoscale_None()

    def to_rgba(self, x, alpha=None, bytes=False, norm=True):
        """
        Return a normalized rgba array corresponding to *x*.

        Usually *x* is a 1D or 2D sequence of scalars; it is normalized with
        this mappable's norm (unless *norm* is False, in which case the data
        are assumed to already lie in 0-1) and mapped through its colormap.

        Special case: a 3-D ndarray whose last dimension is 3 or 4 is taken
        to be an RGB or RGBA image and is not colormapped.  It may be uint8,
        or floating point with values in 0-1 (checked only when *norm* is
        true; otherwise a ValueError is raised).  Masks of masked arrays are
        ignored.  For 3 channels, *alpha* (default 1) fills the transparency
        channel; for 4 channels, *alpha* is ignored.  Any other last
        dimension raises ValueError.

        With *bytes* False (default) the result holds floats in 0-1; with
        *bytes* True it holds uint8 values in 0-255.
        """
        # Special case: RGB(A) image input, passed through without mapping.
        try:
            if x.ndim == 3:
                depth = x.shape[2]
                if depth == 3:
                    if alpha is None:
                        alpha = 1
                    if x.dtype == np.uint8:
                        alpha = np.uint8(alpha * 255)
                    height, width = x.shape[:2]
                    image = np.empty(shape=(height, width, 4), dtype=x.dtype)
                    image[:, :, :3] = x
                    image[:, :, 3] = alpha
                elif depth == 4:
                    image = x
                else:
                    raise ValueError("Third dimension must be 3 or 4")

                kind = image.dtype.kind
                if kind == 'f':
                    if norm and (image.max() > 1 or image.min() < 0):
                        raise ValueError("Floating point image RGB values "
                                         "must be in the 0..1 range.")
                    if bytes:
                        image = (image * 255).astype(np.uint8)
                elif image.dtype == np.uint8:
                    if not bytes:
                        image = image.astype(np.float32) / 255
                else:
                    raise ValueError("Image RGB array must be uint8 or "
                                     "floating point; found %s" % image.dtype)
                return image
        except AttributeError:
            # x has no ndim/shape/dtype (e.g. a list): map it like scalars.
            pass

        # Normal case: (optionally) normalize, then apply the colormap.
        values = ma.asarray(x)
        if norm:
            values = self.norm(values)
        return self.cmap(values, alpha=alpha, bytes=bytes)

    def set_array(self, A):
        """
        Set the image array from numpy array *A*.

        Parameters
        ----------
        A : ndarray or None
        """
        self._A = A
        self._update_dict['array'] = True

    def get_array(self):
        """Return the data array."""
        return self._A

    def get_cmap(self):
        """Return the `.Colormap` instance."""
        return self.cmap

    def get_clim(self):
        """
        Return the values (min, max) that are mapped to the colormap limits.
        """
        return self.norm.vmin, self.norm.vmax

    def set_clim(self, vmin=None, vmax=None):
        """
        Set the norm limits for image scaling.

        Parameters
        ----------
        vmin, vmax : float
             The limits.

             The limits may also be passed as a tuple (*vmin*, *vmax*) as a
             single positional argument.

             .. ACCEPTS: (vmin: float, vmax: float)
        """
        if vmax is None:
            # Support set_clim((vmin, vmax)); otherwise keep vmin as given.
            try:
                vmin, vmax = vmin
            except (TypeError, ValueError):
                pass
        for limit_name, limit in (("vmin", vmin), ("vmax", vmax)):
            if limit is not None:
                setattr(self.norm, limit_name,
                        colors._sanitize_extrema(limit))
        self.changed()

    def get_alpha(self):
        """
        Returns
        -------
        float
            Always returns 1.
        """
        # Artist subclasses are expected to override this.
        return 1.

    def set_cmap(self, cmap):
        """
        Set the colormap for luminance data.

        Parameters
        ----------
        cmap : `.Colormap` or str or None
        """
        in_init = self.cmap is None
        self.cmap = get_cmap(cmap)
        # While still in __init__ the callback machinery is not set up yet,
        # so only notify listeners afterwards.
        if not in_init:
            self.changed()

    def set_norm(self, norm):
        """
        Set the normalization instance.

        Parameters
        ----------
        norm : `.Normalize` or None

        Notes
        -----
        If there are any colorbars using the mappable for this norm, setting
        the norm of the mappable will reset the norm, locator, and formatters
        on the colorbar to default.
        """
        _api.check_isinstance((colors.Normalize, None), norm=norm)
        in_init = self.norm is None
        self.norm = colors.Normalize() if norm is None else norm
        # While still in __init__ the callback machinery is not set up yet,
        # so only notify listeners afterwards.
        if not in_init:
            self.changed()

    def autoscale(self):
        """
        Autoscale the scalar limits on the norm instance using the
        current array
        """
        if self._A is None:
            raise TypeError('You must first set_array for mappable')
        self.norm.autoscale(self._A)
        self.changed()

    def autoscale_None(self):
        """
        Autoscale the scalar limits on the norm instance using the
        current array, changing only limits that are None
        """
        if self._A is None:
            raise TypeError('You must first set_array for mappable')
        self.norm.autoscale_None(self._A)
        self.changed()

    def _add_checker(self, checker):
        """Register *checker* with its "changed" flag initially cleared."""
        self._update_dict[checker] = False

    def _check_update(self, checker):
        """
        Return whether the mappable changed since *checker* last asked,
        clearing that checker's flag when it was set.
        """
        dirty = self._update_dict[checker]
        if dirty:
            self._update_dict[checker] = False
        return bool(dirty)

    def changed(self):
        """
        Notify 'changed' listeners on callbacksSM, flag every registered
        checker as changed and mark the mappable stale.
        """
        self.callbacksSM.process('changed', self)
        # Update in place: the dict object itself must be preserved.
        self._update_dict.update(dict.fromkeys(self._update_dict, True))
        self.stale = True

    update_dict = _api.deprecate_privatize_attribute("3.3")
    add_checker = _api.deprecate_privatize_attribute("3.3")
    check_update = _api.deprecate_privatize_attribute("3.3")
Example #5
0
class FloatingAxisArtistHelper(AxisArtistHelper.Floating):
    """
    Axis-artist helper for an axis drawn along a curve of constant value in
    a curvilinear grid.

    The axis follows the curve on which coordinate ``nth_coord`` (0: the
    first, "lon", coordinate; 1: the second, "lat", coordinate) equals
    ``value``.  ``update_lim`` computes the curve and tick data and caches
    them in ``self._grid_info``.
    """
    # Deprecated public alias of ``_grid_info``.
    grid_info = _api.deprecate_privatize_attribute("3.5")

    def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
        """
        Parameters
        ----------
        grid_helper
            Object providing ``grid_finder`` and ``update_lim(axes)``
            (presumably a GridHelperCurveLinear -- confirm with callers).
        nth_coord : {0, 1}
            Index of the curvilinear coordinate held fixed at *value*;
            the line runs over the other coordinate.
        value : float
            The fixed value of coordinate ``nth_coord``.
        axis_direction
            Accepted for API compatibility; not used by this initializer.
        """
        super().__init__(nth_coord, value)
        self.value = value
        self.grid_helper = grid_helper
        # Limits on the *other* coordinate; unbounded by default.
        self._extremes = -np.inf, np.inf
        self._line_num_points = 100  # number of points to create a line

    def set_extremes(self, e1, e2):
        """
        Restrict the range of the non-fixed coordinate to [*e1*, *e2*];
        None means unbounded on that side.
        """
        if e1 is None:
            e1 = -np.inf
        if e2 is None:
            e2 = np.inf
        self._extremes = e1, e2

    def update_lim(self, axes):
        """
        Recompute line and tick data for the current view limits of *axes*
        and store them in ``self._grid_info``.
        """
        self.grid_helper.update_lim(axes)

        x1, x2 = axes.get_xlim()
        y1, y2 = axes.get_ylim()
        grid_finder = self.grid_helper.grid_finder
        # Range of the curvilinear coordinates visible in the data limits.
        extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy, x1,
                                              y1, x2, y2)

        lon_min, lon_max, lat_min, lat_max = extremes
        e_min, e_max = self._extremes  # ranges of other coordinates
        # Clip the varying coordinate's range to the user-set extremes.
        if self.nth_coord == 0:
            lat_min = max(e_min, lat_min)
            lat_max = min(e_max, lat_max)
        elif self.nth_coord == 1:
            lon_min = max(e_min, lon_min)
            lon_max = min(e_max, lon_max)

        lon_levs, lon_n, lon_factor = \
            grid_finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = \
            grid_finder.grid_locator2(lat_min, lat_max)

        # Sample the constant-coordinate curve and map it to data coords.
        # NOTE(review): xx/yy stay unbound if nth_coord is not 0 or 1.
        if self.nth_coord == 0:
            xx0 = np.full(self._line_num_points, self.value)
            yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
            xx, yy = grid_finder.transform_xy(xx0, yy0)
        elif self.nth_coord == 1:
            xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
            yy0 = np.full(self._line_num_points, self.value)
            xx, yy = grid_finder.transform_xy(xx0, yy0)

        self._grid_info = {
            "extremes": (lon_min, lon_max, lat_min, lat_max),
            "lon_info": (lon_levs, lon_n, lon_factor),
            "lat_info": (lat_levs, lat_n, lat_factor),
            "lon_labels":
            grid_finder.tick_formatter1("bottom", lon_factor, lon_levs),
            "lat_labels":
            grid_finder.tick_formatter2("bottom", lat_factor, lat_levs),
            # Curve points in data coordinates (see get_line_transform).
            "line_xy": (xx, yy),
        }

    def get_axislabel_transform(self, axes):
        # get_axislabel_pos_angle already returns display coordinates.
        return Affine2D()  # axes.transData

    def get_axislabel_pos_angle(self, axes):
        """
        Return ``((x, y), angle_deg)`` in display coordinates for the axis
        label, placed at the midpoint of the varying coordinate's range, or
        ``(None, None)`` if that point lies outside the axes.
        """

        extremes = self._grid_info["extremes"]

        # Midpoint of the curve plus a small step along it (for the angle).
        if self.nth_coord == 0:
            xx0 = self.value
            yy0 = (extremes[2] + extremes[3]) / 2
            dxx = 0
            dyy = abs(extremes[2] - extremes[3]) / 1000
        elif self.nth_coord == 1:
            xx0 = (extremes[0] + extremes[1]) / 2
            yy0 = self.value
            dxx = abs(extremes[0] - extremes[1]) / 1000
            dyy = 0

        grid_finder = self.grid_helper.grid_finder
        (xx1, ), (yy1, ) = grid_finder.transform_xy([xx0], [yy0])

        data_to_axes = axes.transData - axes.transAxes
        p = data_to_axes.transform([xx1, yy1])

        # Only place the label if the midpoint is inside the axes.
        if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
            xx1c, yy1c = axes.transData.transform([xx1, yy1])
            (xx2, ), (yy2, ) = grid_finder.transform_xy([xx0 + dxx],
                                                        [yy0 + dyy])
            xx2c, yy2c = axes.transData.transform([xx2, yy2])
            return (xx1c,
                    yy1c), np.rad2deg(np.arctan2(yy2c - yy1c, xx2c - xx1c))
        else:
            return None, None

    def get_tick_transform(self, axes):
        # Tick locations from get_tick_iterators are already in display
        # coordinates (transform_xy there applies axes.transData).
        return IdentityTransform()  # axes.transData

    def get_tick_iterators(self, axes):
        """
        Return ``(major, minor)`` tick iterators.

        *major* yields ``[xy, normal_deg, tangent_deg, label]`` with *xy* in
        display coordinates, for ticks that fall inside the axes; *minor*
        is always empty.
        """

        grid_finder = self.grid_helper.grid_finder

        lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
        lat_levs = np.asarray(lat_levs)
        yy0 = lat_levs / lat_factor
        dy = 0.01 / lat_factor

        lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
        lon_levs = np.asarray(lon_levs)
        xx0 = lon_levs / lon_factor
        dx = 0.01 / lon_factor

        e0, e1 = self._extremes

        # Keep only tick levels of the varying coordinate within extremes.
        if self.nth_coord == 0:
            mask = (e0 <= yy0) & (yy0 <= e1)
            #xx0, yy0 = xx0[mask], yy0[mask]
            yy0 = yy0[mask]
        elif self.nth_coord == 1:
            mask = (e0 <= xx0) & (xx0 <= e1)
            #xx0, yy0 = xx0[mask], yy0[mask]
            xx0 = xx0[mask]

        def transform_xy(x, y):
            # Curvilinear coordinates -> data -> display coordinates.
            x1, y1 = grid_finder.transform_xy(x, y)
            x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
            x2, y2 = x2y2.transpose()
            return x2, y2

        # find angles
        # Pair "1" (a/b) steps across the curve -> normal direction;
        # pair "2" (a/b) steps along the curve -> tangent direction.
        if self.nth_coord == 0:
            xx0 = np.full_like(yy0, self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            xx00 = xx0.copy()
            # NOTE(review): xx0 is the fixed "lon" coordinate here, but e1 is
            # the upper extreme of the *other* ("lat") coordinate -- confirm
            # this comparison is intended.
            xx00[xx0 + dx > e1] -= dx
            xx1a, yy1a = transform_xy(xx00, yy0)
            xx1b, yy1b = transform_xy(xx00 + dx, yy0)

            xx2a, yy2a = transform_xy(xx0, yy0)
            xx2b, yy2b = transform_xy(xx0, yy0 + dy)

            labels = self._grid_info["lat_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        elif self.nth_coord == 1:
            yy0 = np.full_like(xx0, self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            xx1a, yy1a = transform_xy(xx0, yy0)
            xx1b, yy1b = transform_xy(xx0, yy0 + dy)

            xx00 = xx0.copy()
            # Step backwards instead where stepping forward exceeds e1.
            xx00[xx0 + dx > e1] -= dx
            xx2a, yy2a = transform_xy(xx00, yy0)
            xx2b, yy2b = transform_xy(xx00 + dx, yy0)

            labels = self._grid_info["lon_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        def f1():
            dd = np.arctan2(yy1b - yy1a, xx1b - xx1a)  # angle normal
            dd2 = np.arctan2(yy2b - yy2a, xx2b - xx2a)  # angle tangent
            mm = (yy1b == yy1a) & (xx1b == xx1a)  # mask where dd not defined
            # Fall back to tangent + 90 degrees where the normal step was
            # degenerate.
            dd[mm] = dd2[mm] + np.pi / 2

            tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
            for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels):
                c2 = tick_to_axes.transform((x, y))
                delta = 0.00001
                # Yield only ticks inside the axes (with a tiny tolerance).
                if 0 - delta <= c2[0] <= 1 + delta and 0 - delta <= c2[
                        1] <= 1 + delta:
                    d1, d2 = np.rad2deg([d, d2])
                    yield [x, y], d1, d2, lab

        return f1(), iter([])

    def get_line_transform(self, axes):
        # "line_xy" is stored in data coordinates.
        return axes.transData

    def get_line(self, axes):
        """Refresh the cached data and return the axis curve as a Path."""
        self.update_lim(axes)
        x, y = self._grid_info["line_xy"]
        return Path(np.column_stack([x, y]))
Example #6
0
class GridHelperCurveLinear(GridHelperBase):
    """
    Grid helper whose grid lines and ticks follow a curvilinear coordinate
    system defined by *aux_trans*.
    """
    # Deprecated public alias of ``_grid_info``.
    grid_info = _api.deprecate_privatize_attribute("3.5")

    def __init__(self,
                 aux_trans,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        aux_trans : a transform from the source (curved) coordinate to
        target (rectilinear) coordinate. An instance of MPL's Transform
        (inverse transform should be defined) or a tuple of two callable
        objects which defines the transform and its inverse. The callables
        need take two arguments of array of source coordinates and
        should return two target coordinates.

        e.g., ``x2, y2 = trans(x1, y1)``

        The remaining arguments are passed unchanged to `GridFinder`.
        """
        super().__init__()
        self._grid_info = None
        self._aux_trans = aux_trans
        self.grid_finder = GridFinder(
            aux_trans,
            extreme_finder,
            grid_locator1,
            grid_locator2,
            tick_formatter1,
            tick_formatter2,
        )

    def update_grid_finder(self, aux_trans=None, **kw):
        """Update the grid finder's transform and/or settings."""
        if aux_trans is not None:
            self.grid_finder.update_transform(aux_trans)
        self.grid_finder.update(**kw)
        self._old_limits = None  # Force revalidation.

    def new_fixed_axis(self,
                       loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """Create an `AxisArtist` fixed at side *loc* of the axes."""
        axes = self.axes if axes is None else axes
        direction = loc if axis_direction is None else axis_direction
        helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
        # NOTE(review): unlike new_floating_axis (and the
        # floating_axes.GridHelperCurveLinear subclass), no clipping is set
        # on this axis line -- confirm whether that is intended.
        return AxisArtist(axes, helper, axis_direction=direction)

    def new_floating_axis(self,
                          nth_coord,
                          value,
                          axes=None,
                          axis_direction="bottom"):
        """
        Create an `AxisArtist` along the curve where coordinate *nth_coord*
        equals *value*; its line is clipped to the axes bounding box.
        """
        if axes is None:
            axes = self.axes

        helper = FloatingAxisArtistHelper(self, nth_coord, value,
                                          axis_direction)
        axisline = AxisArtist(axes, helper)

        line = axisline.line
        line.set_clip_on(True)
        line.set_clip_box(axisline.axes.bbox)
        return axisline

    def _update_grid(self, x1, y1, x2, y2):
        self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)

    def get_gridlines(self, which="major", axis="both"):
        """
        Return the flattened grid-line segments for *axis* ("both", "x" for
        the first coordinate, "y" for the second).  *which* is ignored.
        """
        keys = []
        if axis in ["both", "x"]:
            keys.append("lon")
        if axis in ["both", "y"]:
            keys.append("lat")
        return [segment
                for key in keys
                for line_group in self._grid_info[key]["lines"]
                for segment in line_group]

    def get_tick_iterator(self, nth_coord, axis_side, minor=False):
        """
        Yield ``(xy, angle_normal, angle_tangent, label)`` for the ticks of
        coordinate *nth_coord* on *axis_side*; labels are "" when *minor*.
        """
        angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
        show_labels = not minor
        coord_info = self._grid_info[["lon", "lat"][nth_coord]]
        locations = coord_info["tick_locs"][axis_side]
        texts = coord_info["tick_labels"][axis_side]
        for (xy, angle_normal), text in zip(locations, texts):
            yield xy, angle_normal, angle_tangent, (text if show_labels
                                                    else "")
Example #7
0
class ToolViewsPositions(ToolBase):
    """
    Auxiliary Tool to handle changes in views and positions.

    Runs in the background and should get used by all the tools that
    need to access the figure's history of views and positions, e.g.

    * `ToolZoom`
    * `ToolPan`
    * `ToolHome`
    * `ToolBack`
    * `ToolForward`
    """
    def __init__(self, *args, **kwargs):
        # Per-figure state.  WeakKeyDictionary holds its keys weakly, so these
        # mappings do not by themselves keep a figure alive.
        self.views = WeakKeyDictionary()
        self.positions = WeakKeyDictionary()
        self.home_views = WeakKeyDictionary()
        super().__init__(*args, **kwargs)

    def add_figure(self, figure):
        """
        Start tracking *figure*, if it is not tracked yet.

        Creates its view and position stacks, pushes the current state as
        the home entry, and registers an observer so that axes added later
        also get a home view.
        """
        if figure not in self.views:
            self.views[figure] = cbook.Stack()
            self.positions[figure] = cbook.Stack()
            self.home_views[figure] = WeakKeyDictionary()
            # Define Home
            self.push_current(figure)
            # Make sure we add a home view for new axes as they're added
            figure.add_axobserver(lambda fig: self.update_home_views(fig))

    def clear(self, figure):
        """
        Reset the view and position stacks of *figure*.

        The home views of *figure* are rebuilt from its current axes.  Does
        nothing if *figure* is not tracked.
        """
        if figure in self.views:
            self.views[figure].clear()
            self.positions[figure].clear()
            self.home_views[figure].clear()
            # Rebuild the home views of the figure that was just cleared;
            # it need not be ``self.figure``.
            self.update_home_views(figure)

    def update_view(self):
        """
        Update the view limits and position for each axes from the current
        stack position. If any axes are present in the figure that aren't in
        the current stack position, use the home view limits for those axes and
        don't update *any* positions.
        """

        views = self.views[self.figure]()
        if views is None:
            return
        pos = self.positions[self.figure]()
        if pos is None:
            return
        home_views = self.home_views[self.figure]
        all_axes = self.figure.get_axes()
        for a in all_axes:
            # Axes missing from the current entry fall back to their home
            # view (assumes update_home_views has registered them).
            if a in views:
                cur_view = views[a]
            else:
                cur_view = home_views[a]
            a._set_view(cur_view)

        # Positions are restored only if every axes has a stored position.
        if set(all_axes).issubset(pos):
            for a in all_axes:
                # Restore both the original and modified positions
                a._set_position(pos[a][0], 'original')
                a._set_position(pos[a][1], 'active')

        self.figure.canvas.draw_idle()

    def push_current(self, figure=None):
        """
        Push the current view limits and position onto their respective stacks.

        Parameters
        ----------
        figure : `~matplotlib.figure.Figure`, optional
            The figure whose state is pushed; defaults to ``self.figure``.
        """
        if figure is None:
            figure = self.figure
        views = WeakKeyDictionary()
        pos = WeakKeyDictionary()
        for a in figure.get_axes():
            views[a] = a._get_view()
            pos[a] = self._axes_pos(a)
        self.views[figure].push(views)
        self.positions[figure].push(pos)

    def _axes_pos(self, ax):
        """
        Return the original and modified positions for the specified axes.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The `.Axes` to get the positions for.

        Returns
        -------
        original_position, modified_position
            A tuple of the original and modified positions.
        """

        return (ax.get_position(True).frozen(), ax.get_position().frozen())

    def update_home_views(self, figure=None):
        """
        Make sure that ``self.home_views`` has an entry for all axes present
        in the figure.

        Parameters
        ----------
        figure : `~matplotlib.figure.Figure`, optional
            Defaults to ``self.figure``.  Existing home views are kept.
        """
        if figure is None:
            figure = self.figure
        figure_home_views = self.home_views[figure]
        for a in figure.get_axes():
            if a not in figure_home_views:
                figure_home_views[a] = a._get_view()

    # Can be removed once Locator.refresh() is removed, and replaced by an
    # inline call to self.figure.canvas.draw_idle().
    def _refresh_locators(self):
        """
        Call overridden ``Locator.refresh`` methods (with a deprecation
        warning) for every axis of every axes, then request a redraw.
        """
        for a in self.figure.get_axes():
            locators = []
            for axis_name in ('xaxis', 'yaxis', 'zaxis'):
                axis = getattr(a, axis_name, None)
                if axis is not None:
                    locators.append(axis.get_major_locator())
                    locators.append(axis.get_minor_locator())

            for loc in locators:
                mpl.ticker._if_refresh_overridden_call_and_emit_deprec(loc)
        self.figure.canvas.draw_idle()

    refresh_locators = _api.deprecate_privatize_attribute(
        "3.3", alternative="self.figure.canvas.draw_idle()")

    def home(self):
        """Recall the first view and position from the stack."""
        self.views[self.figure].home()
        self.positions[self.figure].home()

    def back(self):
        """Back one step in the stack of views and positions."""
        self.views[self.figure].back()
        self.positions[self.figure].back()

    def forward(self):
        """Forward one step in the stack of views and positions."""
        self.views[self.figure].forward()
        self.positions[self.figure].forward()
Example #8
0
class TexManager:
    """
    Convert strings to dvi files using TeX, caching the results to a directory.

    Repeated calls to this constructor always return the same instance.

    Generated files (``.tex``, ``.dvi``, ``.png``) are stored in `texcache`
    and named after an MD5 hash of the TeX string, the font configuration,
    the font size, the custom preamble and (for PNGs) the dpi; see
    `get_basefile`.
    """

    texcache = os.path.join(mpl.get_cachedir(), 'tex.cache')

    # In-memory cache of alpha channels keyed by
    # (tex, font config, fontsize, dpi); being a class attribute, it is
    # shared by all instances.
    _grey_arrayd = {}
    # Generic family used for the document body; updated by
    # get_font_config().
    _font_family = 'serif'
    _font_families = ('serif', 'sans-serif', 'cursive', 'monospace')
    # Lower-cased font name -> (font code appended to the font-config
    # string, LaTeX preamble command selecting that font).
    _font_info = {
        'new century schoolbook': ('pnc', r'\renewcommand{\rmdefault}{pnc}'),
        'bookman': ('pbk', r'\renewcommand{\rmdefault}{pbk}'),
        'times': ('ptm', r'\usepackage{mathptmx}'),
        'palatino': ('ppl', r'\usepackage{mathpazo}'),
        'zapf chancery': ('pzc', r'\usepackage{chancery}'),
        'cursive': ('pzc', r'\usepackage{chancery}'),
        'charter': ('pch', r'\usepackage{charter}'),
        'serif': ('cmr', ''),
        'sans-serif': ('cmss', ''),
        'helvetica': ('phv', r'\usepackage{helvet}'),
        'avant garde': ('pag', r'\usepackage{avant}'),
        'courier': ('pcr', r'\usepackage{courier}'),
        # Loading the type1ec package ensures that cm-super is installed, which
        # is necessary for unicode computer modern.  (It also allows the use of
        # computer modern at arbitrary sizes, but that's just a side effect.)
        'monospace': ('cmtt', r'\usepackage{type1ec}'),
        'computer modern roman': ('cmr', r'\usepackage{type1ec}'),
        'computer modern sans serif': ('cmss', r'\usepackage{type1ec}'),
        'computer modern typewriter': ('cmtt', r'\usepackage{type1ec}')
    }
    # Specific font name -> generic family it belongs to; used when
    # rcParams['font.family'] names a single specific font.
    _font_types = {
        'new century schoolbook': 'serif',
        'bookman': 'serif',
        'times': 'serif',
        'palatino': 'serif',
        'charter': 'serif',
        'computer modern roman': 'serif',
        'zapf chancery': 'cursive',
        'helvetica': 'sans-serif',
        'avant garde': 'sans-serif',
        'computer modern sans serif': 'sans-serif',
        'courier': 'monospace',
        'computer modern typewriter': 'monospace'
    }

    grey_arrayd = _api.deprecate_privatize_attribute("3.5")
    font_family = _api.deprecate_privatize_attribute("3.5")
    font_families = _api.deprecate_privatize_attribute("3.5")
    font_info = _api.deprecate_privatize_attribute("3.5")

    @functools.lru_cache()  # Always return the same instance.
    def __new__(cls):
        Path(cls.texcache).mkdir(parents=True, exist_ok=True)
        return object.__new__(cls)

    def get_font_config(self):
        """
        Derive the LaTeX font setup from the current rcParams.

        Side effects: sets ``self._font_family`` (generic family used for
        the document body) and ``self._font_preamble`` (font-related
        preamble lines inserted by `_get_preamble`).

        Returns
        -------
        str
            The generic family, the font code chosen for each family in
            ``_font_families``, and an MD5 hex digest of the custom preamble,
            concatenated.  Used to key caches on the font configuration.
        """
        ff = rcParams['font.family']
        ff_val = ff[0].lower() if len(ff) == 1 else None
        # True when font.family names one specific font (e.g. 'times'):
        # that font is then used directly for its generic family.
        reduced_notation = False
        if len(ff) == 1 and ff_val in self._font_families:
            self._font_family = ff_val
        elif len(ff) == 1 and ff_val in self._font_info:
            reduced_notation = True
            self._font_family = self._font_types[ff_val]
        else:
            _log.info(
                'font.family must be one of (%s) when text.usetex is '
                'True. serif will be used by default.',
                ', '.join(self._font_families))
            self._font_family = 'serif'

        fontconfig = [self._font_family]
        fonts = {}
        for font_family in self._font_families:
            if reduced_notation and self._font_family == font_family:
                fonts[font_family] = self._font_info[ff_val]
            else:
                for font in rcParams['font.' + font_family]:
                    if font.lower() in self._font_info:
                        fonts[font_family] = self._font_info[font.lower()]
                        _log.debug('family: %s, font: %s, info: %s',
                                   font_family, font,
                                   self._font_info[font.lower()])
                        break
                    else:
                        _log.debug('%s font is not compatible with usetex.',
                                   font)
                else:
                    # Reached only when no listed font was usable (no break).
                    _log.info(
                        'No LaTeX-compatible font found for the %s font '
                        'family in rcParams. Using default.', font_family)
                    fonts[font_family] = self._font_info[font_family]
            fontconfig.append(fonts[font_family][0])
        # Add a hash of the latex preamble to fontconfig so that the
        # correct png is selected for strings rendered with same font and dpi
        # even if the latex preamble changes within the session
        preamble_bytes = self.get_custom_preamble().encode('utf-8')
        fontconfig.append(hashlib.md5(preamble_bytes).hexdigest())

        # The following packages and commands need to be included in the latex
        # file's preamble:
        cmd = {
            fonts[family][1]
            for family in ['serif', 'sans-serif', 'monospace']
        }
        if self._font_family == 'cursive':
            cmd.add(fonts['cursive'][1])
        cmd.add(r'\usepackage{type1cm}')
        # Sorted so the preamble text is deterministic despite using a set.
        self._font_preamble = '\n'.join(sorted(cmd))

        return ''.join(fontconfig)

    def get_basefile(self, tex, fontsize, dpi=None):
        """
        Return a filename based on a hash of the string, fontsize, and dpi.

        The returned path lives in `texcache` and has no extension.  This
        calls `get_font_config`, which refreshes the font state used by
        `make_tex`.
        """
        s = ''.join([
            tex,
            self.get_font_config(),
            '%f' % fontsize,
            self.get_custom_preamble(),
            str(dpi or '')
        ])
        return os.path.join(self.texcache,
                            hashlib.md5(s.encode('utf-8')).hexdigest())

    def get_font_preamble(self):
        """
        Return a string containing font configuration for the tex preamble.
        """
        # ``_font_preamble`` is only assigned by get_font_config(); compute
        # it on first use instead of raising AttributeError.
        if not hasattr(self, '_font_preamble'):
            self.get_font_config()
        return self._font_preamble

    def get_custom_preamble(self):
        """Return a string containing user additions to the tex preamble."""
        return rcParams['text.latex.preamble']

    def _get_preamble(self):
        """
        Return the full LaTeX preamble: document class, font setup, geometry
        and the user's custom preamble.  Uses ``self._font_preamble``, so
        `get_font_config` must have run beforehand.
        """
        return "\n".join([
            r"\documentclass{article}",
            # Pass-through \mathdefault, which is used in non-usetex mode to
            # use the default text font but was historically suppressed in
            # usetex mode.
            r"\newcommand{\mathdefault}[1]{#1}",
            self._font_preamble,
            r"\usepackage[utf8]{inputenc}",
            r"\DeclareUnicodeCharacter{2212}{\ensuremath{-}}",
            # geometry is loaded before the custom preamble as convert_psfrags
            # relies on a custom preamble to change the geometry.
            r"\usepackage[papersize=72in, margin=1in]{geometry}",
            self.get_custom_preamble(),
            # textcomp is loaded last (if not already loaded by the custom
            # preamble) in order not to clash with custom packages (e.g.
            # newtxtext) which load it with different options.
            r"\makeatletter"
            r"\@ifpackageloaded{textcomp}{}{\usepackage{textcomp}}"
            r"\makeatother",
        ])

    def make_tex(self, tex, fontsize):
        """
        Generate a tex file to render the tex string at a specific font size.

        Return the file name.
        """
        # get_basefile() also refreshes _font_family/_font_preamble, which
        # the code below and _get_preamble() rely on.
        basefile = self.get_basefile(tex, fontsize)
        texfile = '%s.tex' % basefile
        fontcmd = {
            'sans-serif': r'{\sffamily %s}',
            'monospace': r'{\ttfamily %s}'
        }.get(self._font_family, r'{\rmfamily %s}')

        Path(texfile).write_text(r"""
%s
\pagestyle{empty}
\begin{document}
%% The empty hbox ensures that a page is printed even for empty inputs, except
%% when using psfrag which gets confused by it.
\fontsize{%f}{%f}%%
\ifdefined\psfrag\else\hbox{}\fi%%
%s
\end{document}
""" % (self._get_preamble(), fontsize, fontsize * 1.25, fontcmd % tex),
                                 encoding='utf-8')

        return texfile

    def _run_checked_subprocess(self, command, tex, *, cwd=None):
        """
        Run *command* (in `texcache` unless *cwd* is given) and return its
        combined stdout/stderr as bytes.

        Raises
        ------
        RuntimeError
            If the executable is missing or exits with a non-zero status;
            the error message includes *tex* and the program's output.
        """
        _log.debug(cbook._pformat_subprocess(command))
        try:
            report = subprocess.check_output(
                command,
                cwd=cwd if cwd is not None else self.texcache,
                stderr=subprocess.STDOUT)
        except FileNotFoundError as exc:
            raise RuntimeError(
                'Failed to process string with tex because {} could not be '
                'found'.format(command[0])) from exc
        except subprocess.CalledProcessError as exc:
            raise RuntimeError(
                '{prog} was not able to process the following string:\n'
                '{tex!r}\n\n'
                'Here is the full report generated by {prog}:\n'
                '{exc}\n\n'.format(prog=command[0],
                                   tex=tex.encode('unicode_escape'),
                                   exc=exc.output.decode('utf-8'))) from exc
        _log.debug(report)
        return report

    def make_dvi(self, tex, fontsize):
        """
        Generate a dvi file containing latex's layout of tex string.

        Return the file name.
        """
        basefile = self.get_basefile(tex, fontsize)
        dvifile = '%s.dvi' % basefile
        if not os.path.exists(dvifile):
            texfile = Path(self.make_tex(tex, fontsize))
            # Generate the dvi in a temporary directory to avoid race
            # conditions e.g. if multiple processes try to process the same tex
            # string at the same time.  Having tmpdir be a subdirectory of the
            # final output dir ensures that they are on the same filesystem,
            # and thus replace() works atomically.  It also allows referring to
            # the texfile with a relative path: the absolute cache path may
            # contain characters (e.g. ~) that TeX does not support.
            with TemporaryDirectory(dir=Path(dvifile).parent) as tmpdir:
                self._run_checked_subprocess(
                    ["latex", "-interaction=nonstopmode", "--halt-on-error",
                     f"../{texfile.name}"], tex, cwd=tmpdir)
                (Path(tmpdir) / Path(dvifile).name).replace(dvifile)
        return dvifile

    def make_png(self, tex, fontsize, dpi):
        """
        Generate a png file containing latex's rendering of tex string.

        Return the file name.
        """
        basefile = self.get_basefile(tex, fontsize, dpi)
        pngfile = '%s.png' % basefile
        # see get_rgba for a discussion of the background
        if not os.path.exists(pngfile):
            dvifile = self.make_dvi(tex, fontsize)
            cmd = [
                "dvipng", "-bg", "Transparent", "-D",
                str(dpi), "-T", "tight", "-o", pngfile, dvifile
            ]
            # When testing, disable FreeType rendering for reproducibility; but
            # dvipng 1.16 has a bug (fixed in f3ff241) that breaks --freetype0
            # mode, so for it we keep FreeType enabled; the image will be
            # slightly off.
            bad_ver = parse_version("1.16")
            if (getattr(mpl, "_called_from_pytest", False)
                    and mpl._get_executable_info("dvipng").version != bad_ver):
                cmd.insert(1, "--freetype0")
            self._run_checked_subprocess(cmd, tex)
        return pngfile

    def get_grey(self, tex, fontsize=None, dpi=None):
        """
        Return the alpha channel.

        *fontsize* and *dpi* default to ``rcParams['font.size']`` and
        ``rcParams['savefig.dpi']`` when falsy.  Results are memoized in
        ``_grey_arrayd``.
        """
        if not fontsize:
            fontsize = rcParams['font.size']
        if not dpi:
            dpi = rcParams['savefig.dpi']
        key = tex, self.get_font_config(), fontsize, dpi
        alpha = self._grey_arrayd.get(key)
        if alpha is None:
            pngfile = self.make_png(tex, fontsize, dpi)
            rgba = mpl.image.imread(os.path.join(self.texcache, pngfile))
            self._grey_arrayd[key] = alpha = rgba[:, :, -1]
        return alpha

    def get_rgba(self, tex, fontsize=None, dpi=None, rgb=(0, 0, 0)):
        r"""
        Return latex's rendering of the tex string as an rgba array.

        Examples
        --------
        >>> texmanager = TexManager()
        >>> s = r"\TeX\ is $\displaystyle\sum_n\frac{-e^{i\pi}}{2^n}$!"
        >>> Z = texmanager.get_rgba(s, fontsize=12, dpi=80, rgb=(1, 0, 0))
        """
        alpha = self.get_grey(tex, fontsize, dpi)
        rgba = np.empty((*alpha.shape, 4))
        rgba[..., :3] = mpl.colors.to_rgb(rgb)
        rgba[..., -1] = alpha
        return rgba

    def get_text_width_height_descent(self, tex, fontsize, renderer=None):
        """
        Return width, height and descent of the text.

        Whitespace-only strings short-circuit to ``(0, 0, 0)`` without
        running TeX.  With a *renderer*, values are scaled by
        ``renderer.points_to_pixels(1.)``.
        """
        if tex.strip() == '':
            return 0, 0, 0
        dvifile = self.make_dvi(tex, fontsize)
        dpi_fraction = renderer.points_to_pixels(1.) if renderer else 1
        with dviread.Dvi(dvifile, 72 * dpi_fraction) as dvi:
            page, = dvi
        # A total height (including the descent) needs to be returned.
        return page.width, page.height + page.descent, page.descent
Exemple #9
0
class TexManager:
    """
    Convert strings to dvi files using TeX, caching the results to a directory.

    Repeated calls to this constructor always return the same instance.

    Generated files are stored in `texcache` and named after an MD5 hash of
    the complete TeX source (plus the dpi for PNGs); see `get_basefile`.
    """

    texcache = os.path.join(mpl.get_cachedir(), 'tex.cache')
    # In-memory cache of alpha channels keyed by (TeX source, dpi); being a
    # class attribute, it is shared by all instances.
    _grey_arrayd = {}

    _font_families = ('serif', 'sans-serif', 'cursive', 'monospace')
    # Lower-cased font name -> LaTeX preamble command selecting that font.
    _font_preambles = {
        'new century schoolbook': r'\renewcommand{\rmdefault}{pnc}',
        'bookman': r'\renewcommand{\rmdefault}{pbk}',
        'times': r'\usepackage{mathptmx}',
        'palatino': r'\usepackage{mathpazo}',
        'zapf chancery': r'\usepackage{chancery}',
        'cursive': r'\usepackage{chancery}',
        'charter': r'\usepackage{charter}',
        'serif': '',
        'sans-serif': '',
        'helvetica': r'\usepackage{helvet}',
        'avant garde': r'\usepackage{avant}',
        'courier': r'\usepackage{courier}',
        # Loading the type1ec package ensures that cm-super is installed, which
        # is necessary for Unicode computer modern.  (It also allows the use of
        # computer modern at arbitrary sizes, but that's just a side effect.)
        'monospace': r'\usepackage{type1ec}',
        'computer modern roman': r'\usepackage{type1ec}',
        'computer modern sans serif': r'\usepackage{type1ec}',
        'computer modern typewriter': r'\usepackage{type1ec}',
    }
    # Specific font name -> generic family it belongs to; used when
    # rcParams['font.family'] names a single specific font.
    _font_types = {
        'new century schoolbook': 'serif',
        'bookman': 'serif',
        'times': 'serif',
        'palatino': 'serif',
        'zapf chancery': 'cursive',
        'charter': 'serif',
        'helvetica': 'sans-serif',
        'avant garde': 'sans-serif',
        'courier': 'monospace',
        'computer modern roman': 'serif',
        'computer modern sans serif': 'sans-serif',
        'computer modern typewriter': 'monospace',
    }

    # NOTE(review): ``font_family`` and ``font_info`` forward to
    # ``_font_family`` / ``_font_info``, which this class no longer defines
    # (fonts now come from ``_font_preambles``).  Reading them therefore
    # appears to raise AttributeError -- confirm, and decide how these
    # deprecated aliases should keep working.
    grey_arrayd = _api.deprecate_privatize_attribute("3.5")
    font_family = _api.deprecate_privatize_attribute("3.5")
    font_families = _api.deprecate_privatize_attribute("3.5")
    font_info = _api.deprecate_privatize_attribute("3.5")

    @functools.lru_cache()  # Always return the same instance.
    def __new__(cls):
        Path(cls.texcache).mkdir(parents=True, exist_ok=True)
        return object.__new__(cls)

    @_api.deprecated("3.6")
    def get_font_config(self):
        """
        Return an MD5 hex digest of the font preamble, the font-selection
        command and the custom preamble.
        """
        preamble, font_cmd = self._get_font_preamble_and_command()
        # Add a hash of the latex preamble to fontconfig so that the
        # correct png is selected for strings rendered with same font and dpi
        # even if the latex preamble changes within the session
        preambles = preamble + font_cmd + self.get_custom_preamble()
        return hashlib.md5(preambles.encode('utf-8')).hexdigest()

    @classmethod
    def _get_font_family_and_reduced(cls):
        """
        Return the font family name and whether the font is reduced.

        "Reduced" means rcParams['font.family'] holds a single specific font
        name (e.g. 'times'); the generic family it belongs to is returned.
        Anything unsupported falls back to ``('serif', False)``.
        """
        ff = rcParams['font.family']
        ff_val = ff[0].lower() if len(ff) == 1 else None
        if len(ff) == 1 and ff_val in cls._font_families:
            return ff_val, False
        elif len(ff) == 1 and ff_val in cls._font_preambles:
            return cls._font_types[ff_val], True
        else:
            _log.info('font.family must be one of (%s) when text.usetex is '
                      'True. serif will be used by default.',
                      ', '.join(cls._font_families))
            return 'serif', False

    @classmethod
    def _get_font_preamble_and_command(cls):
        """
        Return ``(preamble, fontcmd)`` for the current rcParams.

        *preamble* is the sorted, newline-joined set of font package lines;
        *fontcmd* is the LaTeX family switch (``\\sffamily``, ``\\ttfamily``
        or ``\\rmfamily``) for the requested generic family.
        """
        requested_family, is_reduced_font = cls._get_font_family_and_reduced()

        preambles = {}
        for font_family in cls._font_families:
            if is_reduced_font and font_family == requested_family:
                preambles[font_family] = cls._font_preambles[
                    rcParams['font.family'][0].lower()]
            else:
                for font in rcParams['font.' + font_family]:
                    if font.lower() in cls._font_preambles:
                        preambles[font_family] = \
                            cls._font_preambles[font.lower()]
                        _log.debug(
                            'family: %s, font: %s, info: %s',
                            font_family, font,
                            cls._font_preambles[font.lower()])
                        break
                    else:
                        _log.debug('%s font is not compatible with usetex.',
                                   font)
                else:
                    # Reached only when no listed font was usable (no break).
                    _log.info('No LaTeX-compatible font found for the %s font '
                              'family in rcParams. Using default.',
                              font_family)
                    preambles[font_family] = cls._font_preambles[font_family]

        # The following packages and commands need to be included in the latex
        # file's preamble:
        cmd = {preambles[family]
               for family in ['serif', 'sans-serif', 'monospace']}
        if requested_family == 'cursive':
            cmd.add(preambles['cursive'])
        cmd.add(r'\usepackage{type1cm}')
        # Sorted so the preamble (and hence the cache hash) is deterministic.
        preamble = '\n'.join(sorted(cmd))
        fontcmd = (r'\sffamily' if requested_family == 'sans-serif' else
                   r'\ttfamily' if requested_family == 'monospace' else
                   r'\rmfamily')
        return preamble, fontcmd

    @classmethod
    def get_basefile(cls, tex, fontsize, dpi=None):
        """
        Return a filename based on a hash of the string, fontsize, and dpi.

        The hash covers the complete TeX source (see `_get_tex_source`), so
        font or preamble changes also yield a new name.  The returned path
        lives in `texcache` and has no extension.
        """
        src = cls._get_tex_source(tex, fontsize) + str(dpi)
        return os.path.join(
            cls.texcache, hashlib.md5(src.encode('utf-8')).hexdigest())

    @classmethod
    def get_font_preamble(cls):
        """
        Return a string containing font configuration for the tex preamble.
        """
        return cls._get_font_preamble_and_command()[0]

    @classmethod
    def get_custom_preamble(cls):
        """Return a string containing user additions to the tex preamble."""
        return rcParams['text.latex.preamble']

    @classmethod
    def _get_tex_source(cls, tex, fontsize):
        """Return the complete TeX source for processing a TeX string."""
        font_preamble, fontcmd = cls._get_font_preamble_and_command()
        baselineskip = 1.25 * fontsize
        return "\n".join([
            r"\documentclass{article}",
            r"% Pass-through \mathdefault, which is used in non-usetex mode",
            r"% to use the default text font but was historically suppressed",
            r"% in usetex mode.",
            r"\newcommand{\mathdefault}[1]{#1}",
            font_preamble,
            r"\usepackage[utf8]{inputenc}",
            r"\DeclareUnicodeCharacter{2212}{\ensuremath{-}}",
            r"% geometry is loaded before the custom preamble as ",
            r"% convert_psfrags relies on a custom preamble to change the ",
            r"% geometry.",
            r"\usepackage[papersize=72in, margin=1in]{geometry}",
            cls.get_custom_preamble(),
            r"% Use `underscore` package to take care of underscores in text.",
            r"% The [strings] option allows to use underscores in file names.",
            _usepackage_if_not_loaded("underscore", option="strings"),
            r"% Custom packages (e.g. newtxtext) may already have loaded ",
            r"% textcomp with different options.",
            _usepackage_if_not_loaded("textcomp"),
            r"\pagestyle{empty}",
            r"\begin{document}",
            r"% The empty hbox ensures that a page is printed even for empty",
            r"% inputs, except when using psfrag which gets confused by it.",
            r"% matplotlibbaselinemarker is used by dviread to detect the",
            r"% last line's baseline.",
            rf"\fontsize{{{fontsize}}}{{{baselineskip}}}%",
            r"\ifdefined\psfrag\else\hbox{}\fi%",
            rf"{{\obeylines{fontcmd} {tex}}}%",
            r"\special{matplotlibbaselinemarker}%",
            r"\end{document}",
        ])

    @classmethod
    def make_tex(cls, tex, fontsize):
        """
        Generate a tex file to render the tex string at a specific font size.

        Return the file name.
        """
        texfile = cls.get_basefile(tex, fontsize) + ".tex"
        # The source declares \usepackage[utf8]{inputenc}, so it must be
        # written as UTF-8 regardless of the platform's locale encoding.
        Path(texfile).write_text(cls._get_tex_source(tex, fontsize),
                                 encoding='utf-8')
        return texfile

    @classmethod
    def _run_checked_subprocess(cls, command, tex, *, cwd=None):
        """
        Run *command* (in `texcache` unless *cwd* is given) and return its
        combined stdout/stderr as bytes.

        Raises
        ------
        RuntimeError
            If the executable is missing or exits with a non-zero status;
            the message includes *tex*, the command line and its output.
        """
        _log.debug(cbook._pformat_subprocess(command))
        try:
            report = subprocess.check_output(
                command, cwd=cwd if cwd is not None else cls.texcache,
                stderr=subprocess.STDOUT)
        except FileNotFoundError as exc:
            raise RuntimeError(
                'Failed to process string with tex because {} could not be '
                'found'.format(command[0])) from exc
        except subprocess.CalledProcessError as exc:
            raise RuntimeError(
                '{prog} was not able to process the following string:\n'
                '{tex!r}\n\n'
                'Here is the full command invocation and its output:\n\n'
                '{format_command}\n\n'
                '{exc}\n\n'.format(
                    prog=command[0],
                    format_command=cbook._pformat_subprocess(command),
                    tex=tex.encode('unicode_escape'),
                    exc=exc.output.decode('utf-8'))) from None
        _log.debug(report)
        return report

    @classmethod
    def make_dvi(cls, tex, fontsize):
        """
        Generate a dvi file containing latex's layout of tex string.

        Return the file name.
        """
        basefile = cls.get_basefile(tex, fontsize)
        dvifile = '%s.dvi' % basefile
        if not os.path.exists(dvifile):
            texfile = Path(cls.make_tex(tex, fontsize))
            # Generate the dvi in a temporary directory to avoid race
            # conditions e.g. if multiple processes try to process the same tex
            # string at the same time.  Having tmpdir be a subdirectory of the
            # final output dir ensures that they are on the same filesystem,
            # and thus replace() works atomically.  It also allows referring to
            # the texfile with a relative path (for pathological MPLCONFIGDIRs,
            # the absolute path may contain characters (e.g. ~) that TeX does
            # not support.)
            with TemporaryDirectory(dir=Path(dvifile).parent) as tmpdir:
                cls._run_checked_subprocess(
                    ["latex", "-interaction=nonstopmode", "--halt-on-error",
                     f"../{texfile.name}"], tex, cwd=tmpdir)
                (Path(tmpdir) / Path(dvifile).name).replace(dvifile)
        return dvifile

    @classmethod
    def make_png(cls, tex, fontsize, dpi):
        """
        Generate a png file containing latex's rendering of tex string.

        Return the file name.
        """
        basefile = cls.get_basefile(tex, fontsize, dpi)
        pngfile = '%s.png' % basefile
        # see get_rgba for a discussion of the background
        if not os.path.exists(pngfile):
            dvifile = cls.make_dvi(tex, fontsize)
            cmd = ["dvipng", "-bg", "Transparent", "-D", str(dpi),
                   "-T", "tight", "-o", pngfile, dvifile]
            # When testing, disable FreeType rendering for reproducibility; but
            # dvipng 1.16 has a bug (fixed in f3ff241) that breaks --freetype0
            # mode, so for it we keep FreeType enabled; the image will be
            # slightly off.
            if (getattr(mpl, "_called_from_pytest", False) and
                    mpl._get_executable_info("dvipng").raw_version != "1.16"):
                cmd.insert(1, "--freetype0")
            cls._run_checked_subprocess(cmd, tex)
        return pngfile

    @classmethod
    def get_grey(cls, tex, fontsize=None, dpi=None):
        """
        Return the alpha channel.

        *fontsize* and *dpi* default to ``rcParams['font.size']`` and
        ``rcParams['savefig.dpi']`` when falsy.  Results are memoized in
        ``_grey_arrayd``, keyed by the full TeX source and the dpi.
        """
        if not fontsize:
            fontsize = rcParams['font.size']
        if not dpi:
            dpi = rcParams['savefig.dpi']
        key = cls._get_tex_source(tex, fontsize), dpi
        alpha = cls._grey_arrayd.get(key)
        if alpha is None:
            pngfile = cls.make_png(tex, fontsize, dpi)
            rgba = mpl.image.imread(os.path.join(cls.texcache, pngfile))
            cls._grey_arrayd[key] = alpha = rgba[:, :, -1]
        return alpha

    @classmethod
    def get_rgba(cls, tex, fontsize=None, dpi=None, rgb=(0, 0, 0)):
        r"""
        Return latex's rendering of the tex string as an rgba array.

        Examples
        --------
        >>> texmanager = TexManager()
        >>> s = r"\TeX\ is $\displaystyle\sum_n\frac{-e^{i\pi}}{2^n}$!"
        >>> Z = texmanager.get_rgba(s, fontsize=12, dpi=80, rgb=(1, 0, 0))
        """
        alpha = cls.get_grey(tex, fontsize, dpi)
        rgba = np.empty((*alpha.shape, 4))
        rgba[..., :3] = mpl.colors.to_rgb(rgb)
        rgba[..., -1] = alpha
        return rgba

    @classmethod
    def get_text_width_height_descent(cls, tex, fontsize, renderer=None):
        """
        Return width, height and descent of the text.

        Whitespace-only strings short-circuit to ``(0, 0, 0)`` without
        running TeX.  With a *renderer*, values are scaled by
        ``renderer.points_to_pixels(1.)``.
        """
        if tex.strip() == '':
            return 0, 0, 0
        dvifile = cls.make_dvi(tex, fontsize)
        dpi_fraction = renderer.points_to_pixels(1.) if renderer else 1
        with dviread.Dvi(dvifile, 72 * dpi_fraction) as dvi:
            page, = dvi
        # A total height (including the descent) needs to be returned.
        return page.width, page.height + page.descent, page.descent