Example #1
0
class Plot(QtGui.QWidget):
    """Qt widget embedding a single matplotlib figure.

    The widget owns a ``Figure`` rendered through a ``FigureCanvas`` and
    starts with one subplot whose top and right spines are hidden.
    """

    def __init__(self, winTitle="plot", parent=None,
                 flags=QtCore.Qt.WindowFlags(0)):
        """Construct a new plot widget.

        :param winTitle: Tab title.
        :param parent: Widget parent.
        :param flags: QWidget flags.
        """
        QtGui.QWidget.__init__(self, parent, flags)
        self.setWindowTitle(winTitle)
        # Create matplotlib canvas
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        # Get axes: a single subplot with only bottom/left spines visible
        self.axes = self.fig.add_subplot(111)
        self.axesList = [self.axes]
        self.axes.xaxis.set_ticks_position('bottom')
        self.axes.spines['top'].set_color('none')
        self.axes.yaxis.set_ticks_position('left')
        self.axes.spines['right'].set_color('none')
        # Setup layout
        vbox = QtGui.QVBoxLayout()
        vbox.addWidget(self.canvas)
        self.setLayout(vbox)
        # Active series
        self.series = []
        # Indicators
        self.skip = False  # re-entrancy guard for update()
        self.legend = False
        self.legPos = (1.0, 1.0)
        self.legSiz = 14
        self.grid = False

    def plot(self, x, y, name=None):
        """Plot a new line on the active axes, redraw, and return it.

        :param x: X values.
        :param y: Y values.
        :param name: Serie name (for legend).
        :returns: The created ``Line``.
        """
        l = Line(self.axes, x, y, name)
        self.series.append(l)
        # Update window
        self.update()
        return l

    def update(self):
        """Redraw the canvas, refreshing the legend first when it is active.

        Note: this shadows ``QWidget.update``. ``self.skip`` prevents
        re-entrant redraws (presumably the module-level ``legend`` helper may
        call back into this method — not visible here). The flag is reset in
        a ``finally`` clause so an exception raised while drawing does not
        leave every later update() silently disabled.
        """
        if self.skip:
            return
        self.skip = True
        try:
            if self.legend:
                legend(self.legend)
            self.canvas.draw()
        finally:
            self.skip = False

    def isGrid(self):
        """Return True if Grid is active, False otherwise."""
        return bool(self.grid)

    def isLegend(self):
        """Return True if Legend is active, False otherwise."""
        return bool(self.legend)

    def setActiveAxes(self, index):
        """Change the current active axes.

        :param index: Index in ``self.axesList`` of the new active axes.
        """
        self.axes = self.axesList[index]
        self.fig.sca(self.axes)
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.

    The figure holds two axes sharing the X axis: ``self.ax`` (left Y axis,
    also used for images, items and markers) and ``self.ax2`` (right Y axis,
    created with ``twinx`` and used by curves added with ``yaxis='right'``).
    """
    def __init__(self, plot, parent=None):
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # critical for picking!!!!
        # The left axes is kept above the right one (higher zorder); its
        # background is made transparent below, presumably so the right
        # axes' artists stay visible through it.
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        # Compare the major version numerically: comparing only the first
        # character of the version string misclassifies a 2-digit major.
        if int(matplotlib.__version__.split('.')[0]) < 2:
            self.ax.set_axis_bgcolor('none')
        else:
            self.ax.set_facecolor('none')
        self.fig.sca(self.ax)

        # Artists flagged as overlays (animated) by addItem/addMarker
        self._overlays = set()
        self._background = None

        self._colormaps = {}

        # Empty tuple when disabled, (hline, vline) when enabled
        self._graphCursor = tuple()
        self.matplotlibVersion = matplotlib.__version__

        self.setGraphXLimits(0., 100.)
        self.setGraphYLimits(0., 100., axis='right')
        self.setGraphYLimits(0., 100., axis='left')

        self._enableAxis('right', False)

    # Add methods

    def addCurve(self, x, y, legend, color, symbol, linewidth, linestyle,
                 yaxis, xerror, yerror, z, selectable, fill, alpha):
        """Add a curve and return a Container of all its artists.

        When ``color`` is an array with one color per point, the data is
        drawn as a scatter plot (plus a single-colored line if ``linestyle``
        is set); otherwise a regular line plot is drawn. Error bars are added
        first so they are drawn behind the curve.
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # Integer RGBA color in [0, 255]: convert to floats in [0, 1].
        # numpy.float64 replaces the numpy.float alias removed in NumPy 1.24.
        if (len(color) == 4
                and type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (yerror is not None and yerror.ndim == 2
                    and yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x,
                                      y,
                                      label=legend,
                                      xerr=xerror,
                                      yerr=yerror,
                                      linestyle=' ',
                                      color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot: one color per point; non-float colors are
            # assumed to be in [0, 255]
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x,
                                      y,
                                      label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x,
                                   y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker)
            artists.append(scatter)

            if fill:
                artists.append(
                    axes.fill_between(x,
                                      1.0e-8,
                                      y,
                                      facecolor=actualColor[0],
                                      linestyle=''))

        else:  # Curve
            curveList = axes.plot(x,
                                  y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker)
            artists += list(curveList)

            if fill:
                artists.append(axes.fill_between(x, 1.0e-8, y,
                                                 facecolor=color))

        for artist in artists:
            artist.set_zorder(z)
            if alpha < 1:
                artist.set_alpha(alpha)

        return Container(artists)

    def addImage(self, data, legend, origin, scale, z, selectable, draggable,
                 colormap, alpha):
        """Add an image (2D data with colormap, or 3D RGB(A)) to the left axes.

        :returns: The created image artist (ModestImage for large images
            with default origin/scale, AxesImage otherwise).
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, legend, origin, scale, z, selectable,
                          draggable):
            assert parameter is not None

        origin = float(origin[0]), float(origin[1])
        scale = float(scale[0]), float(scale[1])
        height, width = data.shape[0:2]

        picker = (selectable or draggable)

        # Debian 7 specific support
        # No transparent colormap with matplotlib < 1.2.0
        # Add support for transparent colormap for uint8 data with
        # colormap with 256 colors, linear norm, [0, 255] range
        if matplotlib.__version__ < '1.2.0':
            if (len(data.shape) == 2 and colormap['name'] is None
                    and 'colors' in colormap):
                colors = numpy.asarray(colormap['colors'])
                # Transparent if any alpha value (last column of every
                # color) differs from 255 -- not just the 4th color row.
                if (colors.shape[-1] == 4
                        and not numpy.all(numpy.equal(colors[..., 3], 255))):
                    # This is a transparent colormap
                    if (colors.shape == (256, 4)
                            and colormap['normalization'] == 'linear'
                            and not colormap['autoscale']
                            and colormap['vmin'] == 0
                            and colormap['vmax'] == 255
                            and data.dtype == numpy.uint8):
                        # Supported case, convert data to RGBA
                        data = colors[data.reshape(-1)].reshape(data.shape +
                                                                (4, ))
                    else:
                        _logger.warning(
                            'matplotlib %s does not support transparent '
                            'colormap.', matplotlib.__version__)

        if ((height * width) > 5.0e5 and origin == (0., 0.)
                and scale == (1., 1.)):
            imageClass = ModestImage
        else:
            imageClass = AxesImage

        # the normalization can be a source of time waste
        # Two possibilities, we receive data or a ready to show image
        if len(data.shape) == 3:  # RGBA image
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               picker=picker,
                               zorder=z,
                               origin='lower')

        else:
            # Convert colormap argument to matplotlib colormap
            scalarMappable = Colors.getMPLScalarMappable(colormap, data)

            # try as data
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               cmap=scalarMappable.cmap,
                               picker=picker,
                               zorder=z,
                               norm=scalarMappable.norm,
                               origin='lower')
        if alpha < 1:
            image.set_alpha(alpha)

        # Set image extent
        xmin = origin[0]
        xmax = xmin + scale[0] * width
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin

        ymin = origin[1]
        ymax = ymin + scale[1] * height
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin

        image.set_extent((xmin, xmax, ymin, ymax))

        # Set image data
        if scale[0] < 0. or scale[1] < 0.:
            # For negative scale, step by -1
            xstep = 1 if scale[0] >= 0. else -1
            ystep = 1 if scale[1] >= 0. else -1
            data = data[::ystep, ::xstep]

        image.set_data(data)

        self.ax.add_artist(image)

        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z):
        """Add a shape item to the left axes and return its artist.

        :param shape: One of 'line', 'hline', 'vline', 'rectangle',
            'polygon' or 'polylines'.
        :param overlay: If True, the artist is animated and tracked in
            ``self._overlays``.
        :raises NotImplementedError: For an unsupported shape.
        """
        # asarray avoids copies like the former array(copy=False), but also
        # works with NumPy 2, where copy=False raises for list input.
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        if shape == "line":
            item = self.ax.plot(x,
                                y,
                                label=legend,
                                color=color,
                                linestyle='-',
                                marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color)
            if fill:
                item.set_hatch('.')

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            xView = xView.reshape(1, -1)
            yView = yView.reshape(1, -1)
            item = Polygon(numpy.vstack((xView, yView)).T,
                           closed=(shape == 'polygon'),
                           fill=False,
                           label=legend,
                           color=color)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color, selectable, draggable,
                  symbol, constraint, overlay):
        """Add a point, vertical-line or horizontal-line marker.

        The marker kind depends on which of ``x``/``y`` is given. An optional
        text label is stored on the returned artist as ``_infoText`` so that
        :meth:`remove` can delete it too.

        :raises RuntimeError: If both ``x`` and ``y`` are None.
        """
        legend = "__MARKER__" + legend

        if x is not None and y is not None:
            line = self.ax.plot(x,
                                y,
                                label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                # Round-trip (x, y) through display coordinates; ytmp ends
                # up equal to y up to floating point error.
                xtmp, ytmp = self.ax.transData.transform_point((x, y))
                inv = self.ax.transData.inverted()
                xtmp, ytmp = inv.transform_point((xtmp, ytmp))

                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                line._infoText = self.ax.text(x,
                                              ytmp,
                                              text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x, label=legend, color=color)
            if text is not None:
                text = " " + text
                # Place the text just below the top of the visible Y range
                ymin, ymax = self.getGraphYLimits(axis='left')
                delta = abs(ymax - ymin)
                if ymin > ymax:
                    ymax = ymin
                ymax -= 0.005 * delta
                line._infoText = self.ax.text(x,
                                              ymax,
                                              text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y, label=legend, color=color)

            if text is not None:
                text = " " + text
                # Place the text just left of the right end of the X range
                xmin, xmax = self.getGraphXLimits()
                delta = abs(xmax - xmin)
                if xmin > xmax:
                    xmax = xmin
                xmax -= 0.005 * delta
                line._infoText = self.ax.text(xmax,
                                              y,
                                              text,
                                              color=color,
                                              horizontalalignment='right',
                                              verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        if overlay:
            line.set_animated(True)
            self._overlays.add(line)

        return line

    # Remove methods

    def remove(self, item):
        """Remove an artist (and its marker text, if any) from the plot."""
        # Warning: It also needs to remove extra stuff if added as for markers
        if hasattr(item, "_infoText"):  # For markers text
            item._infoText.remove()
            item._infoText = None
        self._overlays.discard(item)
        item.remove()

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Enable or disable the crosshair cursor lines.

        When enabled, a hidden, animated horizontal line and vertical line
        are created on the left axes and stored as ``self._graphCursor``.
        When disabled, existing cursor lines are removed.
        """
        # Drop any existing cursor first: ``_graphCursor`` is always a
        # tuple (never None), so test its emptiness. This also avoids
        # leaking the previous lines when the cursor is re-enabled.
        if self._graphCursor:
            lineh, linev = self._graphCursor
            lineh.remove()
            linev.remove()
            self._graphCursor = tuple()

        if flag:
            lineh = self.ax.axhline(self.ax.get_ybound()[0],
                                    visible=False,
                                    color=color,
                                    linewidth=linewidth,
                                    linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(self.ax.get_xbound()[0],
                                    visible=False,
                                    color=color,
                                    linewidth=linewidth,
                                    linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev

    # Active curve

    def setCurveColor(self, curve, color):
        """Set the color of every line/collection artist of a curve."""
        # Store Line2D and PathCollection
        for artist in curve.get_children():
            if isinstance(artist, (Line2D, LineCollection)):
                artist.set_color(color)
            elif isinstance(artist, PathCollection):
                artist.set_facecolors(color)
                artist.set_edgecolors(color)
            else:
                _logger.warning('setActiveCurve ignoring artist %s',
                                str(artist))

    # Misc.

    def getWidgetHandle(self):
        """Return the canvas attached to the figure."""
        return self.fig.canvas

    def _enableAxis(self, axis, flag=True):
        """Show/hide Y axis

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Default, True
        """
        assert axis in ('right', 'left')
        axes = self.ax2 if axis == 'right' else self.ax
        axes.get_yaxis().set_visible(flag)

    def replot(self):
        """Do not perform rendering.

        Override in subclass to actually draw something.
        """
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only support curve for now
        # Hide right Y axis if no line is present
        self._dirtyLimits = False
        if not self.ax2.lines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure to ``fileName`` and mark the plot as dirty."""
        # fileName can be also a StringIO or file instance
        if dpi is not None:
            self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
        else:
            self.fig.savefig(fileName, format=fileFormat)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        """Set the title of the left axes."""
        self.ax.set_title(title)

    def setGraphXLabel(self, label):
        """Set the X axis label."""
        self.ax.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        """Set the label of the 'left' or 'right' Y axis."""
        axes = self.ax if axis == 'left' else self.ax2
        axes.set_ylabel(label)

    # Graph limits

    def resetZoom(self, dataMargins):
        """Fit the auto-scaled axes to the data range plus ``dataMargins``.

        Axes that are not auto-scaled get their current limits restored.
        With keep aspect ratio, the range is enlarged along one direction
        to match the figure's height/width ratio.
        """
        xAuto = self._plot.isXAxisAutoScale()
        yAuto = self._plot.isYAxisAutoScale()

        if not xAuto and not yAuto:
            _logger.debug("Nothing to autoscale")
        else:  # Some axes to autoscale
            xLimits = self.getGraphXLimits()
            yLimits = self.getGraphYLimits(axis='left')
            y2Limits = self.getGraphYLimits(axis='right')

            # Get data range
            ranges = self._plot.getDataRange()
            xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
            ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
            if ranges.yright is None:
                ymin2, ymax2 = None, None
            else:
                ymin2, ymax2 = ranges.yright

            # Add margins around data inside the plot area
            newLimits = list(
                utils.addMarginsToLimits(dataMargins,
                                         self.ax.get_xscale() == 'log',
                                         self.ax.get_yscale() == 'log', xmin,
                                         xmax, ymin, ymax, ymin2, ymax2))

            if self.isKeepDataAspectRatio():
                # Use limits with margins to keep ratio
                xmin, xmax, ymin, ymax = newLimits[:4]

                # Compute bbox wth figure aspect ratio
                figW, figH = self.fig.get_size_inches()
                figureRatio = figH / figW

                dataRatio = (ymax - ymin) / (xmax - xmin)
                if dataRatio < figureRatio:
                    # Increase y range
                    ycenter = 0.5 * (ymax + ymin)
                    yrange = (xmax - xmin) * figureRatio
                    newLimits[2] = ycenter - 0.5 * yrange
                    newLimits[3] = ycenter + 0.5 * yrange

                elif dataRatio > figureRatio:
                    # Increase x range
                    xcenter = 0.5 * (xmax + xmin)
                    xrange_ = (ymax - ymin) / figureRatio
                    newLimits[0] = xcenter - 0.5 * xrange_
                    newLimits[1] = xcenter + 0.5 * xrange_

            self.setLimits(*newLimits)

            if not xAuto and yAuto:
                self.setGraphXLimits(*xLimits)
            elif xAuto and not yAuto:
                if y2Limits is not None:
                    self.setGraphYLimits(y2Limits[0],
                                         y2Limits[1],
                                         axis='right')
                if yLimits is not None:
                    self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and (if both given) right Y limits at once.

        Bounds may be given in any order; Y limits honor Y axis inversion.
        """
        # Let matplotlib taking care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        if y2min is not None and y2max is not None:
            if not self.isYAxisInverted():
                self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
            else:
                self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))

        if not self.isYAxisInverted():
            self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
        else:
            self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))

    def getGraphXLimits(self):
        """Return the X bounds as (min, max)."""
        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        """Set the X limits (bounds may be given in any order)."""
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

    def getGraphYLimits(self, axis):
        """Return the (min, max) bounds of the 'left' or 'right' Y axis.

        :returns: None if the corresponding axes is not visible.
        """
        assert axis in ('left', 'right')
        ax = self.ax2 if axis == 'right' else self.ax

        if not ax.get_visible():
            return None

        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits

        return ax.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        """Set the limits of the 'left' or 'right' Y axis.

        With keep aspect ratio, the X range is rescaled first so matplotlib
        does not alter the requested Y limits.
        """
        ax = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of shared axis when keeping aspect ratio
            # So x limits are kept when changing y limits....
            # Change x limits first by taking into account aspect ratio
            # and then change y limits.. so matplotlib does not need
            # to make change (to y) to keep aspect ratio
            xmin, xmax = ax.get_xbound()
            curYMin, curYMax = ax.get_ybound()

            newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
            xcenter = 0.5 * (xmin + xmax)
            ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)

        if not self.isYAxisInverted():
            ax.set_ylim(ymin, ymax)
        else:
            ax.set_ylim(ymax, ymin)

    # Graph axes

    def setXAxisLogarithmic(self, flag):
        """Switch the X scale of both axes between log and linear."""
        self.ax2.set_xscale('log' if flag else 'linear')
        self.ax.set_xscale('log' if flag else 'linear')

    def setYAxisLogarithmic(self, flag):
        """Switch the Y scale of both axes between log and linear."""
        self.ax2.set_yscale('log' if flag else 'linear')
        self.ax.set_yscale('log' if flag else 'linear')

    def setYAxisInverted(self, flag):
        """Invert (or restore) the left Y axis direction."""
        if self.ax.yaxis_inverted() != bool(flag):
            self.ax.invert_yaxis()

    def isYAxisInverted(self):
        """Return True if the left Y axis is inverted."""
        return self.ax.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        """Return True if the left axes keeps a 1:1 data aspect ratio."""
        return self.ax.get_aspect() in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        """Enable/disable 1:1 data aspect ratio on both axes."""
        self.ax.set_aspect(1.0 if flag else 'auto')
        self.ax2.set_aspect(1.0 if flag else 'auto')

    def setGraphGrid(self, which):
        """Show the grid for ``which`` ticks, or hide it if None."""
        self.ax.grid(False, which='both')  # Disable all grid first
        if which is not None:
            self.ax.grid(True, which=which)

    # colormap

    def getSupportedColormaps(self):
        """Return the base colormaps followed by matplotlib's, sorted."""
        default = super(BackendMatplotlib, self).getSupportedColormaps()
        return default + tuple(sorted(cm.datad))

    # Data <-> Pixel coordinates conversion

    def dataToPixel(self, x, y, axis):
        """Convert data coordinates of the given Y axis to pixels."""
        ax = self.ax2 if axis == "right" else self.ax

        pixels = ax.transData.transform_point((x, y))
        xPixel, yPixel = pixels.T
        return xPixel, yPixel

    def pixelToData(self, x, y, axis, check):
        """Convert pixel coordinates to data coordinates of the given Y axis.

        :returns: (x, y), or None if ``check`` is True and the point lies
            outside the current limits.
        """
        ax = self.ax2 if axis == "right" else self.ax

        inv = ax.transData.inverted()
        x, y = inv.transform_point((x, y))

        if check:
            xmin, xmax = self.getGraphXLimits()
            ymin, ymax = self.getGraphYLimits(axis=axis)

            if x > xmax or x < xmin or y > ymax or y < ymin:
                return None  # (x, y) is out of plot area

        return x, y

    def getPlotBoundsInPixels(self):
        """Return the left axes bounds (x0, y0, width, height) in pixels."""
        bbox = self.ax.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        dpi = self.fig.dpi
        # Warning this is not returning int...
        return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi,
                bbox.bounds[2] * dpi, bbox.bounds[3] * dpi)
Example #3
0
class Plot(PySide.QtGui.QWidget):
    """PySide widget embedding a matplotlib figure and its navigation toolbar.

    The figure starts with one subplot whose top and right spines are hidden.
    """

    def __init__(self,
                 winTitle="plot",
                 parent=None,
                 flags=PySide.QtCore.Qt.WindowFlags(0)):
        """Construct a new plot widget.

        Keyword arguments:
        winTitle -- Tab title.
        parent -- Widget parent.
        flags -- QWidget flags
        """
        PySide.QtGui.QWidget.__init__(self, parent, flags)
        self.setWindowTitle(winTitle)
        # Create matplotlib canvas
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        # Get axes: a single subplot with only bottom/left spines visible
        self.axes = self.fig.add_subplot(111)
        self.axesList = [self.axes]
        self.axes.xaxis.set_ticks_position('bottom')
        self.axes.spines['top'].set_color('none')
        self.axes.yaxis.set_ticks_position('left')
        self.axes.spines['right'].set_color('none')
        # Add the navigation toolbar by default
        self.mpl_toolbar = NavigationToolbar(self.canvas, self)
        # Setup layout: canvas above toolbar
        vbox = PySide.QtGui.QVBoxLayout()
        vbox.addWidget(self.canvas)
        vbox.addWidget(self.mpl_toolbar)
        self.setLayout(vbox)
        # Active series
        self.series = []
        # Indicators
        self.skip = False  # re-entrancy guard for update()
        self.legend = False
        self.legPos = (1.0, 1.0)
        self.legSiz = 14
        self.grid = False

    def plot(self, x, y, name=None):
        """Plot a new line and return it.

        Keyword arguments:
        x -- X values
        y -- Y values
        name -- Serie name (for legend). """
        l = Line(self.axes, x, y, name)
        self.series.append(l)
        # Update window
        self.update()
        return l

    def update(self):
        """Update the plot, redrawing the canvas.

        Note: this shadows ``QWidget.update``. ``self.skip`` prevents
        re-entrant redraws (presumably ``legend`` may call back into this
        method — not visible here). It is reset in a ``finally`` clause so
        that an exception while drawing does not permanently disable updates.
        """
        if self.skip:
            return
        self.skip = True
        try:
            if self.legend:
                legend(self.legend, self.legPos, self.legSiz)
            self.canvas.draw()
        finally:
            self.skip = False

    def isGrid(self):
        """Return True if Grid is active, False otherwise."""
        return bool(self.grid)

    def isLegend(self):
        """Return True if Legend is active, False otherwise."""
        return bool(self.legend)

    def setActiveAxes(self, index):
        """Change the current active axes.

        Keyword arguments:
        index -- Index of the new active axes set.
        """
        self.axes = self.axesList[index]
        self.fig.sca(self.axes)
Example #4
0
class Plot(QtGui.QWidget):
    """Qt widget embedding a single matplotlib figure.

    The figure starts with one subplot whose top and right spines are hidden.
    """

    def __init__(self,
                 winTitle="plot",
                 parent=None,
                 flags=QtCore.Qt.WindowFlags(0)):
        """Construct a new plot widget.

        :param winTitle: Tab title.
        :param parent: Widget parent.
        :param flags: QWidget flags.
        """
        QtGui.QWidget.__init__(self, parent, flags)
        self.setWindowTitle(winTitle)
        # Create matplotlib canvas
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        # Get axes: a single subplot with only bottom/left spines visible
        self.axes = self.fig.add_subplot(111)
        self.axesList = [self.axes]
        self.axes.xaxis.set_ticks_position('bottom')
        self.axes.spines['top'].set_color('none')
        self.axes.yaxis.set_ticks_position('left')
        self.axes.spines['right'].set_color('none')
        # Setup layout
        vbox = QtGui.QVBoxLayout()
        vbox.addWidget(self.canvas)
        self.setLayout(vbox)
        # Active series
        self.series = []
        # Indicators
        self.skip = False  # re-entrancy guard for update()
        self.legend = False
        self.legPos = (1.0, 1.0)
        self.legSiz = 14
        self.grid = False

    def plot(self, x, y, name=None):
        """Plot a new line on the active axes, redraw, and return it.

        :param x: X values.
        :param y: Y values.
        :param name: Serie name (for legend).
        :returns: The created ``Line``.
        """
        l = Line(self.axes, x, y, name)
        self.series.append(l)
        # Update window
        self.update()
        return l

    def update(self):
        """Redraw the canvas, refreshing the legend first when it is active.

        Note: this shadows ``QWidget.update``. ``self.skip`` prevents
        re-entrant redraws (presumably ``legend`` may call back into this
        method — not visible here). It is reset in a ``finally`` clause so
        that an exception while drawing does not permanently disable updates.
        """
        if self.skip:
            return
        self.skip = True
        try:
            if self.legend:
                legend(self.legend)
            self.canvas.draw()
        finally:
            self.skip = False

    def isGrid(self):
        """Return True if Grid is active, False otherwise."""
        return bool(self.grid)

    def isLegend(self):
        """Return True if Legend is active, False otherwise."""
        return bool(self.legend)

    def setActiveAxes(self, index):
        """Change the current active axes.

        :param index: Index in ``self.axesList`` of the new active axes.
        """
        self.axes = self.axesList[index]
        self.fig.sca(self.axes)
Example #5
0
class Plot(PySide.QtGui.QWidget):
    """Qt widget embedding a matplotlib figure with a navigation toolbar.

    The widget owns one Figure/FigureCanvas pair, starts with a single
    111 subplot (top and right spines hidden), and keeps the Line objects
    created through :meth:`plot` in ``self.series``.
    """

    def __init__(self,
                 winTitle="plot",
                 parent=None,
                 flags=PySide.QtCore.Qt.WindowFlags(0)):
        """Construct a new plot widget.

        Keyword arguments:
        winTitle -- Tab title.
        parent -- Widget parent.
        flags -- QWidget flags
        """
        PySide.QtGui.QWidget.__init__(self, parent, flags)
        self.setWindowTitle(winTitle)
        # Create matplotlib canvas
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        # Get axes
        self.axes = self.fig.add_subplot(111)
        self.axesList = [self.axes]
        self.axes.xaxis.set_ticks_position('bottom')
        self.axes.spines['top'].set_color('none')
        self.axes.yaxis.set_ticks_position('left')
        self.axes.spines['right'].set_color('none')
        # Add the navigation toolbar by default
        self.mpl_toolbar = NavigationToolbar(self.canvas, self)
        # Setup layout
        vbox = PySide.QtGui.QVBoxLayout()
        vbox.addWidget(self.canvas)
        vbox.addWidget(self.mpl_toolbar)
        self.setLayout(vbox)
        # Active series
        self.series = []
        # Indicators
        self.skip = False  # True while update() is redrawing
        self.legend = False
        self.legPos = (1.0, 1.0)
        self.legSiz = 14
        self.grid = False

    def plot(self, x, y, name=None):
        """Plot a new line and return it.

        Keyword arguments:
        x -- X values
        y -- Y values
        name -- Serie name (for legend). """
        l = Line(self.axes, x, y, name)
        self.series.append(l)
        # Update window
        self.update()
        return l

    def update(self):
        """Update the plot, redrawing the canvas.

        Calls made while a redraw is already in progress are ignored
        (guarded by ``self.skip``). The guard is always released, even if
        drawing raises, so later updates are not silently blocked.
        """
        if self.skip:
            return
        self.skip = True
        try:
            if self.legend:
                legend(self.legend, self.legPos, self.legSiz)
            self.canvas.draw()
        finally:
            # Without this, a failing draw would leave skip=True forever
            # and every subsequent update() would become a no-op.
            self.skip = False

    def isGrid(self):
        """Return True if Grid is active, False otherwise."""
        return bool(self.grid)

    def isLegend(self):
        """Return True if Legend is active, False otherwise."""
        return bool(self.legend)

    def setActiveAxes(self, index):
        """Change the current active axes.

        Keyword arguments:
        index -- Index of the new active axes set.
        """
        self.axes = self.axesList[index]
        self.fig.sca(self.axes)
Exemple #6
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.
    """

    def __init__(self, plot, parent=None):
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True
        self._axesDisplayed = True
        self._matplotlibVersion = _parse_version(matplotlib.__version__)

        self.fig = Figure()
        self.fig.set_facecolor("w")

        # Left Y axis (self.ax) and right Y axis (self.ax2) sharing X
        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")
        # Make sure background of Axes is displayed
        self.ax2.patch.set_visible(True)

        # Set axis zorder=0.5 so grid is displayed at 0.5
        self.ax.set_axisbelow(True)

        # Disable the use of offsets. This is best effort: a failure is
        # only logged and must not prevent the backend creation.
        try:
            for axes in (self.ax, self.ax2):
                axes.get_yaxis().get_major_formatter().set_useOffset(False)
                axes.get_xaxis().get_major_formatter().set_useOffset(False)
        except Exception:
            _logger.warning('Cannot disabled axes offsets in %s ',
                            matplotlib.__version__)

        # critical for picking!!!!
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        if self._matplotlibVersion < _parse_version('2'):
            self.ax.set_axis_bgcolor('none')
        else:
            self.ax.set_facecolor('none')
        self.fig.sca(self.ax)

        self._background = None

        self._colormaps = {}

        # Crosshair as (horizontal line, vertical line); empty when disabled
        self._graphCursor = tuple()

        self._enableAxis('right', False)
        self._isXAxisTimeSeries = False

    def _overlayItems(self):
        """Generator of backend renderer for overlay items"""
        for item in self._plot.getItems():
            if (item.isOverlay() and
                    item.isVisible() and
                    item._backendRenderer is not None):
                yield item._backendRenderer

    def _hasOverlays(self):
        """Returns whether there is an overlay layer or not.

        The overlay layers contains overlay items and the crosshair.

        :rtype: bool
        """
        if self._graphCursor:
            return True  # There is the crosshair
        # True as soon as there is at least one overlay item
        return any(True for _ in self._overlayItems())

    # Add methods

    def addCurve(self, x, y,
                 color, symbol, linewidth, linestyle,
                 yaxis,
                 xerror, yerror, z, selectable,
                 fill, alpha, symbolsize):
        """Add a curve to the left or right Y axis.

        When ``color`` is an array with one color per point, the curve is
        rendered as a scatter plot (with an optional line drawn using the
        first color). Error bars, if any, are created first so they are
        drawn behind the curve.

        :returns: A _PickableContainer grouping all the created artists.
        """
        for parameter in (x, y, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill, alpha, symbolsize):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # Integer RGBA color: convert from [0, 255] to [0, 1]
        if (len(color) == 4 and
                type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
            if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and
                    xerror.shape[1] == 1):
                xerror = numpy.ravel(xerror)
            if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
                    yerror.shape[1] == 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x, y,
                                      xerr=xerror, yerr=yerror,
                                      linestyle=' ', color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot
            # numpy.float64 is what the removed numpy.float alias stood for
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x, y,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x, y,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker,
                                   s=symbolsize**2)
            artists.append(scatter)

            if fill:
                artists.append(axes.fill_between(
                    x, FLOAT32_MINPOS, y, facecolor=actualColor[0], linestyle=''))

        else:  # Curve
            curveList = axes.plot(x, y,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker,
                                  markersize=symbolsize)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color))

        for artist in artists:
            artist.set_animated(True)
            artist.set_zorder(z + 1)
            if alpha < 1:
                artist.set_alpha(alpha)

        return _PickableContainer(artists)

    def addImage(self, data, origin, scale, z, selectable, draggable, colormap, alpha):
        """Add an image on the left axes and return its artist.

        2D data is converted to RGBA with ``colormap``. Negative scales are
        handled by flipping the data and swapping the extent bounds.
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, origin, scale, z, selectable, draggable):
            assert parameter is not None

        origin = float(origin[0]), float(origin[1])
        scale = float(scale[0]), float(scale[1])
        height, width = data.shape[0:2]

        picker = (selectable or draggable)

        # All image are shown as RGBA image
        image = Image(self.ax,
                      interpolation='nearest',
                      picker=picker,
                      zorder=z + 1,
                      origin='lower')

        if alpha < 1:
            image.set_alpha(alpha)

        # Set image extent
        xmin = origin[0]
        xmax = xmin + scale[0] * width
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin

        ymin = origin[1]
        ymax = ymin + scale[1] * height
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin

        image.set_extent((xmin, xmax, ymin, ymax))

        # Set image data
        if scale[0] < 0. or scale[1] < 0.:
            # For negative scale, step by -1
            xstep = 1 if scale[0] >= 0. else -1
            ystep = 1 if scale[1] >= 0. else -1
            data = data[::ystep, ::xstep]

        if data.ndim == 2:  # Data image, convert to RGBA image
            data = colormap.applyToData(data)

        image.set_data(data)
        image.set_animated(True)
        self.ax.add_artist(image)
        return image

    def addTriangles(self, x, y, triangles, color, z, selectable, alpha):
        """Add a triangle mesh with one color per vertex on the left axes.

        Non-float colors are converted from [0, 255] to [0, 1].
        """
        for parameter in (x, y, triangles, color, z, selectable, alpha):
            assert parameter is not None

        # 0 enables picking on filled triangle
        picker = 0 if selectable else None

        # asarray avoids a copy when possible and, unlike
        # numpy.array(copy=False), does not raise on NumPy 2 for list input
        color = numpy.asarray(color)
        assert color.ndim == 2 and len(color) == len(x)

        if color.dtype not in [numpy.float32, numpy.float64]:
            color = color.astype(numpy.float32) / 255.

        collection = TriMesh(
            Triangulation(x, y, triangles),
            alpha=alpha,
            picker=picker,
            zorder=z + 1)
        collection.set_color(color)
        collection.set_animated(True)
        self.ax.add_collection(collection)

        return collection

    def addItem(self, x, y, shape, color, fill, overlay, z,
                linestyle, linewidth, linebgcolor):
        """Add a shape item (line, hline, vline, rectangle, polygon,
        polylines) on the left axes and return its artist.

        :raises NotImplementedError: For an unsupported shape.
        """
        if (linebgcolor is not None and
                shape not in ('rectangle', 'polygon', 'polylines')):
            _logger.warning(
                'linebgcolor not implemented for %s with matplotlib backend',
                shape)
        # asarray: no copy when possible, and NumPy 2 compatible
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        linestyle = normalize_linestyle(linestyle)

        if shape == "line":
            item = self.ax.plot(x, y, color=color,
                                linestyle=linestyle, linewidth=linewidth,
                                marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color,
                             linestyle=linestyle,
                             linewidth=linewidth)
            if fill:
                item.set_hatch('.')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            points = numpy.array((xView, yView)).T
            if shape == 'polygon':
                closed = True
            else:  # shape == 'polylines'
                closed = numpy.all(numpy.equal(points[0], points[-1]))
            item = Polygon(points,
                           closed=closed,
                           fill=False,
                           color=color,
                           linestyle=linestyle,
                           linewidth=linewidth)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z + 1)
        item.set_animated(True)

        return item

    def addMarker(self, x, y, text, color,
                  selectable, draggable,
                  symbol, linestyle, linewidth, constraint, yaxis):
        """Add a point, vertical or horizontal marker with optional text.

        :returns: A _MarkerContainer of the marker line and its text.
        :raises RuntimeError: If both x and y are None.
        """
        textArtist = None

        xmin, xmax = self.getGraphXLimits()
        ymin, ymax = self.getGraphYLimits(axis=yaxis)

        if yaxis == 'left':
            ax = self.ax
        elif yaxis == 'right':
            ax = self.ax2
        else:
            # Explicit error rather than assert(False): under "python -O"
            # the assert vanished and `ax` was left unbound.
            raise ValueError('Unsupported yaxis: %s' % yaxis)

        if x is not None and y is not None:
            line = ax.plot(x, y,
                           linestyle=" ",
                           color=color,
                           marker=symbol,
                           markersize=10.)[-1]

            if text is not None:
                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                textArtist = ax.text(x, y, text,
                                     color=color,
                                     horizontalalignment='left',
                                     verticalalignment=valign)

        elif x is not None:
            line = ax.axvline(x,
                              color=color,
                              linewidth=linewidth,
                              linestyle=linestyle)
            if text is not None:
                # Y position will be updated in updateMarkerText call
                textArtist = ax.text(x, 1., " " + text,
                                     color=color,
                                     horizontalalignment='left',
                                     verticalalignment='top')

        elif y is not None:
            line = ax.axhline(y,
                              color=color,
                              linewidth=linewidth,
                              linestyle=linestyle)

            if text is not None:
                # X position will be updated in updateMarkerText call
                textArtist = ax.text(1., y, " " + text,
                                     color=color,
                                     horizontalalignment='right',
                                     verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        # All markers are overlays
        line.set_animated(True)
        if textArtist is not None:
            textArtist.set_animated(True)

        artists = [line] if textArtist is None else [line, textArtist]
        container = _MarkerContainer(artists, x, y, yaxis)
        container.updateMarkerText(xmin, xmax, ymin, ymax)

        return container

    def _updateMarkers(self):
        """Reposition the text of overlay markers to the current bounds."""
        xmin, xmax = self.ax.get_xbound()
        ymin1, ymax1 = self.ax.get_ybound()
        ymin2, ymax2 = self.ax2.get_ybound()
        for item in self._overlayItems():
            if isinstance(item, _MarkerContainer):
                if item.yAxis == 'left':
                    item.updateMarkerText(xmin, xmax, ymin1, ymax1)
                else:
                    item.updateMarkerText(xmin, xmax, ymin2, ymax2)

    # Remove methods

    def remove(self, item):
        """Remove an item's artist, ignoring already removed ones."""
        try:
            item.remove()
        except ValueError:
            pass  # Already removed e.g., in set[X|Y]AxisLogarithmic

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Enable (flag True) or disable the crosshair cursor.

        Existing crosshair lines are always removed first, so enabling the
        cursor again (e.g., with another color) does not leave stale lines
        on the axes.
        """
        if self._graphCursor:
            lineh, linev = self._graphCursor
            lineh.remove()
            linev.remove()
            self._graphCursor = tuple()

        if flag:
            lineh = self.ax.axhline(
                self.ax.get_ybound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(
                self.ax.get_xbound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev

    # Active curve

    def setCurveColor(self, curve, color):
        """Recolor the line and scatter artists composing a curve."""
        # Store Line2D and PathCollection
        for artist in curve.get_children():
            if isinstance(artist, (Line2D, LineCollection)):
                artist.set_color(color)
            elif isinstance(artist, PathCollection):
                artist.set_facecolors(color)
                artist.set_edgecolors(color)
            else:
                _logger.warning(
                    'setActiveCurve ignoring artist %s', str(artist))

    # Misc.

    def getWidgetHandle(self):
        """Return the figure's canvas."""
        return self.fig.canvas

    def _enableAxis(self, axis, flag=True):
        """Show/hide Y axis

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Default, True
        """
        assert axis in ('right', 'left')
        axes = self.ax2 if axis == 'right' else self.ax
        axes.get_yaxis().set_visible(flag)

    def replot(self):
        """Do not perform rendering.

        Override in subclass to actually draw something.
        """
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only support curve for now
        # Hide right Y axis if no line is present
        self._dirtyLimits = False
        if not self.ax2.lines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure to ``fileName`` (a path or file-like object)."""
        # fileName can be also a StringIO or file instance
        if dpi is not None:
            self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
        else:
            self.fig.savefig(fileName, format=fileFormat)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        self.ax.set_title(title)

    def setGraphXLabel(self, label):
        self.ax.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        axes = self.ax if axis == 'left' else self.ax2
        axes.set_ylabel(label)

    # Graph limits

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and optionally right Y limits at once."""
        # Let matplotlib taking care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        if y2min is not None and y2max is not None:
            if not self.isYAxisInverted():
                self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
            else:
                self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))

        if not self.isYAxisInverted():
            self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
        else:
            self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))

        self._updateMarkers()

    def _applyPendingAspectRatio(self):
        """Force keep-aspect-ratio on both axes if limits changed since
        the last draw, so returned limits are consistent (see __init__)."""
        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.ax.apply_aspect()
            self.ax2.apply_aspect()
            self._dirtyLimits = False

    def getGraphXLimits(self):
        self._applyPendingAspectRatio()
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
        self._updateMarkers()

    def getGraphYLimits(self, axis):
        """Return the (min, max) bounds of the given Y axis, or None if
        the corresponding axes are not visible."""
        assert axis in ('left', 'right')
        ax = self.ax2 if axis == 'right' else self.ax

        if not ax.get_visible():
            return None

        self._applyPendingAspectRatio()

        return ax.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        ax = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of shared axis when keeping aspect ratio
            # So x limits are kept when changing y limits....
            # Change x limits first by taking into account aspect ratio
            # and then change y limits.. so matplotlib does not need
            # to make change (to y) to keep aspect ratio
            xmin, xmax = ax.get_xbound()
            curYMin, curYMax = ax.get_ybound()

            # Skip on a degenerate current Y range to avoid dividing by zero
            if curYMax != curYMin:
                newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
                xcenter = 0.5 * (xmin + xmax)
                ax.set_xlim(xcenter - 0.5 * newXRange,
                            xcenter + 0.5 * newXRange)

        if not self.isYAxisInverted():
            ax.set_ylim(ymin, ymax)
        else:
            ax.set_ylim(ymax, ymin)

        self._updateMarkers()

    # Graph axes

    def setXAxisTimeZone(self, tz):
        super(BackendMatplotlib, self).setXAxisTimeZone(tz)

        # Make new formatter and locator with the time zone.
        self.setXAxisTimeSeries(self.isXAxisTimeSeries())

    def isXAxisTimeSeries(self):
        return self._isXAxisTimeSeries

    def setXAxisTimeSeries(self, isTimeSeries):
        """Switch the X axis between time-series and scalar formatting."""
        self._isXAxisTimeSeries = isTimeSeries
        if self._isXAxisTimeSeries:
            # We can't use a matplotlib.dates.DateFormatter because it expects
            # the data to be in datetimes. Silx works internally with
            # timestamps (floats).
            locator = NiceDateLocator(tz=self.getXAxisTimeZone())
            self.ax.xaxis.set_major_locator(locator)
            self.ax.xaxis.set_major_formatter(
                NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
        else:
            try:
                scalarFormatter = ScalarFormatter(useOffset=False)
            except Exception:
                # Best effort: fall back to the default formatter
                _logger.warning('Cannot disabled axes offsets in %s ',
                                matplotlib.__version__)
                scalarFormatter = ScalarFormatter()
            self.ax.xaxis.set_major_formatter(scalarFormatter)

    def setXAxisLogarithmic(self, flag):
        # Workaround for matplotlib 2.1.0 when one tries to set an axis
        # to log scale with both limits <= 0
        # In this case a draw with positive limits is needed first
        if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
            xlim = self.ax.get_xlim()
            if xlim[0] <= 0 and xlim[1] <= 0:
                self.ax.set_xlim(1, 10)
                self.draw()

        self.ax2.set_xscale('log' if flag else 'linear')
        self.ax.set_xscale('log' if flag else 'linear')

    def setYAxisLogarithmic(self, flag):
        # Workaround for matplotlib 2.0 issue with negative bounds
        # before switching to log scale
        if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
            redraw = False
            for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
                ylim = axis.get_ylim()
                if ylim[0] <= 0 or ylim[1] <= 0:
                    dataRange = self._plot.getDataRange()[dataRangeIndex]
                    if dataRange is None:
                        dataRange = 1, 100  # Fallback
                    axis.set_ylim(*dataRange)
                    redraw = True
            if redraw:
                self.draw()

        self.ax2.set_yscale('log' if flag else 'linear')
        self.ax.set_yscale('log' if flag else 'linear')

    def setYAxisInverted(self, flag):
        if self.ax.yaxis_inverted() != bool(flag):
            self.ax.invert_yaxis()

    def isYAxisInverted(self):
        return self.ax.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        return self.ax.get_aspect() in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        self.ax.set_aspect(1.0 if flag else 'auto')
        self.ax2.set_aspect(1.0 if flag else 'auto')

    def setGraphGrid(self, which):
        self.ax.grid(False, which='both')  # Disable all grid first
        if which is not None:
            self.ax.grid(True, which=which)

    # Data <-> Pixel coordinates conversion

    def _mplQtYAxisCoordConversion(self, y):
        """Qt origin (top) to/from matplotlib origin (bottom) conversion.

        :rtype: float
        """
        height = self.fig.get_window_extent().height
        return height - y

    def dataToPixel(self, x, y, axis):
        """Convert data coordinates to pixels (Qt origin: top-left)."""
        ax = self.ax2 if axis == "right" else self.ax

        pixels = ax.transData.transform_point((x, y))
        xPixel, yPixel = pixels.T

        # Convert from matplotlib origin (bottom) to Qt origin (top)
        yPixel = self._mplQtYAxisCoordConversion(yPixel)

        return xPixel, yPixel

    def pixelToData(self, x, y, axis):
        """Convert pixels (Qt origin: top-left) to data coordinates."""
        ax = self.ax2 if axis == "right" else self.ax

        # Convert from Qt origin (top) to matplotlib origin (bottom)
        y = self._mplQtYAxisCoordConversion(y)

        inv = ax.transData.inverted()
        x, y = inv.transform_point((x, y))
        return x, y

    def getPlotBoundsInPixels(self):
        """Return (left, top, width, height) of the plot area in pixels."""
        bbox = self.ax.get_window_extent()
        # Warning this is not returning int...
        return (bbox.xmin,
                self._mplQtYAxisCoordConversion(bbox.ymax),
                bbox.width,
                bbox.height)

    def setAxesDisplayed(self, displayed):
        """Display or not the axes.

        :param bool displayed: If `True` axes are displayed. If `False` axes
            are not anymore visible and the margin used for them is removed.
        """
        BackendBase.BackendBase.setAxesDisplayed(self, displayed)
        if displayed:
            # show axes and viewbox rect
            self.ax.set_axis_on()
            self.ax2.set_axis_on()
            # set the default margins
            self.ax.set_position([.15, .15, .75, .75])
            self.ax2.set_position([.15, .15, .75, .75])
        else:
            # hide axes and viewbox rect
            self.ax.set_axis_off()
            self.ax2.set_axis_off()
            # remove external margins
            self.ax.set_position([0, 0, 1, 1])
            self.ax2.set_position([0, 0, 1, 1])
        self._synchronizeBackgroundColors()
        self._synchronizeForegroundColors()
        self._plot._setDirtyPlot()

    def _synchronizeBackgroundColors(self):
        """Apply the plot's background colors to the figure and axes."""
        backgroundColor = self._plot.getBackgroundColor().getRgbF()

        dataBackgroundColor = self._plot.getDataBackgroundColor()
        if dataBackgroundColor.isValid():
            dataBackgroundColor = dataBackgroundColor.getRgbF()
        else:
            dataBackgroundColor = backgroundColor

        if self.ax2.axison:
            self.fig.patch.set_facecolor(backgroundColor)
            if self._matplotlibVersion < _parse_version('2'):
                self.ax2.set_axis_bgcolor(dataBackgroundColor)
            else:
                self.ax2.set_facecolor(dataBackgroundColor)
        else:
            self.fig.patch.set_facecolor(dataBackgroundColor)

    def _synchronizeForegroundColors(self):
        """Apply foreground and grid colors to visible axes decorations."""
        foregroundColor = self._plot.getForegroundColor().getRgbF()

        gridColor = self._plot.getGridColor()
        if gridColor.isValid():
            gridColor = gridColor.getRgbF()
        else:
            gridColor = foregroundColor

        for axes in (self.ax, self.ax2):
            if not axes.axison:
                continue
            for spineName in ('bottom', 'top', 'right', 'left'):
                axes.spines[spineName].set_color(foregroundColor)
            axes.tick_params(axis='x', colors=foregroundColor)
            axes.tick_params(axis='y', colors=foregroundColor)
            axes.yaxis.label.set_color(foregroundColor)
            axes.xaxis.label.set_color(foregroundColor)
            axes.title.set_color(foregroundColor)

            for line in axes.get_xgridlines():
                line.set_color(gridColor)

            for line in axes.get_ygridlines():
                line.set_color(gridColor)
Exemple #7
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.

    Two matplotlib axes share the X axis: ``self.ax`` carries the left Y
    axis (and the images, items and markers added here), ``self.ax2``
    (created with ``twinx``) carries the right Y axis.
    """
    def __init__(self, plot, parent=None):
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True
        self._axesDisplayed = True

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # disable the use of offsets
        try:
            self.ax.get_yaxis().get_major_formatter().set_useOffset(False)
            self.ax.get_xaxis().get_major_formatter().set_useOffset(False)
            self.ax2.get_yaxis().get_major_formatter().set_useOffset(False)
            self.ax2.get_xaxis().get_major_formatter().set_useOffset(False)
        except Exception:
            # Best effort only: keep the default formatters if unsupported
            _logger.warning('Cannot disabled axes offsets in %s ',
                            matplotlib.__version__)

        # critical for picking!!!!
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        # (self.ax is stacked above self.ax2, so its background is made
        # transparent)
        if self._matplotlibVersionInfo() < (2,):
            self.ax.set_axis_bgcolor('none')
        else:
            self.ax.set_facecolor('none')
        self.fig.sca(self.ax)

        # Animated overlay artists/containers (markers, overlay items)
        self._overlays = set()
        self._background = None

        self._colormaps = {}

        # Empty tuple when there is no graph cursor, (hline, vline) otherwise
        self._graphCursor = tuple()
        self.matplotlibVersion = matplotlib.__version__

        self._enableAxis('right', False)
        self._isXAxisTimeSeries = False

    @staticmethod
    def _matplotlibVersionInfo():
        """Return the running matplotlib version as a tuple of integers.

        Version strings do not compare reliably as strings
        (e.g. ``'10.0' < '2.1'``). Each dot-separated component is
        therefore reduced to its leading digits, stopping at the first
        component without any (e.g. ``'2.1.0rc1'`` -> ``(2, 1, 0)``).

        :rtype: tuple of int
        """
        numbers = []
        for component in matplotlib.__version__.split('.'):
            digits = ''
            for char in component:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    # Add methods

    def addCurve(self, x, y, legend, color, symbol, linewidth, linestyle,
                 yaxis, xerror, yerror, z, selectable, fill, alpha,
                 symbolsize):
        """Add a curve and return a Container of its matplotlib artists.

        When ``color`` is an array holding one color per point, the data
        is drawn as a scatter plot (plus a line if ``linestyle`` is set).
        Otherwise a regular line plot is used. Error bars and fill are
        optional extra artists in the same container.
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill, alpha, symbolsize):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # Integer RGBA color: convert to floats by dividing by 255
        if (len(color) == 4
                and type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            # numpy.float64 (not the removed numpy.float alias)
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2
                    and yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x,
                                      y,
                                      label=legend,
                                      xerr=xerror,
                                      yerr=yerror,
                                      linestyle=' ',
                                      color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x,
                                      y,
                                      label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x,
                                   y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker,
                                   s=symbolsize)
            artists.append(scatter)

            if fill:
                artists.append(
                    axes.fill_between(x,
                                      FLOAT32_MINPOS,
                                      y,
                                      facecolor=actualColor[0],
                                      linestyle=''))

        else:  # Curve
            curveList = axes.plot(x,
                                  y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker,
                                  markersize=symbolsize)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color))

        for artist in artists:
            artist.set_zorder(z)
            if alpha < 1:
                artist.set_alpha(alpha)

        return Container(artists)

    def addImage(self, data, legend, origin, scale, z, selectable, draggable,
                 colormap, alpha):
        """Add an image (2D data with colormap, or 3D RGB(A)) to self.ax.

        Returns the created image artist. Negative scales flip the data
        along the corresponding dimension.
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, legend, origin, scale, z, selectable,
                          draggable):
            assert parameter is not None

        origin = float(origin[0]), float(origin[1])
        scale = float(scale[0]), float(scale[1])
        height, width = data.shape[0:2]

        picker = (selectable or draggable)

        # Debian 7 specific support
        # No transparent colormap with matplotlib < 1.2.0
        # Add support for transparent colormap for uint8 data with
        # colormap with 256 colors, linear norm, [0, 255] range
        if self._matplotlibVersionInfo() < (1, 2, 0):
            if (len(data.shape) == 2 and colormap.getName() is None
                    and colormap.getColormapLUT() is not None):
                colors = colormap.getColormapLUT()
                if (colors.shape[-1] == 4
                        and not numpy.all(numpy.equal(colors[3], 255))):
                    # This is a transparent colormap
                    if (colors.shape == (256, 4)
                            and colormap.getNormalization() == 'linear'
                            and not colormap.isAutoscale()
                            and colormap.getVMin() == 0
                            and colormap.getVMax() == 255
                            and data.dtype == numpy.uint8):
                        # Supported case, convert data to RGBA
                        data = colors[data.reshape(-1)].reshape(data.shape +
                                                                (4, ))
                    else:
                        _logger.warning(
                            'matplotlib %s does not support transparent '
                            'colormap.', matplotlib.__version__)

        # Large images without offset/scaling use ModestImage
        if ((height * width) > 5.0e5 and origin == (0., 0.)
                and scale == (1., 1.)):
            imageClass = ModestImage
        else:
            imageClass = AxesImage

        # the normalization can be a source of time waste
        # Two possibilities, we receive data or a ready to show image
        if len(data.shape) == 3:  # RGBA image
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               picker=picker,
                               zorder=z,
                               origin='lower')

        else:
            # Convert colormap argument to matplotlib colormap
            scalarMappable = MPLColormap.getScalarMappable(colormap, data)

            # try as data
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               cmap=scalarMappable.cmap,
                               picker=picker,
                               zorder=z,
                               norm=scalarMappable.norm,
                               origin='lower')
        if alpha < 1:
            image.set_alpha(alpha)

        # Set image extent
        xmin = origin[0]
        xmax = xmin + scale[0] * width
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin

        ymin = origin[1]
        ymax = ymin + scale[1] * height
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin

        image.set_extent((xmin, xmax, ymin, ymax))

        # Set image data
        if scale[0] < 0. or scale[1] < 0.:
            # For negative scale, step by -1
            xstep = 1 if scale[0] >= 0. else -1
            ystep = 1 if scale[1] >= 0. else -1
            data = data[::ystep, ::xstep]

        if self._matplotlibVersionInfo() < (2, 1):
            # matplotlib 1.4.2 do not support float128
            dtype = data.dtype
            if dtype.kind == "f" and dtype.itemsize >= 16:
                _logger.warning("Your matplotlib version do not support "
                                "float128. Data converted to float64.")
                data = data.astype(numpy.float64)

        image.set_data(data)

        self.ax.add_artist(image)

        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z):
        """Add a shape item ('line', 'hline', 'vline', 'rectangle',
        'polygon' or 'polylines') to self.ax and return its artist.

        :raises NotImplementedError: For any other shape.
        """
        # asarray avoids a copy when possible; numpy.array(copy=False)
        # raises on NumPy >= 2 when a copy is required (e.g. list input).
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        if shape == "line":
            item = self.ax.plot(x,
                                y,
                                label=legend,
                                color=color,
                                linestyle='-',
                                marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color)
            if fill:
                item.set_hatch('.')

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            points = numpy.array((xView, yView)).T
            if shape == 'polygon':
                closed = True
            else:  # shape == 'polylines'
                closed = numpy.all(numpy.equal(points[0], points[-1]))
            item = Polygon(points,
                           closed=closed,
                           fill=False,
                           label=legend,
                           color=color)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color, selectable, draggable,
                  symbol, constraint):
        """Add a point, vertical or horizontal line marker with optional
        text. Markers are always registered as animated overlays.

        :raises RuntimeError: If both x and y are None.
        """
        legend = "__MARKER__" + legend

        textArtist = None

        xmin, xmax = self.getGraphXLimits()
        ymin, ymax = self.getGraphYLimits(axis='left')

        if x is not None and y is not None:
            line = self.ax.plot(x,
                                y,
                                label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                textArtist = self.ax.text(x,
                                          y,
                                          text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x, label=legend, color=color)
            if text is not None:
                # Y position will be updated in updateMarkerText call
                textArtist = self.ax.text(x,
                                          1.,
                                          " " + text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y, label=legend, color=color)

            if text is not None:
                # X position will be updated in updateMarkerText call
                textArtist = self.ax.text(1.,
                                          y,
                                          " " + text,
                                          color=color,
                                          horizontalalignment='right',
                                          verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        # All markers are overlays
        line.set_animated(True)
        if textArtist is not None:
            textArtist.set_animated(True)

        artists = [line] if textArtist is None else [line, textArtist]
        container = _MarkerContainer(artists, x, y)
        container.updateMarkerText(xmin, xmax, ymin, ymax)
        self._overlays.add(container)

        return container

    def _updateMarkers(self):
        """Reposition marker texts according to the current self.ax bounds."""
        xmin, xmax = self.ax.get_xbound()
        ymin, ymax = self.ax.get_ybound()
        for item in self._overlays:
            if isinstance(item, _MarkerContainer):
                item.updateMarkerText(xmin, xmax, ymin, ymax)

    # Remove methods

    def remove(self, item):
        """Remove an item previously returned by one of the add methods."""
        # Warning: It also needs to remove extra stuff if added as for markers
        self._overlays.discard(item)
        try:
            item.remove()
        except ValueError:
            pass  # Already removed e.g., in set[X|Y]AxisLogarithmic

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Create (flag True) or remove (flag False) the cross-hair lines.

        The lines are created hidden and animated.
        """
        if flag:
            lineh = self.ax.axhline(self.ax.get_ybound()[0],
                                    visible=False,
                                    color=color,
                                    linewidth=linewidth,
                                    linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(self.ax.get_xbound()[0],
                                    visible=False,
                                    color=color,
                                    linewidth=linewidth,
                                    linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev
        else:
            # _graphCursor is an empty tuple (never None) when no cursor
            # exists: unpacking it would raise ValueError.
            if self._graphCursor:
                lineh, linev = self._graphCursor
                lineh.remove()
                linev.remove()
                self._graphCursor = tuple()

    # Active curve

    def setCurveColor(self, curve, color):
        """Set the color of the line and scatter artists of a curve."""
        # Store Line2D and PathCollection
        for artist in curve.get_children():
            if isinstance(artist, (Line2D, LineCollection)):
                artist.set_color(color)
            elif isinstance(artist, PathCollection):
                artist.set_facecolors(color)
                artist.set_edgecolors(color)
            else:
                _logger.warning('setActiveCurve ignoring artist %s',
                                str(artist))

    # Misc.

    def getWidgetHandle(self):
        """Return the figure canvas."""
        return self.fig.canvas

    def _enableAxis(self, axis, flag=True):
        """Show/hide Y axis

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Default, True
        """
        assert axis in ('right', 'left')
        axes = self.ax2 if axis == 'right' else self.ax
        axes.get_yaxis().set_visible(flag)

    def replot(self):
        """Do not perform rendering.

        Override in subclass to actually draw something.
        """
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only support curve for now
        # Hide right Y axis if no line is present
        self._dirtyLimits = False
        if not self.ax2.lines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure to fileName using matplotlib's savefig."""
        # fileName can be also a StringIO or file instance
        if dpi is not None:
            self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
        else:
            self.fig.savefig(fileName, format=fileFormat)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        self.ax.set_title(title)

    def setGraphXLabel(self, label):
        self.ax.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        axes = self.ax if axis == 'left' else self.ax2
        axes.set_ylabel(label)

    # Graph limits

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and (if both given) right Y limits at once.

        The Y-axis inversion state of self.ax is applied to both Y axes.
        """
        # Let matplotlib taking care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        if y2min is not None and y2max is not None:
            if not self.isYAxisInverted():
                self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
            else:
                self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))

        if not self.isYAxisInverted():
            self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
        else:
            self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))

        self._updateMarkers()

    def getGraphXLimits(self):
        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
        self._updateMarkers()

    def getGraphYLimits(self, axis):
        """Return (min, max) of the given Y axis, or None if not visible."""
        assert axis in ('left', 'right')
        ax = self.ax2 if axis == 'right' else self.ax

        if not ax.get_visible():
            return None

        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits

        return ax.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        ax = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of shared axis when keeping aspect ratio
            # So x limits are kept when changing y limits....
            # Change x limits first by taking into account aspect ratio
            # and then change y limits.. so matplotlib does not need
            # to make change (to y) to keep aspect ratio
            xmin, xmax = ax.get_xbound()
            curYMin, curYMax = ax.get_ybound()

            newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
            xcenter = 0.5 * (xmin + xmax)
            ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)

        if not self.isYAxisInverted():
            ax.set_ylim(ymin, ymax)
        else:
            ax.set_ylim(ymax, ymin)

        self._updateMarkers()

    # Graph axes

    def setXAxisTimeZone(self, tz):
        super(BackendMatplotlib, self).setXAxisTimeZone(tz)

        # Make new formatter and locator with the time zone.
        self.setXAxisTimeSeries(self.isXAxisTimeSeries())

    def isXAxisTimeSeries(self):
        return self._isXAxisTimeSeries

    def setXAxisTimeSeries(self, isTimeSeries):
        """Switch the X axis tick locator/formatter between time-series
        (timestamps as floats) and plain scalar formatting."""
        self._isXAxisTimeSeries = isTimeSeries
        if self._isXAxisTimeSeries:
            # We can't use a matplotlib.dates.DateFormatter because it expects
            # the data to be in datetimes. Silx works internally with
            # timestamps (floats).
            locator = NiceDateLocator(tz=self.getXAxisTimeZone())
            self.ax.xaxis.set_major_locator(locator)
            self.ax.xaxis.set_major_formatter(
                NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
        else:
            # NOTE(review): the major locator is not reset here, so a
            # NiceDateLocator set above stays in use after switching back
            # to scalar mode -- confirm whether this is intended.
            try:
                scalarFormatter = ScalarFormatter(useOffset=False)
            except Exception:
                # Older matplotlib may not accept the useOffset argument
                _logger.warning('Cannot disabled axes offsets in %s ',
                                matplotlib.__version__)
                scalarFormatter = ScalarFormatter()
            self.ax.xaxis.set_major_formatter(scalarFormatter)

    def setXAxisLogarithmic(self, flag):
        # Workaround for matplotlib 2.1.0 when one tries to set an axis
        # to log scale with both limits <= 0
        # In this case a draw with positive limits is needed first
        # NOTE(review): self.draw is not defined in this class; presumably
        # provided by a canvas-backed subclass -- confirm.
        if flag and self._matplotlibVersionInfo() >= (2, 1, 0):
            xlim = self.ax.get_xlim()
            if xlim[0] <= 0 and xlim[1] <= 0:
                self.ax.set_xlim(1, 10)
                self.draw()

        self.ax2.set_xscale('log' if flag else 'linear')
        self.ax.set_xscale('log' if flag else 'linear')

    def setYAxisLogarithmic(self, flag):
        # Workaround for matplotlib 2.1.0 when one tries to set an axis
        # to log scale with both limits <= 0
        # In this case a draw with positive limits is needed first
        if flag and self._matplotlibVersionInfo() >= (2, 1, 0):
            redraw = False
            for axis in (self.ax, self.ax2):
                ylim = axis.get_ylim()
                if ylim[0] <= 0 and ylim[1] <= 0:
                    axis.set_ylim(1, 10)
                    redraw = True
            if redraw:
                self.draw()

        self.ax2.set_yscale('log' if flag else 'linear')
        self.ax.set_yscale('log' if flag else 'linear')

    def setYAxisInverted(self, flag):
        if self.ax.yaxis_inverted() != bool(flag):
            self.ax.invert_yaxis()

    def isYAxisInverted(self):
        return self.ax.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        return self.ax.get_aspect() in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        self.ax.set_aspect(1.0 if flag else 'auto')
        self.ax2.set_aspect(1.0 if flag else 'auto')

    def setGraphGrid(self, which):
        """Disable the grid, then enable it for `which` unless it is None."""
        self.ax.grid(False, which='both')  # Disable all grid first
        if which is not None:
            self.ax.grid(True, which=which)

    # Data <-> Pixel coordinates conversion

    def dataToPixel(self, x, y, axis):
        ax = self.ax2 if axis == "right" else self.ax

        pixels = ax.transData.transform_point((x, y))
        xPixel, yPixel = pixels.T
        return xPixel, yPixel

    def pixelToData(self, x, y, axis, check):
        """Convert pixel coordinates to data coordinates.

        When `check` is True, return None for points outside the plot
        limits.
        """
        ax = self.ax2 if axis == "right" else self.ax

        inv = ax.transData.inverted()
        x, y = inv.transform_point((x, y))

        if check:
            xmin, xmax = self.getGraphXLimits()
            ymin, ymax = self.getGraphYLimits(axis=axis)

            if x > xmax or x < xmin or y > ymax or y < ymin:
                return None  # (x, y) is out of plot area

        return x, y

    def getPlotBoundsInPixels(self):
        """Return (left, bottom, width, height) of self.ax in pixels."""
        bbox = self.ax.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        dpi = self.fig.dpi
        # Warning this is not returning int...
        return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi,
                bbox.bounds[2] * dpi, bbox.bounds[3] * dpi)

    def setAxesDisplayed(self, displayed):
        """Display or not the axes.

        :param bool displayed: If `True` axes are displayed. If `False` axes
            are not anymore visible and the margin used for them is removed.
        """
        BackendBase.BackendBase.setAxesDisplayed(self, displayed)
        if displayed:
            # show axes and viewbox rect
            self.ax.set_axis_on()
            self.ax2.set_axis_on()
            # set the default margins
            self.ax.set_position([.15, .15, .75, .75])
            self.ax2.set_position([.15, .15, .75, .75])
        else:
            # hide axes and viewbox rect
            self.ax.set_axis_off()
            self.ax2.set_axis_off()
            # remove external margins
            self.ax.set_position([0, 0, 1, 1])
            self.ax2.set_position([0, 0, 1, 1])
        self._plot._setDirtyPlot()
Exemple #8
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.

    The figure holds two axes sharing the X axis: ``self.ax`` for the left
    Y axis (stacked on top) and ``self.ax2`` (created with ``twinx``) for the
    right Y axis.
    """

    def __init__(self, plot, parent=None):
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # critical for picking!!!!
        # The left axes must be stacked above the right axes.
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # Transparent left axes background so the right axes below stays
        # visible (this works but the figure color is left).
        # Axes.set_axis_bgcolor was deprecated in matplotlib 2.0 and removed
        # in 2.2 in favour of Axes.set_facecolor.
        if hasattr(self.ax, 'set_facecolor'):
            self.ax.set_facecolor('none')
        else:  # matplotlib < 2.0
            self.ax.set_axis_bgcolor('none')
        self.fig.sca(self.ax)

        # Artists added with overlay=True (set as animated)
        self._overlays = set()
        self._background = None

        # Custom colormaps, lazily filled by __getColormap
        self._colormaps = {}

        # (horizontal line, vertical line) while the graph cursor is enabled,
        # empty tuple otherwise
        self._graphCursor = tuple()
        self.matplotlibVersion = matplotlib.__version__

        self.setGraphXLimits(0., 100.)
        self.setGraphYLimits(0., 100., axis='right')
        self.setGraphYLimits(0., 100., axis='left')

        self._enableAxis('right', False)

    # Add methods

    def addCurve(self, x, y, legend,
                 color, symbol, linewidth, linestyle,
                 yaxis,
                 xerror, yerror, z, selectable,
                 fill):
        """Add a curve and return a Container of the created artists.

        When ``color`` is an array holding one color per point, the curve is
        drawn as a scatter plot (plus a line if ``linestyle`` is set),
        otherwise as a regular line plot.

        :param x: Abscissa values
        :param y: Ordinate values
        :param str legend: Label given to the matplotlib artists
        :param color: A single color or an array of one color per point
        :param symbol: matplotlib marker
        :param linewidth: Line width
        :param linestyle: matplotlib line style
        :param str yaxis: 'left' or 'right'
        :param xerror: X error values or None
        :param yerror: Y error values or None
        :param z: zorder applied to all the created artists
        :param bool selectable: If True, artists are pickable
        :param bool fill: If True, fill the area under the curve
        :return: Container grouping all the created artists
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # 4 integer components: RGBA in [0, 255], convert to floats in [0, 1].
        # numpy.float (alias of builtin float, i.e. float64) was removed in
        # numpy 1.24, so numpy.float64 is used explicitly.
        if (len(color) == 4 and
                type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        # One color per point: the curve is rendered as a scatter plot
        isScatter = hasattr(color, 'dtype') and len(color) == len(x)

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            errorbarColor = 'k' if isScatter else color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (yerror is not None and yerror.ndim == 2 and
                    yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x, y, label=legend,
                                      xerr=xerror, yerr=yerror,
                                      linestyle=' ', color=errorbarColor)
            artists += list(errorbars.get_children())

        if isScatter:
            # scatter plot
            if color.dtype not in [numpy.float32, numpy.float64]:
                # Non-float colors are scaled from [0, 255] to [0, 1]
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x, y, label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x, y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker)
            artists.append(scatter)

            if fill:
                artists.append(axes.fill_between(
                    x, 1.0e-8, y, facecolor=actualColor[0], linestyle=''))

        else:  # Curve
            curveList = axes.plot(x, y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, 1.0e-8, y,
                                      facecolor=color, linewidth=0))

        for artist in artists:
            artist.set_zorder(z)

        return Container(artists)

    def addImage(self, data, legend,
                 origin, scale, z,
                 selectable, draggable,
                 colormap):
        """Add an image to the left axes and return the image artist.

        :param numpy.ndarray data: 2D data to colormap or 3D RGB(A) image
        :param str legend: Used to build the artist label "__IMAGE__" + legend
        :param origin: (x, y) position of the image origin in data coords
        :param scale: (sx, sy) size of one pixel in data coordinates
        :param z: zorder of the image
        :param bool selectable: The image is pickable if selectable...
        :param bool draggable: ...or draggable
        :param dict colormap: Used for 2D data, with keys 'name', 'colors'
            (required when name is None), 'normalization', 'autoscale',
            'vmin' and 'vmax'
        :raises ValueError: If the colormap has neither a name nor colors
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, legend, origin, scale, z,
                          selectable, draggable):
            assert parameter is not None

        h, w = data.shape[0:2]
        xmin = origin[0]
        xmax = xmin + scale[0] * w
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin
        ymin = origin[1]
        ymax = ymin + scale[1] * h
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin
        # Passed to the colormapped image constructor, the final extent is
        # set below according to the image origin
        extent = (xmin, xmax, ymax, ymin)

        picker = (selectable or draggable)

        # Debian 7 specific support
        # No transparent colormap with matplotlib < 1.2.0
        # Add support for transparent colormap for uint8 data with
        # colormap with 256 colors, linear norm, [0, 255] range
        if matplotlib.__version__ < '1.2.0':
            if (len(data.shape) == 2 and colormap['name'] is None and
                    'colors' in colormap):
                # numpy.asarray: numpy.array(copy=False) raises with
                # numpy >= 2 when a copy is required
                colors = numpy.asarray(colormap['colors'])
                # Alpha is the last component of each color
                if (colors.shape[-1] == 4 and
                        not numpy.all(numpy.equal(colors[..., 3], 255))):
                    # This is a transparent colormap
                    if (colors.shape == (256, 4) and
                            colormap['normalization'] == 'linear' and
                            not colormap['autoscale'] and
                            colormap['vmin'] == 0 and
                            colormap['vmax'] == 255 and
                            data.dtype == numpy.uint8):
                        # Supported case, convert data to RGBA
                        data = colors[data.reshape(-1)].reshape(
                            data.shape + (4,))
                    else:
                        _logger.warning(
                            'matplotlib %s does not support transparent '
                            'colormap.', matplotlib.__version__)

        # ModestImage for large images without offset/scaling
        if tuple(origin) != (0., 0.) or tuple(scale) != (1., 1.):
            # for the time being not properly handled
            imageClass = AxesImage
        elif (data.shape[0] * data.shape[1]) > 5.0e5:
            imageClass = ModestImage
        else:
            imageClass = AxesImage

        # the normalization can be a source of time waste
        # Two possibilities, we receive data or a ready to show image
        if len(data.shape) == 3:
            # RGB(A) image, displayed as is
            # TODO: Possibility to mirror the image
            # in case of pixmaps just setting
            # extend = (xmin, xmax, ymax, ymin)
            # instead of (xmin, xmax, ymin, ymax)
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               picker=picker,
                               zorder=z)

        else:
            assert colormap is not None

            if colormap['name'] is not None:
                cmap = self.__getColormap(colormap['name'])
            else:  # No name, use custom colors
                if 'colors' not in colormap:
                    raise ValueError(
                        'addImage: colormap no name nor list of colors.')
                colors = numpy.array(colormap['colors'], copy=True)
                assert len(colors.shape) == 2
                assert colors.shape[-1] in (3, 4)
                if colors.dtype == numpy.uint8:
                    # Convert to float in [0., 1.]
                    colors = colors.astype(numpy.float32) / 255.
                cmap = ListedColormap(colors)

            if colormap['normalization'].startswith('log'):
                vmin, vmax = None, None
                if not colormap['autoscale']:
                    if colormap['vmin'] > 0.:
                        vmin = colormap['vmin']
                    if colormap['vmax'] > 0.:
                        vmax = colormap['vmax']

                    if vmin is None or vmax is None:
                        _logger.warning('Log colormap with negative bounds, ' +
                                        'changing bounds to positive ones.')
                    elif vmin > vmax:
                        _logger.warning('Colormap bounds are inverted.')
                        vmin, vmax = vmax, vmin

                # Set unset/negative bounds to positive bounds
                if vmin is None or vmax is None:
                    finiteData = data[numpy.isfinite(data)]
                    posData = finiteData[finiteData > 0]
                    if vmax is None:
                        # 1. as an ultimate fallback
                        vmax = posData.max() if posData.size > 0 else 1.
                    if vmin is None:
                        vmin = posData.min() if posData.size > 0 else vmax
                    if vmin > vmax:
                        vmin = vmax

                norm = LogNorm(vmin, vmax)

            else:  # Linear normalization
                if colormap['autoscale']:
                    finiteData = data[numpy.isfinite(data)]
                    if finiteData.size > 0:
                        vmin = finiteData.min()
                        vmax = finiteData.max()
                    else:  # No finite value: min/max of empty array fails
                        vmin, vmax = 0., 1.
                else:
                    vmin = colormap['vmin']
                    vmax = colormap['vmax']
                    if vmin > vmax:
                        _logger.warning('Colormap bounds are inverted.')
                        vmin, vmax = vmax, vmin

                norm = Normalize(vmin, vmax)

            # try as data
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               cmap=cmap,
                               extent=extent,
                               picker=picker,
                               zorder=z,
                               norm=norm)

        if image.origin == 'upper':
            image.set_extent((xmin, xmax, ymax, ymin))
        else:
            image.set_extent((xmin, xmax, ymin, ymax))
        image.set_data(data)

        self.ax.add_artist(image)

        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z):
        """Add a shape to the left axes and return its artist.

        :param x: X coordinates of the shape points
        :param y: Y coordinates of the shape points
        :param str legend: Label of the artist
        :param str shape: 'line', 'hline', 'vline', 'rectangle', 'polygon'
            or 'polylines'
        :param color: Color of the shape
        :param bool fill: If True, hatch rectangles and polygons
        :param bool overlay: If True, the artist is set as animated and
            tracked as an overlay
        :param z: zorder of the artist
        :raises NotImplementedError: For unsupported shapes
        """
        # numpy.asarray: numpy.array(copy=False) raises with numpy >= 2 when
        # a copy is required (e.g. for lists)
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        if shape == "line":
            item = self.ax.plot(x, y, label=legend, color=color,
                                linestyle='-', marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            # label set like for all other shapes
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             label=legend,
                             color=color)
            if fill:
                item.set_hatch('.')

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            xView = xView.reshape(1, -1)
            yView = yView.reshape(1, -1)
            item = Polygon(numpy.vstack((xView, yView)).T,
                           closed=(shape == 'polygon'),
                           fill=False,
                           label=legend,
                           color=color)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color,
                  selectable, draggable,
                  symbol, constraint, overlay):
        """Add a marker and return its artist.

        With both x and y, a point marker is drawn; with only x, a vertical
        line; with only y, a horizontal line. The optional text artist is
        stored on the returned artist as ``_infoText``.

        :param x: X position or None
        :param y: Y position or None
        :param str legend: Used to build the label "__MARKER__" + legend
        :param text: Text displayed next to the marker or None
        :param color: Color of the marker and text
        :param bool selectable: The marker is pickable if selectable...
        :param bool draggable: ...or draggable
        :param symbol: matplotlib marker for point markers
        :param constraint: Not used by this method
        :param bool overlay: If True, set as animated and tracked as overlay
        :raises RuntimeError: If both x and y are None
        """
        legend = "__MARKER__" + legend

        if x is not None and y is not None:
            line = self.ax.plot(x, y, label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                # Text placed at the marker position (a data -> display ->
                # data round-trip previously computed the same position)
                line._infoText = self.ax.text(x, y, text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x, label=legend, color=color)
            if text is not None:
                text = " " + text
                ymin, ymax = self.getGraphYLimits(axis='left')
                delta = abs(ymax - ymin)
                if ymin > ymax:
                    ymax = ymin
                # Slightly below the top of the plot area
                ymax -= 0.005 * delta
                line._infoText = self.ax.text(x, ymax, text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y, label=legend, color=color)

            if text is not None:
                text = " " + text
                xmin, xmax = self.getGraphXLimits()
                delta = abs(xmax - xmin)
                if xmin > xmax:
                    xmax = xmin
                # Slightly left of the right edge of the plot area
                xmax -= 0.005 * delta
                line._infoText = self.ax.text(xmax, y, text,
                                              color=color,
                                              horizontalalignment='right',
                                              verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        if overlay:
            line.set_animated(True)
            self._overlays.add(line)

        return line

    # Remove methods

    def remove(self, item):
        """Remove an artist returned by one of the add methods."""
        # Warning: It also needs to remove extra stuff if added as for markers
        infoText = getattr(item, "_infoText", None)
        if infoText is not None:  # For markers text
            infoText.remove()
            item._infoText = None
        self._overlays.discard(item)
        item.remove()

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Enable (flag True) or disable the crosshair graph cursor.

        When enabled, hidden animated horizontal and vertical lines are
        created and stored in ``_graphCursor``.
        """
        if flag:
            lineh = self.ax.axhline(
                self.ax.get_ybound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(
                self.ax.get_xbound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev
        else:
            # _graphCursor is an empty tuple when no cursor is set
            if self._graphCursor:
                lineh, linev = self._graphCursor
                lineh.remove()
                linev.remove()
                self._graphCursor = tuple()

    # Active curve

    def setActiveCurve(self, curve, active, color=None):
        """Highlight a curve with color, or restore its original colors.

        :param curve: Object providing get_children(), presumably the
            Container returned by addCurve
        :param bool active: True to highlight, False to restore
        :param color: Highlight color (used when active is True)
        """
        # Store Line2D and PathCollection
        for artist in curve.get_children():
            if active:
                if isinstance(artist, (Line2D, LineCollection)):
                    artist._initialColor = artist.get_color()
                    artist.set_color(color)
                elif isinstance(artist, PathCollection):
                    artist._initialColor = artist.get_facecolors()
                    artist.set_facecolors(color)
                    artist.set_edgecolors(color)
                else:
                    _logger.warning(
                        'setActiveCurve ignoring artist %s', str(artist))
            else:
                if hasattr(artist, '_initialColor'):
                    if isinstance(artist, (Line2D, LineCollection)):
                        artist.set_color(artist._initialColor)
                    elif isinstance(artist, PathCollection):
                        artist.set_facecolors(artist._initialColor)
                        artist.set_edgecolors(artist._initialColor)
                    else:
                        _logger.info(
                            'setActiveCurve ignoring artist %s', str(artist))
                    del artist._initialColor

    # Misc.

    def getWidgetHandle(self):
        """Return the canvas of the figure."""
        return self.fig.canvas

    def _enableAxis(self, axis, flag=True):
        """Show/hide Y axis

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Default, True
        """
        assert axis in ('right', 'left')
        axes = self.ax2 if axis == 'right' else self.ax
        axes.get_yaxis().set_visible(flag)

    def replot(self):
        """Do not perform rendering.

        Override in subclass to actually draw something.
        This base implementation clears the dirty limits flag and hides the
        right Y axis when it holds no line.
        """
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only support curve for now
        # Hide right Y axis if no line is present
        self._dirtyLimits = False
        if not self.ax2.lines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure with savefig and mark the plot as dirty.

        :param fileName: File name, or file-like object (e.g. StringIO)
        :param str fileFormat: Format passed to savefig
        :param dpi: Resolution, or None to use the savefig default
        """
        if dpi is not None:
            self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
        else:
            self.fig.savefig(fileName, format=fileFormat)
        # NOTE(review): presumably savefig alters the rendered state,
        # hence the request to redraw -- confirm
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        self.ax.set_title(title)

    def setGraphXLabel(self, label):
        self.ax.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        axes = self.ax if axis == 'left' else self.ax2
        axes.set_ylabel(label)

    # Graph limits

    def resetZoom(self, dataMargins):
        """Autoscale the axes flagged as autoscale to the data range.

        :param dataMargins: Margins passed to _utils.addMarginsToLimits
        """
        xAuto = self._plot.isXAxisAutoScale()
        yAuto = self._plot.isYAxisAutoScale()

        if not xAuto and not yAuto:
            _logger.debug("Nothing to autoscale")
        else:  # Some axes to autoscale
            xLimits = self.getGraphXLimits()
            yLimits = self.getGraphYLimits(axis='left')
            y2Limits = self.getGraphYLimits(axis='right')

            # Get data range
            ranges = self._plot.getDataRange()
            xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
            ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
            if ranges.yright is None:
                ymin2, ymax2 = None, None
            else:
                ymin2, ymax2 = ranges.yright

            # Add margins around data inside the plot area
            newLimits = list(_utils.addMarginsToLimits(
                dataMargins,
                self.ax.get_xscale() == 'log',
                self.ax.get_yscale() == 'log',
                xmin, xmax, ymin, ymax, ymin2, ymax2))

            if self.isKeepDataAspectRatio():
                # Compute bbox wth figure aspect ratio
                figW, figH = self.fig.get_size_inches()
                figureRatio = figH / figW

                # Compare the data ratio ySpan / xSpan with figureRatio
                # without dividing, so a null-width data range does not
                # raise ZeroDivisionError
                xSpan = xmax - xmin
                ySpan = ymax - ymin
                if ySpan < figureRatio * xSpan:
                    # Increase y range
                    ycenter = 0.5 * (newLimits[3] + newLimits[2])
                    yrange = xSpan * figureRatio
                    newLimits[2] = ycenter - 0.5 * yrange
                    newLimits[3] = ycenter + 0.5 * yrange

                elif ySpan > figureRatio * xSpan:
                    # Increase x range
                    xcenter = 0.5 * (newLimits[1] + newLimits[0])
                    xrange_ = ySpan / figureRatio
                    newLimits[0] = xcenter - 0.5 * xrange_
                    newLimits[1] = xcenter + 0.5 * xrange_

            self.setLimits(*newLimits)

            # Restore the limits of the axes which are not autoscaled
            if not xAuto and yAuto:
                self.setGraphXLimits(*xLimits)
            elif xAuto and not yAuto:
                if y2Limits is not None:
                    self.setGraphYLimits(
                        y2Limits[0], y2Limits[1], axis='right')
                if yLimits is not None:
                    self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and (if both given) right Y limits."""
        # Let matplotlib taking care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        if y2min is not None and y2max is not None:
            if not self.isYAxisInverted():
                self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
            else:
                self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))

        if not self.isYAxisInverted():
            self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
        else:
            self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))

    def getGraphXLimits(self):
        """Return the X axis bounds as (min, max).

        When keep aspect ratio is on and limits changed since the last
        replot, replot first so the returned values are consistent.
        """
        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

    def getGraphYLimits(self, axis):
        """Return the Y bounds of axis 'left' or 'right' as (min, max).

        Returns None if the corresponding matplotlib axes is not visible.
        """
        assert axis in ('left', 'right')
        ax = self.ax2 if axis == 'right' else self.ax

        if not ax.get_visible():
            return None

        if self._dirtyLimits and self.isKeepDataAspectRatio():
            self.replot()  # makes sure we get the right limits

        return ax.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        """Set the Y limits of axis 'right' or (otherwise) 'left'."""
        ax = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of shared axis when keeping aspect ratio
            # So x limits are kept when changing y limits....
            # Change x limits first by taking into account aspect ratio
            # and then change y limits.. so matplotlib does not need
            # to make change (to y) to keep aspect ratio
            xmin, xmax = ax.get_xbound()
            curYMin, curYMax = ax.get_ybound()
            curYRange = curYMax - curYMin

            if curYRange != 0:  # Cannot rescale from a degenerate range
                newXRange = (xmax - xmin) * (ymax - ymin) / curYRange
                xcenter = 0.5 * (xmin + xmax)
                ax.set_xlim(xcenter - 0.5 * newXRange,
                            xcenter + 0.5 * newXRange)

        if not self.isYAxisInverted():
            ax.set_ylim(ymin, ymax)
        else:
            ax.set_ylim(ymax, ymin)

    # Graph axes

    def setXAxisLogarithmic(self, flag):
        self.ax2.set_xscale('log' if flag else 'linear')
        self.ax.set_xscale('log' if flag else 'linear')

    def setYAxisLogarithmic(self, flag):
        self.ax2.set_yscale('log' if flag else 'linear')
        self.ax.set_yscale('log' if flag else 'linear')

    def setYAxisInverted(self, flag):
        """Invert the left Y axis if its state differs from flag."""
        if self.ax.yaxis_inverted() != bool(flag):
            self.ax.invert_yaxis()

    def isYAxisInverted(self):
        return self.ax.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        return self.ax.get_aspect() in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        self.ax.set_aspect(1.0 if flag else 'auto')
        self.ax2.set_aspect(1.0 if flag else 'auto')

    def setGraphGrid(self, which):
        """Disable all grid lines, then enable those given by which."""
        self.ax.grid(False, which='both')  # Disable all grid first
        if which is not None:
            self.ax.grid(True, which=which)

    # colormap

    def getSupportedColormaps(self):
        """Return the base class colormaps followed by matplotlib's, sorted."""
        default = super(BackendMatplotlib, self).getSupportedColormaps()
        return default + sorted(cm.datad)

    def __getColormap(self, name):
        """Return the colormap called name.

        Own colormaps ('red', 'green', 'blue', 'temperature',
        'reversed gray') are built on first call, then MPLColormap
        attributes are tried, then matplotlib built-in colormaps.
        """
        if not self._colormaps:  # Lazy initialization of own colormaps
            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 1.0, 1.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            self._colormaps['red'] = LinearSegmentedColormap(
                'red', cdict, 256)

            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 0.0, 0.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 1.0, 1.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            self._colormaps['green'] = LinearSegmentedColormap(
                'green', cdict, 256)

            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 0.0, 0.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 1.0, 1.0))}
            self._colormaps['blue'] = LinearSegmentedColormap(
                'blue', cdict, 256)

            # Temperature as defined in spslut
            cdict = {'red': ((0.0, 0.0, 0.0),
                             (0.5, 0.0, 0.0),
                             (0.75, 1.0, 1.0),
                             (1.0, 1.0, 1.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (0.25, 1.0, 1.0),
                               (0.75, 1.0, 1.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 1.0, 1.0),
                              (0.25, 1.0, 1.0),
                              (0.5, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            # but limited to 256 colors for a faster display (of the colorbar)
            self._colormaps['temperature'] = LinearSegmentedColormap(
                'temperature', cdict, 256)

            # reversed gray
            cdict = {'red':     ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0)),
                     'green':   ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0)),
                     'blue':    ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0))}

            self._colormaps['reversed gray'] = LinearSegmentedColormap(
                'yerg', cdict, 256)

        if name in self._colormaps:
            return self._colormaps[name]
        elif hasattr(MPLColormap, name):  # viridis and sister colormaps
            return getattr(MPLColormap, name)
        else:
            # matplotlib built-in
            return cm.get_cmap(name)

    # Data <-> Pixel coordinates conversion

    def dataToPixel(self, x, y, axis):
        """Convert data coordinates to pixel coordinates (xPixel, yPixel)."""
        ax = self.ax2 if axis == "right" else self.ax

        xPixel, yPixel = ax.transData.transform_point((x, y))
        return xPixel, yPixel

    def pixelToData(self, x, y, axis, check):
        """Convert pixel coordinates to data coordinates.

        :param bool check: If True, return None for points outside the
            current graph limits
        """
        ax = self.ax2 if axis == "right" else self.ax

        inv = ax.transData.inverted()
        x, y = inv.transform_point((x, y))

        if check:
            xmin, xmax = self.getGraphXLimits()
            ymin, ymax = self.getGraphYLimits(axis=axis)

            if x > xmax or x < xmin or y > ymax or y < ymin:
                return None  # (x, y) is out of plot area

        return x, y

    def getPlotBoundsInPixels(self):
        """Return the left axes area as (x0, y0, width, height) in pixels."""
        bbox = self.ax.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        dpi = self.fig.dpi
        # Warning this is not returning int...
        return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi,
                bbox.bounds[2] * dpi, bbox.bounds[3] * dpi)
Exemple #9
0
class BackendMPL(FigureCanvas, Backend):
    """matplotlib backend.

    Displays the plot on a Qt ``FigureCanvas`` with a left matplotlib Axes
    and a right Axes created with ``twinx``. Plot modifications queued in
    ``self._changes`` are converted to matplotlib artists/settings in
    :meth:`draw`.
    """

    _signalRedisplay = QtCore.pyqtSignal()  # PyQt binds it to instances

    def __init__(self, plot, parent=None, **kwargs):
        Backend.__init__(self, plot)

        self.fig = Figure()
        self.fig.set_facecolor('w')
        FigureCanvas.__init__(self, self.fig)
        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)

        self._items = {}  # Mapping: plot item -> matplotlib item

        # Set-up axes
        leftAxes = self.fig.add_axes([.15, .15, .75, .75], label="left")
        rightAxes = leftAxes.twinx()
        rightAxes.set_label("right")

        # critical for picking!!!!
        rightAxes.set_zorder(0)
        rightAxes.set_autoscaley_on(True)
        leftAxes.set_zorder(1)
        # The left axes is drawn above the right one: make its background
        # transparent (the figure color is then the left one).
        # set_axis_bgcolor is deprecated since matplotlib 2.0 (and later
        # removed) in favor of set_facecolor: use whichever is available.
        if hasattr(leftAxes, 'set_facecolor'):
            leftAxes.set_facecolor('none')
        else:
            leftAxes.set_axis_bgcolor('none')
        self.fig.sca(leftAxes)

        self._axes = {self.plot.axes.left: leftAxes,
                      self.plot.axes.right: rightAxes}

        # Sync matplotlib and plot axes
        for plotAxes, mplAxes in self._axes.items():
            self._syncAxes(mplAxes, plotAxes)

        # TODO sync all the plot items to support backend switch !

        # Set-up events
        self.fig.canvas.mpl_connect('button_press_event',
                                    self.onMousePressed)
        self.fig.canvas.mpl_connect('button_release_event',
                                    self.onMouseReleased)
        self.fig.canvas.mpl_connect('motion_notify_event',
                                    self.onMouseMoved)
        self.fig.canvas.mpl_connect('scroll_event',
                                    self.onMouseWheel)

        # Connect draw to redisplay through a queued connection, so that
        # triggerRedisplay only posts the request and draw runs later.
        self._signalRedisplay.connect(
            self.draw, QtCore.Qt.QueuedConnection)

    @staticmethod
    def _syncAxes(mplAxes, plotAxes):
        """Copy labels, limits and scales of plot Axes to matplotlib Axes."""
        mplAxes.set_xlabel(plotAxes.xlabel)
        mplAxes.set_xlim(plotAxes.xlimits)
        mplAxes.set_xscale(plotAxes.xscale)

        mplAxes.set_ylabel(plotAxes.ylabel)
        mplAxes.set_ylim(plotAxes.ylimits)
        mplAxes.set_yscale(plotAxes.yscale)

    def triggerRedisplay(self):
        """Request a redraw through the queued redisplay signal."""
        self._signalRedisplay.emit()

    def draw(self):
        """Apply all queued modifications to matplotlib, then redraw."""
        for change in self._changes:
            event = change['event']
            if event == 'addItem':
                self._addItem(change['source'], change['item'])
            elif event == 'removeItem':
                # _removeItem takes (axes, item) like _addItem; the axes is
                # not used for removal, hence the tolerant .get().
                # NOTE(review): assumes 'removeItem' changes carry an 'item'
                # key like 'addItem' ones — confirm with the plot model.
                self._removeItem(change.get('source'), change['item'])
            elif event == 'set':
                self._setAttr(
                    change['source'], change['attr'], change['value'])
            else:
                logger.warning('Unhandled event %s', event)

        Backend.draw(self)
        FigureCanvas.draw(self)

    def onMousePressed(self, event):
        pass  # TODO

    def onMouseMoved(self, event):
        pass  # TODO

    def onMouseReleased(self, event):
        pass  # TODO

    def onMouseWheel(self, event):
        pass  # TODO

    def _addItem(self, axes, item):
        """Create the matplotlib artist for a plot item on the given axes.

        :param axes: Plot axes (key of ``self._axes``).
        :param item: ``items.Curve`` or ``items.Image``; other types are
            logged and ignored.
        """
        mplAxes = self._axes[axes]

        if isinstance(item, items.Curve):
            x, y = item.getData(copy=False)
            line = Line2D(xdata=x, ydata=y,
                          color=item.color,
                          marker=item.marker,
                          linewidth=item.linewidth,
                          linestyle=item.linestyle,
                          zorder=item.z)
            # TODO error bars, scatter plot..
            # TODO set picker
            mplAxes.add_line(line)
            self._items[item] = line

        elif isinstance(item, items.Image):
            # TODO ModestImage, set picker
            data = item.getData(copy=False)

            if len(data.shape) == 3:  # RGB(A) images
                image = AxesImage(mplAxes,
                                  origin='lower',
                                  interpolation='nearest')
            else:  # Colormap
                # TODO use own colormaps
                cmap = cm.get_cmap(item.colormap.cmap)
                if item.colormap.norm == 'log':
                    norm = LogNorm(item.colormap.vbegin, item.colormap.vend)
                else:
                    norm = Normalize(item.colormap.vbegin, item.colormap.vend)
                image = AxesImage(mplAxes,
                                  origin='lower',
                                  cmap=cmap,
                                  norm=norm,
                                  interpolation='nearest')
            image.set_data(data)
            image.set_zorder(item.z)

            height, width = data.shape[0:2]
            xmin, ymin = item.origin
            xmax = xmin + item.scale[0] * width
            # Top edge is offset from the bottom edge (was wrongly from xmax)
            ymax = ymin + item.scale[1] * height

            # set extent (left, right, bottom, top)
            if image.origin == 'upper':
                image.set_extent((xmin, xmax, ymax, ymin))
            else:
                image.set_extent((xmin, xmax, ymin, ymax))

            mplAxes.add_artist(image)
            self._items[item] = image

        else:
            logger.warning('Unsupported item type %s', str(type(item)))

    def _removeItem(self, axes, item):
        """Remove the matplotlib artist associated with a plot item.

        Items that were never mapped (e.g. unsupported types skipped by
        _addItem) are ignored instead of raising KeyError.
        """
        mplItem = self._items.pop(item, None)
        if mplItem is not None:
            mplItem.remove()

    def _setAttr(self, obj, attr, value):
        """Apply an attribute change of a plot object to matplotlib.

        :param obj: plot.Axis, plot.Axes, plot.Plot or an item.
        :param str attr: Name of the modified attribute.
        :param value: New value of the attribute.
        """
        if isinstance(obj, plot.Axis):
            plotAxes = obj.parents[0]
            if obj == plotAxes.x:
                direction = 'x'
            elif obj == plotAxes.y:
                direction = 'y'
            else:
                logger.warning('Incoherent axes information.')
                return

            mplAxes = self._axes[plotAxes]

            if attr == 'label':
                if direction == 'x':
                    mplAxes.set_xlabel(value)
                else:
                    mplAxes.set_ylabel(value)

            elif attr == 'scale':
                if direction == 'x':
                    mplAxes.set_xscale(value)
                else:
                    mplAxes.set_yscale(value)

            elif attr == 'limits':
                if direction == 'x':
                    mplAxes.set_xlim(value)
                else:
                    mplAxes.set_ylim(value)

            else:
                logger.warning('Unsupported attribute %s', attr)

        elif isinstance(obj, plot.Axes):
            mplAxes = self._axes[obj]

            if attr == 'visible':
                # For twin axes, Axes can be not visible but Axis is visible
                # So use Axis (and not Axes) visible value.
                mplAxes.get_xaxis().set_visible(obj.x.visible)
                mplAxes.get_yaxis().set_visible(obj.y.visible)

            elif attr == 'aspectRatio':
                mplAxes.set_aspect('equal' if value else 'auto')

            else:
                logger.warning('Unsupported attribute %s', attr)

        elif isinstance(obj, plot.Plot):
            leftAxes = self._axes[self.plot.axes.left]

            if attr == 'title':
                # Set title on left axes which should not be hidden
                leftAxes.set_title(value)

            elif attr == 'grid':
                # grid(which=...) without a visibility argument toggles the
                # grid in matplotlib: reset all grids, then enable the
                # requested one (a falsy value means no grid).
                leftAxes.grid(False, which='both')
                if value:
                    leftAxes.grid(True, which=value)

            else:
                logger.warning('Unsupported attribute %s', attr)

        elif isinstance(obj, items.Curve):  # TODO
            logger.warning('Unsupported item type %s', str(type(obj)))

        elif isinstance(obj, items.Image):  # TODO
            logger.warning('Unsupported item type %s', str(type(obj)))

        else:
            logger.warning('Unsupported item type %s', str(type(obj)))
Exemple #10
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.
    """

    def __init__(self, plot, parent=None):
        """Create the matplotlib figure and its left/right (twinx) axes.

        :param plot: The plot this backend renders.
        :param parent: Forwarded to the base class.
        """
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True
        self._axesDisplayed = True
        self._matplotlibVersion = _parse_version(matplotlib.__version__)

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # disable the use of offsets (y then x, left axes then right axes)
        try:
            for axes in (self.ax, self.ax2):
                axes.get_yaxis().get_major_formatter().set_useOffset(False)
                axes.get_xaxis().get_major_formatter().set_useOffset(False)
        except Exception:
            # Best effort: presumably the formatter lacks set_useOffset
            # with some matplotlib versions. Do not catch BaseException
            # (KeyboardInterrupt, SystemExit) as the bare except did.
            _logger.warning('Cannot disabled axes offsets in %s ',
                            matplotlib.__version__)

        # critical for picking!!!!
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        if self._matplotlibVersion < _parse_version('2'):
            self.ax.set_axis_bgcolor('none')
        else:
            self.ax.set_facecolor('none')
        self.fig.sca(self.ax)

        self._overlays = set()  # Animated artists (markers, overlay items)
        self._background = None

        self._colormaps = {}

        # Empty tuple when no graph cursor, else (horizontal, vertical) lines
        self._graphCursor = tuple()

        self._enableAxis('right', False)
        self._isXAxisTimeSeries = False

    # Add methods

    def addCurve(self, x, y, legend,
                 color, symbol, linewidth, linestyle,
                 yaxis,
                 xerror, yerror, z, selectable,
                 fill, alpha, symbolsize):
        """Add a curve (line and/or scatter plot) to the graph.

        :param x: X coordinates.
        :param y: Y coordinates.
        :param str legend: Label given to the created artists.
        :param color: A single color, or an array with one color per point,
            in which case the curve is drawn as a scatter plot.
        :param symbol: Marker symbol.
        :param linewidth: Line width.
        :param linestyle: Line style. For per-point colors, '', ' ' or None
            means no connecting line.
        :param str yaxis: 'left' or 'right' Y axis.
        :param xerror: X error bars or None.
        :param yerror: Y error bars or None.
        :param z: zorder applied to all created artists.
        :param bool selectable: If True, artists are pickable (tolerance 3).
        :param bool fill: If True, fill the area under the curve.
        :param float alpha: Opacity, applied to all artists when < 1.
        :param symbolsize: Marker size.
        :return: A Container grouping all the created artists.
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill, alpha, symbolsize):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # 4-component color whose last component is an integer: rescale
        # from [0, 255] to [0, 1].
        # numpy.float (removed in NumPy 1.24) was the builtin float,
        # i.e. float64.
        if (len(color) == 4 and
                type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
                    yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x, y, label=legend,
                                      xerr=xerror, yerr=yerror,
                                      linestyle=' ', color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot: one color per point
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x, y, label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            # scatter's s is an area, hence the squared symbol size
            scatter = axes.scatter(x, y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker,
                                   s=symbolsize**2)
            artists.append(scatter)

            if fill:
                artists.append(axes.fill_between(
                    x, FLOAT32_MINPOS, y, facecolor=actualColor[0], linestyle=''))

        else:  # Curve
            curveList = axes.plot(x, y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker,
                                  markersize=symbolsize)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color))

        for artist in artists:
            artist.set_zorder(z)
            if alpha < 1:
                artist.set_alpha(alpha)

        return Container(artists)

    def addImage(self, data, legend,
                 origin, scale, z,
                 selectable, draggable,
                 colormap, alpha):
        """Add an image to the left axes.

        :param numpy.ndarray data: 2D data (converted to RGBA with
            ``colormap``) or 3D RGB(A) image.
        :param str legend: Image legend; the artist label is
            ``"__IMAGE__" + legend``.
        :param origin: (x, y) data coordinates of the image origin corner.
        :param scale: (sx, sy) pixel size in data coordinates; a negative
            value flips the image along that axis.
        :param z: zorder of the image artist.
        :param bool selectable: Picking is enabled if selectable or draggable.
        :param bool draggable: Picking is enabled if selectable or draggable.
        :param colormap: Colormap object applied to 2D data.
        :param float alpha: Opacity, applied only when < 1.
        :return: The matplotlib image artist.
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        for parameter in (data, legend, origin, scale, z,
                          selectable, draggable):
            assert parameter is not None

        origin = float(origin[0]), float(origin[1])
        scale = float(scale[0]), float(scale[1])
        height, width = data.shape[0:2]

        picker = (selectable or draggable)

        # Debian 7 specific support
        # No transparent colormap with matplotlib < 1.2.0
        # Add support for transparent colormap for uint8 data with
        # colormap with 256 colors, linear norm, [0, 255] range
        if self._matplotlibVersion < _parse_version('1.2.0'):
            if (len(data.shape) == 2 and colormap.getName() is None and
                    colormap.getColormapLUT() is not None):
                colors = colormap.getColormapLUT()
                # Test the alpha channel (last column) of every LUT entry;
                # colors[3] would be the 4th LUT entry instead.
                # NOTE(review): assumes LUT values are in [0, 255] — confirm.
                if (colors.shape[-1] == 4 and
                        not numpy.all(numpy.equal(colors[..., 3], 255))):
                    # This is a transparent colormap
                    if (colors.shape == (256, 4) and
                            colormap.getNormalization() == 'linear' and
                            not colormap.isAutoscale() and
                            colormap.getVMin() == 0 and
                            colormap.getVMax() == 255 and
                            data.dtype == numpy.uint8):
                        # Supported case, convert data to RGBA
                        data = colors[data.reshape(-1)].reshape(
                            data.shape + (4,))
                    else:
                        _logger.warning(
                            'matplotlib %s does not support transparent '
                            'colormap.', matplotlib.__version__)

        # Large images without offset/scale use ModestImage
        if ((height * width) > 5.0e5 and
                origin == (0., 0.) and scale == (1., 1.)):
            imageClass = ModestImage
        else:
            imageClass = AxesImage

        # All image are shown as RGBA image
        image = imageClass(self.ax,
                           label="__IMAGE__" + legend,
                           interpolation='nearest',
                           picker=picker,
                           zorder=z,
                           origin='lower')

        if alpha < 1:
            image.set_alpha(alpha)

        # Set image extent
        xmin = origin[0]
        xmax = xmin + scale[0] * width
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin

        ymin = origin[1]
        ymax = ymin + scale[1] * height
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin

        image.set_extent((xmin, xmax, ymin, ymax))

        # Set image data
        if scale[0] < 0. or scale[1] < 0.:
            # For negative scale, step by -1
            xstep = 1 if scale[0] >= 0. else -1
            ystep = 1 if scale[1] >= 0. else -1
            data = data[::ystep, ::xstep]

        if self._matplotlibVersion < _parse_version('2.1'):
            # matplotlib 1.4.2 do not support float128
            dtype = data.dtype
            if dtype.kind == "f" and dtype.itemsize >= 16:
                _logger.warning("Your matplotlib version do not support "
                                "float128. Data converted to float64.")
                data = data.astype(numpy.float64)

        if data.ndim == 2:  # Data image, convert to RGBA image
            data = colormap.applyToData(data)

        image.set_data(data)

        self.ax.add_artist(image)

        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z,
                linestyle, linewidth, linebgcolor):
        """Add a shape item to the left axes.

        :param x: X coordinates of the shape.
        :param y: Y coordinates of the shape.
        :param str legend: Label of the item (not set on rectangles).
        :param str shape: 'line', 'hline', 'vline', 'rectangle', 'polygon'
            or 'polylines'.
        :param color: Color of the shape.
        :param bool fill: Hatch rectangles ('.') and polygons ('/').
        :param bool overlay: If True, the item is animated and registered
            as an overlay.
        :param z: zorder of the item.
        :param linestyle: Line style, normalized with normalize_linestyle.
        :param linewidth: Line width.
        :param linebgcolor: Background color of dashed lines; only used for
            rectangle, polygon and polylines.
        :return: The created matplotlib artist.
        :raises NotImplementedError: For an unsupported shape.
        """
        if (linebgcolor is not None and
                shape not in ('rectangle', 'polygon', 'polylines')):
            _logger.warning(
                'linebgcolor not implemented for %s with matplotlib backend',
                shape)
        # asarray avoids a copy when possible; numpy.array(..., copy=False)
        # raises ValueError with NumPy >= 2 when a copy is required
        # (e.g. for list input).
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        linestyle = normalize_linestyle(linestyle)

        if shape == "line":
            item = self.ax.plot(x, y, label=legend, color=color,
                                linestyle=linestyle, linewidth=linewidth,
                                marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color,
                                   linestyle=linestyle, linewidth=linewidth)

        elif shape == 'rectangle':
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color,
                             linestyle=linestyle,
                             linewidth=linewidth)
            if fill:
                item.set_hatch('.')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            points = numpy.array((xView, yView)).T
            if shape == 'polygon':
                closed = True
            else:  # shape == 'polylines'
                # Closed only if the first and last points coincide
                closed = numpy.all(numpy.equal(points[0], points[-1]))
            item = Polygon(points,
                           closed=closed,
                           fill=False,
                           label=legend,
                           color=color,
                           linestyle=linestyle,
                           linewidth=linewidth)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            if linestyle != "solid" and linebgcolor is not None:
                item = _DoubleColoredLinePatch(item)
                item.linebgcolor = linebgcolor

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color,
                  selectable, draggable,
                  symbol, linestyle, linewidth, constraint):
        """Add a marker (point, vertical or horizontal line) on left axes.

        With both x and y, a point is drawn with ``symbol``; with only x, a
        vertical line; with only y, a horizontal line. An optional text
        artist is added next to it.

        :param x: X position or None.
        :param y: Y position or None.
        :param str legend: Marker legend; artists are labelled
            ``"__MARKER__" + legend``.
        :param text: Text to display or None.
        :param color: Color of the marker and its text.
        :param bool selectable: Picking (tolerance 5) is enabled if
            selectable or draggable.
        :param bool draggable: Picking (tolerance 5) is enabled if
            selectable or draggable.
        :param symbol: Marker symbol for point markers; also sets the text
            vertical alignment.
        :param linestyle: Line style of vertical/horizontal markers.
        :param linewidth: Line width of vertical/horizontal markers.
        :param constraint: Not used in this method.
        :return: The _MarkerContainer, registered as an overlay.
        :raises RuntimeError: If both x and y are None.
        """
        legend = "__MARKER__" + legend

        textArtist = None

        # Current limits, used to position the marker text below.
        # NOTE: these getters may call replot() when limits are dirty.
        xmin, xmax = self.getGraphXLimits()
        ymin, ymax = self.getGraphYLimits(axis='left')

        if x is not None and y is not None:
            line = self.ax.plot(x, y, label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                if symbol is None:
                    valign = 'baseline'
                else:
                    valign = 'top'
                    text = "  " + text

                textArtist = self.ax.text(x, y, text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x,
                                   label=legend,
                                   color=color,
                                   linewidth=linewidth,
                                   linestyle=linestyle)
            if text is not None:
                # Y position will be updated in updateMarkerText call
                textArtist = self.ax.text(x, 1., " " + text,
                                          color=color,
                                          horizontalalignment='left',
                                          verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y,
                                   label=legend,
                                   color=color,
                                   linewidth=linewidth,
                                   linestyle=linestyle)

            if text is not None:
                # X position will be updated in updateMarkerText call
                textArtist = self.ax.text(1., y, " " + text,
                                          color=color,
                                          horizontalalignment='right',
                                          verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            line.set_picker(5)

        # All markers are overlays
        line.set_animated(True)
        if textArtist is not None:
            textArtist.set_animated(True)

        artists = [line] if textArtist is None else [line, textArtist]
        container = _MarkerContainer(artists, x, y)
        container.updateMarkerText(xmin, xmax, ymin, ymax)
        self._overlays.add(container)

        return container

    def _updateMarkers(self):
        """Reposition the text of every marker overlay.

        The current left axes bounds are passed to each marker container.
        """
        xLow, xHigh = self.ax.get_xbound()
        yLow, yHigh = self.ax.get_ybound()
        markers = (overlay for overlay in self._overlays
                   if isinstance(overlay, _MarkerContainer))
        for marker in markers:
            marker.updateMarkerText(xLow, xHigh, yLow, yHigh)

    # Remove methods

    def remove(self, item):
        """Remove an item from the graph and from the overlay set.

        A ValueError raised by ``item.remove()`` is ignored: the item may
        already have been removed (e.g., in set[X|Y]AxisLogarithmic).
        """
        # Warning: It also needs to remove extra stuff if added as for markers
        self._overlays.discard(item)
        try:
            item.remove()
        except ValueError:
            return  # Already removed

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Enable or disable the graph cursor (horizontal + vertical lines).

        :param bool flag: True to create the cursor lines, False to remove
            them.
        :param color: Color of the cursor lines.
        :param linewidth: Width of the cursor lines.
        :param linestyle: Style of the cursor lines.
        """
        # Remove existing cursor lines first. _graphCursor is an empty tuple
        # when there is no cursor, so test its truthiness: the former
        # `is not None` test was always True and disabling an inactive
        # cursor failed when unpacking the empty tuple. Removing also
        # avoids leaking old lines when the cursor is enabled twice.
        if self._graphCursor:
            for line in self._graphCursor:
                try:
                    line.remove()
                except ValueError:
                    pass  # Already removed from the axes
            self._graphCursor = tuple()

        if flag:
            lineh = self.ax.axhline(
                self.ax.get_ybound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(
                self.ax.get_xbound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev

    # Active curve

    def setCurveColor(self, curve, color):
        """Set the color of all the artists composing a curve.

        Scatter artists (PathCollection) get face and edge colors, lines
        (Line2D, LineCollection) get ``set_color``; any other artist is
        logged and left unchanged.
        """
        for child in curve.get_children():
            if isinstance(child, PathCollection):
                child.set_facecolors(color)
                child.set_edgecolors(color)
            elif isinstance(child, (Line2D, LineCollection)):
                child.set_color(color)
            else:
                _logger.warning(
                    'setActiveCurve ignoring artist %s', str(child))

    # Misc.

    def getWidgetHandle(self):
        """Return the canvas attached to this backend's figure."""
        canvas = self.fig.canvas
        return canvas

    def _enableAxis(self, axis, flag=True):
        """Show or hide a Y axis.

        :param str axis: Axis name: 'left' or 'right'
        :param bool flag: Visibility to set, True by default
        """
        assert axis in ('right', 'left')
        if axis == 'right':
            targetAxes = self.ax2
        else:
            targetAxes = self.ax
        targetAxes.get_yaxis().set_visible(flag)

    def replot(self):
        """Mark limits as up to date without rendering anything.

        Subclasses override this to actually draw. The right Y axis is
        hidden when no line is attached to it.
        """
        self._dirtyLimits = False
        # TODO images, markers? scatter plot? move in remove?
        # Right Y axis only supports curves for now
        hasRightLines = bool(self.ax2.lines)
        if not hasRightLines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure, then mark the plot as dirty.

        :param fileName: Destination path, or a file-like object
            (e.g. StringIO or file instance).
        :param str fileFormat: Output format passed to savefig.
        :param dpi: Resolution, or None for the savefig default.
        """
        savefigOptions = {'format': fileFormat}
        if dpi is not None:
            savefigOptions['dpi'] = dpi
        self.fig.savefig(fileName, **savefigOptions)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        """Set the graph title on the left axes."""
        leftAxes = self.ax
        leftAxes.set_title(title)

    def setGraphXLabel(self, label):
        """Set the X axis label on the left axes."""
        leftAxes = self.ax
        leftAxes.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        """Set a Y axis label.

        :param str label: The label text.
        :param str axis: 'left' labels the left axes, any other value the
            right axes.
        """
        targetAxes = self.ax2 if axis != 'left' else self.ax
        targetAxes.set_ylabel(label)

    # Graph limits

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and optionally right Y limits.

        Bounds are sorted, and Y limits are given in decreasing order when
        the Y axis is inverted. The right axes limits are only set when
        both y2min and y2max are provided. Marker texts are updated.
        """
        # Let matplotlib take care of keep aspect ratio if any
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        inverted = self.isYAxisInverted()

        def orientedRange(first, second):
            # (low, high), or (high, low) for an inverted Y axis
            low, high = min(first, second), max(first, second)
            return (high, low) if inverted else (low, high)

        if y2min is not None and y2max is not None:
            self.ax2.set_ylim(*orientedRange(y2min, y2max))

        self.ax.set_ylim(*orientedRange(ymin, ymax))

        self._updateMarkers()

    def getGraphXLimits(self):
        """Return the X (min, max) bounds of the left axes.

        When limits are dirty and aspect ratio is kept, replot() is called
        first so matplotlib has applied the aspect ratio.
        """
        needsReplot = self._dirtyLimits and self.isKeepDataAspectRatio()
        if needsReplot:
            self.replot()
        return self.ax.get_xbound()

    def setGraphXLimits(self, xmin, xmax):
        """Set the X limits (in increasing order) and refresh markers."""
        lower, upper = min(xmin, xmax), max(xmin, xmax)
        self._dirtyLimits = True
        self.ax.set_xlim(lower, upper)
        self._updateMarkers()

    def getGraphYLimits(self, axis):
        """Return the Y (min, max) bounds of the given axes.

        :param str axis: 'left' or 'right'.
        :return: (ymin, ymax), or None if the matplotlib Axes is not visible.
        """
        assert axis in ('left', 'right')
        targetAxes = self.ax2 if axis == 'right' else self.ax

        if not targetAxes.get_visible():
            return None

        needsReplot = self._dirtyLimits and self.isKeepDataAspectRatio()
        if needsReplot:
            self.replot()  # makes sure we get the right limits

        return targetAxes.get_ybound()

    def setGraphYLimits(self, ymin, ymax, axis):
        """Set the Y limits of the left or right axes.

        With aspect ratio kept, the X range is first rescaled around its
        center by the same factor as the Y range, so matplotlib does not
        need to alter the requested Y limits.

        :param ymin: One Y bound.
        :param ymax: The other Y bound.
        :param str axis: 'right' for the right axes, else left axes.
        """
        targetAxes = self.ax2 if axis == 'right' else self.ax
        if ymax < ymin:
            ymin, ymax = ymax, ymin
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            # matplotlib keeps limits of the shared axis when keeping aspect
            # ratio, so update x limits first to match the new y range.
            xLow, xHigh = targetAxes.get_xbound()
            yLow, yHigh = targetAxes.get_ybound()
            xSpan = (xHigh - xLow) * (ymax - ymin) / (yHigh - yLow)
            xMiddle = 0.5 * (xLow + xHigh)
            targetAxes.set_xlim(xMiddle - 0.5 * xSpan, xMiddle + 0.5 * xSpan)

        if self.isYAxisInverted():
            targetAxes.set_ylim(ymax, ymin)
        else:
            targetAxes.set_ylim(ymin, ymax)

        self._updateMarkers()

    # Graph axes

    def setXAxisTimeZone(self, tz):
        """Set the X axis time zone and rebuild its formatter/locator."""
        super(BackendMatplotlib, self).setXAxisTimeZone(tz)

        # Re-applying the current mode rebuilds objects with the new zone
        currentMode = self.isXAxisTimeSeries()
        self.setXAxisTimeSeries(currentMode)

    def isXAxisTimeSeries(self):
        """Return the value last given to setXAxisTimeSeries."""
        timeSeries = self._isXAxisTimeSeries
        return timeSeries

    def setXAxisTimeSeries(self, isTimeSeries):
        """Switch the X axis between time-series and scalar display.

        :param bool isTimeSeries: If True, X values (timestamps) are shown as
            dates in the X axis time zone. Otherwise a ScalarFormatter
            (without offset when possible) is used.
        """
        self._isXAxisTimeSeries = isTimeSeries
        if self._isXAxisTimeSeries:
            # We can't use a matplotlib.dates.DateFormatter because it expects
            # the data to be in datetimes. Silx works internally with
            # timestamps (floats).
            locator = NiceDateLocator(tz=self.getXAxisTimeZone())
            self.ax.xaxis.set_major_locator(locator)
            self.ax.xaxis.set_major_formatter(
                NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
        else:
            try:
                scalarFormatter = ScalarFormatter(useOffset=False)
            except Exception:
                # Presumably an older ScalarFormatter without useOffset.
                # Catch Exception rather than everything (the bare except
                # also swallowed KeyboardInterrupt/SystemExit).
                _logger.warning('Cannot disabled axes offsets in %s ',
                                matplotlib.__version__)
                scalarFormatter = ScalarFormatter()
            # NOTE(review): the NiceDateLocator installed in time-series
            # mode is not reset here — confirm whether the x major locator
            # should be restored as well.
            self.ax.xaxis.set_major_formatter(scalarFormatter)

    def setXAxisLogarithmic(self, flag):
        """Switch both axes X scale between 'log' and 'linear'.

        Workaround for matplotlib >= 2.1.0: switching to log scale while
        both limits are <= 0 requires a draw with positive limits first.
        """
        needsWorkaround = (flag and
                           self._matplotlibVersion >= _parse_version('2.1.0'))
        if needsWorkaround:
            lowerX, upperX = self.ax.get_xlim()
            if lowerX <= 0 and upperX <= 0:
                self.ax.set_xlim(1, 10)
                self.draw()

        newScale = 'log' if flag else 'linear'
        self.ax2.set_xscale(newScale)
        self.ax.set_xscale(newScale)

    def setYAxisLogarithmic(self, flag):
        """Switch both axes Y scale between 'log' and 'linear'.

        Workaround for matplotlib >= 2.0 issue with non-positive bounds:
        before switching to log scale, axes with a limit <= 0 get the
        plot data range (or (1, 100) if none) and a draw is done.
        """
        if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
            mustDraw = False
            for targetAxes, rangeIndex in ((self.ax, 1), (self.ax2, 2)):
                lowerY, upperY = targetAxes.get_ylim()
                if lowerY <= 0 or upperY <= 0:
                    newRange = self._plot.getDataRange()[rangeIndex]
                    if newRange is None:
                        newRange = 1, 100  # Fallback
                    targetAxes.set_ylim(*newRange)
                    mustDraw = True
            if mustDraw:
                self.draw()

        newScale = 'log' if flag else 'linear'
        self.ax2.set_yscale(newScale)
        self.ax.set_yscale(newScale)

    def setYAxisInverted(self, flag):
        """Set the Y axis direction of both left and right axes.

        setLimits and setGraphYLimits already orient the right axes
        limits according to isYAxisInverted(), so the right axes is
        inverted here too for consistency (it was left unchanged before).

        :param bool flag: True for an inverted (downward) Y axis.
        """
        flag = bool(flag)
        for targetAxes in (self.ax, self.ax2):
            if targetAxes.yaxis_inverted() != flag:
                targetAxes.invert_yaxis()

    def isYAxisInverted(self):
        """Return whether the left axes Y axis is inverted."""
        leftAxes = self.ax
        return leftAxes.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        """Return True when ``ax`` uses a 1:1 ('equal') data aspect ratio."""
        aspect = self.ax.get_aspect()
        return aspect in (1.0, 'equal')

    def setKeepDataAspectRatio(self, flag):
        """Lock (1.0) or release ('auto') the data aspect ratio.

        The same aspect is applied to ``ax`` and then ``ax2``.
        """
        aspect = 1.0 if flag else 'auto'
        for axes in (self.ax, self.ax2):
            axes.set_aspect(aspect)

    def setGraphGrid(self, which):
        """Configure the grid lines of ``ax``.

        :param which: None to hide every grid line, otherwise the ``which``
            argument forwarded to matplotlib ``Axes.grid``.
        """
        # Start from a clean state: both major and minor grids off.
        self.ax.grid(False, which='both')
        if which is None:
            return
        self.ax.grid(True, which=which)

    # Data <-> Pixel coordinates conversion

    def _mplQtYAxisCoordConversion(self, y):
        """Qt origin (top) to/from matplotlib origin (bottom) conversion.

        :rtype: float
        """
        height = self.fig.get_window_extent().height
        return height - y

    def dataToPixel(self, x, y, axis):
        """Convert a data point to pixel coordinates (Qt convention).

        :param x: X data coordinate.
        :param y: Y data coordinate.
        :param str axis: "right" to use ``ax2``, anything else uses ``ax``.
        :return: (xPixel, yPixel) with the Y origin at the top.
        """
        targetAxes = self.ax if axis != "right" else self.ax2
        xPixel, yPixel = targetAxes.transData.transform_point((x, y)).T
        # matplotlib puts the Y origin at the bottom, Qt at the top.
        return xPixel, self._mplQtYAxisCoordConversion(yPixel)

    def pixelToData(self, x, y, axis, check):
        """Convert pixel coordinates (Qt convention) to data coordinates.

        :param x: X pixel position.
        :param y: Y pixel position, origin at the top.
        :param str axis: "right" to use ``ax2``, anything else uses ``ax``.
        :param bool check: If True, return None for points that fall outside
            the current graph limits.
        :return: (x, y) in data coordinates, or None (see ``check``).
        """
        targetAxes = self.ax2 if axis == "right" else self.ax

        # Qt origin is at the top, matplotlib origin at the bottom.
        mplY = self._mplQtYAxisCoordConversion(y)
        toData = targetAxes.transData.inverted()
        xData, yData = toData.transform_point((x, mplY))

        if check:
            xmin, xmax = self.getGraphXLimits()
            ymin, ymax = self.getGraphYLimits(axis=axis)
            outside = (xData > xmax or xData < xmin or
                       yData > ymax or yData < ymin)
            if outside:
                return None  # (x, y) is out of plot area

        return xData, yData

    def getPlotBoundsInPixels(self):
        """Return the area of ``ax`` as (left, top, width, height) in pixels.

        ``top`` uses the Qt convention (origin at the top of the figure).
        Warning: the values are floats, not ints.
        """
        extent = self.ax.get_window_extent()
        top = self._mplQtYAxisCoordConversion(extent.ymax)
        return extent.xmin, top, extent.width, extent.height

    def setAxesDisplayed(self, displayed):
        """Display or not the axes.

        :param bool displayed: If `True` axes are displayed. If `False` axes
            are not anymore visible and the margin used for them is removed.
        """
        BackendBase.BackendBase.setAxesDisplayed(self, displayed)

        # Axes rectangle as [left, bottom, width, height] in figure fraction:
        # default margins when axes are shown, the whole figure otherwise.
        if displayed:
            position = [.15, .15, .75, .75]
        else:
            position = [0, 0, 1, 1]

        # Show/hide axes and viewbox rect for both axes first...
        for axes in (self.ax, self.ax2):
            if displayed:
                axes.set_axis_on()
            else:
                axes.set_axis_off()
        # ...then apply the margins.
        for axes in (self.ax, self.ax2):
            axes.set_position(position)

        self._synchronizeBackgroundColors()
        self._synchronizeForegroundColors()
        self._plot._setDirtyPlot()

    def _synchronizeBackgroundColors(self):
        """Apply the plot's background colors to the figure and ``ax``.

        With the axes shown, the figure takes the background color and the
        axes area takes the data background color. An invalid data
        background color falls back to the background color. With the axes
        hidden, the whole figure takes the data background color.
        """
        figureColor = self._plot.getBackgroundColor().getRgbF()

        dataColor = self._plot.getDataBackgroundColor()
        dataColor = dataColor.getRgbF() if dataColor.isValid() else figureColor

        if not self.ax.axison:
            self.fig.patch.set_facecolor(dataColor)
            return

        self.fig.patch.set_facecolor(figureColor)
        if self._matplotlibVersion < _parse_version('2'):
            # matplotlib < 2 API
            self.ax.set_axis_bgcolor(dataColor)
        else:
            self.ax.set_facecolor(dataColor)

    def _synchronizeForegroundColors(self):
        foregroundColor = self._plot.getForegroundColor().getRgbF()

        gridColor = self._plot.getGridColor()
        if gridColor.isValid():
            gridColor = gridColor.getRgbF()
        else:
            gridColor = foregroundColor

        if self.ax.axison:
            self.ax.spines['bottom'].set_color(foregroundColor)
            self.ax.spines['top'].set_color(foregroundColor)
            self.ax.spines['right'].set_color(foregroundColor)
            self.ax.spines['left'].set_color(foregroundColor)
            self.ax.tick_params(axis='x', colors=foregroundColor)
            self.ax.tick_params(axis='y', colors=foregroundColor)
            self.ax.yaxis.label.set_color(foregroundColor)
            self.ax.xaxis.label.set_color(foregroundColor)
            self.ax.title.set_color(foregroundColor)

            for line in self.ax.get_xgridlines():
                line.set_color(gridColor)

            for line in self.ax.get_ygridlines():
                line.set_color(gridColor)
Exemple #11
0
class ScrollableWindow(QMainWindow):
    """Main window that shows a matplotlib heatmap inside a scroll area.

    Clicking a heatmap cell stores its indices and emits ``cell_select``;
    the mouse wheel scrolls the view vertically.
    """

    # Emitted with (row, column) of the clicked heatmap cell.
    cell_select = pyqtSignal(int, int)

    def __init__(self):
        # Qt allows a single application object per process: reuse a running
        # one instead of unconditionally constructing another instance.
        app = QApplication.instance()
        self.qapp = app if app is not None else QApplication([])

        QMainWindow.__init__(self)
        self.widget = QWidget()
        self.setCentralWidget(self.widget)
        self.widget.setLayout(QVBoxLayout())
        # self.widget.layout().setContentsMargins(0,0,0,0)
        # self.widget.layout().setSpacing(0)

        self.fig = Figure(dpi=200)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.draw()
        # The canvas is placed in a scroll area so large heatmaps can scroll.
        self.scroll = QScrollArea(self.widget)
        self.scroll.setWidget(self.canvas)

        # self.nav = NavigationToolbar(self.canvas, self.widget)
        # self.widget.layout().addWidget(self.nav)
        self.widget.layout().addWidget(self.scroll)

        self.axes = self.fig.add_subplot(111)

        self.fig.sca(self.axes)

        self.text_list = None  # annotation Text objects of the last plot
        self.cbar = None       # colorbar (heatmap() currently returns None)
        self.i = None          # row of the last clicked cell
        self.j = None          # column of the last clicked cell

        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.sel_rect = None   # selection Rectangle, created lazily

        self.fig.canvas.mpl_connect('scroll_event', self._on_mousewheel)

    def _on_mousewheel(self, event):
        """Scroll vertically by a tenth of the scroll bar range.

        :param event: matplotlib scroll event; ``event.button`` is 'up' when
            scrolling up, anything else scrolls down.
        """
        vbar = self.scroll.verticalScrollBar()
        # QScrollBar.setValue() takes an int. True division produced a float,
        # which recent PyQt versions reject with a TypeError.
        step = vbar.maximum() // 10
        if event.button == 'up':
            step = -step
        vbar.setValue(vbar.value() + step)

    def get_clicked_symbol(self):
        """Return the row of the last clicked cell, or -1 if none yet."""
        if self.i is None:
            return -1
        return self.i

    def get_clicked_iv_rank(self):
        """Return the column of the last clicked cell (None if none yet)."""
        return self.j

    def heatmap(self,
                data,
                row_labels,
                col_labels,
                ax=None,
                cbar_kw=None,
                cbarlabel="",
                **kwargs):
        """
        Create a heatmap from a numpy array and two lists of labels.

        Arguments:
            data       : A 2D numpy array of shape (N,M)
            row_labels : A list or array of length N with the labels
                         for the rows
            col_labels : A list or array of length M with the labels
                         for the columns
        Optional arguments:
            ax         : A matplotlib.axes.Axes instance to which the heatmap
                         is plotted. If not provided, use the current pyplot
                         axes (``plt.gca()``).
            cbar_kw    : A dictionary with arguments to
                         :meth:`matplotlib.Figure.colorbar`. Currently unused:
                         the colorbar is disabled.
            cbarlabel  : The label for the colorbar (currently unused).
        All other arguments are directly passed on to the imshow call. The
        color range is fixed to [0, 100] via vmin/vmax.

        Returns:
            (im, cbar): the AxesImage and the colorbar, which is always None
            while the colorbar is disabled.
        """
        if cbar_kw is None:
            cbar_kw = {}
        if ax is None:
            ax = plt.gca()

        # Plot the heatmap
        im = ax.imshow(data, **kwargs, vmin=0, vmax=100)

        # Create colorbar (disabled)
        # cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
        # cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
        cbar = None

        # We want to show all ticks...
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_yticks(np.arange(data.shape[0]))
        # ... and label them with the respective list entries.
        ax.xaxis.set_ticks_position('top')
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels)

        # Override the position set above: ticks and their labels are shown
        # at the bottom of the heatmap.
        ax.tick_params(top=False,
                       bottom=True,
                       labeltop=False,
                       labelbottom=True)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(),
                 rotation=0,
                 ha="right",
                 rotation_mode="anchor")

        # Turn spines off and create white grid.
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Minor ticks on cell boundaries draw the white separators.
        ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
        ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, left=False)

        return im, cbar

    def annotate_heatmap(self,
                         im,
                         data=None,
                         valfmt="{x:.2f}",
                         textcolors=("black", "black"),
                         threshold=None,
                         **textkw):
        """
        A function to annotate a heatmap.

        Arguments:
            im         : The AxesImage to be labeled.
        Optional arguments:
            data       : Data used to annotate (2D array or nested list). If
                         None, the image's data is used.
            valfmt     : The format of the annotations inside the heatmap.
                         This should either use the string format method, e.g.
                         "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`.
            textcolors : A sequence of two color specifications. The first is
                         used for values below a threshold, the second for those
                         above.
            threshold  : Value in data units according to which the colors from
                         textcolors are applied. If None (the default) uses the
                         middle of the colormap as separation.

        Further arguments are passed on to the created text labels.

        Returns:
            The list of created Text objects, in row-major order.
        """
        if not isinstance(data, (list, np.ndarray)):
            data = im.get_array()
        elif isinstance(data, list):
            # Lists have no .shape / .max() / [i, j] indexing used below.
            data = np.asarray(data)

        # Normalize the threshold to the images color range.
        if threshold is not None:
            threshold = im.norm(threshold)
        else:
            threshold = im.norm(data.max()) / 2.

        # Set default alignment to center, but allow it to be
        # overwritten by textkw.
        kw = dict(horizontalalignment="center", verticalalignment="center")
        kw.update(textkw)

        # Get the formatter in case a string is supplied
        if isinstance(valfmt, str):
            valfmt = ticker.StrMethodFormatter(valfmt)

        # Loop over the data and create a `Text` for each "pixel".
        # Change the text's color depending on the data.
        texts = []
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                # int(): indexing a sequence with a numpy bool is deprecated.
                above = int(im.norm(data[i, j]) > threshold)
                kw.update(color=textcolors[above])
                text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
                texts.append(text)

        return texts

    def plot(self, data_in, x_labels, y_labels, title):
        """
        Draw ``data_in`` as an annotated heatmap, replacing the previous plot.

        :param data_in: 2-D array of cell values, one row per entry of
            ``y_labels`` and one column per entry of ``x_labels``. The color
            scale is fixed to [0, 100] by heatmap().
        :param x_labels: column labels.
        :param y_labels: row labels.
        :param title: title of the axes.
        """
        self.axes.clear()
        # clear() removed the selection rectangle from the axes; drop the
        # stale reference so highlight_selection() adds a fresh one.
        self.sel_rect = None
        if self.cbar is not None:
            self.cbar.remove()
            self.cbar = None
        self.axes.set_title(title)

        im, self.cbar = self.heatmap(data_in,
                                     y_labels,
                                     x_labels,
                                     ax=self.axes,
                                     cmap="YlGn",
                                     cbarlabel=title)
        self.text_list = self.annotate_heatmap(im, valfmt="{x:.0f}")

        self.fig.autofmt_xdate()
        # self.fig.tight_layout()
        self.canvas.draw()

    def highlight_selection(self, row, col):
        """Draw (or move) a red frame around the cell at (row, col).

        The frame is inset by 0.1 on each side of the unit cell centered on
        (col, row).
        """
        corner = (col - 0.5 + 0.1, row - 0.5 + 0.1)
        if self.sel_rect is None:
            self.sel_rect = Rectangle(corner,
                                      width=0.8,
                                      height=0.8,
                                      edgecolor='red',
                                      fill=False,
                                      linewidth=4)
            self.axes.add_patch(self.sel_rect)
        else:
            self.sel_rect.set_xy(corner)
        self.canvas.draw()

    def onclick(self, event):
        """Store the clicked cell and emit ``cell_select(row, column)``.

        Clicks without data coordinates (outside the axes) are ignored.
        """
        print('Event')
        if event.ydata is not None and event.xdata is not None:
            # Image cells are centered on integer coordinates, so rounding
            # yields the cell index.
            self.i = int(round(event.ydata))  # Rows (1. Index = 0)
            self.j = int(round(event.xdata))  # Columns (1. Index = 0)
            self.cell_select.emit(self.i, self.j)
            # self.highlight_selection(self.i, self.j)
            print('{}, {} '.format(self.i, self.j))
Exemple #12
0
class PriceNavPlot(QWidget):
    """Price chart widget with a matplotlib navigation toolbar.

    Series are stored on the instance and the chart is fully redrawn on each
    change, either as lines (primary ``axes`` plus a twin ``axes2``) or as
    candlesticks.
    """

    def __init__(self, parent=None, mw=None, width=5, height=4, dpi=100):
        # NOTE(review): mw, width, height and dpi are accepted but unused.
        # super() must name this class. super(QWidget, self) starts the MRO
        # lookup after QWidget and so skips QWidget.__init__.
        super(PriceNavPlot, self).__init__(parent)
        self.fig = Figure()
        self.canvas = Canvas(self.fig)
        self.canvas.setParent(parent)
        self.toolbar = NavigationToolbar(self.canvas, self)
        # set the layout
        # NOTE(review): the layout is installed on ``parent``, not on self.
        l1 = QVBoxLayout(parent)
        l1.addWidget(self.toolbar)
        l1.addWidget(self.canvas)

        self.axes = self.fig.add_subplot(111)
        # Twin Y axis for series added with plot_line(..., twin=True).
        self.axes2 = self.axes.twinx()
        self.fig.sca(self.axes)

        self.setParent(parent)

        self.data_list = []      # series drawn on self.axes
        self.legend_list = []    # labels of the self.data_list series
        self.data_list_2 = []    # twin series drawn on self.axes2
        self.legend_list_2 = []  # labels of the self.data_list_2 series
        self.title = ''

        # X limits, applied on redraw only when both are set.
        self.left_lim = None
        self.right_lim = None

        self.plot_type = ''      # 'line', 'candlestick' or '' (nothing yet)

    def set_xlim(self, left_lim, right_lim):
        """Store new X limits and redraw the current chart with them."""
        self.left_lim = left_lim
        self.right_lim = right_lim
        if self.plot_type == 'line':
            self._plot_line()
        if self.plot_type == 'candlestick':
            self._plot_candlestick()

    def plot_candlestick(self, data_in, legend=' ', title=' '):
        """Add a candlestick series, set the title and redraw all series."""
        self.plot_type = 'candlestick'
        self.data_list.append(data_in)
        self.legend_list.append(legend)
        self.title = title
        self._plot_candlestick()

    def _plot_candlestick(self):
        """Redraw every stored series as candlesticks on ``self.axes``.

        Each series is a sequence of bars laid out as
        [date, open, high, low, close, volume]; the date is read at index
        ``BAR_DICT['Date']`` and converted with ``date2num``.
        Adding volume: https://stackoverflow.com/questions/13128647/matplotlib-finance-volume-overlay
        """
        self.axes.cla()
        for data_in in self.data_list:
            # Rows of this series only. A list shared across series made
            # every later series redraw all earlier candles again.
            ohlc_data = []
            for bar in data_in:
                ohlc = [date2num(bar[BAR_DICT['Date']])]
                ohlc.extend(float(bar[k]) for k in range(1, 6))
                ohlc_data.append(ohlc)

            candlestick_ohlc(self.axes,
                             ohlc_data,
                             width=0.4,
                             colorup='#77d879',
                             colordown='#db3f3f')
        if self.right_lim is not None and self.left_lim is not None:
            self.axes.set_xlim(self.left_lim, self.right_lim)
        self.axes.set_title(self.title)
        self.axes.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        self.axes.xaxis.set_major_locator(mticker.MaxNLocator(10))
        self.fig.autofmt_xdate()
        self.axes.grid(True)
        self.canvas.draw()

    def clear(self):
        """Forget all stored series, labels and the title.

        This does not redraw, and it keeps ``plot_type`` and the X limits.
        """
        self.data_list = []
        self.data_list_2 = []
        self.legend_list = []
        self.legend_list_2 = []
        self.title = ''

    def plot_line(self, data_in, legend=' ', title=' ', twin=False):
        """Add a line series, set the title and redraw all line series.

        :param data_in: 2-D array; X values from column ``BAR_DICT['Date']``,
            Y values from column 1.
        :param legend: label of the series.
        :param title: new chart title.
        :param twin: exactly ``False`` plots on the primary axes; any other
            value plots on the twin axes ``axes2``.
        """
        self.plot_type = 'line'
        self.title = title
        if twin is False:
            self.data_list.append(data_in)
            # Only primary labels go here: axes.legend() pairs them with the
            # primary lines in order, so twin labels would misalign them.
            self.legend_list.append(legend)
        else:
            self.data_list_2.append(data_in)
            self.legend_list_2.append(legend)
        self._plot_line()

    def _plot_line(self):
        """Redraw all stored line series.

        Primary series go on ``self.axes`` (with a legend). Twin series are
        drawn in red on ``self.axes2``, whose Y label lists their labels.
        """
        self.axes.cla()
        self.axes2.cla()

        for data in self.data_list:
            x = data[:, BAR_DICT['Date']]
            y = data[:, 1]
            self.axes.plot(x, y)
        for data in self.data_list_2:
            x = data[:, BAR_DICT['Date']]
            y = data[:, 1]
            self.axes2.plot(x, y, color=(1, 0, 0))
        # set_ylabel() expects text. Passing the list rendered its repr,
        # e.g. "['Volume']".
        twin_label = ', '.join(str(name) for name in self.legend_list_2)
        self.axes2.set_ylabel(twin_label, color=(1, 0, 0))
        # self.axes2.legend(self.legend_list_2)

        if self.right_lim is not None and self.left_lim is not None:
            self.axes.set_xlim(self.left_lim, self.right_lim)
            self.axes2.set_xlim(self.left_lim, self.right_lim)

        self.axes.set_title(self.title)
        self.axes.legend(self.legend_list)
        self.axes.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        self.axes.xaxis.set_major_locator(mticker.MaxNLocator(10))
        self.fig.autofmt_xdate()
        self.axes.grid(True)
        self.canvas.draw()