Example #1
0
def set_title(fig: Figure, ax: Axes, title: str, digital: bool):
    """Give *ax* a title whose padding and font size scale with *fig*.

    This is more particular than the default matplotlib ``Axes.set_title``:
    the padding grows with the figure height, the initial font size grows with
    the figure width, and the font is then shrunk in 5-point steps (only while
    it is above 9 points, never below 5) until the rendered title is narrower
    than 90% of the figure width minus a 26-pixel allowance.

    :param Figure fig: the figure that contains ``ax``
    :param Axes ax: the axes that should receive the title
    :param str title: the title text
    :param bool digital: True selects the larger font scale; False selects
        the smaller one

    :returns: the ``Text`` instance returned by ``ax.set_title``
    """
    width_in = fig.get_figwidth()
    height_in = fig.get_figheight()
    width_px = width_in * fig.get_dpi()

    # Padding scales with the figure height but never drops below 1.
    title_pad = max(int(0.3125 * height_in), 1)

    scale = (8 / 1.92) if digital else (4 / 1.92)
    size = max(5, int(scale * width_in))

    text = ax.set_title(title, pad=title_pad)
    text.set_fontsize(size)

    # Measure the rendered title and shrink it until it fits the width budget.
    renderer = fig.canvas.get_renderer()
    max_width = (width_px - 26) * 0.9
    extent = text.get_window_extent(renderer=renderer)
    while size > 9 and extent.width >= max_width:
        size = max(5, size - 5)
        text.set_fontsize(size)
        extent = text.get_window_extent(renderer=renderer)
    return text
Example #2
0
class MatplotlibWidget(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).

    Holds a single Matplotlib figure with one subplot, exposed as
    ``self.fig`` and ``self.axes``.
    """
    def __init__(self,
                 parent=None,
                 name=None,
                 width=5,
                 height=4,
                 dpi=100,
                 bgcolor=None):
        """Create the figure and attach the canvas to *parent*.

        :param parent: optional parent widget
        :param name: accepted for compatibility; not used
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution in dots per inch
        :param bgcolor: accepted for compatibility; not used
        """
        self.parent = parent

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)
        # We want the axes cleared every time plot() is called.
        # Axes.hold() was removed in Matplotlib 3.0; on those versions plotted
        # artists always accumulate, so callers must clear the axes themselves.
        if hasattr(self.axes, 'hold'):
            self.axes.hold(False)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def sizeHint(self):
        """Preferred size in pixels: the figure size (inches) times its dpi."""
        dpi = self.fig.get_dpi()
        # QSize takes integer pixels; recent PyQt versions raise TypeError
        # when given floats.
        w = int(round(self.fig.get_figwidth() * dpi))
        h = int(round(self.fig.get_figheight() * dpi))
        return QSize(w, h)

    def minimumSizeHint(self):
        """Smallest size the layout should shrink this widget to."""
        return QSize(10, 10)
Example #3
0
class MatplotlibWidget(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).

    Holds one Matplotlib figure with a single subplot (``self.fig`` /
    ``self.axes``). Subclasses draw initial content by overriding
    ``compute_initial_figure``.
    """
    def __init__(self, parent=None, name=None, width=5, height=4, dpi=100, bgcolor=None):
        """Create the figure, draw the initial content and attach to *parent*.

        :param parent: optional parent widget; when given, its background
            brush colour is used as the figure face/edge colour (overriding
            *bgcolor*)
        :param name: accepted for compatibility; not used
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution in dots per inch
        :param bgcolor: figure face/edge colour used when no parent is given
        """
        self.parent = parent
        if self.parent:
            bgc = parent.backgroundBrush().color()
            # Convert Qt's 0-255 channels to the 0-1 floats Matplotlib expects.
            bgcolor = float(bgc.red())/255.0, float(bgc.green())/255.0, float(bgc.blue())/255.0

        self.fig = Figure(figsize=(width, height), dpi=dpi, facecolor=bgcolor, edgecolor=bgcolor)
        self.axes = self.fig.add_subplot(111)
        # We want the axes cleared every time plot() is called.
        # Axes.hold() was removed in Matplotlib 3.0; on those versions plotted
        # artists always accumulate, so callers must clear the axes themselves.
        if hasattr(self.axes, 'hold'):
            self.axes.hold(False)

        self.compute_initial_figure()

        FigureCanvas.__init__(self, self.fig)
        self.reparent(parent, QPoint(0, 0))

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def sizeHint(self):
        """Preferred size in pixels: the figure size (inches) times its dpi."""
        dpi = self.fig.get_dpi()
        # QSize takes integer pixels, not inches or floats.
        w = int(round(self.fig.get_figwidth() * dpi))
        h = int(round(self.fig.get_figheight() * dpi))
        return QSize(w, h)

    def minimumSizeHint(self):
        """Smallest size the layout should shrink this widget to."""
        return QSize(10, 10)

    def compute_initial_figure(self):
        """Hook for subclasses to draw initial content; does nothing here."""
        pass
Example #4
0
class MatplotlibWidget(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).

    Holds one Matplotlib figure with a single subplot (``self.fig`` /
    ``self.axes``). ``compute_initial_figure`` draws the initial content
    (a sine wave here); subclasses may override it.
    """
    def __init__(self,
                 parent=None,
                 name=None,
                 width=5,
                 height=4,
                 dpi=100,
                 bgcolor=None,
                 mplPars=None):
        """Create the figure, draw the initial content and attach to *parent*.

        :param parent: optional parent widget; when given, its palette
            background colour is used as the figure face/edge colour
            (overriding *bgcolor*)
        :param name: accepted for compatibility; not used
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution in dots per inch
        :param bgcolor: figure face/edge colour used when no parent is given
        :param mplPars: optional mapping merged into the global
            ``matplotlib.rcParams`` before the figure is created
        """
        self.parent = parent

        # Only the background-colour lookup depends on having a parent; the
        # figure and the Qt base class must be set up in every case, otherwise
        # a parentless widget is left uninitialised.
        if self.parent:
            bgc = parent.palette().color(QtGui.QPalette.Background)
            # Convert Qt's 0-255 channels to the 0-1 floats Matplotlib expects.
            bgcolor = (float(bgc.red()) / 255.0,
                       float(bgc.green()) / 255.0,
                       float(bgc.blue()) / 255.0)

        if mplPars:
            matplotlib.rcParams.update(mplPars)

        self.fig = Figure(figsize=(width, height),
                          dpi=dpi,
                          facecolor=bgcolor,
                          edgecolor=bgcolor)
        self.axes = self.fig.add_subplot(111)

        self.compute_initial_figure()

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def sizeHint(self):
        """Preferred size in pixels: the figure size (inches) times its dpi."""
        dpi = self.fig.get_dpi()
        # QSize takes integer pixels; recent PyQt versions raise TypeError
        # when given floats.
        w = int(round(self.fig.get_figwidth() * dpi))
        h = int(round(self.fig.get_figheight() * dpi))
        return QtCore.QSize(w, h)

    def minimumSizeHint(self):
        """Smallest size the layout should shrink this widget to."""
        return QtCore.QSize(10, 10)

    def compute_initial_figure(self):
        """Draw sin(2*pi*t) for t in [0, 3) as placeholder content."""
        t = np.arange(0.0, 3.0, 0.01)
        s = np.sin(2 * np.pi * t)
        self.axes.plot(t, s)
Example #5
0
class EJcanvas(FigureCanvas):
    """Matplotlib canvas used by ErwinJr to draw energy bands and
    wavefunctions on a single subplot."""
    def __init__(self, xlabel='x', ylabel='y', parent=None):
        """Create the figure and its only subplot and attach to *parent*.

        :param xlabel: text for the x-axis label applied by ``set_axes``
        :param ylabel: text for the y-axis label applied by ``set_axes``
        :param parent: optional parent widget
        """
        self.figure = Figure()
        super(EJcanvas, self).__init__(self.figure)
        self.setParent(parent)
        self.setMinimumWidth(200)
        self.setMinimumHeight(200)
        grow = QSizePolicy.MinimumExpanding
        self.setSizePolicy(grow, grow)
        self.updateGeometry()
        self.axes = self.figure.add_subplot(111)
        self.xlabel, self.ylabel = xlabel, ylabel

    def set_axes(self, fsize=config["fontsize"]):
        """Apply the standard axes styling.

        Hides the top/right spines, pushes the bottom/left spines 5 points
        outward, labels both axes and enables tight autoscaling on x and y.

        :param fsize: font size for tick labels and axis labels; the default
            is ``config["fontsize"]`` as read when the class was defined
        """
        ax = self.axes
        ax.tick_params(axis='both', which='major', labelsize=fsize)
        for side in ('top', 'right'):
            ax.spines[side].set_color('none')
        ax.spines['bottom'].set_position(('outward', 5))
        ax.set_xlabel(self.xlabel, fontsize=fsize)
        ax.spines['left'].set_position(('outward', 5))
        ax.set_ylabel(self.ylabel, fontsize=fsize)
        for axis_name in ('x', 'y'):
            ax.autoscale(enable=True, axis=axis_name, tight=True)

    def test(self):
        """A test function, plotting sin(x) over [0, 10]."""
        samples = np.linspace(0, 10, 100)
        self.axes.plot(samples, np.sin(samples))

    def resizeEvent(self, event):
        """Keep fixed subplot margins when the widget is resized.

        ``config["PlotMargin"]`` supplies margins under keys 'l', 'b', 'r',
        't'; they are divided by the figure size in inches, so they are
        presumably in inches too.
        """
        super(EJcanvas, self).resizeEvent(event)
        fig_h = self.figure.get_figheight()
        fig_w = self.figure.get_figwidth()
        pad = config["PlotMargin"]
        self.figure.subplots_adjust(
            left=pad['l'] / fig_w,
            bottom=pad['b'] / fig_h,
            right=1 - pad['r'] / fig_w,
            top=1 - pad['t'] / fig_h,
            wspace=0,
            hspace=0,
        )

    def clear(self):
        """Remove everything from the axes and re-apply the styling."""
        self.axes.clear()
        self.set_axes()
Example #6
0
def make_square(fig: Figure, ax: Axes) -> None:
    """Shrink *ax* symmetrically so that it covers a square area of *fig*.

    The axes' position box (in figure fractions) is converted to inches using
    the figure size. The longer side is then trimmed equally at both ends
    until width and height match. Axes that are already square are re-set to
    their current position unchanged.
    """
    box = ax.get_position()
    fig_w = fig.get_figwidth()
    fig_h = fig.get_figheight()

    # Physical extent of the axes, in inches.
    ax_w = fig_w * (box.x1 - box.x0)
    ax_h = fig_h * (box.y1 - box.y0)

    if ax_w > ax_h:
        # Excess width as a fraction of the figure width, split over both ends.
        half_excess = ((ax_w - ax_h) / fig_w) / 2
        box.x0 += half_excess
        box.x1 -= half_excess
    elif ax_h > ax_w:
        # Excess height as a fraction of the figure height, split over both ends.
        half_excess = ((ax_h - ax_w) / fig_h) / 2
        box.y0 += half_excess
        box.y1 -= half_excess

    ax.set_position(box)
Example #7
0
def set_ticklabel_sizes(fig: Figure, ax: Axes, digital: bool):
    """Updates the sizes of the major tick labels of *ax* based on the width
    of *fig*.

    The font size is proportional to the figure width (in inches), using a
    larger factor for digital display, and is never below 5.

    :param Figure fig: The figure whose width determines the font size
    :param Axes ax: The specific axes within the figure to update
    :param bool digital: True if this is for digital display,
        False for physical display
    """
    scale = (30 / 19.2) if digital else (20 / 19.2)
    font_size = max(int(scale * fig.get_figwidth()), 5)

    # Tick.label was deprecated in Matplotlib 3.5 and removed in 3.8;
    # Tick.label1 is the same (bottom/left) label and exists in all versions.
    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(font_size)
Example #8
0
class EJcanvas(FigureCanvas):
    """Matplotlib canvas with one subplot and ErwinJr-style axes."""
    def __init__(self, xlabel='x', ylabel='y', parent=None):
        """Create the figure and its only subplot and attach to *parent*.

        :param xlabel: text for the x-axis label applied by ``set_axes``
        :param ylabel: text for the y-axis label applied by ``set_axes``
        :param parent: optional parent widget
        """
        self.figure = Figure()
        super(EJcanvas, self).__init__(self.figure)
        self.setParent(parent)
        expanding = QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.updateGeometry()
        self.axes = self.figure.add_subplot(111)
        self.xlabel, self.ylabel = xlabel, ylabel

    def set_axes(self, fsize=12):
        """Apply the standard axes styling.

        Tight autoscaling on x only (y keeps a margin), hidden top/right
        spines, bottom/left spines pushed 5 points outward, and axis labels
        at font size *fsize*.
        """
        ax = self.axes
        ax.autoscale(enable=True, axis='x', tight=True)
        ax.autoscale(enable=True, axis='y', tight=False)
        for side in ('top', 'right'):
            ax.spines[side].set_color('none')
        ax.spines['bottom'].set_position(('outward', 5))
        ax.set_xlabel(self.xlabel, fontsize=fsize)
        ax.spines['left'].set_position(('outward', 5))
        ax.set_ylabel(self.ylabel, fontsize=fsize)

    def test(self):
        """Plot sin(x) over [0, 10] as a smoke test."""
        samples = np.linspace(0, 10, 100)
        self.axes.plot(samples, np.sin(samples))

    def resizeEvent(self, event):
        """Keep fixed subplot margins when the widget is resized.

        NOTE(review): relies on a ``margin`` mapping (keys 'l', 'b', 'r', 't')
        defined outside this class; it is divided by the figure size in
        inches, so presumably in inches too -- confirm where it is defined.
        """
        super(EJcanvas, self).resizeEvent(event)
        fig_h = self.figure.get_figheight()
        fig_w = self.figure.get_figwidth()
        self.figure.subplots_adjust(
            left=margin['l'] / fig_w,
            bottom=margin['b'] / fig_h,
            right=1 - margin['r'] / fig_w,
            top=1 - margin['t'] / fig_h,
            wspace=0,
            hspace=0,
        )

    def clear(self):
        """Remove everything from the axes and re-apply the styling."""
        self.axes.clear()
        self.set_axes()
Example #9
0
class MatplotlibWidget(QtWidgets.QWidget):
    """Widget showing a stacked bar chart of departures/arrivals per time bin.

    The canvas lives in a scroll area so that, when ``scroll_resize`` is
    enabled and there are many bins, the chart can be wider than the view.
    A navigation toolbar and a ``QLabel`` (``self.prop``) sit above it.
    """

    def __init__(self, parent=None):
        super(MatplotlibWidget, self).__init__(parent)

        self.figure = Figure(edgecolor='k')
        self.figure.subplots_adjust(left=0.05,
                                    right=0.95,
                                    top=0.85,
                                    bottom=0.15)
        self.figure.set_size_inches(7, 3)
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.axes = self.figure.add_subplot(111)
        self.axes.set_title('Dep/Arr')

        self.scroll = QtWidgets.QScrollArea(self)
        sizePolicy = QtWidgets.QSizePolicy(
            QtWidgets.QSizePolicy.MinimumExpanding,
            QtWidgets.QSizePolicy.MinimumExpanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.scroll.sizePolicy().hasHeightForWidth())
        self.scroll.setSizePolicy(sizePolicy)
        self.scroll.setWidget(self.canvas)

        self.layoutVertical = QtWidgets.QVBoxLayout(self)
        self.layoutVertical.setSizeConstraint(
            QtWidgets.QLayout.SetDefaultConstraint)
        self.layoutVertical.setContentsMargins(0, 0, 0, 0)
        self.nav = NavigationToolbar(self.canvas, self)

        self.prop = QtWidgets.QLabel(self)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred,
                                           QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.prop.sizePolicy().hasHeightForWidth())
        self.prop.setSizePolicy(sizePolicy)

        self.layoutVertical.addWidget(self.nav)
        self.layoutVertical.addWidget(self.prop)
        self.layoutVertical.addWidget(self.scroll)

        self.bins = [1]
        # 15 minutes, in the units of op.get_op_timestamp()
        # (presumably seconds -- confirm).
        self.binsize = 15 * 60
        self.max_y = 0
        self.plot_off = False
        self.scroll_resize = False
        self.show_labels = False
        self.grid_on = True
        self.reset_fig()
        self.first_timestamp = None

    def reset_fig(self):
        """Clear the axes, restore title/grid and resize to the scroll area."""
        self.axes.clear()
        self.axes.set_title('Dep/Arr', loc='center')
        self.axes.grid(self.grid_on)
        self.axes.set_xticklabels([''])
        self.resize_fig()

    def resize_fig(self):
        """Size the figure and canvas to the scroll area.

        When ``scroll_resize`` is on and there are at least 5 bins needing
        50 px each beyond the visible width, the figure is made 50 px per bin
        wide so it can be scrolled. Otherwise it fills the scroll area width.
        The height is the scroll area height minus 20 px.
        """
        if self.plot_off:
            return
        dpi = float(self.figure.get_dpi())
        # Clamp to 1 px: a negative size makes set_size_inches raise when the
        # scroll area is shorter than the 20 px allowance.
        height_px = max(self.scroll.height() - 20, 1)
        wide_px = len(self.bins) * 50
        if (self.scroll_resize and len(self.bins) >= 5
                and wide_px >= self.scroll.width()):
            width_px = wide_px
        else:
            width_px = self.scroll.width()
        self.figure.set_size_inches(width_px / dpi, height_px / dpi)
        # QWidget.resize() takes integer pixels; recent PyQt versions raise
        # TypeError when given floats.
        self.canvas.resize(int(round(self.figure.get_figwidth() * dpi)),
                           int(round(self.figure.get_figheight() * dpi)))

    def _bin_ops(self, op_list):
        """Count operations per time bin, starting at ``self.first_timestamp``.

        Returns ``(labels, dep, arr)``: one label per bin (the bin start time
        formatted by ``time_string``) plus per-bin counts of operations with
        ``LorT == 'T'`` (dep) and ``LorT == 'L'`` (arr).
        """
        dep = [0]
        arr = [0]
        labels = []
        for op in reversed(op_list):
            stamp = op.get_op_timestamp()
            index = int((stamp - self.first_timestamp) / self.binsize)
            # fill in labels and empty values for new bars
            if index >= len(labels):
                for m in reversed(range(index - len(labels) + 1)):
                    bin_time = stamp - m * self.binsize
                    labels.append(
                        time_string(bin_time - (bin_time % self.binsize)))
                dep.extend([0] * (index - len(dep) + 1))
                arr.extend([0] * (index - len(arr) + 1))

            if op.LorT == 'T':
                dep[index] += 1
            elif op.LorT == 'L':
                arr[index] += 1
        return labels, dep, arr

    def update_figure(self, op_list, config_list):
        """Redraw the stacked bar chart from *op_list*.

        Each entry of *config_list* is drawn as a dashed vertical line at its
        ``from_epoch`` time, annotated with its ``config`` text. Nothing
        happens while ``plot_off`` is set or when *op_list* is empty.
        """
        if self.plot_off or not op_list:
            return
        self.reset_fig()

        # A timestamp of 0 is valid, so test for "unset" explicitly.
        if self.first_timestamp is None:
            self.first_timestamp = op_list[-1].get_op_timestamp()

        labels, dep, arr = self._bin_ops(op_list)

        N = len(labels)
        self.bins = np.arange(N)
        width = 0.8
        # plot stacked bars
        p1 = self.axes.bar(self.bins, arr, width, color='#9CB380')
        p2 = self.axes.bar(self.bins,
                           dep,
                           width,
                           bottom=arr,
                           color='#5CC8FF')
        self.axes.legend((p2[0], p1[0]), ('Dep', 'Arr'), loc=0)
        self.axes.set_xticks(self.bins - 0.5)
        if N > 10:
            # Keep only every `step`-th label so they don't overlap.
            step = N // 15 + 1
            labels = [label if m % step == 0 else ''
                      for m, label in enumerate(labels)]
        self.axes.set_xticklabels(labels, rotation=-40)
        self.max_y = self.axes.get_ylim()[1]
        self.axes.set_xlim(-1, N)
        x_offset = 0.2
        # set text for count of each bar
        if self.show_labels:
            for x, n_arr, n_dep in zip(self.bins, arr, dep):
                if n_arr != 0:
                    self.axes.text(x - x_offset,
                                   n_arr + self.max_y * 0.01,
                                   '{:.0f}'.format(n_arr))
                if n_dep != 0:
                    self.axes.text(x + x_offset / 2,
                                   n_dep + n_arr + self.max_y * 0.01,
                                   '{:.0f}'.format(n_dep))
        # config change vertical lines and text
        for cfg in config_list:
            x = float((cfg.from_epoch - self.first_timestamp) / self.binsize)
            self.axes.plot([x] * 2, [0, 0.9 * self.max_y],
                           'b--',
                           linewidth=1)
            self.axes.text(x + x_offset,
                           self.max_y * 0.9,
                           cfg.config,
                           fontsize=16)
        # show plot
        self.canvas.draw()
Example #10
0
class PlotterWidget(QWidget):
    """Widget surrounding matplotlib plotter.

    Shows a scatter plot on a Matplotlib canvas. Clicking a point emits
    ``selectionChangedSignal`` with the index of the nearest picked point.
    Dragging with the left button pans, and the scroll wheel zooms.
    """

    selectionChangedSignal = Signal(list)

    def __init__(self, parent=None):
        """Create PlotterWidget."""
        super(PlotterWidget, self).__init__(parent)

        self.selected = []
        self.ylabel = ""
        self.xlabel = ""
        # figsize is in inches: 300/72 in at 72 dpi is a 300x300 px figure.
        # A figsize of (300, 300) would mean a 21600x21600 px canvas, which
        # makes the first canvas.draw() below extremely slow.
        side = 300.0 / 72
        self.fig = Figure(figsize=(side, side), dpi=72, facecolor=(1, 1, 1),
                          edgecolor=(0, 0, 0))

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        self.canvas.mpl_connect('pick_event', self.onPick)

        # Toolbar Doesn't get along with Kate's MPL at the moment so the
        # mpl_connects will handle that for the moment
        self.canvas.mpl_connect('motion_notify_event', self.onMouseMotion)
        self.canvas.mpl_connect('scroll_event', self.onScroll)
        self.canvas.mpl_connect('button_press_event', self.onMouseButtonPress)
        self.lastX = 0
        self.lastY = 0

        self.axes = self.fig.add_subplot(111)

        vbox = QVBoxLayout()
        vbox.addWidget(self.canvas)
        vbox.setContentsMargins(0, 0, 0, 0)
        self.setLayout(vbox)

        # Placeholder content shown until plotData() is called.
        self.axes.plot(range(6), range(6), 'ob')
        self.axes.set_title("Drag attributes to change graph.")
        self.axes.set_xlabel("Drag here to set x axis.")
        self.axes.set_ylabel("Drag here to set y axis.")
        self.canvas.draw()

    def setXLabel(self, label):
        """Changes the x label of the plot."""
        self.xlabel = label
        self.axes.set_xlabel(label)
        self.canvas.draw()

    def setYLabel(self, label):
        """Changes the y label of the plot."""
        self.ylabel = label
        self.axes.set_ylabel(label)
        self.canvas.draw()

    def plotData(self, xs, ys):
        """Plots the given x and y data.

        The first selected index, if any, is redrawn in red.
        """
        self.axes.clear()
        self.axes.set_xlabel(self.xlabel)
        self.axes.set_ylabel(self.ylabel)
        self.xs = np.array(xs)
        self.ys = np.array(ys)
        # picker=3: points within 3 points of a click generate pick events.
        self.axes.plot(xs, ys, 'ob', picker=3)
        # np.alen was removed in NumPy 1.23; len() works for lists and arrays.
        if len(self.selected) > 0:
            self.highlighted = self.axes.plot(self.xs[self.selected[0]],
                                              self.ys[self.selected[0]],
                                              'or')[0]
        self.canvas.draw()

    def onPick(self, event):
        """Handles pick event, taking the closest single point.

           Note that since the id associated with a given point may be
           associated with many points, the final selection displayed to the
           user may be several points.
        """
        candidates = np.array(event.ind)

        mouseevent = event.mouseevent
        xt = self.xs[candidates]
        yt = self.ys[candidates]
        # Squared distance, in data units, from the click to each candidate.
        d = (xt - mouseevent.xdata)**2 + (yt - mouseevent.ydata)**2
        thepoint = candidates[d.argmin()]

        self.selectionChangedSignal.emit([thepoint])

    @Slot(list)
    def setHighlights(self, ids):
        """Sets highlights based on the given ids. These ids are the indices
           of the x and y data, not the domain.

           Selecting the same non-empty ids again toggles the selection off.
           Returns True if the plot was redrawn, False otherwise.
        """
        old_selection = list(self.selected)
        self.selected = ids
        if ids is None:
            self.selected = []

        if ((old_selection == self.selected and old_selection != [])
                or self.selected != []):

            if self.selected != [] and old_selection != self.selected:
                # Color new selection
                for indx in self.selected:
                    self.highlighted = self.axes.plot(self.xs[indx],
                                                      self.ys[indx], 'or')[0]
            if old_selection == self.selected:  # Turn off existing selection
                self.selected = []
            if old_selection != []:  # Do not color old selection
                for indx in old_selection:
                    self.axes.plot(self.xs[indx],
                                   self.ys[indx],
                                   'ob',
                                   picker=3)

            self.canvas.draw()
            return True
        return False

    # ---------------------- NAVIGATION CONTROLS ---------------------------

    # Mouse movement (with or w/o button press) handling

    def onMouseMotion(self, event):
        """Handles the panning while the left mouse button is held."""
        if event.button == 1:
            xmotion = self.lastX - event.x
            ymotion = self.lastY - event.y
            self.lastX = event.x
            self.lastY = event.y
            figsize = min(self.fig.get_figwidth(), self.fig.get_figheight())
            xmin, xmax = self.calcTranslate(self.axes.get_xlim(), xmotion,
                                            figsize)
            ymin, ymax = self.calcTranslate(self.axes.get_ylim(), ymotion,
                                            figsize)
            self.axes.set_xlim(xmin, xmax)
            self.axes.set_ylim(ymin, ymax)
            self.canvas.draw()

    # Note: the dtuple is in data coordinates, the motion is in pixels,
    # we estimate how much motion there is based on the figsize and then
    # scale it appropriately to the data coordinates to get the proper
    # offset in figure limits.
    def calcTranslate(self, dtuple, motion, figsize):
        """Calculates the translation necessary in one direction given a
           mouse drag in that direction.

           dtuple
               The current limits in a single dimension

           motion
               The number of pixels the mouse was dragged in the dimension.
               This may be negative.

           figsize
               The approximate size of the figure, in inches.
        """
        dmin, dmax = dtuple
        drange = dmax - dmin
        dots = self.fig.dpi * figsize
        offset = float(motion * drange) / float(dots)
        return (dmin + offset, dmax + offset)

    # When the user clicks the left mouse button, that is the start of
    # their drag event, so we set the last-coordinates that are used to
    # calculate drag
    def onMouseButtonPress(self, event):
        """Records start of drag event."""
        if event.button == 1:
            self.lastX = event.x
            self.lastY = event.y

    # On mouse wheel scroll, we zoom
    def onScroll(self, event):
        """Zooms on mouse scroll: each step scales both limits by 5%."""
        zoom = event.step
        xmin, xmax = self.calcZoom(self.axes.get_xlim(), 1. + zoom * 0.05)
        ymin, ymax = self.calcZoom(self.axes.get_ylim(), 1. + zoom * 0.05)
        self.axes.set_xlim(xmin, xmax)
        self.axes.set_ylim(ymin, ymax)
        self.canvas.draw()

    # Calculates the zoom required by the wheel scroll for a single dimension
    # dtuple - the current limits in some dimension
    # scale - fraction to increase/decrease the image size
    # This does a zoom by scaling the limits in that direction appropriately
    def calcZoom(self, dtuple, scale):
        """Calculates the zoom in a single direction based on:

           dtuple
               The limits in the direction

           scale
               Fraction by which to increase/decrease the figure.

           The limits are scaled about their midpoint.
        """
        dmin, dmax = dtuple
        dlen = 0.5 * (dmax - dmin)
        dcenter = dlen + dmin
        return (dcenter - dlen * scale, dcenter + dlen * scale)
Example #11
0
class Chart(object):
    """
    Simple and clean facade to Matplotlib's plotting API.
    
    A chart instance abstracts a plotting device, on which one or
    multiple related plots can be drawn. Charts can be exported as images, or
    visualized interactively. Each chart instance will always open in its own
    GUI window, and this window will never block the execution of the rest of
    the program, or interfere with other L{Chart}s.
    The GUI can be safely opened in the background and closed any number
    of times, as long as the client program is still running.
    
    By default, a chart contains a single plot:
    
    >>> chart.plot
    matplotlib.axes.AxesSubplot
    >>> chart.plot.hist(...)
    
    If C{rows} and C{columns} are defined, the chart will contain
    C{rows} x C{columns} number of plots (equivalent to MPL's sub-plots).
    Each plot can be accessed by its index:
    
    >>> chart.plots[0]
    first plot
    
    or by its position in the grid:
    
    >>> chart.plots[0, 1]
    plot at row=0, column=1
    
    @param number: chart number; by default this is L{Chart.AUTONUMBER},
                   which assigns the next serial number
    @type number: int or None
    @param title: chart master title
    @type title: str
    @param rows: number of rows in the chart window (values below 1 become 1)
    @type rows: int
    @param columns: number of columns in the chart window (values below 1
                    become 1)
    @type columns: int
    @param backend: GUI backend used for interactive display; defaults to
                    C{Backends.WX_WIDGETS}
    
    @note: additional arguments are passed directly to Matplotlib's Figure
           constructor. 
    """

    AUTONUMBER = None

    _serial = 0

    def __init__(self,
                 number=None,
                 title='',
                 rows=1,
                 columns=1,
                 backend=Backends.WX_WIDGETS,
                 *fa,
                 **fk):

        if number is Chart.AUTONUMBER:
            Chart._serial += 1
            number = Chart._serial

        if rows < 1:
            rows = 1
        if columns < 1:
            columns = 1

        self._rows = int(rows)
        self._columns = int(columns)
        self._number = int(number)
        self._title = str(title)
        self._figure = Figure(*fa, **fk)
        self._figure._figure_number = self._number
        self._figure.suptitle(self._title)
        self._beclass = backend
        self._hasgui = False
        self._plots = PlotsCollection(self._figure, self._rows, self._columns)
        self._canvas = FigureCanvasAgg(self._figure)

        # Enum members are the upper-cased file extensions, e.g. PNG -> 'png'.
        formats = {f.upper(): f
                   for f in self._canvas.get_supported_filetypes()}
        self._formats = csb.core.Enum.create('OutputFormats', **formats)

    def __getitem__(self, i):
        if i in self._plots:
            return self._plots[i]
        else:
            raise KeyError('No such plot number: {0}'.format(i))

    def __enter__(self):
        return self

    def __exit__(self, *a, **k):
        self.dispose()

    @property
    def _backend(self):
        # NOTE(review): started=True presumably starts the backend on demand;
        # use _backend_started to check without side effects.
        return Backend.get(self._beclass, started=True)

    @property
    def _backend_started(self):
        return Backend.query(self._beclass)

    @property
    def title(self):
        """
        Chart title
        @rtype: str
        """
        return self._title

    @property
    def number(self):
        """
        Chart number
        @rtype: int
        """
        return self._number

    @property
    def plots(self):
        """
        Index-based access to the plots in this chart
        @rtype: L{PlotsCollection}
        """
        return self._plots

    @property
    def plot(self):
        """
        First plot
        @rtype: matplotlib.AxesSubplot
        """
        return self._plots[0]

    @property
    def rows(self):
        """
        Number of rows in this chart
        @rtype: int
        """
        return self._rows

    @property
    def columns(self):
        """
        Number of columns in this chart
        @rtype: int
        """
        return self._columns

    @property
    def width(self):
        """
        Chart's width in inches
        @rtype: float
        """
        return self._figure.get_figwidth()

    @width.setter
    def width(self, inches):
        self._figure.set_figwidth(inches)
        if self._backend_started:
            self._backend.resize(self._figure)

    @property
    def height(self):
        """
        Chart's height in inches
        @rtype: float
        """
        return self._figure.get_figheight()

    @height.setter
    def height(self, inches):
        self._figure.set_figheight(inches)
        if self._backend_started:
            self._backend.resize(self._figure)

    @property
    def dpi(self):
        """
        Chart's DPI
        @rtype: float
        """
        return self._figure.get_dpi()

    @dpi.setter
    def dpi(self, dpi):
        self._figure.set_dpi(dpi)
        # Like the width/height setters, only notify a backend that is
        # already running instead of touching self._backend unconditionally.
        if self._backend_started:
            self._backend.resize(self._figure)

    @property
    def formats(self):
        """
        Supported output file formats
        @rtype: L{csb.core.enum}
        """
        return self._formats

    def show(self):
        """
        Show the GUI window (non-blocking).
        """
        if not self._hasgui:
            self._backend.add(self._figure)
            self._hasgui = True

        self._backend.show(self._figure)

    def hide(self):
        """
        Hide (but do not dispose) the GUI window.
        """
        self._backend.hide(self._figure)

    def dispose(self):
        """
        Dispose the GUI interface. Must be called at the end if any
        chart.show() calls have been made. Automatically called if using
        the chart in context manager ("with" statement).
        
        @note: Failing to call this method if show() has been called at least
        once may cause backend-related errors.
        """
        if self._backend_started:

            service = self._backend

            if service and service.running:
                service.destroy(self._figure, wait=True)
                service.client_disposed(self)

    def save(self, file, format='png', crop=False, dpi=None, *a, **k):
        """
        Save all plots to an image.
        
        @param file: destination file name
        @type file: str
        @param format: output image format; see C{chart.formats} for enumeration
        @type format: str or L{csb.core.EnumItem}
        @param crop: if True, crop the image (equivalent to MPL's bbox=tight);
                     ignored when C{bbox_inches} is passed explicitly
        @type crop: bool
        @param dpi: output resolution; None lets MPL choose
                
        @note: additional arguments are passed directly to MPL's savefig method
        """
        # An explicit bbox_inches keyword wins over the crop flag.
        bbox = k.pop('bbox_inches', 'tight' if crop else None)

        self._canvas.print_figure(file,
                                  *a,
                                  format=str(format),
                                  bbox_inches=bbox,
                                  dpi=dpi,
                                  **k)
Example #12
0
class Viewer():
    """
    A tkinter based GUI to look at fact events in the camera view.

    Attributes
    ----------

    dataset  : array like with shape (num_events, 1440) or (1440, )
        the data you want to plot into the pixels; a (1440, ) array is
        treated as a single event
    label    : str
        the label for the colormap
    pixelset : boolean array with shape (num_events, 1440) or (1440, )
        the pixels where pixelset is True are marked with 'pixelsetcolour';
        a (1440, ) array is applied to every event
        [default: None]
    pixelsetcolour : a matplotlib conform colour representation
        the colour for the pixels in 'pixelset',
        [default: green]
    clickedcolour: a matplotlib conform colour representation
        the colour for the clicked pixel
        [default: red]
    mapfile : str
        path/to/fact/pixelmap.csv; accepted but not used by this class,
        the pixel coordinates come from get_pixel_coords()
        [default pixel-map.csv]
    cmap : str or matplotlib colormap instance
        the colormap to use for plotting the 'dataset'
        [default: gray]
    vmin : float
        the minimum for the colorbar, if None min(dataset[event]) is used
        [default: None]
    vmax : float
        the maximum for the colorbar, if None max(dataset[event]) is used
        [default: None]
    """

    # number of camera pixels every event must provide
    _NUM_PIXELS = 1440

    def __init__(self,
                 dataset,
                 label,
                 pixelset=None,
                 pixelsetcolour="g",
                 clickedcolour="r",
                 mapfile="pixel-map.csv",
                 cmap="gray",
                 vmin=None,
                 vmax=None,
                 ):
        # matplotlib.use() lost its 'warn' argument in matplotlib 3.3;
        # force=True alone switches the backend.
        matplotlib.use('TkAgg', force=True)
        matplotlib.rcdefaults()
        self.event = 0

        npix = self._NUM_PIXELS
        dataset = np.asanyarray(dataset)
        if dataset.shape == (npix, ):
            dataset = np.reshape(dataset, (1, npix))
        elif dataset.ndim != 2 or dataset.shape[1] != npix:
            # checking ndim first keeps a wrongly sized 1-d input from
            # failing with an IndexError on shape[1]
            raise ValueError('Viewer expects dataset with shape (1440, )\n'
                             'or (n_events, 1440)'
                             )
        self.dataset = dataset
        # count events on the (possibly reshaped) array: a (1440, ) input
        # is one event, not 1440
        self.numEvents = self.dataset.shape[0]

        if pixelset is not None:
            pixelset = np.asanyarray(pixelset)
            if pixelset.shape == (npix, ):
                # one pixel mask shared by every event (read-only view)
                pixelset = np.broadcast_to(pixelset, (self.numEvents, npix))
        self.pixelset = pixelset

        self.pixelsetcolour = pixelsetcolour
        self.clickedcolour = clickedcolour
        self.label = label
        self.cmap = cmap
        self.vmin = vmin
        self.vmax = vmax
        self.pixel_x, self.pixel_y = get_pixel_coords()
        self.fig = Figure(figsize=(7, 6), dpi=100)

        self.init_plot()

        # ---- GUI Stuff ----

        self.root = tk.Tk()
        self.root.geometry("1024x768")
        self.root.wm_title("PyFactViewer")

        buttonFrame = tk.Frame(self.root)
        plotFrame = tk.Frame(self.root)
        infoFrame = tk.Frame(plotFrame)

        buttonFrame.pack(side=tk.TOP)
        plotFrame.pack(side=tk.BOTTOM, expand=True, fill=tk.BOTH)
        infoFrame.pack(side=tk.BOTTOM)

        self.canvas = FigureCanvasTkAgg(self.fig, master=plotFrame)
        self.canvas.mpl_connect("pick_event", self.onpick)
        self.canvas.mpl_connect("resize_event", self.redraw)
        # FigureCanvasTkAgg.show() no longer exists in current matplotlib;
        # draw() renders the figure on all versions
        self.canvas.draw()
        # pack the canvas widget once, stacked above the info frame
        self.canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH,
                                         expand=1)
        self.root.bind('<Key>', self.on_key_event)

        self.quit_button = tk.Button(master=buttonFrame,
                                     text='Quit',
                                     command=self.quit,
                                     )
        self.next_button = tk.Button(master=buttonFrame,
                                     text='Next',
                                     command=self.next,
                                     )
        self.previous_button = tk.Button(master=buttonFrame,
                                         text='Previous',
                                         command=self.previous,
                                         )
        self.save_button = tk.Button(master=buttonFrame,
                                     text='Save Image',
                                     command=self.save,
                                     )

        self.eventstring = tk.StringVar()
        self.eventstring.set("EventNum: {:05d}".format(self.event))
        self.eventbox = tk.Label(master=buttonFrame,
                                 textvariable=self.eventstring,
                                 )

        self.eventbox.pack(side=tk.LEFT)
        self.previous_button.pack(side=tk.LEFT)
        self.next_button.pack(side=tk.LEFT)
        self.quit_button.pack(side=tk.RIGHT)
        self.save_button.pack(side=tk.RIGHT)

        self.infotext = tk.StringVar()
        self.infotext.set("Click on a Pixel")
        self.infobox = tk.Label(master=infoFrame, textvariable=self.infotext)
        self.infobox.pack(side=tk.LEFT)

        self.update()
        # blocks until the window is closed
        tk.mainloop()

    def _color_limits(self):
        """Return (vmin, vmax) for the current event.

        A limit fixed by the user wins; otherwise the min/max of the
        current event's data is used.
        """
        data = self.dataset[self.event]
        vmin = np.min(data) if self.vmin is None else self.vmin
        vmax = np.max(data) if self.vmax is None else self.vmax
        return vmin, vmax

    def init_plot(self):
        """(Re)build the figure: camera axes, colorbar axes and the camera
        plot of the current event. Also clears the clicked pixel."""
        self.width, self.height = self.fig.get_figwidth(), self.fig.get_figheight()
        self.fig.clf()
        self.ax = self.fig.add_subplot(1, 1, 1, aspect=1)
        divider = make_axes_locatable(self.ax)
        self.cax = divider.append_axes("right", size="5%", pad=0.1)
        self.ax.set_axis_off()

        vmin, vmax = self._color_limits()

        if self.pixelset is None:
            pixelset = np.zeros(self._NUM_PIXELS, dtype=bool)
        else:
            pixelset = self.pixelset[self.event]

        self.plot = self.ax.factcamera(
            data=self.dataset[self.event],
            pixelcoords=None,
            cmap=self.cmap,
            vmin=vmin,
            vmax=vmax,
            pixelset=pixelset,
            pixelsetcolour=self.pixelsetcolour,
            linewidth=None,
            picker=False,
        )
        self.plot.set_picker(0)
        self.clicked_pixel = None

        self.cb = self.fig.colorbar(self.plot, cax=self.cax, label=self.label)
        self.cb.set_clim(vmin=vmin, vmax=vmax)
        self.cb.draw_all()

    def save(self):
        """Ask for a file name and save the current figure there."""
        filename = filedialog.asksaveasfilename(
            initialdir=os.getcwd(),
            parent=self.root,
            title="Choose a filename for the saved image",
            defaultextension=".pdf",
        )
        # an empty result means the dialog was cancelled
        if filename:
            self.fig.savefig(filename, dpi=300, bbox_inches="tight",
                             transparent=True)
            print("Image successfully saved to", filename)

    def redraw(self, event):
        """Resize handler: rescale the pixel outlines and redraw."""
        self.linewidth = calc_linewidth(self.ax)
        self.plot.set_linewidths(self.linewidth)
        self.fig.tight_layout(pad=0)
        self.canvas.draw()

    def quit(self):
        """Stop the Tk main loop and destroy the window."""
        self.root.quit()     # stops mainloop
        # Destroying the root is also needed on Windows.
        # NOTE(review): the original rationale was truncated here;
        # matplotlib's embedding_in_tk example cites a
        # "Fatal Python Error: PyEval_RestoreThread: NULL tstate".
        self.root.destroy()

    def next(self):
        """Show the next event, wrapping around after the last one."""
        self.event = (self.event + 1) % self.numEvents
        self.update()

    def previous(self):
        """Show the previous event, wrapping around before the first one."""
        # Python's modulo maps -1 to numEvents - 1
        self.event = (self.event - 1) % self.numEvents
        self.update()

    def on_key_event(self, event):
        """Keyboard shortcuts: Right/Left arrows browse events, 'q' quits."""
        actions = {
            "Right": self.next,
            "Left": self.previous,
            "q": self.quit,
        }
        action = actions.get(event.keysym)
        if action is not None:
            action()

    def update(self):
        """Show the current event: data, pixel outlines and colour limits."""
        self.plot.set_array(self.dataset[self.event])
        # object dtype keeps colour values of any length intact; a plain
        # array built from "k" strings is '<U1' and would truncate e.g.
        # "orange" to "o"
        edgecolors = np.full(self._NUM_PIXELS, "k", dtype=object)
        if self.pixelset is not None:
            edgecolors[self.pixelset[self.event]] = self.pixelsetcolour
        if self.clicked_pixel is not None:
            edgecolors[self.clicked_pixel] = self.clickedcolour
        self.plot.set_edgecolors(edgecolors)
        self.plot.changed()
        vmin, vmax = self._color_limits()
        self.linewidth = calc_linewidth(self.ax)
        self.plot.set_linewidths(self.linewidth)
        self.cb.set_clim(vmin=vmin, vmax=vmax)
        self.cb.draw_all()
        self.plot.set_picker(0)
        self.canvas.draw()
        self.eventstring.set("EventNum: {:05d}".format(self.event))

    def onpick(self, event):
        """Pick handler: highlight the clicked pixel and show its value."""
        hitpixel = event.ind[0]
        if hitpixel != self.clicked_pixel:
            self.clicked_pixel = hitpixel
            self.update()
            self.infotext.set(
                "chid: {:04d}, {} = {:4.2f}".format(
                    hitpixel,
                    self.label,
                    self.plot.get_array()[hitpixel]
                )
            )
Example #13
0
class PlotPanel(wx.Panel):
    """Plot Panel

    Contains the plot and some plotting based functions.  The figure holds
    three axes:

    * ax1 -- the sky projection (cosa/cosb when ``self.cosine`` is True,
      azimuth/elevation otherwise), left column;
    * ax2 -- the time/elevation overview, upper right;
    * ax3 -- the time/elevation zoom, lower right, optionally overlaid with
      a waveform and LMA points.
    """

    def __init__(self, parent, root):
        wx.Panel.__init__(self, parent)
        self.parent = parent
        self.root = root

        # important parameters
        self.inFileS = None
        self.data = None
        # lma
        self.lma = None
        self.lmaColor = 'Black'
        # waveforms
        self.waveHead = None
        self.waveData = None
        self.waveLpf = 5  # waveform low pass filter
        self.waveColor = 'Black'
        # appearance
        self.face = 'w'
        self.txtc = 'k'
        self.color = 'k'
        self.colorMap = gCmap
        self.colorOp = 1
        self.colorHL = (0, 1, 0, .25)
        self.alphaOp = 0
        self.sizeOp = 1
        self.markerSz = 6
        self.marker = 'D'
        self.cosine = True
        # log10 amplitude bounds used by mkSize (sizeOp == 3)
        self.maxA = np.log10(65000)
        self.minA = np.log10(700)
        # qualities
        self.eCls = 2.25
        self.eXpk = 0.3
        self.eMlt = .40
        self.eStd = 2.0
        # redraw() throttling state
        self._draw_pending = False
        self._draw_counter = 0

        self.tOffset = 0
        self.limitsHistory = []

        self.SetBackgroundColour(wx.NamedColour("WHITE"))

        self.figure = Figure(figsize=(8.0, 4.0))
        self.figure_canvas = FigureCanvas(self, -1, self.figure)

        # Note that event is a MplEvent
        self.figure_canvas.mpl_connect('motion_notify_event',
                                       self.UpdateStatusBar)

        # this status bar is actually part of the main frame
        self.statusBar = wx.StatusBar(self.root, -1)
        self.statusBar.SetFieldsCount(1)
        self.root.SetStatusBar(self.statusBar)

        self.mkPlot()

        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.figure_canvas, 1, wx.LEFT | wx.TOP | wx.EXPAND)
        self.sizer.Add(self.statusBar, 0, wx.BOTTOM)

        self.SetSizer(self.sizer)
        self.Bind(wx.EVT_SIZE, self.OnSize)

    ###
    # Helpers

    @staticmethod
    def _normalize(values, fill=0.0):
        """Linearly rescale *values* onto [0, 1] as a new float array.

        Empty input is returned as an empty array.  If all values are equal
        (zero range) every element becomes *fill* instead of the NaNs a plain
        division by the range would produce.
        """
        values = np.asarray(values, dtype=float)
        if values.size == 0:
            return values
        low = values.min()
        span = values.max() - low
        if span == 0:
            return np.full(values.shape, fill, dtype=float)
        return (values - low) / span

    @staticmethod
    def _logScale(amplitude, logLow, logHigh):
        """Map log10(amplitude) linearly from [logLow, logHigh] onto [0, 1],
        clipping values that fall outside that interval."""
        scaled = (np.log10(amplitude) - logLow) / (logHigh - logLow)
        return np.clip(scaled, 0, 1)

    def _mkAxis(self, spec):
        """Add a subplot at gridspec cell *spec*, using the panel's face
        colour and tick/label colour."""
        ax = self.figure.add_subplot(spec, axisbg=self.face)
        ax.yaxis.set_tick_params(labelcolor=self.txtc, color=self.txtc)
        ax.xaxis.set_tick_params(labelcolor=self.txtc, color=self.txtc)
        return ax

    def _mkOverviewMesh(self):
        """Draw ``self.data.rawDataHist`` on the overview axis; return the mesh."""
        hist = self.data.rawDataHist
        # the 0.1 power compresses the dynamic range of the mesh values
        return self.ax2.pcolormesh(hist[2], hist[1], hist[0] ** .1,
                                   edgecolor='None', cmap=cm.binary)

    ###
    # Makers

    def mkPlot(self, lims=None):
        """Rebuild the whole figure and, if data is loaded, plot it.

        Clears the figure, the limits history and the LMA/waveform overlays.
        It then recreates the three axes and their selectors.  When
        ``self.data`` is set, it also resets the data limits and draws every
        panel.

        :param lims: unused; kept for backward compatibility.
        """
        # clear the figure and forget every artist that lived on it
        self.figure.clf()
        self.ax1Coll = None
        self.ax2Coll = None
        self.ax2Mesh = None
        self.ax3Coll = None
        self.ax1Lma = None
        self.ax3Lma = None
        self.ax3Wave = None
        # clear the history
        self.limitsHistory = []

        # clear the lma and wave data
        self.lma = None
        self.waveHead = None
        self.waveData = None

        # initialize the plot region
        gs = gridspec.GridSpec(2, 2)
        # axis #1, the Az-El (or cosa-cosb) plot
        self.ax1 = self._mkAxis(gs[:, 0])
        # axis #2, the time-El plot (overview)
        self.ax2 = self._mkAxis(gs[0, 1])
        # axis #3, the time-El plot (zoom)
        self.ax3 = self._mkAxis(gs[1, 1])

        self.span_ax1 = RectangleSelector(
            self.ax1, self.OnSelectAx1,
            minspanx=0.01, minspany=0.01,
            rectprops=dict(facecolor=self.colorHL, alpha=0.25),
            useblit=True)
        self.span_ax2 = SpanSelector(
            self.ax2, self.OnSelectAx2, 'horizontal',
            rectprops=dict(facecolor=self.colorHL, alpha=0.25),
            useblit=True, minspan=0.01)
        self.span_ax3 = SpanSelector(
            self.ax3, self.OnSelectAx3, 'horizontal',
            rectprops=dict(facecolor=self.colorHL, alpha=0.25),
            useblit=True, minspan=0.01)

        ########
        # Make the plot
        if self.data is None:
            return
        # initialize all the ranges
        self.data.reset_limits()
        self.data.update()

        self.mkColorMap()
        # refresh markerSz for the new data as UpdatePlot does; a per-point
        # size array left over from earlier data would not match it
        self.mkSize()

        print('Making New Plot')

        # make the title
        self.title = self.figure.suptitle(self.data.TriggerTimeS)

        # Main Plot
        if self.cosine:
            print('Cosine Projection')
            self.ax1Coll = self.ax1.scatter(
                self.data.cosb, self.data.cosa,
                s=self.markerSz,
                marker=self.marker,
                facecolor=self.color,
                edgecolor='None')
            # reference circles of radius 1, cos(30 deg) and cos(60 deg)
            theta = np.linspace(0, 2 * np.pi, 1000)
            X = np.cos(theta)
            Y = np.sin(theta)
            self.ax1.plot(X, Y, 'k-', linewidth=2)
            for deg in (30, 60):
                r = np.cos(deg * np.pi / 180)
                self.ax1.plot(r * X, r * Y, 'k--', linewidth=2)

            # NOTE(review): x holds cosb but is labelled cos(alpha) here,
            # while UpdatePlot labels it 'cosb' -- confirm which is intended
            self.ax1.set_xlabel('cos($\\alpha$)')
            self.ax1.set_ylabel('cos($\\beta$)')
            self.ax1.set_xlim(self.data.cbRange)
            self.ax1.set_ylim(self.data.caRange)
            self.ax1.set_aspect('equal')
        else:
            print('Az-El Projection')
            self.ax1Coll = self.ax1.scatter(
                self.data.azim,
                self.data.elev,
                s=self.markerSz,
                marker=self.marker,
                facecolor=self.color,
                edgecolor='None')
            self.ax1.set_xlim(self.data.azRange)
            self.ax1.set_ylim(self.data.elRange)
            self.ax1.set_ylabel('Elevation')
            self.ax1.set_xlabel('Azimuth')
            self.ax1.set_aspect('auto')

        # the zoomed plot
        self.ax3Coll = self.ax3.scatter(
            self.data.time,
            self.data.elev,
            s=self.markerSz,
            marker=self.marker,
            facecolor=self.color,
            edgecolor='None')
        self.ax3.set_xlim(self.data.tRange)
        self.ax3.set_ylim(self.data.elRange)
        self.ax3.set_xlabel('Time (ms)')

        # the overview plot
        self.ax2Mesh = self._mkOverviewMesh()
        self.ax2Coll = self.ax2.scatter(
            self.data.time,
            self.data.elev,
            s=3,
            marker=self.marker,
            facecolor=self.colorHL,
            edgecolor='None')
        # these limits shouldn't change though
        self.ax2.set_xlim(self.data.tRange)
        self.ax2.set_ylim(self.data.elRange)

        self.root.ctrlPanel.filtTab.set_values()
        self.redraw()

    def mkColorMap(self):
        """Fill ``self.color`` with per-point RGBA colours for ``colorOp``.

        0: all-zero RGBA (black); 1: by time; 2: by point order; 3: by log
        amplitude between ``data.a05`` and ``data.a95``.  Sorts ``self.data``
        (by amplitude for option 3, by time otherwise) and then applies
        mkAlpha().  Does nothing when no data is loaded.
        """
        if self.data is None:
            return
        if self.colorOp == 0:
            print('Color: Greyscale')
            self.data.sort(self.data.time)
            self.color = np.zeros((len(self.data.mask), 4))
        elif self.colorOp == 1:
            print('Color: By time')
            self.data.sort(self.data.time)
            self.color = self.colorMap(self._normalize(self.data.time))
        elif self.colorOp == 2:
            print('Color: by points')
            self.data.sort(self.data.time)
            order = np.arange(len(self.data.mask), dtype='f')
            self.color = self.colorMap(self._normalize(order))
        elif self.colorOp == 3:
            print('Color: by Amplitude')
            self.data.sort(self.data.pkpk)
            c = self._logScale(self.data.pkpk,
                               np.log10(self.data.a05),
                               np.log10(self.data.a95))
            self.color = self.colorMap(c)

        self.mkAlpha()

    def mkSize(self):
        """Set ``self.markerSz`` according to ``sizeOp``.

        0/1/2: fixed sizes 3/6/12.  3: grows with log amplitude between
        ``minA`` and ``maxA``.  4: exaggerated growth between the data's
        ``aMin`` and ``aMax``.  Scaled amplitudes are clipped to [0, 1], so
        points below the lower bound get the smallest size.
        """
        if self.sizeOp == 0:
            print('MarkerSize: small')
            self.markerSz = 3
        elif self.sizeOp == 1:
            print('MarkerSize: medium')
            self.markerSz = 6
        elif self.sizeOp == 2:
            print('MarkerSize: large')
            self.markerSz = 12
        elif self.sizeOp == 3:
            print('MarkerSize: by Amplitude')
            s = self._logScale(self.data.pkpk, self.minA, self.maxA)
            self.markerSz = 6 * (1 + 3 * s ** 2)
        elif self.sizeOp == 4:
            print('MarkerSize: exaggerated')
            s = self._logScale(self.data.pkpk,
                               np.log10(self.data.aMin),
                               np.log10(self.data.aMax))
            self.markerSz = 6 * (1 + 3 * s ** 2) ** 2

    def mkAlpha(self):
        """Set the alpha channel of ``self.color`` according to ``alphaOp``.

        0: fully opaque.  1/2: normalized amplitude raised to 0.2/0.4.  Any
        other value leaves ``self.color`` untouched.
        """
        if self.alphaOp == 0:
            print('Alpha: None')
            self.color[:, 3] = 1
            return
        elif self.alphaOp == 1:
            alphaEx = .2
        elif self.alphaOp == 2:
            alphaEx = .4
        else:
            # don't know this option, don't do anything
            return
        print('Alpha: %s' % alphaEx)
        # fill=1: if every amplitude is equal, keep the points opaque
        a = self._normalize(self.data.pkpk, fill=1.0)
        self.color[:, 3] = a ** alphaEx

    ###
    # On Catches

    def OnSelectAx1(self, click, release):
        """RectangleSelector callback: zoom ax1 to the dragged rectangle."""
        xlims = sorted([click.xdata, release.xdata])
        ylims = sorted([click.ydata, release.ydata])

        if self.cosine:
            self.SetLimits(caRange=ylims, cbRange=xlims)
        else:
            self.SetLimits(elRange=ylims, azRange=xlims)

        # update the plots
        self.UpdatePlot()

    def OnSelectAx2(self, xmin, xmax):
        """SpanSelector callback on the overview: restrict the time range."""
        self.figure_canvas.draw()
        if self.data is None:
            return
        self.SetLimits(tRange=[xmin, xmax])

        # update the mask and plot
        self.UpdatePlot()

    def OnSelectAx3(self, xmin, xmax):
        """SpanSelector callback on the zoom axis: restrict the time range."""
        if self.data is None:
            return
        self.SetLimits(tRange=[xmin, xmax])

        # update the mask and plot
        self.UpdatePlot()

    def OnSize(self, e):
        """Keep fixed pixel margins around the subplots when resized."""
        if self.GetAutoLayout():
            self.Layout()
        # margins in pixels
        left = 60
        right = 30
        top = 30
        bottom = 40
        wspace = 100
        dpi = self.figure.dpi
        h = self.figure.get_figheight() * dpi
        w = self.figure.get_figwidth() * dpi
        # convert the pixel margins into figure fractions
        self.figure.subplots_adjust(left=left / w,
                                    right=1 - right / w,
                                    bottom=bottom / h,
                                    top=1 - top / h,
                                    wspace=wspace / w)
        self.redraw()

    ###
    # Updaters

    def UpdateStatusBar(self, event):
        """Show the cursor's data coordinates while it is over an axis."""
        if event.inaxes:
            x, y = event.xdata, event.ydata
            self.statusBar.SetStatusText("x= " + str(x) + "  y=" + str(y), 0)

    def OffsetLimits(self, offset):
        """OffsetLimits(self,offset)
        this comes up because sometimes you want the time from the second,
        and sometimes you want the time from the trigger.

        This takes care of updating the time limits history so that
        things refer to the same section of the flash.
        """
        delta = offset - self.tOffset
        for lims in self.limitsHistory:
            if 'tRange' not in lims:
                # can this even happen?
                continue
            # Build a new list instead of mutating in place: several history
            # entries (and the live data) may share one range object, which
            # would otherwise be shifted more than once.
            t0, t1 = lims['tRange'][0], lims['tRange'][1]
            lims['tRange'] = [t0 + delta, t1 + delta]

        # update the waveform
        if self.waveData is not None:
            self.waveData[0, :] += delta

        self.tOffset = offset

    def GetLimits(self):
        """Return the current data limits as a dict, or None without data."""
        if self.data is None:
            return None
        return {
            'caRange': self.data.caRange,
            'cbRange': self.data.cbRange,
            'elRange': self.data.elRange,
            'azRange': self.data.azRange,
            'tRange': self.data.tRange,
        }

    def SetLimits(self, caRange=None, cbRange=None,
                  elRange=None, azRange=None, tRange=None, save=True):
        """Set the given data limits; limits left as None are not changed.

        When *save* is true, the current limits are pushed onto
        ``limitsHistory`` before anything is changed.  Does nothing when no
        data is loaded.
        """
        if self.data is None:
            return

        # append the old limits to the history
        if self.limitsHistory is not None and save:
            self.limitsHistory.append(self.GetLimits())

        # 'is not None' also works for numpy array ranges, where '!= None'
        # would compare element-wise
        if caRange is not None:
            self.data.caRange = caRange
        if cbRange is not None:
            self.data.cbRange = cbRange
        if elRange is not None:
            self.data.elRange = elRange
        if azRange is not None:
            self.data.azRange = azRange
        if tRange is not None:
            self.data.tRange = tRange

    def _waveCoversRange(self):
        """Return True if the cached ``waveData`` can be reused for tRange.

        It can when it starts less than 0.1 after tRange[0], ends less than
        0.1 before tRange[1] and spans less than 1.5 times tRange.  The
        individual checks are printed when it cannot.
        """
        t0, t1 = self.data.tRange[0], self.data.tRange[1]
        w0, w1 = self.waveData[0, 0], self.waveData[0, -1]
        startsOk = w0 - t0 < 0.1
        endsOk = w1 - t1 > -0.1
        shortEnough = w1 - w0 < 1.5 * (t1 - t0)
        if startsOk and endsOk and shortEnough:
            return True
        print('%s %s %s' % (w0, t0, startsOk))
        print('%s %s %s' % (w1, t1, endsOk))
        print('%s %s %s' % (w1 - w0, 1.5 * (t1 - t0), shortEnough))
        return False

    def UpdatePlot(self, update_overview=False):
        """Redraw the main (ax1) and zoom (ax3) axes for the current limits.

        The highlighted points on the overview (ax2) are always redrawn.  If
        update_overview is True, the overview's x limits and raw-data mesh
        are rebuilt as well.
        """
        if self.data is None:
            return
        self.data.limits()
        self.data.update()
        self.mkColorMap()
        self.mkSize()

        # Main plot (remake)
        if self.ax1Coll is not None:
            self.ax1Coll.remove()
        if self.ax1Lma is not None:
            self.ax1Lma.remove()
            # forget it so it is not removed a second time on a later call
            # when no LMA points are drawn in its place
            self.ax1Lma = None
        if self.cosine:
            print('Cosine Projection')
            self.ax1Coll = self.ax1.scatter(
                self.data.cosb,
                self.data.cosa,
                s=self.markerSz,
                marker=self.marker,
                facecolor=self.color,
                edgecolor='None')

            if self.lma is not None:
                self.ax1Lma = self.ax1.scatter(
                    self.lma.cosb,
                    self.lma.cosa,
                    s=6,
                    marker=self.marker,
                    facecolor=self.lmaColor,
                    edgecolor='None')

            self.ax1.set_ylabel('cosa')
            self.ax1.set_xlabel('cosb')
            self.ax1.set_ylim(self.data.caRange)
            self.ax1.set_xlim(self.data.cbRange)
            self.ax1.set_aspect('equal')
        else:
            print('Az-El Projection')
            self.ax1Coll = self.ax1.scatter(
                self.data.azim,
                self.data.elev,
                s=self.markerSz,
                marker=self.marker,
                facecolor=self.color,
                edgecolor='None')
            if self.lma is not None:
                self.ax1Lma = self.ax1.scatter(
                    self.lma.azim,
                    self.lma.elev,
                    s=6,
                    marker=self.marker,
                    facecolor=self.lmaColor,
                    edgecolor='None')

            self.ax1.set_xlim(self.data.azRange)
            self.ax1.set_ylim(self.data.elRange)
            self.ax1.set_ylabel('Elevation')
            self.ax1.set_xlabel('Azimuth')
            self.ax1.set_aspect('auto')

        # Zoom plot (remake)
        if self.ax3Coll is not None:
            self.ax3Coll.remove()
        if self.ax3Lma is not None:
            self.ax3Lma.remove()
            self.ax3Lma = None
        if self.ax3Wave is not None:
            self.ax3Wave.remove()
            self.ax3Wave = None

        # plot waveforms?
        if self.waveHead is not None:
            # Re-reading the waveform takes a while, especially for longer
            # durations, so only do it when the cached samples do not fit
            # the current time range.
            if self.waveData is None or not self._waveCoversRange():
                self.readWave()

            # plot the data underneath the points
            self.ax3Wave, = self.ax3.plot(
                self.waveData[0, :],
                self.waveData[1, :],
                self.waveColor,
                zorder=-10)

        # Scatter INTF
        self.ax3Coll = self.ax3.scatter(
            self.data.time,
            self.data.elev,
            s=self.markerSz,
            marker=self.marker,
            facecolor=self.color,
            edgecolor=(1, 1, 1, 0))

        # plot LMA?
        if self.lma is not None:
            self.ax3Lma = self.ax3.scatter(
                self.lma.time,
                self.lma.elev,
                s=6,
                marker=self.marker,
                facecolor=self.lmaColor,
                edgecolor='None')
        self.ax3.set_xlim(self.data.tRange)
        self.ax3.set_ylim(self.data.elRange)
        self.ax3.set_xlabel('Time (ms)')

        # overview plot
        # Remake current stuff only
        if self.ax2Coll is not None:
            self.ax2Coll.remove()
        if update_overview:
            # the overview plot likely just moved to a new location
            # reset the limits
            self.ax2.set_xlim(self.data.tStart + self.data.tOffset,
                              self.data.tStop + self.data.tOffset)
            # replace, rather than stack, the previous raw-data mesh
            if self.ax2Mesh is not None:
                self.ax2Mesh.remove()
            self.ax2Mesh = self._mkOverviewMesh()

        self.ax2Coll = self.ax2.scatter(
            self.data.time,
            self.data.elev,
            s=3,
            marker=self.marker,
            facecolor=self.colorHL,
            edgecolor='None')

        print("redrawing figure")
        self.root.ctrlPanel.filtTab.set_values()
        self.redraw()

    def readWave(self):
        """Read the waveform samples covering tRange into ``self.waveData``.

        Row 1 (amplitude) is rescaled to span elRange so it can be drawn on
        the zoom axis.  Row 0 (time) is shifted by ``data.tOffset``.
        """
        sRate = self.data.header.SampleRate
        pSamp = self.data.header.PreTriggerSamples
        # start sample = (start time - tOffset) * SampleRate / 1000
        #                + PreTriggerSamples
        sSam = int((self.data.tRange[0] - self.data.tOffset) * sRate / 1000
                   + pSamp)
        # get the number of samples, not so hard
        numSam = int((self.data.tRange[1] - self.data.tRange[0]) / 1000.
                     * sRate)
        # read in wave data and plot it under
        self.waveData = it.read_raw_waveform_file_data(self.root.waveFileS,
                                                       self.waveHead,
                                                       sSam,
                                                       numSam,
                                                       lpf=self.waveLpf)
        # map the amplitudes onto elRange (a flat trace lies on its bottom)
        elLow, elHigh = self.data.elRange[0], self.data.elRange[1]
        self.waveData[1, :] = (self._normalize(self.waveData[1, :])
                               * (elHigh - elLow) + elLow)
        self.waveData[0, :] += self.data.tOffset

    def redraw(self):
        """Schedule a canvas redraw, coalescing bursts of requests.

        At most one draw is pending at a time, run 40 ms later via
        wx.CallLater.  Requests that arrive while it is pending cause
        exactly one more redraw after it.
        """
        if self._draw_pending:
            self._draw_counter += 1
            return

        def _draw():
            self.figure_canvas.draw()
            self._draw_pending = False
            if self._draw_counter > 0:
                self._draw_counter = 0
                self.redraw()
        wx.CallLater(40, _draw).Start()
        self._draw_pending = True
Example #14
0
class MatplotlibWidget(FigureCanvas):
    """
    Class handling 1D, 2D, 3D data and displaying them in a canvas as 1D curve
    graphs, 2D images or multiple 2D images.
    """

    MASK_HIDDEN = 0
    MASK_SHOWN = 1
    MASK_ONLY = 2

    def __init__(self, graphMode=None, parent=None, name=None, width=5, height=4,
                 dpi=100, bgColor=None, valueRange=None,
                 maskLabels=None):
        """
        Create matplotlib 'front-end' widget which can render 1D,2D,3D data as
        1D or 2D graphs and handle masks.
        """
        if debug : print '**xndarrayViewRenderer.__init__  ...'
        self.parent = parent

        if graphMode: self.graphMode = graphMode
        else: self.graphMode = viewModes.MODE_2D

        self.fwidth = width
        self.fheight = height
        self.dpi = dpi

        # Will define the range of the colormap associated to values:
        if debug: print 'valueRange :', valueRange
        #valueRange = [0.001, 0.2] #noise var
        #valueRange = [0.001, 0.5] #noise ARp
        #valueRange = [0, 11]
        if valueRange is not None:
            self.norm = Normalize(valueRange[0],
                                  valueRange[1]+_N.abs(valueRange[1])*.01,
                                  clip=True)
            self.backgroundValue = valueRange[0] - 100
        else:
            self.norm = None
            self.backgroundValue = 0 #?
        # Define the range of the colormap associated to the mask:
        # will be used to draw contours of mask
        self.maskCm = None
        self.maskLabels = maskLabels
        if debug: print '######### maskLabels :', maskLabels
        if maskLabels is not None:
            _N.random.seed(1) # ensure we get always the same random colors
            #TODO: put the random seed back in the same state as before!!!
            rndm = _N.random.rand(len(maskLabels),3)
            # black:
            #fixed = _N.zeros((len(maskLabels),3)) + _N.array([0.,0.,0.])
            # green:
            #fixed = _N.zeros((len(maskLabels),3)) + _N.array([0.,1.,0.])
            #white:
            fixed = _N.zeros((len(maskLabels),3)) + _N.array([1.,1.,1.])
            # Create uniform colormaps for every mask label
            # self.maskCm = dict(zip(maskLabels,
            #                       [ListedColormap([ tuple(r) ]) for r in rndm]))
            self.maskCm = dict(zip(maskLabels,
                                   [ListedColormap([tuple(r)]) for r in fixed]))
        self.displayMaskFlag = self.MASK_HIDDEN

        # Set the color of the widget background
        if self.parent:
            bgc = parent.backgroundBrush().color()
            #bgcolor = float(bgc.red())/255.0, float(bgc.green())/255.0, \
            #          float(bgc.blue())/255.0
            bgcolor = "#%02X%02X%02X" % (bgc.red(), bgc.green(), bgc.blue())
        else: bgcolor = 'w'

        # Create the matplotlib figure:
        self.fig = Figure(figsize=(width, height), dpi=dpi,
                          facecolor=bgcolor, edgecolor=bgcolor)
        # Size of the grid of plots:
        self.subplotsH = 0
        self.subplotsW = 0
        self.axes = None
        self.showAxesFlag = True
        self.showAxesLabels = True

        # Init the parent Canvas:
        FigureCanvas.__init__(self, self.fig)

        # Some QT size stuffs
        self.reparent(parent, QPoint(0, 0))
        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Color bar related stuffs:
        self.showColorBar = False
        self.colorBars = None
        # color associated to position where mask=0 :
        self.bgColor = 'w'#QColor('white') if bgColor==None else bgColor

        # Default colormap (~rainbow) : black-blue-green-yellow-red
        self.colorMapString = '0;0;0.5;0.0;0.75;1.0;1.;1.0#' \
                              '0;0;0.5;1;0.75;1;1;0.#'       \
                              '0;0;0.25;1;0.5;0;1;0.'
        self.setColorMapFromString(self.colorMapString)
        self.update()

        # Signal stuffs:
        #self.mpl_connect('button_release_event', self.onRelease)
        self.mpl_connect('motion_notify_event', self.onMove)
        self.mpl_connect('button_press_event', self.onClick)

    def setBackgroundColor(self, color):
        if type(color) == QColor:
            self.bgColor = tuple([c/255. for c in color.getRgb()])
        else:
            self.bgColor = color # assume letter code/ hex color string / ...
                                 # anything supported by matplotlib.colors
                                 #TODO maybe do some checks ?
        if debug:print 'setting over/under color ', self.bgColor
        self.colorMap.set_under(color=self.bgColor)
        self.colorMap.set_over(color=self.bgColor)

    def sizeHint(self):
        w = self.fig.get_figwidth()
        h = self.fig.get_figheight()
        return QSize(w, h)

    def getColorMapString(self):
        return self.colorMapString

    def setColorMapFromString(self, scm):

        if debug : print 'MatplotlibWidget.setColorMapFromString'
        scm = str(scm)
        if debug : print ' recieved scm:', scm
        # Make sure we do not have any negative values :
        subResult = re.subn('(\+[0-9]+(\.\d*)?)','0.',scm)
        if debug and subResult[1] > 0:
            print ' !!! found negative values !!!'
            sys.exit(1)
        scm = subResult[0]
        self.colorMapString = scm
        if debug : print ' negative filtered scm :', scm

        self.colorMap = cmstring_to_mpl_cmap(scm)
        # Set colors corresponding to [minval-1 , minval] to background color:
        self.setBackgroundColor(self.bgColor)

    def setGraphMode(self, m):
        if self.graphMode == m:
            return
        if debug: print 'xndarrayViewRenderer.setGraphMode:',m
        self.graphMode = m
        self.computeFigure()

    def minimumSizeHint(self):
        return QSize(100, 100)


    def setMaskDisplay(self, flag):
        if self.displayMaskFlag == flag:
            return
        self.displayMaskFlag = flag
        self.resetGraph()
        self.refreshFigure()

    def toggleColorbar(self):
        if debug: print "MatplotlibWidget toggleColorbar..."
        self.showColorBar = not self.showColorBar
        if debug: print ' showColorBar = %s' %(['On','Off'][self.showColorBar])
        #if not self.showColorBar:
        self.resetGraph()

        #self.resetGraph()
        self.refreshFigure()
        #self.draw()

    def hideAxes(self):
        if not self.showAxesFlag:
            return
        self.showAxesFlag = False
        for ax in self.axes:
            ax.set_axis_off()
        self.adjustSubplots()
        self.draw()

    def showAxes(self):
        if self.showAxesFlag:
            return
        self.showAxesFlag = True
        for ax in self.axes:
            ax.set_axis_on()
        self.adjustSubplots()
        self.draw()

    def toggleAxesLabels(self):
        self.showAxesLabels = not self.showAxesLabels
        for a in self.axes:
            self.plotAxesLabels(a)
        self.adjustSubplots()
        self.draw()

    def setMaskLabel(self, label):
        if debug: print "MatplotlibWidget setMaskLabel ..."
        self.maskLabel = label
        self._applyMask()
        self.resetGraph()
        self.refreshFigure()


    def _applyMask(self):
        if debug: print "MatplotlibWidget applyMask ..."
        self.maskedValues = self.values.copy()
        self.maskedErrors = None if self.errors==None else self.errors.copy()
        if self.mask != None:
            if debug: print 'self.maskLabel:', self.maskLabel
            if self.maskLabel != 0:
                m = (self.mask != self.maskLabel)
            else:
                m = (self.mask == 0)
            #print 'm :', m
            if debug: print 'backgroundValue :', self.backgroundValue
            self.maskedValues[m] = self.backgroundValue
            if self.errors != None:
                self.maskedErrors[m] = 0

    def updateData(self, cub, mask=None, maskLabel=0, maskName='mask',
                   otherSlices=None):
        if debug:
            print "MatplotlibWidget update data ..."
            print "Got cuboid:"
            print cub.descrip()
            if mask is not None:
                print 'Got mask:', mask.shape
                print mask
        # Retrieve data and make everything be 3D-shaped:
        deltaDims = 3-cub.get_ndims()
        targetSh = cub.data.shape + (1,) * deltaDims
        self.domainNames = cub.axes_names + [None] * deltaDims
        self.domainValues = [cub.axes_domains[a] for a in cub.axes_names] + \
            [None] * deltaDims
        self.values = cub.data.reshape(targetSh)
        self.mask = None if mask==None else mask.reshape(targetSh)
        self.maskName = maskName
        #self.errors = None if cub.errors==None else cub.errors.reshape(targetSh)
        self.errors = None
        self.valueLabel = cub.value_label
        self.maskLabel = maskLabel
        self.otherSlices = otherSlices
        self._applyMask()
        if debug:
            print 'MatplotlibWidget.update: got view:'
            print '**Domains:'
            for i in xrange(len(self.domainNames)):
                if self.domainNames[i]!=None:
                    if _N.isreal(self.domainNames[i][0]) :
                        print '%s : [%d %d]' %(self.domainNames[i],
                                               self.domainValues[i].min(),
                                               self.domainValues[i].max())
                    else:
                        print '%s : [%s ... %s]' %(self.domainNames[i],
                                                   self.domainValues[i][0],
                                                   self.domainValues[i][-1])

            if debug:
                print 'self.maskedValues:', self.maskedValues.shape
                print self.maskedValues[:,:,0]
            #print 'self.values :', self.values.shape
            if self.errors != None: print 'self.errors:', self.maskedErrors.shape

        self.emit(PYSIGNAL('valuesChanged'),(self.getShortDataDescription(),))
        self.computeFigure()

    def getShortDataDescription(self):
        if self.mask!=None:
            if self.maskLabel > 0:
                v = self.values[self.mask == self.maskLabel]
            else: v = self.values[self.mask != 0]
        else:
            v = self.maskedValues
        if v.size > 0:
            descr = '%1.3g(%1.3g)[%1.3g,%1.3g]' %(v.mean(),v.std(),
                                                  v.min(),v.max())
        else: descr = ''
        return descr

    def resetGraph(self):
        if debug : print 'MatplotlibWidget.resetGraph ...'

        if self.colorBars != None:
            if debug: print 'self.colorbars :', self.colorBars
            for cb in self.colorBars:
                self.fig.delaxes(cb.ax)
            self.colorBars = None

        if self.graphMode == viewModes.MODE_HIST:
            if self.axes:
                for a in self.axes:
                    self.fig.delaxes(a)

            ax = self.fig.add_subplot(111)
            ax.zValue = 0
            self.axes = [ax]
            #self.axes.clear()
            self.axes[0].hold(False)

        else:
            nbZVal = self.values.shape[2]
            if debug: print 'nbZVal :', nbZVal
            h = round(nbZVal**.5)
            w = _N.ceil(nbZVal/h)
            if 1 or (h != self.subplotsH or w != self.subplotsW):
                if self.axes:
                    for a in self.axes:
                        self.fig.delaxes(a)
                self.subplotsW = w
                self.subplotsH = h
                self.axes = []
                if debug: print 'add subplots : w=%d, h=%d' %(w,h)
                for i in xrange(1,nbZVal+1):
                   #if debug: print ' add(%d, %d, %d)' %(h,w,i)
                   ax = self.fig.add_subplot(h,w,i)
                   ax.zValue = i-1
                   self.axes.append(ax)

        self.adjustSubplots()

    def adjustSubplots(self):
        if debug: print "adjustSubplots ..."
        if self.graphMode == viewModes.MODE_HIST:
            self.fig.subplots_adjust(left=0.2, right=.95,
                                     bottom=0.2, top=.9,
                                     hspace=0., wspace=0.)
            return

        if self.values.shape[2] == 1:
            if not self.showAxesFlag:
                self.fig.subplots_adjust(left=0.2, right=.8,
                                         bottom=0.2, top=.8,
                                         hspace=0.05, wspace=0.05)
            elif not self.showAxesLabels:
                if self.graphMode == viewModes.MODE_2D:
                    self.fig.subplots_adjust(left=0.05, right=.95,
                                             bottom=0.01, top=.95,
                                             hspace=0., wspace=0.)
                else:
                    self.fig.subplots_adjust(left=0.2, right=.95,
                                             bottom=0.1, top=.95,
                                             hspace=0., wspace=0.)

            else:
                self.fig.subplots_adjust(left=0.1, right=.95,
                                         bottom=0.1, top=.9,
                                         hspace=0.05, wspace=0.05)
        else:
            if not self.showAxesFlag:
                self.fig.subplots_adjust(left=0., right=1.,
                                         bottom=0., top=1.,
                                         hspace=0.05, wspace=0.05)
            elif not self.showAxesLabels:
                self.fig.subplots_adjust(left=0.1, right=.9,
                                         bottom=0.2, top=.9,
                                         hspace=0.7, wspace=0.3)
            else:
                self.fig.subplots_adjust(left=0.1, right=.9,
                                         bottom=0.1, top=.9,
                                         hspace=0.01, wspace=0.7)

    def showMask(self, splot, mask):
        if mask != None:
            nr, nc = mask.shape
            extent = [-0.5, nc-0.5, -0.5, nr-0.5]

            if self.displayMaskFlag == self.MASK_SHOWN:
                labels = _N.unique(mask)
                for il, lab in enumerate(labels):
                    if lab != 0:
                        if self.maskCm!=None : cm = self.maskCm[lab]
                        else : cm = get_cmap('binary')
                        splot.contour((mask==lab).astype(int), 1,
                                      cmap=cm, linewidths=2., extent=extent,alpha=.7)
                                     #cmap=cm, linewidths=1.5, extent=extent,alpha=.7)
            if self.displayMaskFlag == self.MASK_ONLY:
                if self.maskLabel == 0:
                    labels = _N.unique(mask)
                    for il, lab in enumerate(labels):
                        if lab != 0:
                            if self.maskCm != None:
                                cm = self.maskCm[lab]
                            else:
                                cm = get_cmap('binary')
                            ml = (mask==lab).astype(int)
                            print 'contouf of mask label %d -> %d pos' \
                                %(lab, ml.sum())
                            splot.contourf(ml, 1,
                                           cmap=cm, linewidths=1., extent=extent)
                elif (mask==self.maskLabel).sum() > 0:
                    if self.maskCm != None:
                        cm = self.maskCm[_N.where(mask==self.maskLabel)[0]]
                    else:
                        cm = get_cmap('binary')
                    ax.contourf((mask==self.maskLabel).astype(int), 1,
                                cmap=cm, linewidths=1.5, extent=extent)

    def plot1D(self):
        if debug: print 'MatplotlibWidget.computeFigure: MODE_1D'
        di2 = 0
        nbins = 100.

        d1 = self.domainValues[1]
        d0 = self.domainValues[0]
        plotPDF = False
        if d0[0] == 'mean' and len(d0)>1 and d0[1] == 'var':
            plotPDF = True
            if (self.values[1,:] < 10.).all():
                xMin = (self.values[0,:] - 6*self.values[1,:]**.5).min()
                xMax = (self.values[0,:] + 6*self.values[1,:]**.5).max()
            else:
                xMin = (self.values[0,:] - 10*self.values[1,:]**.5).min()
                xMax = (self.values[0,:] + 10*self.values[1,:]**.5).max()
            bins = _N.arange(xMin, xMax, (xMax-xMin)/nbins)

        x = d0 if _N.isreal(d0[0]) else _N.arange(len(d0))

        me = self.maskedErrors.max() if self.errors!=None else 0
        yMax = self.maskedValues.max()+me
        yMin = self.maskedValues.min()-me
        if 1 or self.errors !=None :
            dy = (yMax-yMin)*0.05
            dx = (x.max()-x.min())*0.05
        else: dx,dy = 0,0

        for ax in self.axes:
            ax.hold(True)
            #ax.set_axis_off()
            vSlice = self.maskedValues[:,:,di2]
            if self.errors != None:
                errSlice = self.maskedErrors[:,:,di2]
            for di1 in xrange(self.values.shape[1]):
                if plotPDF:
                    plotNormPDF(ax, bins, vSlice[0,di1], vSlice[1,di1])
                else:
                    print 'di1:',di1
                    print 'domainValues:', self.domainValues
                    if self.domainValues[1] is not None:
                        val = str(self.domainNames[1]) + ' = ' + \
                            str(self.domainValues[1][di1])
                    else:
                        val = 'nothing'
                    ax.plot(x, vSlice[:,di1], picker=ValuePicker(val))
                    #ax.plot(vSlice[:,di1], picker=ValuePicker(val))
                    if not _N.isreal(d0[0]):
                        setTickLabels(self.axes[0].xaxis, d0)
                    if self.errors != None and not _N.allclose(errSlice[:,di1],0.) :
                        ax.errorbar(x, vSlice[:,di1], errSlice[:,di1], fmt=None)
            if not plotPDF:
                ax.set_xlim(x.min()-dx, x.max()+dx)
                ax.set_ylim(yMin-dy, yMax+dy)
            elif ax.get_ylim()[1] > 1.0:
                ax.set_ylim(0, 1.)
            if not self.showAxesFlag:
                ax.set_axis_off()

            self.plotAxesLabels(ax)

            #ax.set_title(self.domainNames[2]+' ' \
            #             +str(self.domainValues[2][di2]))
            di2 += 1

    def plotAxesLabels(self, axis):
        if not self.showAxesLabels:
            axis.set_xlabel('')
            axis.set_ylabel('')
            return

        if self.graphMode == viewModes.MODE_1D:
            axis.set_xlabel(self.domainNames[0])
            axis.set_ylabel(self.valueLabel)
        elif self.graphMode == viewModes.MODE_2D:
            axis.set_ylabel(self.domainNames[0])

            if self.domainValues[1] != None:
                axis.set_xlabel(self.domainNames[1])
        else: #MODE_HIST
            axis.set_xlabel(self.valueLabel)
            axis.set_ylabel('density')


    def plot2D(self):
        if debug: print 'MatplotlibWidget.computeFigure: MODE_2D'
        di2 = 0
        self.colorBars = []
        for ax in self.axes:
            ax.hold(True)
            if self.mask != None:
                self.showMask(ax, self.mask[:,:,di2])
            #print 'maskedValues:', self.maskedValues.min(), self.maskedValues.max()
            if not hasattr(ax, 'matshow'): # matplotlib version < 0.9:
                ms = _matshow(ax, self.maskedValues[:,:,di2], cmap=self.colorMap,
                              norm=self.norm)
            else:
                ms = ax.matshow(self.maskedValues[:,:,di2], cmap=self.colorMap,
                                norm=self.norm)

            if self.showColorBar and len(self.axes)<2:
                if debug: print ' plot colorbar ...'
                self.colorBars.append(self.fig.colorbar(ms))

            if not self.showAxesFlag:
                ax.set_axis_off()
            else:
                setTickLabels(ax.yaxis, self.domainValues[0])
                if self.domainValues[1] != None:
                    setTickLabels(ax.xaxis, self.domainValues[1])

                self.plotAxesLabels(ax)

            #ax.set_title(self.domainNames[2]+' ' \
            #             +str(self.domainValues[2][di2]))
            di2 += 1

    def plotHist(self):
        if debug: print 'MatplotlibWidget.computeFigure: MODE_HIST'

        v = self.values
        if 0 and self.mask != None:
            if debug: print 'self.mask:', _N.unique(self.mask)
            if self.maskLabel > 0:
                #if debug: print 'self.values[self.mask == %d] :' %self.maskLabel
                #if debug: print self.values[self.mask == self.maskLabel]
                vs = [self.values[self.mask == self.maskLabel]]
            else:
                #if debug: print 'self.values[self.mask != 0] :'
                #if debug: print self.values[self.mask != 0]
                vs = [self.values[self.mask == ml] for ml in self.maskLabels
                      if ml!=0]

        else:
            vs = [self.values]
        bins = 30
        #bins = _N.arange(0,1.,0.05)
        normed = True
        colors = ['b','r','g']
        n,bins = _N.histogram(self.values, bins=bins, normed=normed)

        for iv, v in enumerate(vs):
            if v.size > 0:
                fColor = colors[iv]
                alpha = 0.5 if iv > 0 else 1.
                self.axes[0].hold(True)
                n,b,p = self.axes[0].hist(v, bins=bins, normed=normed, fc=fColor,
                                          alpha=alpha)
                #if type(bins) == int :
                #    bins = b
            else:
                print "Nothing in histogram"

        self.plotAxesLabels(self.axes[0])

    def computeFigure(self):
        if debug: print 'MatplotlibWidget.computeFigure: ...'

        # Reset subplots adjustment:
        #self.fig.subplots_adjust(left=0.15, right=.9, bottom=0.2, top=.8,
        #                         hspace=0.1, wspace=0.1)
        self.resetGraph()
        self.refreshFigure()

    def refreshFigure(self):
        if debug : print 'MatplotlibWidget.refreshFigure ...'
        if self.graphMode == viewModes.MODE_1D:
            self.plot1D()
        elif self.graphMode == viewModes.MODE_2D:
            self.plot2D()
        elif self.graphMode == viewModes.MODE_HIST:
            self.plotHist()

        if debug: print 'fig:', self.fig.get_figwidth(), self.fig.get_figheight()

        self.draw()

    def save(self, fn):
        if debug : print 'MatplotlibWidget: self.fig.savefig ...'
        self.fig.savefig(fn)


    def printVal(self, v):
        if not _N.isreal(v): # assume string
            return v
        #else assume number
        elif int(v) == v:  #integer
            return '%d' %v
        else: # float
            return '%1.3g' %v

    def onClick(self, event):
        if debug:
            print 'mpl press event !'
        if not event.inaxes:
            return
        i = round(event.ydata)
        j = round(event.xdata)
        if hasattr(event.inaxes, 'zValue'):
            #if self.graphMode == viewModes.MODE_2D:
            k = event.inaxes.zValue
        else:
            k = -1
        if debug:
            print 'click on :', (i,j,k)
            print 'self.values.shape :', self.values.shape
        if self.otherSlices is not None:
            pos = self.otherSlices.copy()
        else:
            pos = {}
        if self.graphMode == viewModes.MODE_2D:
            if i<self.values.shape[0] and j<self.values.shape[1]:
                for n in xrange(3):
                    if debug: print 'self.domainNames[n]:', self.domainNames[n]
                    if self.domainNames[n] != None:
                        dv = self.domainValues[n][[i,j,k][n]]
                        pos[self.domainNames[n]] = dv
                if self.mask is not None:
                    pos[self.maskName] = self.mask[i,j,k]
        pos[self.valueLabel] = self.values[i,j,k]
        if debug:
            print '-> ', pos
            print "emitting positionClicked ..."
        self.emit(PYSIGNAL("positionClicked"), (pos.keys(), pos.values()) )


    def onMove(self, event):
        if event.inaxes and hasattr(event.inaxes, 'zValue'):
            #print 'zVal:', event.inaxes.zValue
            k = event.inaxes.zValue
            i = round(event.ydata)
            j = round(event.xdata)
            #print 'xdata : %f, ydata : %f' %(event.xdata, event.ydata)
            #print 'i:%d, j:%d, k:%d' %(i,j,k)

            if self.graphMode == viewModes.MODE_2D:
                if i >=self.values.shape[0] or j>=self.values.shape[1]:
                    msg = ''
                else:
                    if self.mask==None or (self.mask[i,j,k] != 0 and          \
                                           (self.maskLabel==0 or          \
                                            self.mask[i,j,k]==self.maskLabel)):
                        msg = '%s: %1.3g' %(self.valueLabel, self.values[i,j,k])
                    else:
                        msg = 'background'
                    if self.mask != None:
                        msg += ', %s:%d' %(self.maskName, self.mask[i,j,k])

                if msg != '':
                    for n in xrange(3):
                        if self.domainNames[n] is not None:
                            dv = self.domainValues[n][[i,j,k][n]]
                            msg += ', %s: %s' \
                                   %(self.domainNames[n], self.printVal(dv))
                    if self.errors != None:
                        msg += ', error: %1.3g' %(self.errors[i,j,k])

            elif self.graphMode == viewModes.MODE_1D:
                msg = '%s: %1.3g, %s: %1.3g' %(self.domainNames[0],event.xdata,
                                               self.valueLabel,event.ydata)


            elif self.graphMode == viewModes.MODE_HIST:
                msg = '%s: %1.3g, %s: %1.3g' %(self.valueLabel,event.xdata,
                                               'freq',event.ydata)
        else:
            msg = ''
        self.emit(PYSIGNAL('pointerInfoChanged'), (msg,))


    def onRelease(self, event):
        if debug:print event.name
        if event.inaxes:
            self.matData = _N.random.randn(10,10)
            self.ai.set_data(self.matData)
Example #15
0
from matplotlib import pyplot as plt

# Scratch exploration of title / label / tick-label geometry in matplotlib.
fig = plt.figure()
r = fig.canvas.get_renderer()
t = plt.title('0\n0')
ax = fig.gca()

ax.set_ylabel('a')
ax.set_xlabel('b\n\n\nb')
# ax.get_xlabel() returns the label *string*; the Text artist is
# ax.xaxis.label, and Text.get_position() returns an (x, y) tuple.
ax.xaxis.label.get_position()

bb = t.get_window_extent(renderer=r)

dpi = fig.dpi
fw = fig.get_figwidth()
fh = fig.get_figheight()

ax.get_yticklabels()[0].get_window_extent(r).get_points()

# Setting a tick label artist's text directly does not stick (presumably the
# tick formatter rewrites it when ticks are updated).
yt = ax.get_yticklabels()[0]
yt.set(text='aaa') # Does not work.
# Replacing the labels through the axes does:
ylabels = [item.get_text() for item in ax.get_yticklabels()]
ylabels[0] = 'aaa'
ax.set_yticklabels(ylabels)

# Need to iterate over all labels to set properties for each.
yt = ax.get_yticklabels()[0]
yt.set(color='yellow')

Example #16
0
class PlotWidget(FigureCanvas):
    """a QtWidget and FigureCanvasAgg that displays the model plots

    Plots are yearly series registered under a plotID (used to index COLOURS;
    plotID 0 is treated as the baseline in the totals box).  The figure also
    shows a legend, a text box with the totals, and red/blue shading above
    and below zero (emissions/removals).
    """

    def __init__(self, parent=None, name=None, width=5, height=4, dpi=100, bgcolor=None):
        """Create the figure, axes, labels, zero line and shading.

        *name* is unused; *width*/*height* are the figure size in inches.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi, facecolor=bgcolor, edgecolor=bgcolor)
        # Leave the right part of the figure free for the legend and totals.
        self.axes = self.fig.add_axes([0.125,0.1,0.6,0.8])

        # plotIDs whose plot has an uncertainty range (adds a "range" legend
        # entry); nothing in this class appends to it yet.
        self.have_range = []

        self.axes.set_xlabel('Time (years)')
        self.axes.set_ylabel(
                "Greenhouse Gas \nEmissions/Removals (t CO"+SUB_2+ "e / ha)",
                multialignment='center')

        self.axes.axhline(y=0.,ls=':',color='k')
        self.axes.text(0, 0,'Emission',va='bottom')
        self.axes.text(0, 0,'Removal',va='top')
        self.emission_patch = None
        self.removal_patch = None
        self.updateShading()

        self.totalsBox = self.fig.text(0.75,0.1,"")

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        self.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                           QtWidgets.QSizePolicy.Expanding)
        self.updateGeometry()

        # plotID -> list of Line2D artists; plotID -> (name, total)
        self.plots = {}
        self.totals = {}

    def addPlot(self,data,plotID,name):
        """Plot the yearly series *data* (first value at year 1) under
        *plotID* with legend label *name*, then refresh legend, totals,
        shading and canvas.  An existing plot with the same plotID is replaced.

        No support for ranges (uncertainties) yet.
        """
        # Re-adding a plotID must not leave the old, now unreachable, lines.
        for p in self.plots.get(plotID, []):
            p.remove()
        lines = self.axes.plot(
                np.arange(1, len(data) + 1), # to make it start at year 1
                data,color=COLOURS[plotID],label=name)
        totals = (name,np.sum(data))

        self.axes.autoscale(enable=True)
        self.axes.relim()
        self.plots[plotID] = list(lines)
        self.totals[plotID] = totals
        self.updateLegend()
        self.updateTotals()
        self.updateShading()
        self.draw()

    def updateShading(self):
        """Extend the y-limits to include [-1, 1] and fit the red (above zero)
        and blue (below zero) background patches to the current view."""
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        ylim = (min(ylim[0],-1),max(ylim[1],1))
        self.axes.set_ylim(ylim)
        upper = [(xlim[0],0),(xlim[1],0),(xlim[1],ylim[1]),(xlim[0],ylim[1])]
        lower = [(xlim[0],ylim[0]),(xlim[1],ylim[0]),(xlim[1],0),(xlim[0],0)]
        # shade positive and negative values (created once, then reshaped)
        if self.emission_patch is None:
            self.emission_patch = Polygon(
                    upper, facecolor='r', alpha=0.25, fill=True,
                    edgecolor=None, zorder=-1000)
            self.axes.add_patch(self.emission_patch)
        else:
            self.emission_patch.set_xy(upper)
        if self.removal_patch is None:
            self.removal_patch = Polygon(
                    lower, facecolor='b', alpha=0.25, fill=True,
                    edgecolor=None, zorder=-1000)
            self.axes.add_patch(self.removal_patch)
        else:
            self.removal_patch.set_xy(lower)


    def updateTotals(self):
        """Rewrite the totals box: the total of each model (sorted by name)
        and, when a baseline (plotID 0) exists, each other plot's total
        minus the baseline total.  No support for ranges (yet)."""
        # get model names (as keys)
        models = {}
        for c in self.totals:
            models[self.totals[c][0]] = self.totals[c][1:]
        ms = sorted(models.keys())

        outstr = "Total Emissions/Removals\nover %d years " % cfg.N_ACCT
        outstr += "(t CO"+SUB_2+"e / ha):\n"
        for m in ms:
            outstr += '%s: %.1f\n' % (m, models[m][0])

        # Figure out total of project - baseline for each project
        if len(self.totals)>1 and 0 in self.totals:
            outstr+="\nNet impact (t CO"+SUB_2+"e / ha):\n"
            for c in self.totals:
                if c == 0:
                    continue
                outstr += "%s: %.1f\n" % (
                        self.totals[c][0],
                        self.totals[c][1]-self.totals[0][1])

        self.totalsBox.set_text(outstr)

    def removePlot(self,plotID):
        """Remove the plot registered under *plotID* (no-op if unknown) and
        refresh legend, totals and canvas."""
        if plotID not in self.plots:
            return
        for p in self.plots[plotID]:
            p.remove()
        del self.totals[plotID]
        del self.plots[plotID]
        if plotID in self.have_range:
            self.have_range.remove(plotID)
        self.updateLegend()
        self.updateTotals()
        # Draw last so the refreshed legend and totals are shown immediately.
        self.draw()


    def updateLegend(self):
        """Show a legend outside the axes (with a dashed "range" entry when
        any plot has a range), or remove it when there are no plots."""
        if self.plots:
            handles, labels = self.axes.get_legend_handles_labels()
            if len(self.have_range) > 0:
                l = Line2D([0,1],[0,1],linestyle="--",color='k')
                handles.append(l)
                labels.append("range")
            self.axes.legend(handles,labels,bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        else:
            self.axes.legend_ = None

    def saveFigure(self,fname):
        """Save the figure to file *fname*."""
        self.fig.savefig(fname)


    def sizeHint(self):
        # NOTE(review): figure size is in inches, not pixels.
        w = self.fig.get_figwidth()
        h = self.fig.get_figheight()
        return QtCore.QSize(w, h)

    def minimumSizeHint(self):
        return QtCore.QSize(10, 10)
Example #17
0
class MapWidget(FigureCanvas):
    """A Qt widget (and matplotlib FigureCanvas) that displays a Basemap map.

    Mouse button presses inside the axes are converted from map coordinates
    back to lon/lat and passed to ``pickCallback``, provided a callback has
    been assigned and a map has been drawn with :meth:`drawBasemap`.
    """

    def __init__(self,
                 parent=None,
                 name=None,
                 width=5,
                 height=4,
                 dpi=100,
                 bgcolor=None):
        """Create the figure, canvas and a single axes for the map.

        :param parent: Qt parent widget
        :param name: unused; kept for call compatibility
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution
        :param bgcolor: figure face/edge colour (matplotlib default if None)
        """
        self.fig = Figure(figsize=(width, height),
                          dpi=dpi,
                          facecolor=bgcolor,
                          edgecolor=bgcolor)
        self.axes = self.fig.add_subplot(111)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        self.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                           QtWidgets.QSizePolicy.Expanding)
        self.updateGeometry()

        # State read by onpick; initialised before the handler is connected.
        self.m = None             # Basemap instance, set by drawBasemap
        self.point = None         # Line2D marking the last plotted point
        self.pickCallback = None  # called with (lon, lat) on click

        self.mpl_connect('button_press_event', self.onpick)

    def onpick(self, event):
        """Forward a click inside the axes to ``pickCallback`` as (lon, lat)."""
        if (event.inaxes is not None and self.pickCallback is not None
                and self.m is not None):
            self.pickCallback(self.m(event.xdata, event.ydata, inverse=True))

    def drawBasemap(self, basemapData):
        """Create a Basemap from *basemapData* on this widget's axes and draw it.

        Only the 'tmerc' projection is supported. When ``basemapData.geotiff``
        is None, continents and lakes are filled; otherwise the image file
        it names is displayed with ``imshow``. Coastlines, countries,
        parallels and meridians are drawn in both cases.

        :raises ValueError: if ``basemapData.projection`` is not 'tmerc'
        """
        if basemapData.projection != 'tmerc':
            raise ValueError('Unknown Projection %s' % basemapData.projection)

        self.m = Basemap(llcrnrlon=basemapData.llcrnrlon,
                         llcrnrlat=basemapData.llcrnrlat,
                         urcrnrlon=basemapData.urcrnrlon,
                         urcrnrlat=basemapData.urcrnrlat,
                         resolution='i',
                         projection=basemapData.projection,
                         lon_0=basemapData.lon_0,
                         lat_0=basemapData.lat_0,
                         ax=self.axes)

        if basemapData.geotiff is None:
            self.m.fillcontinents(color='coral', lake_color='aqua')
            self.m.drawmapboundary(fill_color='aqua')
        else:
            img = matplotlib.image.imread(basemapData.geotiff)
            self.m.imshow(img)

        self.m.drawcoastlines()
        self.m.drawcountries()
        # Draw parallels and meridians; the stop value is extended by one
        # step so the last parallel/meridian is included by arange.
        self.m.drawparallels(numpy.arange(
            basemapData.first_parallel,
            basemapData.last_parallel + basemapData.delta_parallel,
            basemapData.delta_parallel),
                             labels=[False, True, True, False])
        self.m.drawmeridians(numpy.arange(
            basemapData.first_meridian,
            basemapData.last_meridian + basemapData.delta_meridian,
            basemapData.delta_meridian),
                             labels=[True, False, False, True])

    def plotPoint(self, lon, lat):
        """Mark (lon, lat) with a black dot, moving the existing marker if any.

        Does nothing until a map has been drawn.
        """
        if self.m is None:
            return
        x, y = self.m(lon, lat)
        if self.point is not None:
            self.point.set_data(x, y)
        else:
            self.point = self.m.plot(x, y, marker='o', color='k')[0]
        self.draw()

    def plotLocations(self, locations):
        """Plot each (lon, lat) in ``locations.locations`` as a red dot.

        Does nothing until a map has been drawn (consistent with plotPoint).
        """
        if self.m is None:
            return
        for loc in locations.locations:
            x, y = self.m(loc[0], loc[1])
            self.m.plot(x, y, marker='o', color='r')

    def sizeHint(self):
        """Return the preferred widget size in pixels (figure inches * DPI)."""
        dpi = self.fig.get_dpi()
        w = int(round(self.fig.get_figwidth() * dpi))
        h = int(round(self.fig.get_figheight() * dpi))
        return QtCore.QSize(w, h)

    def minimumSizeHint(self):
        """Return a fixed 10x10 minimum size."""
        return QtCore.QSize(10, 10)

    def saveFigure(self, fname):
        """Write the figure to *fname* via ``Figure.savefig``."""
        self.fig.savefig(fname)
Example #18
0
class Chart(object):
    """
    Simple and clean facade to Matplotlib's plotting API.
    
    A chart instance abstracts a plotting device, on which one or
    multiple related plots can be drawn. Charts can be exported as images, or
    visualized interactively. Each chart instance will always open in its own
    GUI window, and this window will never block the execution of the rest of
    the program, or interfere with other L{Chart}s.
    The GUI can be safely opened in the background and closed any number
    of times, as long as the client program is still running.
    
    By default, a chart contains a single plot:
    
    >>> chart.plot
    matplotlib.axes.AxesSubplot
    >>> chart.plot.hist(...)
    
    If C{rows} and C{columns} are defined, the chart will contain
    C{rows} x C{columns} number of plots (equivalent to MPL's sub-plots).
    Each plot can be accessed by its index:
    
    >>> chart.plots[0]
    first plot
    
    or by its position in the grid:
    
    >>> chart.plots[0, 1]
    plot at row=0, column=1
    
    @param number: chart number; by default this is L{Chart.AUTONUMBER},
                   which assigns the next serial number
    @type number: int or None
    @param title: chart master title
    @type title: str
    @param rows: number of rows in the chart window (values below 1 become 1)
    @type rows: int
    @param columns: number of columns in the chart window (values below 1
                    become 1)
    @type columns: int
    @param backend: GUI backend class used when the chart is shown
    
    @note: additional arguments are passed directly to Matplotlib's Figure
           constructor. 
    """

    AUTONUMBER = None
    
    # Last number handed out to an auto-numbered chart (shared by the class).
    _serial = 0
    
    def __init__(self, number=None, title='', rows=1, columns=1, backend=Backends.WX_WIDGETS, *fa, **fk):
        
        if number is Chart.AUTONUMBER:
            Chart._serial += 1
            number = Chart._serial
        
        # The plot grid always has at least one row and one column.
        if rows < 1:
            rows = 1
        if columns < 1:
            columns = 1
            
        self._rows = int(rows)
        self._columns = int(columns)
        self._number = int(number)
        self._title = str(title)
        self._figure = Figure(*fa, **fk)
        self._figure._figure_number = self._number
        self._figure.suptitle(self._title)
        self._beclass = backend
        self._hasgui = False
        self._plots = PlotsCollection(self._figure, self._rows, self._columns)
        # A non-GUI canvas used for saving; the GUI window is managed by the
        # backend only once show() is called.
        self._canvas = FigureCanvasAgg(self._figure)
        
        formats = [(f.upper(), f) for f in self._canvas.get_supported_filetypes()]
        self._formats = csb.core.Enum.create('OutputFormats', **dict(formats))
    
    def __getitem__(self, i):
        if i in self._plots:
            return self._plots[i]
        else:
            raise KeyError('No such plot number: {0}'.format(i))
        
    def __enter__(self):
        return self
    
    def __exit__(self, *a, **k):
        self.dispose()
    
    @property
    def _backend(self):
        # NOTE: requests the backend with started=True, so accessing this
        # property may start the GUI backend; check _backend_started first
        # when that is not wanted.
        return Backend.get(self._beclass, started=True)

    @property
    def _backend_started(self):
        return Backend.query(self._beclass)
      
    @property
    def title(self):
        """
        Chart title
        @rtype: str
        """
        return self._title
        
    @property
    def number(self):
        """
        Chart number
        @rtype: int
        """
        return self._number
    
    @property
    def plots(self):
        """
        Index-based access to the plots in this chart
        @rtype: L{PlotsCollection}
        """
        return self._plots
    
    @property
    def plot(self):
        """
        First plot
        @rtype: matplotlib.AxesSubplot
        """
        return self._plots[0]
    
    @property
    def rows(self):
        """
        Number of rows in this chart
        @rtype: int
        """
        return self._rows
    
    @property
    def columns(self):
        """
        Number of columns in this chart
        @rtype: int
        """
        return self._columns
    
    @property
    def width(self):
        """
        Chart's width in inches
        @rtype: float
        """
        return self._figure.get_figwidth()
    @width.setter
    def width(self, inches):
        self._figure.set_figwidth(inches)
        if self._backend_started:
            self._backend.resize(self._figure)

    @property
    def height(self):
        """
        Chart's height in inches
        @rtype: float
        """
        return self._figure.get_figheight()
    @height.setter
    def height(self, inches):
        self._figure.set_figheight(inches)
        if self._backend_started:
            self._backend.resize(self._figure)
                
    @property
    def dpi(self):
        """
        Chart's DPI
        @rtype: float
        """
        return self._figure.get_dpi()
    @dpi.setter
    def dpi(self, dpi):
        self._figure.set_dpi(dpi)
        # Same guard as the width/height setters: resizing must not start
        # the GUI backend as a side effect.
        if self._backend_started:
            self._backend.resize(self._figure)
            
    @property
    def formats(self):
        """
        Supported output file formats
        @rtype: L{csb.core.enum}
        """
        return self._formats
            
    def show(self):
        """
        Show the GUI window (non-blocking).
        """
        if not self._hasgui:
            self._backend.add(self._figure)
            self._hasgui = True
            
        self._backend.show(self._figure)
                
    def hide(self):
        """
        Hide (but do not dispose) the GUI window. Does nothing if the GUI
        backend has not been started.
        """
        if self._backend_started:
            self._backend.hide(self._figure)
        
    def dispose(self):
        """
        Dispose the GUI interface. Must be called at the end if any
        chart.show() calls have been made. Automatically called if using
        the chart in context manager ("with" statement).
        
        @note: Failing to call this method if show() has been called at least
        once may cause backend-related errors.
        """
        if not self._backend_started:
            return
        
        service = self._backend
        if service and service.running:
            service.destroy(self._figure, wait=True)
            service.client_disposed(self)
        
    def save(self, file, format='png', crop=False, dpi=None, *a, **k):
        """
        Save all plots to an image.
        
        @param file: destination file name
        @type file: str
        @param format: output image format; see C{chart.formats} for enumeration
        @type format: str or L{csb.core.EnumItem}
        @param crop: if True, crop the image (equivalent to MPL's bbox=tight);
                     ignored when C{bbox_inches} is passed explicitly
        @type crop: bool
        @param dpi: output resolution; None uses Matplotlib's default
                
        @note: additional arguments are passed directly to MPL's print_figure
               method
        """
        if 'bbox_inches' in k:
            # An explicit bbox_inches takes precedence over the crop flag.
            bbox = k.pop('bbox_inches')
        elif crop:
            bbox = 'tight'
        else:
            bbox = None
        
        # NOTE(review): *a is expanded positionally right after the file
        # name, so it may collide with parameters also passed by keyword
        # here (such as dpi) depending on print_figure's signature — confirm
        # before relying on extra positional arguments.
        self._canvas.print_figure(file, format=str(format), bbox_inches=bbox, dpi=dpi, *a, **k)
Example #19
0
class PlotWindow:
    """GTK window that renders one plot model (``plot``) with matplotlib.

    The model object supplies the lines and all display options (labels,
    limits, ticks, legend, window geometry); this class builds the window
    and keeps the matplotlib figure in sync with it.
    """

    def __init__(self, plot, title="", lines=None, shown=False):
        """Wrap *plot*; open the window immediately when *shown* is True.

        :param plot: plot model rendered by this window; its ``shown`` flag
            is kept up to date
        :param title: unused; the window title comes from ``plot.title``
        :param lines: unused; the lines come from ``plot.lines``
        :param shown: open the window right away
        """
        self.plot = plot

        self.window = None
        self.vbox = None
        self.figure = None
        self.canvas = None
        self.axes = None
        self.legend = None

        self.show_cursors = False

        self.plot.shown = shown
        if shown:
            self.show()

    def show(self):
        """Build the window, canvas and toolbar, then draw the plot."""
        self.vbox = gtk.VBox()
        self.figure = Figure(figsize=(5, 4))

        self.figure.set_size_inches(self.plot.figwidth, self.plot.figheight)

        self.window = gtk.Window()
        self.window.connect("destroy", self.destroy_cb)
        self.window.connect("notify::is-active", self.window_focus_cb)
        self.window.add(self.vbox)

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        # Connect matplotlib events once per canvas. (They used to be
        # connected inside draw(), so every redraw added another copy of
        # each handler.)
        self.canvas.mpl_connect('button_release_event', self.button_up_cb)
        self.canvas.mpl_connect('pick_event', self.pick_cb)

        self.draw()
        self.update(limits=True)

        self.vbox.pack_start(self.canvas)

        toolbar = NavigationToolbar(self.canvas, self.window)
        self.vbox.pack_start(toolbar, False, False)

        if self.plot.window_size != (0, 0):
            self.window.resize(self.plot.window_size[0],
                               self.plot.window_size[1])
        else:
            self.window.resize(400, 300)
        if self.plot.window_pos != (0, 0):
            self.window.move(self.plot.window_pos[0],
                             self.plot.window_pos[1])

        self.window.set_title(self.plot.title)

        # Marker line used by redraw() to show normalisation cursors. It is
        # seeded with the first line's data (or nothing if there are none).
        if self.plot.lines:
            x_data, y_data = self.plot.lines[0].get_data()
        else:
            x_data, y_data = [], []
        self.cursors, = self.axes.plot(x_data, y_data)
        self.cursors.set_linestyle("None")
        self.cursors.set_markersize(10)
        self.cursors.set_markeredgewidth(2)
        self.cursors.set_markeredgecolor("k")
        self.cursors.set_antialiased(False)

        self.window.show_all()

        self.plot.shown = True

    def set_focus_cb(self, window, data):
        print("Hej!")

    def window_focus_cb(self, window, data):
        """On activation, select this plot in the parent and store geometry."""
        print("%s %s" % (self.plot.window_size, self.plot.window_pos))
        print("window_focus_cb: %s" % (self.plot.title,))
        if window.get_property('is-active'):
            print("is-active")
            if self.plot.parent.plt_combo.get_selected_data() != self.plot:
                print("selecting item...")
                self.plot.parent.plt_combo.select_item_by_data(self.plot)
            self.plot.window_size = self.window.get_size()
            self.plot.window_pos = self.window.get_position()

            self.plot.figwidth = self.figure.get_figwidth()
            self.plot.figheight = self.figure.get_figheight()

    def draw(self, items=None, sources=None):
        """Clear the figure and redraw all lines, the legend and the ticks.

        :param items: unused
        :param sources: unused
        """
        print("drawing " + self.plot.title)

        def myfmt(x, y):
            return 'x=%1.6g\ny=%1.6g' % (x, y)

        self.figure.clf()
        self.axes = self.figure.add_subplot(111)
        self.axes.set_autoscale_on(False)
        self.axes.format_coord = myfmt

        for line in self.plot.lines:
            self.draw_line(line, draw_canvas=False)

        self.update_legend(draw_canvas=False)
        self.update_ticks(draw_canvas=False)

        self.canvas.draw()

    def draw_line(self, line, draw_canvas=True):
        """Plot *line* (if it has a source) and store the artist in line.handle."""
        if line.source is not None:
            x_data, y_data = line.get_data()

            line.handle, = self.axes.plot(x_data, y_data,
                                          color=line.color, ls=line.style,
                                          marker=line.marker, mew=0,
                                          linewidth=line.width, picker=5.0,
                                          label=line.label)
            # Back-reference used by pick_cb to find the model line.
            line.handle.parent = line

        if draw_canvas:
            self.canvas.draw()

    def update_ticks(self, draw_canvas=True):
        """Apply tick formatters/locators from the plot options, then margins."""
        xMajorFormatter = ScalarFormatter()
        yMajorFormatter = ScalarFormatter()
        xMajorFormatter.set_powerlimits((-3, 4))
        yMajorFormatter.set_powerlimits((-3, 4))

        xaxis = self.axes.get_xaxis()
        yaxis = self.axes.get_yaxis()

        xaxis.set_major_formatter(xMajorFormatter)
        yaxis.set_major_formatter(yMajorFormatter)

        if self.plot.x_majorticks_enable:
            xaxis.set_major_locator(MaxNLocator(self.plot.x_majorticks_maxn))
        else:
            xaxis.set_major_locator(NullLocator())

        if self.plot.y_majorticks_enable:
            yaxis.set_major_locator(MaxNLocator(self.plot.y_majorticks_maxn))
        else:
            yaxis.set_major_locator(NullLocator())

        if self.plot.x_minorticks_enable:
            xaxis.set_minor_locator(MaxNLocator(self.plot.x_minorticks_maxn))
        else:
            xaxis.set_minor_locator(NullLocator())

        if self.plot.y_minorticks_enable:
            yaxis.set_minor_locator(MaxNLocator(self.plot.y_minorticks_maxn))
        else:
            yaxis.set_minor_locator(NullLocator())

        self.update_margins(draw_canvas=False)

        if draw_canvas:
            self.canvas.draw()

    def update_margins(self, draw_canvas=True):
        """Widen left/bottom margins for each enabled axis label / tick set."""
        margins = {"left": 0.05, "bottom": 0.05}

        if self.plot.x_axis_label_enable:
            margins["bottom"] += 0.05
        if self.plot.y_axis_label_enable:
            margins["left"] += 0.05
        if self.plot.x_majorticks_enable:
            margins["bottom"] += 0.05
        if self.plot.y_majorticks_enable:
            margins["left"] += 0.05

        print(margins)

        self.figure.subplots_adjust(**margins)

        if draw_canvas:
            self.canvas.draw()

    def update_legend(self, draw_canvas=True):
        """Create or remove the legend according to ``plot.legend_enable``."""
        if self.plot.legend_enable:
            print("update_legend()")
            lines = []
            labels = []
            for line in self.plot.lines:
                labels.append(line.label)
                lines.append(line.handle)

            self.legend = self.axes.legend(
                lines, labels, loc=self.plot.legend_loc,
                prop=FontProperties(size=self.plot.legend_size))
            # set_frame_on is the non-deprecated spelling of draw_frame.
            self.legend.set_frame_on(self.plot.legend_border)
            self.legend.set_picker(legend_picker)
        else:
            self.legend = None
            self.axes.legend_ = None
        if draw_canvas:
            self.canvas.draw()

    def gupdate(self, source=None, parts=()):
        """Redraw the lines fed by *source*, then refresh the named *parts*.

        :param source: data source whose lines should be redrawn
        :param parts: iterable of part names to refresh afterwards: "all",
            "legend", "margins" or "rest". Unknown names are ignored.
            (Previously this name was undefined, so the method always
            failed with NameError after redrawing.)
        """
        self.redraw(sources=[source])

        for part in parts:
            if part == "all":
                self.draw()
            elif part == "legend":
                self.update_legend()
            elif part == "margins":
                self.update_margins()
            elif part == "rest":
                self.update()

    def update(self, limits=True, draw_canvas=True):
        """Update everything except the lines and the legend.

        Applies axis labels, log scales, grids, axis limits (when *limits*
        is true) and finally runs the user's ``plot.mpl_commands``.
        """
        if self.plot.x_axis_label_enable:
            self.axes.set_xlabel(self.plot.x_axis_label)
        else:
            self.axes.set_xlabel("")

        if self.plot.y_axis_label_enable:
            self.axes.set_ylabel(self.plot.y_axis_label)
        else:
            self.axes.set_ylabel("")

        self.axes.set_xscale("log" if self.plot.x_log_enable else "linear")
        self.axes.set_yscale("log" if self.plot.y_log_enable else "linear")

        self.axes.get_xaxis().grid(self.plot.x_grid_enable, which="major")
        self.axes.get_yaxis().grid(self.plot.y_grid_enable, which="major")

        if limits:
            extr = self.plot.get_extremes()
            print("sxtr: %s" % (extr,))
            if len(extr) == 4:
                print("extr: %s" % (extr,))

                if self.plot.xlim_enable:
                    print("xlim")
                    self.axes.set_xlim(self.plot.xlim_min, self.plot.xlim_max,
                                       emit=False)
                else:
                    self.axes.set_xlim(extr[0], extr[1])

                if self.plot.ylim_enable:
                    self.axes.set_ylim(self.plot.ylim_min, self.plot.ylim_max,
                                       emit=False)
                else:
                    # 5% vertical padding, the same as in redraw(); "* 0.05"
                    # also avoids Python 2 integer division on int extremes.
                    y_pad = (extr[3] - extr[2]) * 0.05
                    self.axes.set_ylim(extr[2] - y_pad, extr[3] + y_pad)

        # NOTE(review): this executes arbitrary Python from
        # plot.mpl_commands — presumably typed by the local user; it must
        # never be fed untrusted input.
        try:
            mpl_code = compile(self.plot.mpl_commands, '<string>', 'exec')
            eval(mpl_code, None, {"figure": self.figure,
                                  "axes": self.axes,
                                  "legend": self.legend,
                                  "s": self.plot.parent.source_list[:],
                                  "p": self.plot.parent.plt_combo.get_model_items().values(),
                                  "plot": self.plot})
        except Exception:
            # Any error in the user snippet is reported, not propagated;
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print("Invalid MPL code!")

        if draw_canvas:
            self.canvas.draw()

    def redraw(self, sources, draw_canvas=True):
        """Refresh data and style of every line fed by one of *sources*.

        Also updates the normalisation cursors for the selected source and
        re-applies automatic axis limits. Does nothing for an empty list.
        """
        if sources == []:
            return

        lines = []
        for line in self.plot.lines:
            if line.source in sources and line not in lines:
                lines.append(line)

        for line in lines:
            source = line.source
            # Lines without a source were never plotted (see draw_line);
            # checking first also avoids reading .name on a None source.
            if not source:
                continue
            print("Redraw: " + source.name)
            x_data, y_data = line.get_data()

            line.handle.set_data(x_data, y_data)
            line.update_extremes()

            line.handle.set_color(line.color)
            line.handle.set_linewidth(line.width)
            line.handle.set_marker(line.marker)

        try:
            s = self.plot.parent.source_list.get_selected_rows()[0]
        except Exception:
            # No selected source available: leave the cursors untouched.
            pass
        else:
            if self.show_cursors and s.norm_enable:
                self.cursors.set_data(
                    [s.x_data[s.norm_min_pt], s.x_data[s.norm_max_pt]],
                    [s.norm_min_y, s.norm_max_y])
                self.cursors.set_marker('+')
            else:
                self.cursors.set_marker("None")

        extr = self.plot.get_extremes()
        # Same guard as update(): only apply limits when all four extremes
        # (xmin, xmax, ymin, ymax) are available.
        if len(extr) == 4:
            if not self.plot.xlim_enable:
                self.axes.set_xlim(extr[0], extr[1])
            if not self.plot.ylim_enable:
                y_pad = (extr[3] - extr[2]) * 0.05
                self.axes.set_ylim(extr[2] - y_pad, extr[3] + y_pad)

        if draw_canvas:
            self.canvas.draw()

    def destroy_cb(self, widget):
        """Handle window destruction: mark the plot as hidden."""
        self.plot.shown = False
        self.window.destroy()
        self.plot.parent.shown.update(False)

    # callbacks
    def rectangle_cb(self, event1, event2):
        """Zoom to the rectangle spanned by two mouse events.

        Stores the limits in the same ``xlim_*``/``ylim_*`` attributes the
        rest of this class reads (previously ``x_lim_*``, which nothing used).
        """
        print("%s %s %s %s" % (event1.xdata, event1.ydata,
                               event2.xdata, event2.ydata))
        x_min, x_max = min(event1.xdata, event2.xdata), max(event1.xdata, event2.xdata)
        y_min, y_max = min(event1.ydata, event2.ydata), max(event1.ydata, event2.ydata)
        self.plot.xlim_min, self.plot.xlim_max = x_min, x_max
        self.plot.ylim_min, self.plot.ylim_max = y_min, y_max
        self.axes.set_xlim(x_min, x_max)
        self.axes.set_ylim(y_min, y_max)
        self.canvas.draw()

    def button_up_cb(self, event):
        """After a pan/zoom, copy the current limits into the plot model."""
        self.plot.xlim_min, self.plot.xlim_max = self.axes.get_xlim()
        self.plot.ylim_min, self.plot.ylim_max = self.axes.get_ylim()

        if not self.plot.xlim_enable:
            self.plot.parent.xlim_min.update(self.plot.xlim_min)
            self.plot.parent.xlim_max.update(self.plot.xlim_max)

        if not self.plot.ylim_enable:
            self.plot.parent.ylim_min.update(self.plot.ylim_min)
            self.plot.parent.ylim_max.update(self.plot.ylim_max)

    def xlim_cb(self, event):
        self.plot.xlim_min, self.plot.xlim_max = self.axes.get_xlim()
        if not self.plot.xlim_enable:
            self.plot.parent.xlim_min.update(self.plot.xlim_min)
            self.plot.parent.xlim_max.update(self.plot.xlim_max)

    def ylim_cb(self, event):
        self.plot.ylim_min, self.plot.ylim_max = self.axes.get_ylim()
        if not self.plot.ylim_enable:
            self.plot.parent.ylim_min.update(self.plot.ylim_min)
            self.plot.parent.ylim_max.update(self.plot.ylim_max)

    def pick_cb(self, event):
        """Select the picked line (or open the legend page) in the parent UI."""
        print(event.artist)
        if isinstance(event.artist, Line2D):
            print(event.artist.parent.label)
            xdata = event.artist.get_xdata()
            ydata = event.artist.get_ydata()
            print(event.ind[0])
            print("%s %s" % (xdata[event.ind[0]], ydata[event.ind[0]]))
            axes_h = self.axes.get_ylim()
            print(axes_h)

            self.plot.parent.plot_notebook.set_current_page(0)
            self.plot.parent.lines_list.select(event.artist.parent)
        elif isinstance(event.artist, Legend):
            print("legend clicked")
            self.plot.parent.plot_notebook.set_current_page(2)
        else:
            print(event)
Example #20
0
class PlotterWidget(QWidget):
    """Widget surrounding matplotlib plotter."""

    selectionChangedSignal = Signal(list)

    def __init__(self, parent=None):
        """Create PlotterWidget."""
        super(PlotterWidget, self).__init__(parent)

        self.selected = []
        self.ylabel = ""
        self.xlabel = ""
        self.fig = Figure(figsize=(300, 300), dpi=72, facecolor=(1, 1, 1),
                          edgecolor=(0, 0, 0))

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        self.canvas.mpl_connect('pick_event', self.onPick)

        # The navigation toolbar doesn't get along with Kate's MPL at the
        # moment, so these mpl_connects implement panning and zooming.
        self.canvas.mpl_connect('motion_notify_event', self.onMouseMotion)
        self.canvas.mpl_connect('scroll_event', self.onScroll)
        self.canvas.mpl_connect('button_press_event', self.onMouseButtonPress)
        self.lastX = 0
        self.lastY = 0

        self.axes = self.fig.add_subplot(111)

        vbox = QVBoxLayout()
        vbox.addWidget(self.canvas)
        vbox.setContentsMargins(0, 0, 0, 0)
        self.setLayout(vbox)

        # Placeholder content until real data is plotted.
        self.axes.plot(range(6), range(6), 'ob')
        self.axes.set_title("Drag attributes to change graph.")
        self.axes.set_xlabel("Drag here to set x axis.")
        self.axes.set_ylabel("Drag here to set y axis.")
        self.canvas.draw()  # Why does this take so long on 4726 iMac?

    def setXLabel(self, label):
        """Changes the x label of the plot."""
        self.xlabel = label
        self.axes.set_xlabel(label)
        self.canvas.draw()

    def setYLabel(self, label):
        """Changes the y label of the plot."""
        self.ylabel = label
        self.axes.set_ylabel(label)
        self.canvas.draw()

    def plotData(self, xs, ys):
        """Plots the given x and y data, re-highlighting the current selection."""
        self.axes.clear()
        self.axes.set_xlabel(self.xlabel)
        self.axes.set_ylabel(self.ylabel)
        self.xs = np.array(xs)
        self.ys = np.array(ys)
        self.axes.plot(xs, ys, 'ob', picker=3)
        # len() replaces np.alen, which was removed from NumPy. All selected
        # points are highlighted, as in setHighlights (not just the first).
        if len(self.selected) > 0:
            self.highlighted = self.axes.plot(self.xs[self.selected],
                                              self.ys[self.selected], 'or')[0]
        self.canvas.draw()

    def onPick(self, event):
        """Handles pick event, taking the closest single point.

           Note that since the id associated with a given point may be
           associated with many points, the final selection displayed to the
           user may be several points.
        """
        selected = np.array(event.ind)

        mouseevent = event.mouseevent
        xt = self.xs[selected]
        yt = self.ys[selected]
        d = np.array((xt - mouseevent.xdata)**2 + (yt - mouseevent.ydata)**2)
        thepoint = selected[d.argmin()]

        self.selectionChangedSignal.emit([thepoint])

    @Slot(list)
    def setHighlights(self, ids):
        """Sets highlights based on the given ids. These ids are the indices
           of the x and y data, not the domain.

           Passing the current selection again toggles it off. Returns True
           when the plot was repainted, False when *ids* is empty/None.
        """
        old_selection = list(self.selected)
        self.selected = ids
        if ids is None:
            self.selected = []

        if self.selected == []:
            return False

        unchanged = old_selection == self.selected

        # Repaint the previous selection in blue first, so that where the old
        # and new selections overlap the new red markers end up on top. The
        # overlays are not pickable: picks on a one-point line would report
        # index 0 and select the wrong data point in onPick.
        for indx in old_selection:
            self.axes.plot(self.xs[indx], self.ys[indx], 'ob')

        if unchanged:
            # Selecting the same ids again turns the selection off.
            self.selected = []
        else:
            for indx in self.selected:
                self.highlighted = self.axes.plot(self.xs[indx],
                                                  self.ys[indx], 'or')[0]

        self.canvas.draw()
        return True

    # ---------------------- NAVIGATION CONTROLS ---------------------------

    # Mouse movement (with or w/o button press) handling
    def onMouseMotion(self, event):
        """Handles the panning."""
        if event.button == 1:
            xmotion = self.lastX - event.x
            ymotion = self.lastY - event.y
            self.lastX = event.x
            self.lastY = event.y
            figsize = min(self.fig.get_figwidth(), self.fig.get_figheight())
            xmin, xmax = self.calcTranslate(self.axes.get_xlim(),
                                            xmotion, figsize)
            ymin, ymax = self.calcTranslate(self.axes.get_ylim(),
                                            ymotion, figsize)
            self.axes.set_xlim(xmin, xmax)
            self.axes.set_ylim(ymin, ymax)
            self.canvas.draw()

    # Note: the dtuple is in data coordinates, the motion is in pixels,
    # we estimate how much motion there is based on the figsize and then
    # scale it appropriately to the data coordinates to get the proper
    # offset in figure limits.
    def calcTranslate(self, dtuple, motion, figsize):
        """Calculates the translation necessary in one direction given a
           mouse drag in that direction.

           dtuple
               The current limits in a single dimension

           motion
               The number of pixels the mouse was dragged in the dimension.
               This may be negative.

           figsize
               The approximate size of the figure.
        """
        dmin, dmax = dtuple
        drange = dmax - dmin
        dots = self.fig.dpi * figsize
        offset = float(motion * drange) / float(dots)
        return (dmin + offset, dmax + offset)

    # When the user clicks the left mouse button, that is the start of
    # their drag event, so we set the last-coordinates that are used to
    # calculate drag
    def onMouseButtonPress(self, event):
        """Records start of drag event."""
        if event.button == 1:
            self.lastX = event.x
            self.lastY = event.y

    # On mouse wheel scroll, we zoom
    def onScroll(self, event):
        """Zooms on mouse scroll."""
        scale = 1. + event.step * 0.05
        xmin, xmax = self.calcZoom(self.axes.get_xlim(), scale)
        ymin, ymax = self.calcZoom(self.axes.get_ylim(), scale)
        self.axes.set_xlim(xmin, xmax)
        self.axes.set_ylim(ymin, ymax)
        self.canvas.draw()

    # Calculates the zoom required by the wheel scroll for a single dimension
    # dtuple - the current limits in some dimension
    # scale - fraction to increase/decrease the image size
    # This does a zoom by scaling the limits in that direction appropriately
    def calcZoom(self, dtuple, scale):
        """Calculates the zoom in a single direction based on:

           dtuple
               The limits in the direction

           scale
               Fraction by which to increase/decrease the figure.
        """
        dmin, dmax = dtuple
        dlen = 0.5 * (dmax - dmin)
        dcenter = dlen + dmin
        return (dcenter - dlen * scale, dcenter + dlen * scale)
Example #21
0
class SubWindow(QtGui.QWidget):
    """Base class for rasviewer document windows.

    Loads a data file with a parser from ``rastools.data_parsers``, embeds a
    matplotlib canvas in the window's splitter, and connects the common
    crop / axes / title / histogram / colorbar controls to a debounced
    redraw (see :meth:`invalidate_image`).

    Subclasses must provide :attr:`data`, :attr:`data_cropped`,
    :meth:`canvas_motion`, :meth:`default_title_clicked`,
    :meth:`draw_image`, :meth:`draw_histogram`, :meth:`draw_colorbar` and
    :meth:`format_dict`, all of which raise NotImplementedError here.
    """

    # Emitted whenever any crop_*_spinbox value changes (see crop_changed)
    cropChanged = QtCore.pyqtSignal()

    def __init__(self, ui_file, data_file, channel_file=None):
        super(SubWindow, self).__init__(None)
        self._load_interface(ui_file)
        try:
            self._load_data(data_file, channel_file)
        except (ValueError, IOError) as exc:
            # Report the failure and close the window; the rest of the set-up
            # needs self._file, so it is skipped entirely
            QtGui.QMessageBox.critical(self, self.tr('Error'), str(exc))
            self.close()
            return
        self._config_interface()
        self._config_handlers()
        self.channel_changed()

    def _load_interface(self, ui_file):
        "Called by __init__ to load the Qt interface file"
        self.ui = None
        self.ui = uic.loadUi(get_ui_file(ui_file), self)

    def _load_data(self, data_file, channel_file=None):
        """Called by __init__ to load the data file.

        The parser class is selected from ``DATA_PARSERS`` by the extension
        of *data_file*. ValueError is raised when no parser claims that
        extension. Loading progress is reported through
        :meth:`progress_start`, :meth:`progress_update` and
        :meth:`progress_finish`.
        """
        self._file = None
        self._progress = 0
        self._progress_update = None
        self._progress_dialog = None
        # Show a wait cursor while the parser modules are imported
        QtGui.QApplication.instance().setOverrideCursor(QtCore.Qt.WaitCursor)
        try:
            from rastools.data_parsers import DATA_PARSERS
        finally:
            QtGui.QApplication.instance().restoreOverrideCursor()
        # Open the selected file
        ext = os.path.splitext(data_file)[-1]
        # Map every extension a parser claims to that parser's class. The
        # loop variable is deliberately not called ``ext`` so it cannot be
        # confused with the data file's extension above
        parsers = dict(
            (parser_ext, cls)
            for (cls, exts, _) in DATA_PARSERS
            for parser_ext in exts)
        try:
            parser = parsers[ext]
        except KeyError:
            raise ValueError(
                self.tr('Unrecognized file extension "{0}"').format(ext))
        self._file = parser(data_file,
                            channel_file,
                            delay_load=False,
                            progress=(
                                self.progress_start,
                                self.progress_update,
                                self.progress_finish,
                            ))
        self.setWindowTitle(os.path.basename(data_file))

    def _config_interface(self):
        "Called by __init__ to configure the interface elements"
        # State for the info dialog and for mouse pan / zoom drags
        self._info_dialog = None
        self._drag_start = None
        self._pan_id = None
        self._pan_crop = None
        self._zoom_id = None
        self._zoom_rect = None
        # Create a figure in a tab for the file
        self.figure = Figure(figsize=(5.0, 5.0),
                             dpi=FIGURE_DPI,
                             facecolor='w',
                             edgecolor='w')
        self.canvas = FigureCanvas(self.figure)
        self.image_axes = self.figure.add_axes((0.1, 0.1, 0.8, 0.8))
        # Created on demand: title_axes by draw_title, the others presumably
        # by the subclasses' draw_* methods
        self.histogram_axes = None
        self.colorbar_axes = None
        self.title_axes = None
        self.ui.splitter.addWidget(self.canvas)
        # Set up the redraw timer
        self.redraw_timer = QtCore.QTimer()
        self.redraw_timer.setInterval(REDRAW_TIMEOUT_DEFAULT)
        self.redraw_timer.timeout.connect(self.redraw_timeout)
        # Set up the limits of the crop spinners
        self.ui.crop_left_spinbox.setRange(0, self._file.x_size - 1)
        self.ui.crop_right_spinbox.setRange(0, self._file.x_size - 1)
        self.ui.crop_top_spinbox.setRange(0, self._file.y_size - 1)
        self.ui.crop_bottom_spinbox.setRange(0, self._file.y_size - 1)
        # Configure the common combos. The default entry is selected if
        # present; otherwise index -1 (no selection) is used
        default = -1
        for interpolation in sorted(matplotlib.image.AxesImage._interpd):
            if interpolation == DEFAULT_INTERPOLATION:
                default = self.ui.interpolation_combo.count()
            self.ui.interpolation_combo.addItem(interpolation)
        self.ui.interpolation_combo.setCurrentIndex(default)
        if hasattr(self.ui, 'colorbar_check'):
            default = -1
            # Reversed ("_r") colormaps are omitted; reversal is driven by
            # reverse_check instead
            for color in sorted(matplotlib.cm.datad):
                if not color.endswith('_r'):
                    if color == DEFAULT_COLORMAP:
                        default = self.ui.colormap_combo.count()
                    self.ui.colormap_combo.addItem(color)
            self.ui.colormap_combo.setCurrentIndex(default)

    def _config_handlers(self):
        "Called by __init__ to connect events to handlers"
        # Set up common event connections
        self.ui.interpolation_combo.currentIndexChanged.connect(
            self.invalidate_image)
        self.ui.crop_top_spinbox.valueChanged.connect(self.crop_changed)
        self.ui.crop_left_spinbox.valueChanged.connect(self.crop_changed)
        self.ui.crop_right_spinbox.valueChanged.connect(self.crop_changed)
        self.ui.crop_bottom_spinbox.valueChanged.connect(self.crop_changed)
        self.ui.axes_check.toggled.connect(self.invalidate_image)
        self.ui.x_label_edit.textChanged.connect(self.invalidate_image)
        self.ui.y_label_edit.textChanged.connect(self.invalidate_image)
        self.ui.x_scale_spinbox.valueChanged.connect(self.x_scale_changed)
        self.ui.y_scale_spinbox.valueChanged.connect(self.y_scale_changed)
        self.ui.x_offset_spinbox.valueChanged.connect(self.x_offset_changed)
        self.ui.y_offset_spinbox.valueChanged.connect(self.y_offset_changed)
        self.ui.grid_check.toggled.connect(self.invalidate_image)
        self.ui.histogram_check.toggled.connect(self.invalidate_image)
        self.ui.histogram_bins_spinbox.valueChanged.connect(
            self.invalidate_image)
        self.ui.title_edit.textChanged.connect(self.invalidate_image)
        self.ui.default_title_button.clicked.connect(
            self.default_title_clicked)
        self.ui.clear_title_button.clicked.connect(self.clear_title_clicked)
        self.ui.title_info_button.clicked.connect(self.title_info_clicked)
        self.ui.splitter.splitterMoved.connect(self.splitter_moved)
        QtGui.QApplication.instance().focusChanged.connect(self.focus_changed)
        self.canvas.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.canvas.customContextMenuRequested.connect(self.canvas_popup)
        if hasattr(self.ui, 'colorbar_check'):
            self.ui.colorbar_check.toggled.connect(self.invalidate_image)
            self.ui.colormap_combo.currentIndexChanged.connect(
                self.invalidate_image)
            self.ui.reverse_check.toggled.connect(self.invalidate_image)
        self.press_id = self.canvas.mpl_connect('button_press_event',
                                                self.canvas_press)
        self.release_id = self.canvas.mpl_connect('button_release_event',
                                                  self.canvas_release)
        self.motion_id = self.canvas.mpl_connect('motion_notify_event',
                                                 self.canvas_motion)

    def splitter_moved(self, pos, index):
        "Handler for splitter moved event; schedules a redraw of the canvas"
        self.invalidate_image()

    def progress_start(self):
        "Handler for loading progress start event; shows a progress dialog"
        self._progress = 0
        self._progress_dialog = ProgressDialog(self.window())
        self._progress_dialog.show()
        self._progress_dialog.task = self.tr('Opening file')
        QtGui.QApplication.instance().setOverrideCursor(QtCore.Qt.WaitCursor)

    def progress_update(self, progress):
        """Handler for loading progress update event.

        The dialog is updated at most once every 0.2 seconds. Raises
        KeyboardInterrupt if the user cancelled the progress dialog.
        """
        now = time.time()
        if ((self._progress_update is None)
                or (now - self._progress_update) > 0.2):
            if self._progress_dialog.cancelled:
                raise KeyboardInterrupt
            self._progress_update = now
            if progress != self._progress:
                self._progress_dialog.progress = progress
                self._progress = progress

    def progress_finish(self):
        "Handler for loading progress finished event; closes the dialog"
        QtGui.QApplication.instance().restoreOverrideCursor()
        if self._progress_dialog is not None:
            self._progress_dialog.close()
            self._progress_dialog = None

    def canvas_popup(self, pos):
        "Handler for canvas context menu event"
        main_ui = self.window().ui
        menu = QtGui.QMenu(self)
        menu.addAction(main_ui.zoom_mode_action)
        menu.addAction(main_ui.pan_mode_action)
        menu.addSeparator()
        menu.addAction(main_ui.zoom_in_action)
        menu.addAction(main_ui.zoom_out_action)
        menu.addAction(main_ui.reset_zoom_action)
        menu.addSeparator()
        menu.addAction(main_ui.home_axes_action)
        menu.addAction(main_ui.reset_axes_action)
        menu.popup(self.canvas.mapToGlobal(pos))

    def canvas_motion(self, event):
        "Handler for mouse movement over graph canvas"
        raise NotImplementedError

    def canvas_press(self, event):
        """Handler for mouse press on graph canvas.

        A left-button press inside the image axes starts a zoom or pan drag,
        depending on which mode action is checked in the main window.
        """
        if event.button != 1:
            return
        if event.inaxes != self.image_axes:
            return
        self._drag_start = Coord(event.x, event.y)
        if self.window().ui.zoom_mode_action.isChecked():
            self._zoom_id = self.canvas.mpl_connect('motion_notify_event',
                                                    self.canvas_zoom_motion)
        elif self.window().ui.pan_mode_action.isChecked():
            self._pan_id = self.canvas.mpl_connect('motion_notify_event',
                                                   self.canvas_pan_motion)
            # Remember the crop at the start of the pan; motion is applied
            # relative to it
            self._pan_crop = Crop(top=self.ui.crop_top_spinbox.value(),
                                  left=self.ui.crop_left_spinbox.value(),
                                  bottom=self.ui.crop_bottom_spinbox.value(),
                                  right=self.ui.crop_right_spinbox.value())
            self.redraw_timer.setInterval(REDRAW_TIMEOUT_PAN)

    def canvas_pan_motion(self, event):
        """Handler for mouse movement in pan mode.

        Shifts the crop by the drag distance, measured in data coordinates.
        An axis is only shifted if neither of its crop edges would become
        negative.
        """
        inverse = self.image_axes.transData.inverted()
        start_x, start_y = inverse.transform_point(self._drag_start)
        end_x, end_y = inverse.transform_point((event.x, event.y))
        delta = Coord(int(start_x - end_x), int(start_y - end_y))
        if (self._pan_crop.left + delta.x >=
                0) and (self._pan_crop.right - delta.x >= 0):
            self.ui.crop_left_spinbox.setValue(self._pan_crop.left + delta.x)
            self.ui.crop_right_spinbox.setValue(self._pan_crop.right - delta.x)
        if (self._pan_crop.top + delta.y >=
                0) and (self._pan_crop.bottom - delta.y >= 0):
            self.ui.crop_top_spinbox.setValue(self._pan_crop.top + delta.y)
            self.ui.crop_bottom_spinbox.setValue(self._pan_crop.bottom -
                                                 delta.y)

    def canvas_zoom_motion(self, event):
        "Handler for mouse movement in zoom mode"
        # Calculate the display coordinates of the selection
        box_left, box_top, box_right, box_bottom = self.image_axes.bbox.extents
        height = self.figure.bbox.height
        band_left = max(min(self._drag_start.x, event.x), box_left)
        band_right = min(max(self._drag_start.x, event.x), box_right)
        band_top = max(min(self._drag_start.y, event.y), box_top)
        band_bottom = min(max(self._drag_start.y, event.y), box_bottom)
        rectangle = (band_left, height - band_top, band_right - band_left,
                     band_top - band_bottom)
        # Calculate the data coordinates of the selection. Note that top and
        # bottom are reversed by this conversion
        inverse = self.image_axes.transData.inverted()
        data_left, data_bottom = inverse.transform_point((band_left, band_top))
        data_right, data_top = inverse.transform_point(
            (band_right, band_bottom))
        # Ignore the drag operation until the total number of data-points in
        # the selection exceeds the threshold
        if (abs(data_right - data_left) *
                abs(data_bottom - data_top)) > ZOOM_THRESHOLD:
            self._zoom_rect = (data_left, data_top, data_right, data_bottom)
            self.window().statusBar().showMessage(
                self.tr('Crop from ({left:.0f}, {top:.0f}) to '
                        '({right:.0f}, {bottom:.0f})').format(
                            left=data_left,
                            top=data_top,
                            right=data_right,
                            bottom=data_bottom))
            self.canvas.drawRectangle(rectangle)
        else:
            self._zoom_rect = None
            self.window().statusBar().clearMessage()
            self.canvas.draw()

    def canvas_release(self, event):
        """Handler for mouse release on graph canvas.

        Ends any pan or zoom drag. For a zoom drag with a selection (see
        :meth:`canvas_zoom_motion`), the selection is converted back into
        crop values and the selection is then discarded.
        """
        if self._pan_id:
            # End of a pan: stop tracking motion and restore the normal
            # redraw delay
            self.window().statusBar().clearMessage()
            self.canvas.mpl_disconnect(self._pan_id)
            self._pan_id = None
            self.redraw_timer.setInterval(REDRAW_TIMEOUT_DEFAULT)
        if self._zoom_id:
            self.window().statusBar().clearMessage()
            self.canvas.mpl_disconnect(self._zoom_id)
            self._zoom_id = None
            if self._zoom_rect:
                (
                    data_left,
                    data_top,
                    data_right,
                    data_bottom,
                ) = self._zoom_rect
                # Discard the selection now that it is being applied, so a
                # later click without any drag motion cannot re-apply it
                self._zoom_rect = None
                # Undo the scale/offset transform used by x_limits and
                # y_limits. Those treat a zero scale as 1.0, so do the same
                # here instead of dividing by zero
                x_scale = self.ui.x_scale_spinbox.value() or 1.0
                y_scale = self.ui.y_scale_spinbox.value() or 1.0
                x_offset = self.ui.x_offset_spinbox.value()
                y_offset = self.ui.y_offset_spinbox.value()
                data_left = (data_left / x_scale) - x_offset
                data_right = (data_right / x_scale) - x_offset
                data_top = (data_top / y_scale) - y_offset
                data_bottom = (data_bottom / y_scale) - y_offset
                self.ui.crop_left_spinbox.setValue(data_left)
                self.ui.crop_top_spinbox.setValue(data_top)
                self.ui.crop_right_spinbox.setValue(self._file.x_size -
                                                    data_right)
                self.ui.crop_bottom_spinbox.setValue(self._file.y_size -
                                                     data_bottom)
                self.canvas.draw()

    def channel_changed(self):
        "Handler for data channel change event"
        self.invalidate_data()
        self.crop_changed()

    def crop_changed(self, value=None):
        "Handler for crop_*_spinbox change event; emits cropChanged"
        self.cropChanged.emit()

    @property
    def zoom_factor(self):
        """Returns the (x, y) amounts by which zoom in/out moves each edge.

        Each amount is 20% of the cropped data's width or height, and never
        less than 1.0.
        """
        factor = 0.2
        height, width = self.data_cropped.shape[:2]
        return (max(1.0, width * factor), max(1.0, height * factor))

    @property
    def can_zoom_in(self):
        """Returns True if the image can be zoomed.

        True when the area left after one more zoom-in step would still
        exceed ZOOM_THRESHOLD.
        """
        height, width = self.data_cropped.shape[:2]
        x_factor, y_factor = self.zoom_factor
        return (width - x_factor * 2) * (height -
                                         y_factor * 2) > ZOOM_THRESHOLD

    @property
    def can_zoom_out(self):
        "Returns True if the image is zoomed (any crop edge is non-zero)"
        return (self.ui.crop_left_spinbox.value() > 0
                or self.ui.crop_right_spinbox.value() > 0
                or self.ui.crop_top_spinbox.value() > 0
                or self.ui.crop_bottom_spinbox.value() > 0)

    def zoom_in(self):
        "Zooms the image in by a fixed amount"
        x_factor, y_factor = self.zoom_factor
        self.ui.crop_left_spinbox.setValue(self.ui.crop_left_spinbox.value() +
                                           x_factor)
        self.ui.crop_right_spinbox.setValue(
            self.ui.crop_right_spinbox.value() + x_factor)
        self.ui.crop_top_spinbox.setValue(self.ui.crop_top_spinbox.value() +
                                          y_factor)
        self.ui.crop_bottom_spinbox.setValue(
            self.ui.crop_bottom_spinbox.value() + y_factor)

    def zoom_out(self):
        "Zooms the image out by a fixed amount, never cropping below zero"
        x_factor, y_factor = self.zoom_factor
        self.ui.crop_left_spinbox.setValue(
            max(0.0,
                self.ui.crop_left_spinbox.value() - x_factor))
        self.ui.crop_right_spinbox.setValue(
            max(0.0,
                self.ui.crop_right_spinbox.value() - x_factor))
        self.ui.crop_top_spinbox.setValue(
            max(0.0,
                self.ui.crop_top_spinbox.value() - y_factor))
        self.ui.crop_bottom_spinbox.setValue(
            max(0.0,
                self.ui.crop_bottom_spinbox.value() - y_factor))

    def reset_zoom(self):
        "Handler for reset_zoom_action triggered event"
        self.ui.crop_left_spinbox.setValue(0)
        self.ui.crop_right_spinbox.setValue(0)
        self.ui.crop_top_spinbox.setValue(0)
        self.ui.crop_bottom_spinbox.setValue(0)

    def reset_axes(self):
        "Handler for the reset_axes_action triggered event"
        self.ui.scale_locked_check.setChecked(True)
        self.ui.x_scale_spinbox.setValue(1.0)
        self.ui.y_scale_spinbox.setValue(1.0)
        self.ui.offset_locked_check.setChecked(True)
        self.ui.x_offset_spinbox.setValue(0.0)
        self.ui.y_offset_spinbox.setValue(0.0)

    def home_axes(self):
        """Handler for home_axes_action triggered event.

        Resets the scale to 1.0 and sets the offsets so the axes start at
        zero at the current top-left crop corner.
        """
        self.ui.scale_locked_check.setChecked(True)
        self.ui.x_scale_spinbox.setValue(1.0)
        self.ui.y_scale_spinbox.setValue(1.0)
        # The offsets differ per axis here, so they must not be locked
        self.ui.offset_locked_check.setChecked(False)
        self.ui.x_offset_spinbox.setValue(-self.ui.crop_left_spinbox.value())
        self.ui.y_offset_spinbox.setValue(-self.ui.crop_top_spinbox.value())

    def x_scale_changed(self, value):
        "Handler for x_scale_spinbox change event"
        if self.ui.scale_locked_check.isChecked():
            self.ui.y_scale_spinbox.setValue(value)
        self.invalidate_image()

    def y_scale_changed(self, value):
        "Handler for y_scale_spinbox change event"
        if self.ui.scale_locked_check.isChecked():
            self.ui.x_scale_spinbox.setValue(value)
        self.invalidate_image()

    def x_offset_changed(self, value):
        "Handler for x_offset_spinbox change event"
        if self.ui.offset_locked_check.isChecked():
            self.ui.y_offset_spinbox.setValue(value)
        self.invalidate_image()

    def y_offset_changed(self, value):
        "Handler for y_offset_spinbox change event"
        if self.ui.offset_locked_check.isChecked():
            self.ui.x_offset_spinbox.setValue(value)
        self.invalidate_image()

    def default_title_clicked(self):
        "Handler for default_title_button click event"
        raise NotImplementedError

    def clear_title_clicked(self):
        "Handler for clear_title_button click event"
        self.ui.title_edit.clear()

    def title_info_clicked(self, items):
        """Handler for title_info_button click event.

        Fills the TitleInfoDialog with one row per title template available
        from :meth:`format_dict`, each paired with an example of its
        substituted value, then shows the dialog. Ints from 1 to 9 also get a
        zero-padded variant, floats a 2-decimal variant, and datetimes
        date-only, time-only and long-form variants. *items* is unused.
        """
        from rastools.rasviewer.title_info_dialog import TitleInfoDialog
        if not self._info_dialog:
            self._info_dialog = TitleInfoDialog(self)
        template_list = self._info_dialog.ui.template_list
        template_list.clear()

        def add_row(template, example):
            "Appends a (template, example value) row to the template list"
            template_list.addTopLevelItem(
                QtGui.QTreeWidgetItem([template, example]))

        for key, value in sorted(self.format_dict().items()):
            if isinstance(value, type('')):
                # Only the first line of a multi-line string is shown
                if '\n' in value:
                    value = value.splitlines()[0].rstrip()
                add_row('{{{0}}}'.format(key), value)
            elif isinstance(value, int):
                add_row('{{{0}}}'.format(key), '{0}'.format(value))
                if 0 < value < 10:
                    add_row('{{{0}:02d}}'.format(key), '{0:02d}'.format(value))
            elif isinstance(value, float):
                add_row('{{{0}}}'.format(key), '{0}'.format(value))
                add_row('{{{0}:.2f}}'.format(key), '{0:.2f}'.format(value))
            elif isinstance(value, dt.datetime):
                add_row('{{{0}}}'.format(key), '{0}'.format(value))
                add_row('{{{0}:%Y-%m-%d}}'.format(key),
                        '{0:%Y-%m-%d}'.format(value))
                add_row('{{{0}:%H:%M:%S}}'.format(key),
                        '{0:%H:%M:%S}'.format(value))
                add_row('{{{0}:%A, %d %b %Y, %H:%M:%S}}'.format(key),
                        '{0:%A, %d %b %Y, %H:%M:%S}'.format(value))
            else:
                add_row('{{{0}}}'.format(key), '{0}'.format(value))
        self._info_dialog.show()

    @property
    def data(self):
        "Returns the original data array"
        raise NotImplementedError

    @property
    def data_cropped(self):
        "Returns the data after cropping"
        raise NotImplementedError

    @property
    def x_limits(self):
        """Returns a tuple of the X-axis limits after scaling and offset.

        Returns None while data_cropped is None. A zero scale is treated as
        1.0.
        """
        if self.data_cropped is not None:
            return Range(
                (self.ui.x_scale_spinbox.value() or 1.0) *
                (self.ui.x_offset_spinbox.value() +
                 self.ui.crop_left_spinbox.value()),
                (self.ui.x_scale_spinbox.value() or 1.0) *
                (self.ui.x_offset_spinbox.value() + self._file.x_size -
                 self.ui.crop_right_spinbox.value()))

    @property
    def y_limits(self):
        """Returns a tuple of the Y-axis limits after scaling and offset.

        The first value is derived from the bottom crop and the second from
        the top crop. Returns None while data_cropped is None. A zero scale
        is treated as 1.0.
        """
        if self.data_cropped is not None:
            return Range(
                (self.ui.y_scale_spinbox.value() or 1.0) *
                (self.ui.y_offset_spinbox.value() + self._file.y_size -
                 self.ui.crop_bottom_spinbox.value()),
                (self.ui.y_scale_spinbox.value() or 1.0) *
                (self.ui.y_offset_spinbox.value() +
                 self.ui.crop_top_spinbox.value()))

    @property
    def axes_visible(self):
        "Returns True if the axes should be shown"
        return hasattr(self.ui,
                       'axes_check') and self.ui.axes_check.isChecked()

    @property
    def colorbar_visible(self):
        "Returns True if the colorbar should be shown"
        return hasattr(
            self.ui, 'colorbar_check') and self.ui.colorbar_check.isChecked()

    @property
    def histogram_visible(self):
        "Returns True if the histogram should be shown"
        return hasattr(
            self.ui,
            'histogram_check') and self.ui.histogram_check.isChecked()

    @property
    def margin_visible(self):
        "Returns True if the image margins should be shown"
        return (self.axes_visible or self.histogram_visible
                or self.colorbar_visible or bool(self.image_title))

    @property
    def x_margin(self):
        "Returns the size of the left and right margins when drawing"
        return 0.75 if self.margin_visible else 0.0

    @property
    def y_margin(self):
        "Returns the size of the top and bottom margins when drawing"
        return 0.25 if self.margin_visible else 0.0

    @property
    def sep_margin(self):
        "Returns the size of the separator between image elements"
        return 0.3

    @property
    def image_title(self):
        """Returns the text of the image title after substitution.

        The title_edit text is formatted with :meth:`format_dict`. As a side
        effect, title_error_label is shown with a message when the template
        references an unknown key or is malformed, and is hidden otherwise.
        On error an empty string is returned.
        """
        result = ''
        try:
            if self.ui.title_edit.toPlainText():
                result = str(self.ui.title_edit.toPlainText()).format(
                    **self.format_dict())
        except KeyError as exc:
            self.ui.title_error_label.setText(
                'Unknown template "{}"'.format(exc))
            self.ui.title_error_label.show()
        except ValueError as exc:
            self.ui.title_error_label.setText(str(exc))
            self.ui.title_error_label.show()
        else:
            self.ui.title_error_label.hide()
        return result

    @property
    def figure_box(self):
        "Returns the overall bounding box"
        return BoundingBox(0.0, 0.0, self.figure.get_figwidth(),
                           self.figure.get_figheight())

    @property
    def colorbar_box(self):
        "Returns the colorbar bounding box (zero height when hidden)"
        return BoundingBox(self.x_margin, self.y_margin,
                           self.figure_box.width - (self.x_margin * 2),
                           0.5 if self.colorbar_visible else 0.0)

    @property
    def title_box(self):
        "Returns the title bounding box (zero height when there is no title)"
        return BoundingBox(
            self.x_margin, self.figure_box.height -
            (self.y_margin + 1.0 if bool(self.image_title) else 0.0),
            self.figure_box.width - (self.x_margin * 2),
            1.0 if bool(self.image_title) else 0.0)

    @property
    def histogram_box(self):
        """Returns the histogram bounding box.

        Sits above the colorbar and takes half of the vertical space left
        after margins, colorbar, title and separators (zero when hidden).
        """
        return BoundingBox(
            self.x_margin, self.colorbar_box.top +
            (self.sep_margin if self.colorbar_visible else 0.0),
            self.figure_box.width - (self.x_margin * 2),
            (self.figure_box.height - (self.y_margin * 2) -
             self.colorbar_box.height - self.title_box.height -
             (self.sep_margin if self.colorbar_visible else 0.0) -
             (self.sep_margin if bool(self.image_title) else 0.0)) /
            2.0 if self.histogram_visible else 0.0)

    @property
    def image_box(self):
        "Returns the image bounding box (all remaining vertical space)"
        return BoundingBox(
            self.x_margin,
            self.histogram_box.top + (self.sep_margin if self.colorbar_visible
                                      or self.histogram_visible else 0.0),
            self.figure_box.width - (self.x_margin * 2),
            (self.figure_box.height -
             (self.y_margin * 2) - self.colorbar_box.height -
             self.title_box.height - self.histogram_box.height -
             (self.sep_margin if self.colorbar_visible else 0.0) -
             (self.sep_margin if self.histogram_visible else 0.0) -
             (self.sep_margin if bool(self.image_title) else 0.0)))

    def invalidate_image(self):
        "Invalidate the image"
        # Actually, this method doesn't immediately invalidate the image (as
        # this results in a horribly sluggish UI), but starts a timer which
        # causes a redraw after no invalidations have occurred for a period
        # (see _config_interface for the duration)
        if self.redraw_timer.isActive():
            self.redraw_timer.stop()
        self.redraw_timer.start()

    def redraw_timeout(self):
        "Handler for the redraw_timer's timeout event"
        self.redraw_timer.stop()
        self.redraw_figure()

    def redraw_figure(self):
        "Called to redraw the channel image"
        # The following tests ensure we don't try and draw anything while we're
        # still loading the file
        if self._file and self.data is not None:
            # Draw the various image elements within bounding boxes calculated
            # from the metrics above
            image = self.draw_image()
            self.draw_histogram()
            self.draw_colorbar(image)
            self.draw_title()
            self.canvas.draw()

    def draw_image(self):
        "Draws the image of the data within the specified figure"
        raise NotImplementedError

    def draw_histogram(self):
        "Draws the data's histogram within the figure"
        raise NotImplementedError

    def draw_colorbar(self, image):
        "Draws a range color-bar within the figure"
        raise NotImplementedError

    def draw_title(self):
        """Draws a title within the specified figure.

        Creates or repositions title_axes when there is a title, and removes
        it from the figure when there is none.
        """
        box = self.title_box.relative_to(self.figure_box)
        # image_title re-runs the template substitution (and updates the
        # error label) on every access, so evaluate it once
        title = self.image_title
        if title:
            if self.title_axes is None:
                self.title_axes = self.figure.add_axes(box)
            else:
                self.title_axes.clear()
                self.title_axes.set_position(box)
            self.title_axes.set_axis_off()
            # Render the title
            self.title_axes.text(0.5,
                                 0,
                                 title,
                                 horizontalalignment='center',
                                 verticalalignment='baseline',
                                 multialignment='center',
                                 size='medium',
                                 family='sans-serif',
                                 transform=self.title_axes.transAxes)
        elif self.title_axes:
            self.figure.delaxes(self.title_axes)
            self.title_axes = None

    def format_dict(self):
        "Returns UI settings in a dict for use in format substitutions"
        raise NotImplementedError
class BrowserMatPlotFrame(QtGui.QWidget):
    "定义画图的页面"
    def __init__(self, parent = None):
        QtGui.QWidget.__init__(self)
        self.parent = parent
        self.status_bar = parent.status_bar

        #State
        self.draw_node_labels_tf = True
        self.draw_axis_units_tf = False
        self.draw_grid_tf = False
        self.g = None

        #PATH used in drawing STEP hierarchy, co-occurence, context
        self.step_path = parent.step_path
        
        #MPL figure
        self.dpi = 100
        self.fig = Figure((5.0, 4.0), dpi=self.dpi)
        self.fig.subplots_adjust(left=0,right=1,top=1,bottom=0)
        
        #QT canvas
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        self.canvas.mpl_connect('pick_event', self.on_pick) #used when selectingpyth	 canvas objects
        self.canvas.setSizePolicy(QtGui.QSizePolicy.Expanding, QtGui.QSizePolicy.Expanding) 
        
        self.axes = self.fig.add_subplot(111)
        #self.axes.hold(False) #clear the axes every time plot() is called

        self.mpl_toolbar = NavigationToolbar(self.canvas, self)

        #GUI controls
        self.mode_combo = QComboBox()
        self.mode_combo.addItems(["Graph Test", 
                                  "Graph Test Numpy", 
                                  "STEP Hierarchy", 
                                  "STEP Co-occurence",
                                  "STEP Context"])
        self.mode_combo.setMinimumWidth(200)
        
        self.draw_button = QPushButton("&Draw/Refresh")
        self.connect(self.draw_button, QtCore.SIGNAL('clicked()'), self.on_draw)
        
        self.node_size = QSpinBox(self)
        self.node_size.setSingleStep(5)
        self.node_size.setMaximum(100)
        self.node_size.setValue(25)
        self.node_size_label = QLabel('Node Size (%):')
        # connection set in on_draw() method

        #Horizontal layout
        hbox = QtGui.QHBoxLayout()
    
        #Adding matplotlib widgets
        for w in [self.mode_combo, 's', self.node_size_label, self.node_size, self.draw_button]:
            if w == 's': hbox.addStretch()
            else:
                hbox.addWidget(w)
                hbox.setAlignment(w, Qt.AlignVCenter)

        #Vertical layout. Adding all other widgets, and hbox layout.
        vbox = QtGui.QVBoxLayout()
        vbox.addWidget(self.mpl_toolbar)
        vbox.addWidget(self.canvas)
        vbox.addLayout(hbox)

        self.setLayout(vbox)
        self.canvas.setFocus(True)
        
    def draw_axis_units(self):
        fw = self.fig.get_figwidth()
        fh = self.fig.get_figheight()

        l_margin = .4 / fw #.4in
        b_margin = .3 / fh #.3in

        if self.draw_axis_units_tf == True:
            self.fig.subplots_adjust(left=l_margin,right=1,top=1,bottom=b_margin)
        else: 
            self.fig.subplots_adjust(left=0,right=1,top=1,bottom=0)

        self.canvas.draw()

    def draw_grid(self):
        if self.draw_grid_tf == False:
            self.draw_grid_tf = True
        else:
            self.draw_grid_tf = False
            
        self.axes.grid(self.draw_grid_tf)
        self.canvas.draw()

    def on_draw(self): 
        draw_mode = self.mode_combo.currentText()
        
        self.axes.clear()
        if self.g != None:
            if hasattr(self.g, 'destruct'):
                self.g.destruct()

        if draw_mode == "Graph Test":
            self.g = GraphTest(self)
        elif draw_mode == "Graph Test Numpy":
            self.g = GraphTestNumPy(self)
        elif draw_mode == "STEP Hierarchy":
            self.g = GraphHierarchy(self)
        elif draw_mode == "STEP Co-occurence":
            self.g = GraphCoOccurrence(self)
        elif draw_mode == "STEP Context":
            self.g = GraphCoOccurrence(self)

        self.connect(self.node_size, QtCore.SIGNAL('valueChanged(int)'), self.g.set_node_mult)
        self.axes.grid(self.draw_grid_tf)
        self.canvas.draw()
        
    def on_pick(self, args):
        print "in matplotframe: ", args

    def resizeEvent(self, ev):
        self.draw_axis_units()
        
    def set_step_path(self, path):
        self.step_path = path
        self.parent.set_step_path(path)

    def toggle_axis_units(self):
        if self.draw_axis_units_tf == False: 
            self.draw_axis_units_tf = True
        else:
            self.draw_axis_units_tf = False
        self.draw_axis_units()

    def toggle_node_labels(self):
        if self.draw_node_labels_tf == False: 
            self.draw_node_labels_tf = True
        else:
            self.draw_node_labels_tf = False

        if self.g != None:
            self.g.redraw()
Example #23
0
class Chart(FigureCanvas):
    """Qt widget responsible for drawing a price chart.

    Following Paweł's suggestion, one chart shows a single main-plot
    indicator and a single oscillator; anyone who wants more can press a
    button that opens several charts in a new window.

    The figure holds up to three axes sharing the x (time) axis: the main
    price plot, the volume bars and an optional oscillator panel.
    """
    # margin (vertical and horizontal) and maximum height/width of the chart,
    # as figure fractions
    margin, maxSize = 0.1, 0.8
    # heights of the volume panel and of the oscillator panel
    volHeight, oscHeight = 0.1, 0.15

    def __init__(self, parent, finObj=None, width=8, height=6, dpi=100):
        """Build the default chart (line plot with volume, no indicators) for
        the given data.  The default size is 800x600 pixels."""
        self.mainPlot = None
        self.volumeBars = None
        self.oscPlot = None
        # The data model starts out empty; setData() only replaces it when it
        # is given a non-None object.  Without this, Chart(parent) left
        # self.data undefined and the update* methods raised AttributeError.
        self.data = None
        self.additionalLines = []  # lines drawn on the chart (by the user or by trend detection)
        self.rectangles = []  # rectangles (used to highlight candles)
        self.mainType = None  # type of the main plot
        self.oscType = None  # oscillator type (RSI, momentum, ...)
        self.mainIndicator = None  # indicator drawn on the main plot (moving average, ...)
        self.x0, self.y0 = None, None  # start point of the line being drawn
        self.drawingMode = False  # drawing on the chart can be switched on/off
        self.scaleType = 'linear'  # y-axis scale ('linear' or 'log')
        self.grid = True  # whether the grid is drawn
        self.setData(finObj)
        self.mainType = 'line'
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QtGui.QSizePolicy.Expanding,
                                   QtGui.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.addMainPlot()
        self.addVolumeBars()
        self.mpl_connect('button_press_event', self.onClick)

    def setData(self, finObj, start=None, end=None, step='daily'):
        """Set the data model shown by the chart and redraw it.

        A None finObj is ignored and the current data is kept.
        """
        if finObj is None:
            return
        self.data = ChartData(finObj, start, end, step)
        if self.mainPlot is not None:
            self.updatePlot()

    def getData(self):
        return self.data

    def setGrid(self, grid):
        """Turn grid drawing on (True) or off (False)."""
        self.grid = grid
        self.updateMainPlot()

    def setMainType(self, type):
        """Set the main plot type ('point', 'line', 'candlestick', 'bar', 'none')."""
        self.mainType = type
        self.updateMainPlot()

    def updatePlot(self):
        """Refresh all plots and redraw the canvas."""
        self.updateMainPlot()
        self.updateVolumeBars()
        self.updateOscPlot()
        self.draw()
        #self.drawGeometricFormation()
        #self.drawRateLines()
        #self.drawTrend()
        #self.drawCandleFormations()
        #self.drawGaps()

    def addMainPlot(self):
        """Create the main plot axes (price over time) and draw it."""
        bounds = [self.margin, self.margin, self.maxSize, self.maxSize]
        self.mainPlot = self.fig.add_axes(bounds)
        self.updateMainPlot()

    def updateMainPlot(self):
        """Redraw the main plot according to the current type, indicator,
        scale and grid settings."""
        if self.mainPlot is None or self.data is None or self.data.corrupted:
            return
        ax = self.mainPlot
        ax.clear()
        x = range(len(self.data.close))
        if self.mainType == 'line':
            ax.plot(x, self.data.close, 'b-', label=self.data.name)
        elif self.mainType == 'point':
            ax.plot(x, self.data.close, 'b.', label=self.data.name)
        elif self.mainType == 'candlestick':
            self.drawCandlePlot()
        elif self.mainType == 'bar':
            self.drawBarPlot()
        else:
            # 'none' (or an unknown type): leave the main plot empty
            return
        if self.mainIndicator is not None:
            self.updateMainIndicator()
        ax.set_xlim(x[0], x[-1])
        ax.set_yscale(self.scaleType)
        ax.set_ylim(0.995 * min(self.data.low), 1.005 * max(self.data.high))
        # re-add user/trend lines and highlight rectangles to the cleared axes
        for line in self.additionalLines:
            ax.add_line(line)
            line.figure.draw_artist(line)
        for rect in self.rectangles:
            ax.add_patch(rect)
            rect.figure.draw_artist(rect)
        if self.scaleType == 'log':
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
            ax.yaxis.set_minor_formatter(FormatStrFormatter('%.2f'))
        for tick in ax.yaxis.get_major_ticks():
            tick.label2On = True  # also show price labels on the right side
            if self.grid:
                tick.gridOn = True
        for label in (ax.get_yticklabels() + ax.get_yminorticklabels()):
            label.set_size(8)
        # legend
        leg = ax.legend(loc='best', fancybox=True)
        leg.get_frame().set_alpha(0.5)
        self.formatDateAxis(self.mainPlot)
        self.fixTimeLabels()
        if self.grid:
            for tick in ax.xaxis.get_major_ticks():
                tick.gridOn = True

    def addVolumeBars(self):
        """Show the volume panel below the main plot."""
        # the axes are created only the first time; afterwards they are just
        # shown again and refreshed
        if self.volumeBars is None:
            volBounds = [self.margin, self.margin, self.maxSize, self.volHeight]
            self.volumeBars = self.fig.add_axes(volBounds, sharex=self.mainPlot)
        self.updateVolumeBars()
        self.volumeBars.set_visible(True)
        self.fixPositions()
        self.fixTimeLabels()

    def rmVolumeBars(self):
        """Hide the volume panel."""
        if self.volumeBars is None:
            return
        self.volumeBars.set_visible(False)
        self.fixPositions()
        self.fixTimeLabels()

    def setScaleType(self, type):
        """Set a linear ('linear') or logarithmic ('log') scale on the main plot;
        any other value is ignored."""
        if type not in ['linear', 'log']:
            return
        self.scaleType = type
        self.updateMainPlot()

    def updateVolumeBars(self):
        """Redraw the volume bars."""
        if self.volumeBars is None or self.data is None or self.data.corrupted:
            return
        ax = self.volumeBars
        ax.clear()
        x = range(len(self.data.close))
        ax.vlines(x, 0, self.data.volume)
        ax.set_xlim(x[0], x[-1])
        peak = max(self.data.volume)
        if peak > 0:
            ax.set_ylim(0, 1.2 * peak)
        for label in ax.get_yticklabels():
            label.set_visible(False)
        for o in ax.findobj(Text):
            o.set_visible(False)
        self.formatDateAxis(ax)
        self.fixTimeLabels()

    def drawCandlePlot(self):
        """Draw the main plot as candlesticks."""
        if self.data is None or self.data.corrupted:
            return
        ax = self.mainPlot
        opens = self.data.open
        closes = self.data.close
        xvals = range(len(closes))
        wicks = ax.vlines(xvals, self.data.low, self.data.high,
                          label=self.data.name, linewidth=0.5)
        # draw the high-low wicks underneath the candle bodies
        wicks.set_zorder(wicks.get_zorder() - 1)
        width = 0.7
        bodies = []
        for i in xvals:
            # non-zero minimum height so a body exists even when open == close
            height = max(abs(closes[i] - opens[i]), 0.001)
            x = i - width / 2
            y = min(opens[i], closes[i])
            # rising candles are white, falling candles are black
            face = 'w' if opens[i] <= closes[i] else 'k'
            bodies.append(Rectangle((x, y), width, height,
                                    facecolor=face, edgecolor='k', linewidth=0.5))
        ax.add_collection(PatchCollection(bodies, match_original=True))

    def drawBarPlot(self):
        """Draw the main plot as OHLC bars."""
        if self.data is None or self.data.corrupted:
            return
        ax = self.mainPlot
        x = range(len(self.data.close))
        ax.vlines(x, self.data.low, self.data.high, label=self.data.name)
        ticks = []
        for i in x:
            # open tick to the left of the bar, close tick to the right
            ticks.append(((i - 0.3, self.data.open[i]), (i, self.data.open[i])))
            ticks.append(((i, self.data.close[i]), (i + 0.3, self.data.close[i])))
        tickLines = LineCollection(ticks)
        # set_color is the supported setter (LineCollection.color was a
        # deprecated alias)
        tickLines.set_color('k')
        ax.add_collection(tickLines)

    def setMainIndicator(self, type):
        """Set the indicator drawn on top of the main plot."""
        self.mainIndicator = type
        self.updateMainPlot()

    def updateMainIndicator(self):
        """Draw the selected indicator on the main plot."""
        if self.data is None or self.data.corrupted:
            return
        ax = self.mainPlot
        kind = self.mainIndicator
        ax.hold(True)  # hold on
        x = range(len(self.data.close))
        if kind in ('SMA', 'WMA', 'EMA'):
            indicValues = self.data.movingAverage(kind)
        elif kind == 'bollinger':
            upper = self.data.bollinger('upper')
            # `is not None` rather than `!= None`: an array result would make
            # `!=` elementwise and its truth value ambiguous
            if upper is not None:
                ax.plot(x, upper, 'r-', label=kind)
            indicValues = self.data.bollinger('lower')
        else:
            ax.hold(False)
            return
        if indicValues is not None:
            ax.plot(x, indicValues, 'r-', label=kind)
        ax.hold(False)  # hold off

    def setOscPlot(self, type):
        """Show an oscillator of the given type below the main plot, or hide
        the oscillator panel when the type is not a supported oscillator."""
        if type not in ['momentum', 'CCI', 'RSI', 'ROC', 'williams']:
            # hide the oscillator panel
            if self.oscPlot is None:
                return
            self.oscPlot.set_visible(False)
        else:
            self.oscType = type
            if self.oscPlot is None:
                oscBounds = [self.margin, self.margin, self.maxSize, self.oscHeight]
                self.oscPlot = self.fig.add_axes(oscBounds, sharex=self.mainPlot)
            self.updateOscPlot()
            self.oscPlot.set_visible(True)
        self.fixPositions()
        self.fixTimeLabels()

    def updateOscPlot(self):
        """Redraw the oscillator panel."""
        if self.oscPlot is None or self.data is None or self.data.corrupted:
            return
        ax = self.oscPlot
        kind = self.oscType
        ax.clear()
        # each supported oscillator is computed by the ChartData method of the
        # same name
        if kind not in ('momentum', 'CCI', 'ROC', 'RSI', 'williams',
                        'TRIN', 'mcClellan', 'adLine'):
            return
        oscData = getattr(self.data, kind)()
        if oscData is not None:
            x = range(len(self.data.close))
            ax.plot(x, oscData, 'g-', label=kind)
            ax.set_xlim(x[0], x[-1])
            # legend
            leg = ax.legend(loc='best', fancybox=True)
            leg.get_frame().set_alpha(0.5)
            self.formatDateAxis(self.oscPlot)
            self.fixOscLabels()
            self.fixTimeLabels()

    def fixOscLabels(self):
        """Set the y range appropriate for the current oscillator and move its
        labels to the right side so they do not overlap the price labels."""
        ax = self.oscPlot
        kind = self.oscType
        if kind == 'ROC':
            ax.set_ylim(-100, 100)
        elif kind == 'RSI':
            ax.set_ylim(0, 100)
            ax.set_yticks([30, 70])
        elif kind == 'williams':
            ax.set_ylim(-100, 0)
        for tick in ax.yaxis.get_major_ticks():
            tick.label1On = False
            tick.label2On = True
            tick.label2.set_size(7)

    def formatDateAxis(self, ax):
        """Place and format the date labels of the time axis of ``ax``."""
        chartWidth = int(self.fig.get_figwidth() * self.fig.get_dpi() * self.maxSize)
        t = TextPath((0, 0), '9999-99-99', size=7)
        labelWidth = max(1, int(t.get_extents().width))
        # leave about one label width of space between neighbouring labels;
        # at least one tick, so a narrow chart cannot divide by zero below
        num_ticks = max(1, chartWidth // labelWidth // 2)
        length = len(self.data.date)
        if length > num_ticks:
            step = length // num_ticks
        else:
            step = 1
        ax.xaxis.set_major_locator(FixedLocator(range(0, length, step)))
        ticks = ax.get_xticks()
        labels = []
        for i, label in enumerate(ax.get_xticklabels()):
            label.set_size(7)
            index = int(ticks[i])
            if index >= len(self.data.date):
                labels.append('')
            else:
                labels.append(self.data.date[index].strftime("%Y-%m-%d"))
            label.set_horizontalalignment('center')
        ax.xaxis.set_major_formatter(FixedFormatter(labels))

    def fixTimeLabels(self):
        """Show the time-axis labels only under the lowest visible panel."""
        # the oscillator is always at the very bottom
        if self.oscPlot is not None and self.oscPlot.get_visible():
            for label in self.mainPlot.get_xticklabels():
                label.set_visible(False)
            if self.volumeBars is not None:
                for label in self.volumeBars.get_xticklabels():
                    label.set_visible(False)
            for label in self.oscPlot.get_xticklabels():
                label.set_visible(True)
        # without an oscillator the labels go under the volume panel
        elif self.volumeBars is not None and self.volumeBars.get_visible():
            for label in self.mainPlot.get_xticklabels():
                label.set_visible(False)
            for label in self.volumeBars.get_xticklabels():
                label.set_visible(True)
        # with only the main plot, under the main plot
        else:
            for label in self.mainPlot.get_xticklabels():
                label.set_visible(True)

    def fixPositions(self):
        """Resize and reposition the panels so they always fill the chart
        area: the main plot grows or shrinks, while the volume and oscillator
        panels move up or down."""
        # start with everything pushed to the bottom
        mainBounds = [self.margin, self.margin, self.maxSize, self.maxSize]
        volBounds = [self.margin, self.margin, self.maxSize, self.volHeight]
        oscBounds = [self.margin, self.margin, self.maxSize, self.oscHeight]
        # the oscillator pushes the volume panel up and shrinks the main plot
        if self.oscPlot is not None and self.oscPlot.get_visible():
            mainBounds[1] += self.oscHeight
            mainBounds[3] -= self.oscHeight
            volBounds[1] += self.oscHeight
            self.oscPlot.set_position(oscBounds)
        # the volume panel shrinks the main plot again
        if self.volumeBars.get_visible():
            mainBounds[1] += self.volHeight
            mainBounds[3] -= self.volHeight
            self.volumeBars.set_position(volBounds)
        self.mainPlot.set_position(mainBounds)

    def setDrawingMode(self, mode):
        """Turn drawing on the chart on (True) or off (False)."""
        self.drawingMode = mode
        self.x0, self.y0 = None, None

    def drawLine(self, x0, y0, x1, y1, color='black', lwidth = 1.0, lstyle = '-'):
        """Draw a (trend) line on the main plot and remember it."""
        newLine = Line2D([x0, x1], [y0, y1], linewidth=lwidth, linestyle=lstyle, color=color)
        self.mainPlot.add_line(newLine)
        self.additionalLines.append(newLine)
        newLine.figure.draw_artist(newLine)
        self.blit(self.mainPlot.bbox)  # blit acts as a redraw

    def clearLines(self):
        """Remove every extra line drawn on the chart (not the price or the
        indicators)."""
        for line in self.additionalLines:
            line.remove()
        self.additionalLines = []
        self.draw()
        self.blit(self.mainPlot.bbox)

    def clearLastLine(self):
        """Remove the most recently drawn extra line."""
        if not self.additionalLines:
            return
        self.additionalLines.pop().remove()
        self.draw()
        self.blit(self.mainPlot.bbox)

    def drawRectangle(self, x, y, width, height, colour='blue', lwidth = 2.0, lstyle = 'dashed'):
        """Outline a gap / candle formation (or similar) with a rectangle."""
        newRect = Rectangle((x, y), width, height, facecolor='none', edgecolor=colour,
                            linewidth=lwidth, linestyle=lstyle)
        self.mainPlot.add_patch(newRect)
        self.rectangles.append(newRect)
        newRect.figure.draw_artist(newRect)
        self.blit(self.mainPlot.bbox)  # blit acts as a redraw

    def clearRectangles(self):
        """Remove all highlight rectangles."""
        for rect in self.rectangles:
            rect.remove()
        self.rectangles = []
        self.draw()
        self.blit(self.mainPlot.bbox)

    def onClick(self, event):
        """In drawing mode, draw a line between two consecutive left clicks;
        the right button removes the last line and the middle button removes
        all of them."""
        if not self.drawingMode:
            return
        if event.button == 3:
            self.clearLastLine()
        elif event.button == 2:
            self.clearLines()
        elif event.button == 1:
            if event.xdata is None or event.ydata is None:
                # click outside any axes: no data coordinates to draw with
                return
            if self.x0 is None or self.y0 is None:
                self.x0, self.y0 = event.xdata, event.ydata
                self.firstPoint = True
            else:
                self.drawLine(self.x0, self.y0, event.xdata, event.ydata)
                self.x0, self.y0 = None, None

    def drawTrend(self):
        """Replace the extra lines with the detected trend channel lines."""
        self.clearLines()
        a, b = trend.regression(self.data.close)
        trend.optimizedTrend(self.data.close)
        #self.drawTrendLine(0, b, len(self.data.close)-1, a*(len(self.data.close)-1) + b, 'y', 2.0)
        sup, res = trend.getChannelLines(self.data.close)
        self._drawChannel(sup, res)
        if len(self.data.close) > 30:
            sup, res = trend.getChannelLines(self.data.close, 1, 2)
            self._drawChannel(sup, res, 2.0)

    def _drawChannel(self, sup, res, lwidth=1.0):
        """Draw the support (green) and resistance (red) lines from the first
        to the last point of each channel; points are (value, x) pairs."""
        self.drawLine(sup[0][1], sup[0][0], sup[-1][1], sup[-1][0], 'g', lwidth)
        self.drawLine(res[0][1], res[0][0], res[-1][1], res[-1][0], 'r', lwidth)
class MatplotlibSpecWidget(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.)."""
    def __init__(self,
                 parent=None,
                 name=None,
                 width=4,
                 height=4,
                 dpi=80,
                 bgcolor=None):
        self.parent = parent
        #	if self.parent:
        #bgc = parent.backgroundBrush().color()
        #bgcolor = float(bgc.red())/255.0, float(bgc.green())/255.0, float(bgc.blue())/255.0
        #bgcolor = "#%02X%02X%02X" % (bgc.red(), bgc.green(), bgc.blue())
        self.fig = Figure(figsize=(width, height),
                          dpi=dpi,
                          facecolor=bgcolor,
                          edgecolor=bgcolor)
        self.axes = self.fig.add_axes([0.07, 0.12, 0.91, 0.86])
        # We want the axes cleared every time plot() is called
        self.axes.hold(False)
        #self.axes.set_xticklabels(self.axes.get_xticklabels(), fontsize=10)
        self.axes.set_xlabel('wavelength [$\AA$]', fontsize=12)
        self.colorScheme = color_schema.colorSchemeSpec()
        self.spec1_vis = False
        self.spec2_vis = False
        self.viewLine_vis = False
        self.spec1 = None
        self.spec2 = None
        self.viewLine = None
        self.limitsWidget = None
        self.selectSpecMask = None
        self.xlim = [0, 0]

        FigureCanvas.__init__(self, self.fig)
        ##        self.reparent(parent, QPoint(0, 0))

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        self.cid = self.fig.canvas.mpl_connect('button_press_event',
                                               self.onclick)
        self.cid2 = self.fig.canvas.mpl_connect('motion_notify_event',
                                                self.move)
        self.cid3 = self.fig.canvas.mpl_connect('button_release_event',
                                                self.release)
        self.connect(self.colorScheme, SIGNAL("changed()"),
                     self.updateColorScheme)
#       self.cid4=self.fig.canvas.mpl_connect('pick_event', self.onpick)

    def clearWidget(self):
        self.spec1 = None
        self.spec1 = None
        self.spec2 = None
        self.viewLine = None
        self.selectSpecMask = None
        self.spec1_vis = False
        self.spec2_vis = False
        self.viewLine_vis = False
        self.xlim = [0, 0]
        self.axes.clear()
        self.axes.set_xlabel('wavelength [$\AA$]', fontsize=12)
        self.fig.canvas.draw()

    def setLimitsWidget(self, limits_widget):
        self.limitsWidget = limits_widget
        self.connect(self.limitsWidget, SIGNAL("limitsChanged"), self.setYLim)

    def initSpec1(self, wave, spec1):
        self.spec1 = self.axes.add_line(
            matplotlib.lines.Line2D(wave,
                                    spec1,
                                    linestyle=self.colorScheme.spec1['style'],
                                    color=self.colorScheme.spec1['color'],
                                    lw=self.colorScheme.spec1['width']))
        self.spec1_vis = True

        if self.limitsWidget.auto == True:
            min = numpy.min(spec1)
            max = numpy.max(spec1)
            self.axes.set_ylim(min, max)
            self.limitsWidget.setLimits(min, max)
        else:
            self.axes.set_ylim(self.limitsWidget.min, self.limitsWidget.max)
        self.fig.canvas.draw()

    def initSpec2(self, wave, spec2):
        self.spec2 = self.axes.add_line(
            matplotlib.lines.Line2D(wave,
                                    numpy.zeros(wave.shape[0]),
                                    linestyle=self.colorScheme.spec2['style'],
                                    color=self.colorScheme.spec2['color'],
                                    lw=self.colorScheme.spec2['width']))
        self.spec2_vis = True
        if self.limitsWidget.auto == True:
            min = numpy.min(spec2)
            max = numpy.max(spec2)
            self.axes.set_ylim(min, max)
            self.limitsWidget.setLimits(min, max)
        else:
            self.axes.set_ylim(self.limitsWidget.min, self.limitsWidget.max)
        self.fig.canvas.draw()

    def initViewLine(self, wave):
        self.viewLine = self.axes.axvline(
            wave,
            linestyle=self.colorScheme.slicer['style'],
            color=self.colorScheme.slicer['color'],
            lw=self.colorScheme.slicer['width'])
        self.viewLine_vis = True
        self.fig.canvas.draw()

    def initZoomBox(self, x, y):
        self.xyZoom = (x, y)
        self.zoomRect = matplotlib.patches.Rectangle(
            xy=(x, y),
            width=0.0,
            height=0.0,
            ec=self.colorScheme.zoom['color'],
            lw=self.colorScheme.zoom['width'],
            fill=False,
            visible=True,
            zorder=10)
        self.zoomArtist = self.axes.add_artist(self.zoomRect)
        self.widthZoom = (0, 0)
        self.fig.canvas.draw()

    def initSpecMask(self, waveLimit=None, visible=False):
        self.selectSpecMask = mask_def.displaySpecRegion(
            self,
            waveLimit=waveLimit,
            linestyle=self.colorScheme.select['style'],
            color=self.colorScheme.select['color'],
            lw=self.colorScheme.select['width'],
            hatch=self.colorScheme.select['hatch'],
            fill=False,
            visible=visible,
            alpha=self.colorScheme.select['alpha'])
        if self.colorScheme.select['hatch'] == None:
            self.selectSpecMask.setFill(True)
        self.changeSelectSpec = False

    def setXLim(self, xlim):
        self.xlim = xlim
        self.axes.set_xlim(xlim)
        self.fig.canvas.draw()
#    def updateViewLine(self, pos):

    def setYLim(self, ymin, ymax):
        self.axes.set_ylim((ymin, ymax))
        self.fig.canvas.draw()

    def updateSpec1(self, wave, spec1, limits=True):
        self.spec1.set_data(wave, spec1)
        self.spec1.set_visible(True)
        self.spec1_vis = True
        if self.limitsWidget.auto == True and limits:
            min = numpy.min(spec1)
            max = numpy.max(spec1)
            self.axes.set_ylim(min, max)
            self.limitsWidget.setLimits(min, max)
        else:
            self.axes.set_ylim(self.limitsWidget.min, self.limitsWidget.max)
        self.fig.canvas.draw()

    def updateSpec2(self, wave, spec2, limits=True):
        self.spec2.set_data(wave, spec2)
        self.spec2.set_visible(True)
        self.spec2_vis = True
        if self.limitsWidget.auto == True and limits:
            min = numpy.min(spec2)
            max = numpy.max(spec2)
            self.axes.set_ylim(min, max)
            self.limitsWidget.setLimits(min, max)
        else:
            self.axes.set_ylim(self.limitsWidget.min, self.limitsWidget.max)
        self.fig.canvas.draw()

    def setVisibleSpec2(self, visible):
        self.spec2.set_visible(visible)
        self.spec2_vis = visible
        self.fig.canvas.draw()

    def setVisibleSpec1(self, visible):
        self.spec1.set_visible(visible)
        self.spec1_vis = visible
        self.fig.canvas.draw()

    def updateViewLine(self, wave):
        self.viewLine.set_data([wave, wave], [0, 1])
        self.fig.canvas.draw()

    def resizeZoomBox(self, x, y):
        self.widthZoom = (x - self.xyZoom[0], y - self.xyZoom[1])
        self.zoomRect.set_width(x - self.xyZoom[0])
        self.zoomRect.set_height(y - self.xyZoom[1])
        self.fig.canvas.draw()

    def getZoomLimit(self):
        if self.widthZoom != (0, 0):
            xmin = numpy.min(
                [self.xyZoom[0], self.xyZoom[0] + self.widthZoom[0]])
            xmax = numpy.max(
                [self.xyZoom[0], self.xyZoom[0] + self.widthZoom[0]])
            ymin = numpy.min(
                [self.xyZoom[1], self.xyZoom[1] + self.widthZoom[1]])
            ymax = numpy.max(
                [self.xyZoom[1], self.xyZoom[1] + self.widthZoom[1]])
            return xmin, xmax, ymin, ymax
        else:
            return None

    def delZoomBox(self):
        self.zoomArtist.remove()
        self.zoomRect = None
        self.widthZoom = None
        self.xyZoom = None
        self.fig.canvas.draw()

    def zoomOut(self):
        if self.spec1 != None:
            spec = self.spec1.get_data(orig=True)
            self.axes.set_xlim(spec[0][0], spec[0][-1])
            self.axes.set_ylim(numpy.min(spec[1]), numpy.max(spec[1]))
            self.fig.canvas.draw()

    def updateColorScheme(self):
        if self.spec1 != None:
            self.spec1.set_color(self.colorScheme.spec1['color'])
            self.spec1.set_linewidth(self.colorScheme.spec1['width'])
            self.spec1.set_linestyle(self.colorScheme.spec1['style'])
        if self.spec2 != None:
            self.spec2.set_color(self.colorScheme.spec2['color'])
            self.spec2.set_linewidth(self.colorScheme.spec2['width'])
            self.spec2.set_linestyle(self.colorScheme.spec2['style'])
        if self.viewLine != None:
            self.viewLine.set_color(self.colorScheme.slicer['color'])
            self.viewLine.set_linewidth(self.colorScheme.slicer['width'])
            self.viewLine.set_linestyle(self.colorScheme.slicer['style'])
        if self.selectSpecMask != None:
            self.selectSpecMask.setColor(self.colorScheme.select['color'])
            self.selectSpecMask.setLineWidth(self.colorScheme.select['width'])
            self.selectSpecMask.setLineStyle(self.colorScheme.select['style'])
            self.selectSpecMask.setAlpha(self.colorScheme.select['alpha'])
            self.selectSpecMask.setHatch(self.colorScheme.select['hatch'])
            if self.colorScheme.select['hatch'] == None:
                self.selectSpecMask.setFill(True)
            else:
                self.selectSpecMask.setFill(False)
        self.fig.canvas.draw()
#   def setVisible(self, spec1_vis = None, spec2_vis = None, viewLine_vis = None):
#      if spec1_vis !=None:
#         self.spec1_vis = spec1_vis
#        self.spec1[0].set_visible(self.spec1_vis)

#    if spec2_vis !=None:
#       self.spec2_vis = spec2_vis
#      self.spec2[0].set_visible(self.spec2_vis)
# if viewLine_vis !=None:
#    self.viewLine_vis = viewLine_vis
#   self.viewLine.set_visible(self.viewLine_vis)
#self.fig.canvas.draw()

    def onclick(self, event):
        """Re-emit a matplotlib button-press event as the Qt signal
        "mouse_press_event" with (button, x, y, xdata, ydata)."""
        payload = (event.button, event.x, event.y, event.xdata, event.ydata)
        self.emit(SIGNAL("mouse_press_event"), *payload)

    def move(self, event):
        """Re-emit a matplotlib motion event as the Qt signal
        "mouse_move_event" with (button, x, y, xdata, ydata)."""
        payload = (event.button, event.x, event.y, event.xdata, event.ydata)
        self.emit(SIGNAL("mouse_move_event"), *payload)

    def release(self, event):
        """Re-emit a matplotlib button-release event as the Qt signal
        "mouse_release_event" with (button, x, y, xdata, ydata)."""
        payload = (event.button, event.x, event.y, event.xdata, event.ydata)
        self.emit(SIGNAL("mouse_release_event"), *payload)

    def onpick(self, event):
        """Debug handler for pick events: print the picked artist, the picked
        indices, and the (x, y) data points at those indices."""
        thisline = event.artist
        xdata = thisline.get_xdata()
        ydata = thisline.get_ydata()
        ind = event.ind
        # print() function: the Python-2 print statement is a syntax error in
        # Python 3.  zip() is lazy in Python 3, so materialise it to print
        # the actual points rather than "<zip object ...>".
        print(thisline, ind)
        print('onpick points:', list(zip(xdata[ind], ydata[ind])))

    def sizeHint(self):
        """Preferred widget size: the figure size converted to whole pixels.

        Figure dimensions are in inches, whereas QSize takes integer pixels,
        so multiply by the figure DPI and round.
        """
        dpi = self.fig.get_dpi()
        width_px = int(round(self.fig.get_figwidth() * dpi))
        height_px = int(round(self.fig.get_figheight() * dpi))
        return QSize(width_px, height_px)

    def minimumSizeHint(self):
        """Let the canvas shrink down to a 10 x 10 square."""
        side = 10
        return QSize(side, side)
Example #25
0
class Section(FigureCanvas):
    """Canvas that displays one well section.

    Every checked curve of the well (except the ``DEPT`` curve) is drawn in
    its own subplot column of ``curve_width_px`` pixels, plotted against the
    ``DEPT`` values.
    """

    # Width in pixels allotted to each curve column.
    curve_width_px = 200

    def __init__(self, well, dpi=90):
        """
        :param well: well which will be displayed (reads ``well.las`` and
            ``well.name``)
        :param dpi: resolution of the underlying matplotlib figure
        """
        self.well = well

        self.figure = Figure(dpi=dpi, facecolor="white")
        super().__init__(self.figure)
        self.refresh()

    def refresh(self):
        """Rebuild the figure from the currently checked curves.

        Executed every time a curve is checked or unchecked.  Hides the canvas
        when the well has no LAS data or no curve is checked.
        """
        self.figure.clear()
        if not self.well.las:
            self.hide()
            return
        checked_curves = [
            curve for curve in self.well.las.curves
            if curve.mnemonic != "DEPT" and curve.qt_item.checkState(0)
        ]
        if not checked_curves:
            self.hide()
            return

        number_of_curves = len(checked_curves)
        # One fixed-width column per curve, applied both to the Qt widget
        # (pixels) and to the matplotlib figure (inches).
        total_width_px = self.curve_width_px * number_of_curves
        self.setFixedWidth(total_width_px)
        self.figure.set_figwidth(total_width_px / self.figure.get_dpi())

        self.figure.suptitle('"' + self.well.name + '"',
                             fontsize=10,
                             fontweight='normal')
        depth = self.well.las["DEPT"]

        for i, curve in enumerate(checked_curves):
            ax = self.figure.add_subplot(1, number_of_curves, i + 1)
            ax.set_title(curve.mnemonic, fontsize=8, fontweight='normal')
            ax.tick_params(axis='both', which='both', labelsize=7)
            for side in ('right', 'top', 'left', 'bottom'):
                ax.spines[side].set_visible(False)
            # Only show ticks on the left and bottom spines
            ax.yaxis.set_ticks_position('left')
            ax.xaxis.set_ticks_position('bottom')

            # Depth increases downwards.
            ax.invert_yaxis()
            ax.grid(color="gray")
            ax.set_ylabel('depth (m)', fontsize=7)
            ax.plot(curve.data, depth)

        self.figure.tight_layout(rect=(0, 0, 1, 0.98))
        # The rebuilt figure is only rendered on a redraw.  A resize triggers
        # one only when the width changes, so request it explicitly.
        self.draw_idle()
class MatplotlibImgWidget(FigureCanvas):
    """Qt canvas that shows a 2-D image with marker, mask and zoom overlays.

    Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).
    Besides the image it manages:

    * a one-pixel marker rectangle (``pickRectangle``);
    * an optional selection mask (``selectSpaxMask``);
    * a rubber-band zoom box (``zoomRect``).

    Colour limits come from an optional limits widget.  The appearance is
    driven by ``colorScheme``.  Mouse events are re-emitted as Qt signals.
    """

    # Substituted for non-positive lower limits under logarithmic scaling,
    # because LogNorm needs vmin > 0.
    _LOG_VMIN_FLOOR = 1e-5

    def __init__(self, parent=None, name=None, width=4, height=4, dpi=100, bgcolor=None):
        self.parent = parent
        self.fig = Figure(figsize=(width, height), dpi=dpi, facecolor=bgcolor, edgecolor=bgcolor)
        self.axes = self.fig.add_axes([0.05, 0.05, 0.9, 0.9])
        self.colorScheme = color_schema.colorSchemeSpax()
        # We want the axes cleared every time a new image is shown.
        # Axes.hold() was removed in matplotlib 3.0, so call it only when it
        # exists; initImage() also clears the axes explicitly.
        if hasattr(self.axes, 'hold'):
            self.axes.hold(False)
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.image = None
        self.pickRectangle = None
        self.selectSpaxMask = None
        self.limitsWidget = None
        self.zoomRect = None
        self.zoomArtist = None
        self.xyZoom = None
        self.widthZoom = None

        FigureCanvas.__init__(self, self.fig)

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.cid2 = self.fig.canvas.mpl_connect('motion_notify_event', self.move)
        self.cid3 = self.fig.canvas.mpl_connect('button_release_event', self.release)
        self.connect(self.colorScheme, SIGNAL("changed()"), self.updateColorScheme)

    # ---- helpers -------------------------------------------------------

    @staticmethod
    def _fill_and_hatch(hatch):
        """Translate a colour-scheme hatch setting into (fill, hatch).

        The scheme uses the string 'None' for "no hatch".  In that case the
        patch is filled solidly; otherwise it is hatched and left unfilled.
        """
        if hatch != 'None':
            return False, hatch
        return True, None

    def _colormap(self):
        """Colormap chosen by the colour scheme, reversed if requested."""
        name = self.colorScheme.image['colormap']
        if self.colorScheme.image['reversed']:
            name += '_r'
        return matplotlib.cm.get_cmap(name)

    def _log_vmin(self, vmin):
        """Lower limit that is valid for LogNorm (which needs vmin > 0)."""
        return vmin if vmin > 0 else self._LOG_VMIN_FLOOR

    def _use_auto_limits(self, limits):
        """True when limits should be derived from the image data itself."""
        if self.limitsWidget is None:
            return True
        return bool(self.limitsWidget.auto) and bool(limits)

    def _current_limits(self):
        """(vmin, vmax) from the limits widget, or the image's current clim."""
        if self.limitsWidget is not None:
            return self.limitsWidget.min, self.limitsWidget.max
        return self.image.get_clim()

    def _apply_limits(self, vmin, vmax):
        """Apply colour limits to the image using the configured scaling."""
        if self.colorScheme.image['scaling'] == 'Logarithmic':
            self.image.set_norm(matplotlib.colors.LogNorm(vmin=self._log_vmin(vmin), vmax=vmax))
        else:
            self.image.set_clim(vmin, vmax)

    # ---- public API ----------------------------------------------------

    def clearWidget(self):
        """Drop the image and overlays and show empty, tick-less axes."""
        self.image = None
        self.selectSpaxMask = None
        self.pickRectangle = None
        self.axes.clear()
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.fig.canvas.draw()

    def setLimitsWidget(self, limits_widget):
        """Use *limits_widget* as the source of colour limits.

        The widget's "limitsChanged" signal is connected to setCLim.
        """
        self.limitsWidget = limits_widget
        self.connect(self.limitsWidget, SIGNAL("limitsChanged"), self.setCLim)

    def initImage(self, image, limits=True):
        """Display *image* (origin at the lower left) and apply colour limits.

        With automatic limits (or no limits widget) the data min/max are used
        and pushed to the limits widget.  Otherwise the widget's limits are
        used.
        """
        # Equivalent of the old hold(False): start from empty axes.
        self.axes.clear()
        self.image = self.axes.imshow(image,
                                      interpolation=self.colorScheme.image['interpolation'],
                                      filterrad=self.colorScheme.image['radius'],
                                      origin='lower',
                                      cmap=self._colormap(),
                                      zorder=1)
        if self._use_auto_limits(limits):
            lo = numpy.min(image)
            hi = numpy.max(image)
            if self.limitsWidget is not None:
                self.limitsWidget.setLimits(lo, hi)
            self._apply_limits(lo, hi)
        else:
            self._apply_limits(self.limitsWidget.min, self.limitsWidget.max)

        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.fig.canvas.draw()

    def initSpaxMask(self, visible):
        """Create the selection mask overlay sized to the current image."""
        dim = self.image.get_array().shape
        fill, hatch = self._fill_and_hatch(self.colorScheme.select['hatch'])
        self.selectSpaxMask = mask_def.displayImgMask(self, (dim[0], dim[1]), hatch=hatch,
                                                      color=self.colorScheme.select['color'],
                                                      fill=fill, visible=visible)

    def setCLim(self, min, max):
        """Slot for "limitsChanged": apply new colour limits and redraw."""
        self._apply_limits(min, max)
        self.fig.canvas.draw()

    def initShowSpax(self, x, y, visible):
        """Create the 1 x 1 marker rectangle centred on pixel (x, y)."""
        self.showSpax = [(x - 0.5, y - 0.5), visible]
        marker = self.colorScheme.marker
        fill, hatch = self._fill_and_hatch(marker['hatch'])
        self.pickRectangle = matplotlib.patches.Rectangle(
            xy=self.showSpax[0], width=1.0, height=1.0, lw=marker['width'],
            ec=marker['color'], fc=marker['color'], fill=fill,
            visible=self.showSpax[1], hatch=hatch, alpha=marker['alpha'], zorder=10)
        self.axes.add_artist(self.pickRectangle)
        self.fig.canvas.draw()

    def initZoomBox(self, x, y):
        """Start a rubber-band zoom box anchored at pixel (x, y)."""
        self.xyZoom = (x, y)
        self.widthZoom = (0, 0)
        self.zoomRect = matplotlib.patches.Rectangle(xy=(x - 0.5, y - 0.5), width=0.0, height=0.0,
                                                     ec='r', lw=1.5, fill=False, visible=True,
                                                     zorder=10)
        self.zoomArtist = self.axes.add_artist(self.zoomRect)
        self.fig.canvas.draw()

    def updateImage(self, image, limits=True):
        """Replace the displayed data with *image* and update colour limits.

        The data is always replaced.  Previously the manual-limits
        logarithmic branch skipped set_data and kept showing the old image.
        """
        self.image.set_data(image)
        if self._use_auto_limits(limits):
            lo = numpy.min(image)
            hi = numpy.max(image)
            self._apply_limits(lo, hi)
            if self.limitsWidget is not None:
                self.limitsWidget.setLimits(lo, hi)
        else:
            self._apply_limits(self.limitsWidget.min, self.limitsWidget.max)
        self.fig.canvas.draw()

    def moveSelectSpax(self, x, y):
        """Move the marker so it is centred on pixel (x, y)."""
        self.showSpax = [(x - 0.5, y - 0.5), self.showSpax[1]]
        # Same half-pixel offset as initShowSpax, so the marker covers the
        # pixel instead of being shifted by half a pixel.
        self.pickRectangle.set_xy(self.showSpax[0])
        self.fig.canvas.draw()

    def resizeZoomBox(self, x, y):
        """Stretch the zoom box from its anchor towards pixel (x, y)."""
        self.widthZoom = (x - self.xyZoom[0], y - self.xyZoom[1])
        self.zoomRect.set_width(self.widthZoom[0])
        self.zoomRect.set_height(self.widthZoom[1])
        self.fig.canvas.draw()

    def getZoomLimit(self):
        """Return (xmin, xmax, ymin, ymax) of the zoom box.

        Returns None if there is no box or it has zero size.
        """
        if self.widthZoom is None or self.widthZoom == (0, 0):
            return None
        x0, y0 = self.xyZoom
        dx, dy = self.widthZoom
        xmin = numpy.min([x0, x0 + dx])
        xmax = numpy.max([x0, x0 + dx])
        ymin = numpy.min([y0, y0 + dy])
        ymax = numpy.max([y0, y0 + dy])
        return xmin, xmax, ymin, ymax

    def delZoomBox(self):
        """Remove the zoom box, if any, and reset its state."""
        if self.zoomRect is not None:
            self.zoomRect.remove()
        self.zoomRect = None
        self.zoomArtist = None
        self.widthZoom = None
        self.xyZoom = None
        self.fig.canvas.draw()

    def zoomOut(self):
        """Reset the view to show the whole image."""
        if self.image is not None:
            rows, cols = self.image.get_array().shape[:2]
            # imshow places pixel centres on integers, so the full image
            # spans -0.5 .. n - 0.5 on each axis.
            self.axes.set_xlim(-0.5, cols - 0.5)
            self.axes.set_ylim(-0.5, rows - 0.5)
            self.fig.canvas.draw()

    def updateColorScheme(self):
        """Re-apply the colour scheme to image, marker and mask, then redraw."""
        if self.image is not None:
            self.image.set_cmap(self._colormap())
            self.image.set_interpolation(self.colorScheme.image['interpolation'])
            self.image.set_filterrad(self.colorScheme.image['radius'])
            scaling = self.colorScheme.image['scaling']
            lo, hi = self._current_limits()
            if scaling == 'Linear':
                # NoNorm ignores clim and feeds raw values to the colormap.
                # A linear Normalize maps [lo, hi] onto the colormap.
                self.image.set_norm(matplotlib.colors.Normalize(vmin=lo, vmax=hi))
            elif scaling == 'Logarithmic':
                self.image.set_norm(matplotlib.colors.LogNorm(vmin=self._log_vmin(lo), vmax=hi))
        if self.pickRectangle is not None:
            marker = self.colorScheme.marker
            fill, hatch = self._fill_and_hatch(marker['hatch'])
            self.pickRectangle.set_fill(fill)
            self.pickRectangle.set_hatch(hatch)
            self.pickRectangle.set_ec(marker['color'])
            self.pickRectangle.set_fc(marker['color'])
            self.pickRectangle.set_alpha(marker['alpha'])
            self.pickRectangle.set_lw(marker['width'])
        if self.selectSpaxMask is not None:
            select = self.colorScheme.select
            fill, hatch = self._fill_and_hatch(select['hatch'])
            self.selectSpaxMask.setFill(fill)
            self.selectSpaxMask.setHatch(hatch)
            self.selectSpaxMask.setColor(select['color'])
            self.selectSpaxMask.setAlpha(select['alpha'])
        self.fig.canvas.draw()

    # ---- mouse events re-emitted as Qt signals ------------------------

    def onclick(self, event):
        self.emit(SIGNAL("mouse_press_event"), event.button, event.x, event.y, event.xdata, event.ydata)

    def move(self, event):
        self.emit(SIGNAL("mouse_move_event"), event.button, event.x, event.y, event.xdata, event.ydata)

    def release(self, event):
        self.emit(SIGNAL("mouse_release_event"), event.button, event.x, event.y, event.xdata, event.ydata)

    def sizeHint(self):
        """Preferred size: the figure size converted from inches to pixels."""
        dpi = self.fig.get_dpi()
        return QSize(int(round(self.fig.get_figwidth() * dpi)),
                     int(round(self.fig.get_figheight() * dpi)))

    def minimumSizeHint(self):
        return QSize(10, 10)
Example #27
0
class PlotPanel(wx.Panel):
    def __init__(self, parent, toolbar_visible=False, **kwargs):
        """
        A panel which contains a matplotlib figure with (optionally) a 
            toolbar to zoom/pan/ect.
        Inputs:
            parent              : the parent frame/panel
            toolbar_visible     : the initial state of the toolbar
            **kwargs            : arguments passed on to 
                                  matplotlib.figure.Figure
        Introduces:
            figure              : a matplotlib figure
            canvas              : a FigureCanvasWxAgg from matplotlib's
                                  backends
            toggle_toolbar()    : to toggle the visible state of the toolbar
            show_toolbar()      : to show the toolbar
            hide_toolbar()      : to hide the toolbar
        Subscribes to:
            'TOGGLE_TOOLBAR'    : if data=None or data=self will toggle the
                                  visible state of the toolbar
            'SHOW_TOOLBAR'      : if data=None or data=self will show the
                                  toolbar
            'HIDE_TOOLBAR'      : if data=None or data=self will hide the
                                  toolbar
        """
        wx.Panel.__init__(self, parent)

        self.figure = Figure(**kwargs)
        self.canvas = Canvas(self, wx.ID_ANY, self.figure)
        self.toolbar = CustomToolbar(self.canvas, self)
        self.toolbar.Show(False)
        self.toolbar.Realize()

        # Toolbar row: navigation toolbar followed by x/y cursor readouts.
        toolbar_sizer = wx.BoxSizer(orient=wx.HORIZONTAL)
        self.x_coord = wx.StaticText(self, label='x:')
        self.y_coord = wx.StaticText(self, label='y:')
        toolbar_sizer.Add(self.toolbar, proportion=2)
        toolbar_sizer.Add(self.x_coord, proportion=1,
                flag=wx.ALIGN_CENTER_VERTICAL|wx.ALIGN_LEFT|wx.LEFT, border=5)
        toolbar_sizer.Add(self.y_coord, proportion=1,
                flag=wx.ALIGN_CENTER_VERTICAL|wx.ALIGN_LEFT|wx.LEFT, border=5)

        sizer = wx.BoxSizer(orient=wx.VERTICAL)
        sizer.Add(toolbar_sizer, proportion=0, flag=wx.EXPAND)
        sizer.Add(self.canvas,  proportion=1, flag=wx.EXPAND)
        self.SetSizer(sizer)

        figheight = self.figure.get_figheight()
        figwidth = self.figure.get_figwidth()
        self.set_minsize(figwidth, figheight)

        self._toolbar_visible = toolbar_visible
        if toolbar_visible:
            self.show_toolbar()
        else:
            self.hide_toolbar()

        self.canvas.Bind(wx.EVT_LEFT_DCLICK, self.toggle_toolbar)

        # State for the throttled cursor-coordinate readout.
        self._last_time_coordinates_updated = 0.0
        self._coordinates_blank = True
        self.canvas.mpl_connect('motion_notify_event', self._update_coordinates)

        # ---- Setup Subscriptions
        pub.subscribe(self._toggle_toolbar, topic="TOGGLE_TOOLBAR")
        pub.subscribe(self._show_toolbar,   topic="SHOW_TOOLBAR")
        pub.subscribe(self._hide_toolbar,   topic="HIDE_TOOLBAR")

        self.axes = {}

    def _save_history(self):
        """Snapshot the toolbar's navigation history.

        Relies on the private ``_views`` / ``_positions`` attributes, so it
        is a no-op when the toolbar lacks them.
        """
        if (hasattr(self.toolbar, '_views') and
                hasattr(self.toolbar, '_positions')):
            self._old_history = {}
            self._old_history['views'] = copy.copy(self.toolbar._views)
            self._old_history['positions'] = copy.copy(self.toolbar._positions)

    def _restore_history(self):
        """Restore the navigation history saved by _save_history, if any."""
        if hasattr(self, '_old_history'):
            self.toolbar._views = self._old_history['views']
            self.toolbar._positions = self._old_history['positions']
            self.toolbar.set_history_buttons()
            if hasattr(self.toolbar, '_update_view'):
                self.toolbar._update_view()

    def clear(self, keep_history=False):
        """Clear the figure and forget all axes.

        The navigation history is always saved first.
        NOTE(review): *keep_history* is accepted but not used here — confirm
        whether callers expect it to gate _save_history.
        """
        self._save_history()
        self.axes = {}
        self.figure.clear()
        gc.collect()

    def set_minsize(self, figwidth, figheight):
        """Set the panel's minimum size from a figure size in inches.

        Returns the (width, height) minimum size in pixels.
        """
        dpi = self.figure.get_dpi()
        # compensate for toolbar height, even if not visible, to keep
        #   it from riding up on the plot when it is visible and the
        #   panel is shrunk down.
        toolbar_height = self.toolbar.GetSize()[1]
        min_size_x = dpi*figwidth
        min_size_y = dpi*figheight+toolbar_height
        min_size = (min_size_x, min_size_y)
        self.SetMinSize(min_size)
        return min_size

    # --- TOGGLE TOOLBAR ----
    def _toggle_toolbar(self, message):
        if (message.data is None or
            self is message.data):
            self.toggle_toolbar()

    def toggle_toolbar(self, event=None):
        '''
        Toggle the visible state of the toolbar.
        '''
        if self._toolbar_visible:
            self.hide_toolbar()
        else:
            self.show_toolbar()

    # --- SHOW TOOLBAR ----
    def _show_toolbar(self, message):
        if (message.data is None or
            self is message.data):
            self.show_toolbar()

    def show_toolbar(self):
        '''
        Make the toolbar visible.
        '''
        self.toolbar.Show(True)
        self.x_coord.Show(True)
        self.y_coord.Show(True)
        self._toolbar_visible = True
        self.Layout()

    # --- HIDE TOOLBAR ----
    def _hide_toolbar(self, message):
        if (message.data is None or
            self is message.data):
            self.hide_toolbar()

    def hide_toolbar(self):
        '''
        Make toolbar invisible.
        '''
        self.toolbar.Show(False)
        self.x_coord.Show(False)
        self.y_coord.Show(False)
        self._toolbar_visible = False
        self.Layout()

    def _update_coordinates(self, event=None):
        """Show the cursor's data coordinates.

        Updates are limited to at most once every 100 ms.  The labels are
        blanked once when the cursor leaves the axes.
        """
        if event is None:
            return
        if event.inaxes:
            now = time.time()
            # only once every 100 ms.
            if now-self._last_time_coordinates_updated > 0.100:
                self._last_time_coordinates_updated = now
                x, y = event.xdata, event.ydata
                self._coordinates_blank = False
                self.x_coord.SetLabel('x: %e' % x)
                self.y_coord.SetLabel('y: %e' % y)
        elif not self._coordinates_blank:
            # make the coordinates blank (previously a misspelled attribute
            # was set, so the labels were re-blanked on every motion event)
            self._coordinates_blank = True
            self.x_coord.SetLabel('x: ')
            self.y_coord.SetLabel('y: ')
Example #28
0
class Map(FigureCanvasQTAgg):
    """Interactive map of house values drawn as coloured region polygons.

    The view is defined by ``center`` ([lon, lat] in degrees) and ``scale``
    (degrees per figure inch).  Regions are drawn at one level of detail
    (zip code, county or state), which switches automatically with the scale
    when ``lod_auto_switch`` is set.  Input handlers only set ``do_update``.
    A 100 ms Qt timer then calls update_plot, which coalesces redraws.
    """

    class DetailLevel(enum.Enum):
        """Granularity of the regions drawn on the map."""
        ZIPCODE = 1
        COUNTY = 2
        STATE = 3

    # Detail levels from finest to coarsest.  The leading None pads the tuple
    # so that ``index(lod) - 1`` is always a valid index.
    lod_auto_switch_order = (None, DetailLevel.ZIPCODE, DetailLevel.COUNTY,
                             DetailLevel.STATE)
    # The level at index k is kept while
    # threshold[k - 1] <= scale <= threshold[k].
    lod_auto_switch_threshold = (0.00, 0.1, 0.4, 10)

    scale_min = 0.001
    scale_max = 4
    # Relative change of ``scale`` per degree of mouse-wheel rotation.
    scale_change_per_wheel_deg_ratio = 4.165E-4
    scale_change_per_wheel_deg = 0.833E-6 * 5

    min_update_interval = 0.1

    center_changed = QtCore.pyqtSignal(object)
    clim_changed = QtCore.pyqtSignal(object)

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes: Axes = self.fig.gca()
        super(Map, self).__init__(self.fig)

        # Region geometries, each given an integer "id" column that is used
        # to look values up in the repository.
        self.zipcode_data = geopandas.read_file("data/geography/zipcode")
        self.zipcode_data["id"] = self.zipcode_data["ZCTA5CE10"].apply(int)
        self.county_data = geopandas.read_file("data/geography/county")
        # County id = STATEFP and COUNTYFP concatenated, then parsed as int.
        self.county_data["id"] = (self.county_data["STATEFP"] +
                                  self.county_data["COUNTYFP"]).apply(int)
        self.state_data = geopandas.read_file("data/geography/state")
        print(self.state_data)
        self.state_data["id"] = self.state_data["STATEFP"].apply(int)

        self.center = [-71.0900052, 42.3367387]
        self.scale = 0.01  # degrees per inch
        self.lod = Map.DetailLevel.ZIPCODE
        self.lod_auto_switch = True

        self.graph_min = 1000
        self.graph_max = 1000000
        self.color_bar_auto_range = True
        self.color_axis = (0, 1500000)
        self.cur_clim = None

        self.repo = DataRepository("data/housing.sqlite")
        self.last_update = timeit.default_timer()
        self.dragging = False
        self.drag_start = (0.0, 0.0)
        self.cur_date = "2021-02-28"
        self.show_zip_codes = False
        self.dates = self.repo.get_dates()

        # Regions without a value are drawn light grey.
        self.cmap = copy.copy(cm.get_cmap("plasma"))
        self.cmap.set_bad(color="lightgrey")

        self.do_update = False
        self.timer = Qt.QTimer(self)
        self.timer.timeout.connect(self.on_timer)
        self.timer.start(100)

        self.update_plot()

    def _auto_switch_lod(self):
        """Step ``self.lod`` finer or coarser until it matches ``self.scale``.

        Loops so that a large scale jump reaches the right level in one
        redraw.  Bounds checks keep the lookup inside the level table.
        """
        order = self.lod_auto_switch_order
        thresholds = self.lod_auto_switch_threshold
        while True:
            idx = order.index(self.lod)
            if self.scale < thresholds[idx - 1] and order[idx - 1] is not None:
                self.lod = order[idx - 1]
            elif self.scale > thresholds[idx] and idx + 1 < len(order):
                self.lod = order[idx + 1]
            else:
                break

    def _visible_bbox(self):
        """Return ((lon_min, lon_max), (lat_min, lat_max)) of the view."""
        width_deg = self.fig.get_figwidth() * self.scale
        height_deg = self.fig.get_figheight() * self.scale
        return ((self.center[0] - width_deg / 2,
                 self.center[0] + width_deg / 2),
                (self.center[1] - height_deg / 2,
                 self.center[1] + height_deg / 2))

    def _visible_geo_data(self, bbox):
        """Regions of the current detail level inside *bbox*.

        Adds a "value" column holding the house value for ``cur_date``.
        """
        if self.lod == self.DetailLevel.ZIPCODE:
            ids_visible = self.repo.get_zipcodes_within_bbox(bbox[0], bbox[1])
            geo_data = self.zipcode_data
            lookup = self.repo.get_house_value_by_date_and_zipcode
        elif self.lod == self.DetailLevel.COUNTY:
            ids_visible = self.repo.get_counties_within_bbox(bbox[0], bbox[1])
            geo_data = self.county_data
            lookup = self.repo.get_house_value_by_date_and_county
        else:
            ids_visible = self.repo.get_states_within_bbox(bbox[0], bbox[1])
            geo_data = self.state_data
            lookup = self.repo.get_house_value_by_date_and_state
        geo_data_visible = geo_data[geo_data.id.isin(ids_visible)].copy()
        geo_data_visible["value"] = geo_data_visible["id"].apply(
            lambda region_id: lookup(region_id, self.cur_date))
        return geo_data_visible

    @staticmethod
    def _polygon_parts(geometry):
        """Simple polygons making up *geometry* (a Polygon or MultiPolygon)."""
        if isinstance(geometry, shapely.geometry.MultiPolygon):
            # .geoms works in shapely 1.x and 2.x.  Direct iteration was
            # removed in shapely 2.0.
            return list(geometry.geoms)
        return [geometry]

    def _build_patch_collection(self, geo_data_visible):
        """Build the colour-mapped PatchCollection, one patch per polygon.

        Also records the applied colour limits in ``color_axis`` and emits
        ``clim_changed``.
        """
        polygons = []
        values = []
        for _, row in geo_data_visible.iterrows():
            value = np.nan if row["value"] is None else row["value"]
            for part in self._polygon_parts(row["geometry"]):
                # .exterior.coords gives an (n, 2) array in shapely 1.x and
                # 2.x.  np.asarray(exterior) no longer works in 2.0.
                polygons.append(Polygon(np.asarray(part.exterior.coords)))
                values.append(value)
        patches = PatchCollection(polygons, edgecolors="white")

        patches.set_array(np.asarray(values))
        if self.color_bar_auto_range:
            patches.set_clim(vmin=self.graph_min, vmax=self.graph_max)
        else:
            patches.set_clim(vmin=self.color_axis[0],
                             vmax=self.color_axis[1])
        self.color_axis = patches.get_clim()
        self.clim_changed.emit(self.color_axis)
        patches.set_cmap(self.cmap)
        return patches

    def _annotate_regions(self, geo_data_visible):
        """Label each region at its centroid.

        Zip codes are shown zero-padded to 5 digits; other levels show the
        NAME column.  Text is white when the value is below the colour-range
        midpoint, otherwise (including missing values) black.
        """
        graph_mid = (self.graph_max + self.graph_min) / 2
        is_zipcode = self.lod == self.DetailLevel.ZIPCODE
        for _, row in geo_data_visible.iterrows():
            value = row["value"]
            # Guard against None in every branch: None < float raises.
            color = "white" if value is not None and value < graph_mid else "black"
            label = f"{row['id']:0>5}" if is_zipcode else row["NAME"]
            self.axes.annotate(
                text=label,
                xy=row["geometry"].centroid.coords[:][0],
                horizontalalignment='center',
                color=color,
                fontsize=10,
            )

    def update_plot(self):
        """Redraw the map for the current center, scale, date and detail level."""
        start = timeit.default_timer()

        if self.lod_auto_switch:
            self._auto_switch_lod()
        print(self.lod)
        print(self.scale)

        bbox = self._visible_bbox()
        geo_data_visible = self._visible_geo_data(bbox)
        if self.color_bar_auto_range:
            # Colour range covers the interquartile range of visible values.
            self.graph_min = geo_data_visible["value"].quantile(0.25)
            self.graph_max = geo_data_visible["value"].quantile(0.75)

        t1 = timeit.default_timer()
        self.fig.clear()
        self.axes = self.fig.gca()

        t2 = timeit.default_timer()
        self.axes.set(xlim=bbox[0], ylim=bbox[1])

        patches = self._build_patch_collection(geo_data_visible)
        self.axes.add_collection(patches, autolim=True)

        if self.show_zip_codes:
            self._annotate_regions(geo_data_visible)

        t3 = timeit.default_timer()

        comma_fmt = StrMethodFormatter("${x:,.0f}")
        self.fig.colorbar(cm.ScalarMappable(norm=patches.norm, cmap=self.cmap),
                          ax=self.axes,
                          format=comma_fmt,
                          pad=0.05)
        self.fig.canvas.draw()

        end = timeit.default_timer()
        print(f"{end - start}, {t1 - start}, {end - t1}")
        print(f"clear: {t2-t1}, plot: {t3-t2}, render: {end - t3}")

    def wheelEvent(self, event: Qt.QWheelEvent):
        """Zoom by a relative amount per wheel degree, clamped to the scale range."""
        self.scale = self.scale + (event.angleDelta().y() * -self.scale *
                                   self.scale_change_per_wheel_deg_ratio)
        if self.scale < self.scale_min:
            self.scale = self.scale_min
        elif self.scale > self.scale_max:
            self.scale = self.scale_max

        self.do_update = True

    def mousePressEvent(self, event: Qt.QMouseEvent):
        """Start panning on middle-button press."""
        if event.buttons() & QtCore.Qt.MiddleButton:
            self.dragging = True
            self.drag_start = (event.x(), event.y())

    def mouseMoveEvent(self, event: Qt.QMouseEvent):
        """Pan the map while dragging.

        Screen y grows downwards, hence the opposite sign for latitude.
        """
        if self.dragging:
            scale_factor = self.scale / self.fig.get_dpi()
            self.center[0] -= (event.x() - self.drag_start[0]) * scale_factor
            self.center[1] += (event.y() - self.drag_start[1]) * scale_factor
            self.drag_start = (event.x(), event.y())

            self.do_update = True
            self.center_changed.emit(self.center)

    def mouseReleaseEvent(self, event: Qt.QMouseEvent):
        """Stop panning once the middle button is no longer held."""
        if not event.buttons() & QtCore.Qt.MiddleButton:
            self.dragging = False

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self.do_update = True

    def on_timer(self):
        """Timer slot: redraw once if any change was requested."""
        if self.do_update:
            self.do_update = False
            self.update_plot()

    def update_date_by_idx(self, idx):
        """Show the data for the idx-th available date."""
        self.cur_date = self.dates[idx]
        self.do_update = True
Example #29
0
class Window(QDialog):
    def __init__(self, parent=None):
        """Create the dialog, its matplotlib canvas, widgets and trial state.

        Args:
            parent: optional parent widget passed to QDialog.
        """
        super(Window, self).__init__(parent)
        # a figure instance to plot on
        self.figure = Figure(facecolor=FACE_COLOR)

        # This is to make the QDialog Window full size.
        if START_FULLSCREEN:
            self.showMaximized()
            self.setFixedSize(self.size())

        self.canvas = FigureCanvas(self.figure)

        # Set up the figure's axis, margins, and background color.
        self.setup_figure()

        # set up the gui
        self.setup_ui()

        # initialize the dots list
        self.dots = []
        self.tracked_dots = {}
        self.highlighted_dot = None

        # keep track of the trials
        self.dot_motion_active = False
        self.clicking_active = False
        self.trial_dictionary = TRIAL_DICTIONARY
        self.trial_id = 0
        self.trial_clicks = self.trial_dictionary[self.trial_id]
        self.clicked_dots = []
        self.trial_starts = []
        self.trial_durations = []
        # Per-trial count of correct selections, indexed by trial id (onrelease
        # decrements correct_dots[self.trial_id] on a wrong click). Build it in
        # key order, not value order, so position k matches trial_dictionary[k].
        self.correct_dots = np.array(
            [self.trial_dictionary[k] for k in sorted(self.trial_dictionary)])
        self.total_duration = 0
        self.output_file = ""
        self.mouse_pressed = False

    def setup_dots(self):
        """Discard the current dots and create ``NUMBER_OF_DOTS`` new ones.

        Each dot gets a location from generate_location() (kept apart from
        the dots created before it), a velocity of magnitude ``VELOCITY`` and
        its creation index as id. Nothing is added to the canvas here.
        """
        self.remove_dots()
        for dot_id in range(NUMBER_OF_DOTS):
            location = self.generate_location(self.dots, RADIUS)
            velocity = self.generate_velocity(VELOCITY)
            self.dots.append(
                TrackableDot(location, RADIUS, COLOR, velocity, dot_id))

    def draw_dots(self):
        """Add every dot in ``self.dots`` to the axes and redraw the canvas.

        The dot objects must already exist (see setup_dots()).
        """
        for dot in self.dots:
            self.ax.add_artist(dot)
        self.canvas.draw()

    def track_dots(self):
        """Mark the first ``self.trial_clicks`` dots as the tracked targets.

        Rebuilds ``self.tracked_dots`` as a mapping of dot id to dot.
        """
        self.tracked_dots = {}
        for dot in (self.dots[k] for k in range(self.trial_clicks)):
            self.tracked_dots[dot.id] = dot

    @property
    def valid_click(self):
        """True while clicks are accepted: selection is active and clicks remain."""
        if not self.clicking_active:
            return False
        return self.trial_clicks > 0

    def generate_velocity(self, mag=1.0):
        """Return a random 2-D velocity of magnitude ``mag``.

        A random direction is drawn from uniform(-1, 1) components and
        normalised to unit length before scaling.

        Returns:
            (vx, vy) tuple.
        """
        vx = np.random.uniform(-1, 1)
        vy = np.random.uniform(-1, 1)
        length = np.sqrt(vx**2 + vy**2)
        return vx / length * mag, vy / length * mag

    def generate_location(self, dots_list, radius):
        """Pick a random location away from existing dots.

        Points are drawn uniformly from the central 80% of the
        ``WIDTH`` x ``HEIGHT`` area and redrawn until they are at least
        ``5 * radius`` away from every dot in ``dots_list``.

        Args:
            dots_list: dots already placed.
            radius: dot radius used for the spacing check.
        Returns:
            (x, y) tuple.
        """
        def random_point():
            return (np.random.uniform(0.1, 0.9) * self.WIDTH,
                    np.random.uniform(0.1, 0.9) * self.HEIGHT)

        def collides(point):
            return any(
                np.sqrt((point[0] - dot.center[0])**2 +
                        (point[1] - dot.center[1])**2) < (5 * radius)
                for dot in dots_list)

        location = random_point()
        while collides(location):
            location = random_point()
        return location

    def grab_default_dimensions(self):
        """Record the dialog's min/max sizes and the figure size in inches.

        The sizes are restored by reset_sizing(); ``WIDTH``/``HEIGHT`` are
        used as the axes limits and the dot placement area.
        """
        self.DEFAULT_MINSIZE, self.DEFAULT_MAXSIZE = (self.minimumSize(),
                                                      self.maximumSize())
        self.WIDTH, self.HEIGHT = (self.figure.get_figwidth(),
                                   self.figure.get_figheight())

    def reset_sizing(self):
        """Restore the min/max sizes saved by grab_default_dimensions()."""
        min_size, max_size = self.DEFAULT_MINSIZE, self.DEFAULT_MAXSIZE
        self.setMinimumSize(min_size)
        self.setMaximumSize(max_size)

    def setup_ui(self):
        """Build the dialog's widgets and layout.

        The canvas sits on top. Underneath is a horizontal row with the
        "Begin Tracking" button, the trial info label, the "Next Trial"
        button and the participant-ID field. Three message boxes (missing
        PID, stop confirmation, save notice) are created but not shown.
        """
        def make_button(label, slot, enabled=True):
            # Buttons share the click hookup and a 200px width cap.
            button = QPushButton(label)
            button.clicked.connect(slot)
            if not enabled:
                button.setEnabled(False)
            button.setMaximumWidth(200)
            return button

        def make_message_box(icon, text, buttons=None):
            box = QMessageBox()
            box.setIcon(icon)
            box.setText(text)
            if buttons is not None:
                box.setStandardButtons(buttons)
            return box

        self.track_button = make_button('Begin Tracking',
                                        self.begin_tracking_button_clicked)

        self.text_field = QLineEdit()
        self.text_field.setPlaceholderText("Participant ID")
        self.text_field.setMaximumWidth(200)

        self.next_button = make_button('Next Trial', self.next_button_clicked,
                                       enabled=False)

        label = QLabel("Trial #0 -- 2 Clicks Left")
        label.setMaximumWidth(300)
        label.setMinimumWidth(300)
        label.setMaximumHeight(20)
        label.setFont(QFont("Arial", 12, QFont.Normal))
        self.info_label = label

        row = QHBoxLayout()
        for widget in (self.track_button, self.info_label, self.next_button,
                       self.text_field):
            row.addWidget(widget)
        self.bottom_layout = row

        self.pid_message = make_message_box(
            QMessageBox.Critical, "Please add a Participant ID.",
            QMessageBox.Ok)
        self.stop_message = make_message_box(
            QMessageBox.Critical,
            "Are you sure you would like to stop tracking?",
            QMessageBox.Yes | QMessageBox.No)
        self.save_message = make_message_box(QMessageBox.Information, "")

        # Canvas above, control row below.
        layout = QVBoxLayout()
        layout.addWidget(self.canvas)
        layout.addLayout(self.bottom_layout)
        self.setLayout(layout)

    def remove_dots(self):
        """Detach every dot from the axes, empty ``self.dots`` and redraw."""
        for artist in self.dots:
            artist.remove()
        self.dots = []
        # Refresh so the removed dots disappear from the screen.
        self.canvas.draw()

    def closeEvent(self, event):
        """Qt close handler: disconnect mpl callbacks and stop the animation.

        Disconnects every connection id stored in ``self.cids`` (or the single
        ``self.cid`` if no list exists). Stops the trial animation, which
        animate_plot() stores as ``self.dot_ani``, so its timer stops firing
        on a closing window. Finally accepts the close event.
        """
        print("The window will close.")
        cids = getattr(self, "cids", None)
        if cids is None:
            cid = getattr(self, "cid", None)
            cids = [] if cid is None else [cid]
        for cid in cids:
            self.figure.canvas.mpl_disconnect(cid)
        # 'line_ani' is kept for backward compatibility; 'dot_ani' is the
        # attribute actually set by animate_plot().
        for name in ("dot_ani", "line_ani"):
            try:
                getattr(self, name)._stop()
            except AttributeError:
                # Never started, or already stopped (event source gone).
                pass
        event.accept()

    def resizeEvent(self, event):
        """Handle a resize: forward it to QDialog, then re-record dimensions.

        Forwarding keeps the base class's own resize handling intact.
        """
        super(Window, self).resizeEvent(event)
        self.grab_default_dimensions()

    def display_pid_message(self):
        """Show the modal "Please add a Participant ID." message box."""
        message_box = self.pid_message
        message_box.exec_()

    def next_button_clicked(self):
        """Start the next trial unless participants are still selecting dots.

        Creates and draws a fresh set of dots, marks the tracked ones and
        starts the animation. The session timer starts when trial id 1 begins.
        """
        if self.clicking_active:
            return
        if self.trial_id == 1:
            self.total_duration = time.time()
        self.update_info_label()
        self.setup_dots()
        self.draw_dots()
        self.canvas.draw()
        self.track_dots()
        self.animate_plot()
        self.canvas.draw()

    def begin_tracking_button_clicked(self):
        """Fit the axes to the figure size and, if a PID is set, start a session.

        Without a participant ID, the PID warning dialog is shown instead.
        """
        self.grab_default_dimensions()
        self.ax.set_ylim([0, self.HEIGHT])
        self.ax.set_xlim([0, self.WIDTH])
        if not self.has_valid_pid():
            self.display_pid_message()
            return
        self.next_button.setEnabled(True)
        self.text_field.setReadOnly(True)
        self.track_button.setEnabled(False)
        self.track_button.setText('Tracking ...')

    def stop_tracking_button_clicked(self, event):
        """Ask for confirmation and, on "Yes", stop tracking and clear the dots.

        Re-enables the PID field and the "Begin Tracking" button and stops
        the trial animation if one exists.
        """
        if self.stop_message.exec_() != QMessageBox.Yes:
            return
        self.text_field.setReadOnly(False)
        self.canvas.draw()
        self.track_button.setEnabled(True)
        self.track_button.setText('Begin Tracking')
        self.remove_dots()
        try:
            self.dot_ani._stop()
        except AttributeError:
            # No animation was started (or it is already stopped).
            pass

    def has_valid_pid(self):
        """Return True when the participant-ID field is non-empty."""
        entered = self.text_field.text()
        return entered != ""

    def update_dots(self, i):
        """Advance every dot by one time step ``DT`` and resolve collisions.

        Positions and collision flags are updated for all dots first; then
        velocities are updated in a second pass.

        Args:
            i: animation frame index (unused).
        Returns:
            the list of dots.
        """
        detector = BoundaryCollisionDetector(self)
        dots = self.dots
        for moving in dots:
            moving.update_position(DT)
            moving.colliding = detector.detect_collision(moving)
        for moving in dots:
            detector.update_velocity(moving)
        return dots

    def _blink_stage(self, i):
        """Return True while frame ``i`` falls in the initial blinking phase.

        The trial has ``int(TRIAL_DURATION / INTERVAL)`` frames; the first
        ``BLINKING_DURATION / TRIAL_DURATION`` fraction of them blink.
        """
        total_frames = int(TRIAL_DURATION / INTERVAL)
        blink_frames = int(total_frames / TRIAL_DURATION * BLINKING_DURATION)
        return i < blink_frames

    def conduct_subtrial(self, i):
        """Animation callback for one frame of a sub-trial.

        During the blinking phase, the tracked dots toggle colour every sixth
        frame (frame 0 also disables "Next Trial" and sets
        ``dot_motion_active``). On the final frame, motion ends and clicking
        is enabled, with the start time recorded. Otherwise the dots move.

        Args:
            i: animation frame index.
        """
        if self._blink_stage(i):
            if i == 0:
                self.next_button.setEnabled(False)
                self.dot_motion_active = True
            if i % 6 == 0:
                self.blink_dots(i)
            return
        if i + 1 == int(TRIAL_DURATION / INTERVAL):
            self.dot_motion_active = False
            self.clicking_active = True
            self.next_button.setEnabled(False)
            self.trial_starts.append(time.time())
            return
        self.update_dots(i)

    def animate_plot(self):
        """Start a non-repeating FuncAnimation that drives conduct_subtrial().

        The animation is stored on ``self.dot_ani`` so it stays alive and can
        be stopped later.
        """
        frame_count = int(TRIAL_DURATION / INTERVAL)
        self.dot_ani = animation.FuncAnimation(self.figure,
                                               self.conduct_subtrial,
                                               frame_count,
                                               interval=INTERVAL,
                                               repeat=False)

    def blink_dots(self, i):
        """Toggle every tracked dot between ``COLOR`` and ``BLINKING_COLOR``."""
        for dot in self.tracked_dots.values():
            new_color = COLOR if dot.color == BLINKING_COLOR else BLINKING_COLOR
            dot.set_color(new_color)
            dot.color = new_color

    def setup_figure(self):
        '''Prepare a bare, margin-free axes and hook up the mouse handlers.

        Sets the global matplotlib rcParam ``toolbar`` to 'None', recreates
        the single axes with ticks and frame hidden, removes all margins and
        connects release/press/motion callbacks. All connection ids are kept
        in ``self.cids``; ``self.cid`` keeps the last one for existing callers.
        '''
        # Remove toolbar (global rcParams side effect)
        mpl.rcParams['toolbar'] = 'None'

        # instead of ax.hold(False)
        self.figure.clear()

        self.ax = self.figure.add_subplot(111)

        # Remove ticks and labels on axis
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_xlim([0, 1])
        self.ax.set_ylim([0, 1])
        self.ax.axis('off')

        # Remove margin on plot.
        self.figure.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.00)

        # refresh canvas
        self.canvas.draw()
        canvas = self.figure.canvas
        # Previously each connect overwrote self.cid, so only the last handler
        # could ever be disconnected; keep all ids.
        self.cids = [
            canvas.mpl_connect('button_release_event', self.onrelease),
            canvas.mpl_connect('button_press_event', self.onclick),
            canvas.mpl_connect('motion_notify_event', self.onmouse),
        ]
        self.cid = self.cids[-1]

    def _distance(self, loc1, loc2):
        """Return the Euclidean distance between two (x, y) points."""
        dx = loc1[0] - loc2[0]
        dy = loc1[1] - loc2[1]
        return np.sqrt(dx**2 + dy**2)

    def update_info_label(self):
        """Show the current trial id and remaining clicks, repainting at once."""
        template = "Trial #{} -- {} Clicks Left"
        self.info_label.setText(template.format(self.trial_id,
                                                self.trial_clicks))
        self.info_label.repaint()

    def end_trial(self):
        """Finish the session after the final trial and write the results.

        Restores the UI to its pre-tracking state, then writes
        ``output_file.txt`` with three lines:
        1. the participant ID,
        2. comma-separated trial durations in milliseconds,
        3. comma-separated per-trial correct-dot counts.

        Afterwards the per-session accumulators are reset, so that a following
        participant's file does not contain this session's data.
        """
        self.total_duration = time.time() - self.total_duration
        self.next_button.setEnabled(False)
        self.track_button.setEnabled(True)
        self.track_button.setText("Begin Tracking")
        self.text_field.setReadOnly(False)
        self.clicked_dots = []
        self.trial_clicks = self.trial_dictionary[self.trial_id]
        self.clicking_active = False
        trial_duration_string = [
            format(x * 1000, '.0f') for x in self.trial_durations
        ]
        correct_dots_string = [str(x) for x in self.correct_dots]
        with open('output_file.txt', 'w') as f:
            f.write(self.text_field.text())
            f.write('\n')
            f.write(",".join(trial_duration_string))
            f.write('\n')
            f.write(",".join(correct_dots_string))

        # Reset session accumulators for the next participant. correct_dots is
        # indexed by trial id, so rebuild it in key order.
        self.trial_durations = []
        self.trial_starts = []
        self.correct_dots = np.array(
            [self.trial_dictionary[k] for k in sorted(self.trial_dictionary)])

    def dot_clicked(self):
        """Consume one click and advance the trial when no clicks remain.

        Decrements ``trial_clicks`` and refreshes the label. After the last
        click of a trial:
        - the duration since the last ``trial_starts`` entry is recorded;
        - tracked dots that were not clicked are coloured ``UNSELECTED_COLOR``;
        - after trial id 15 the session ends (end_trial()); otherwise the next
          trial is prepared and "Next Trial" is enabled.
        """
        self.trial_clicks -= 1
        self.update_info_label()
        if self.trial_clicks != 0:
            return

        self.trial_durations.append(time.time() - self.trial_starts[-1])
        missed = [
            dot for dot in self.dots
            if dot.id in self.tracked_dots and dot not in self.clicked_dots
        ]
        for dot in missed:
            dot.set_color(UNSELECTED_COLOR)
        if missed:
            # One redraw covers all recoloured dots (was once per dot).
            self.canvas.draw()

        if self.trial_id == 15:
            self.trial_id = 0
            self.end_trial()
        else:
            self.trial_id += 1
            self.clicked_dots = []
            self.trial_clicks = self.trial_dictionary[self.trial_id]
            self.clicking_active = False
            self.next_button.setEnabled(True)

    def detect_clicked_dot(self, dots_list, event):
        """Return the dot whose disc contains the event position, or None.

        When several dots contain the point, the one with the nearest centre
        wins; ties keep the earliest dot in ``dots_list``.

        Args:
            dots_list: candidate dots (objects with ``center`` and ``radius``).
            event: mouse event providing ``xdata``/``ydata``.
        """
        click = (event.xdata, event.ydata)
        best, best_dist = None, None
        for dot in dots_list:
            dist = self._distance(dot.center, click)
            if not dist < dot.radius:
                continue
            if best is None or dist < best_dist:
                best, best_dist = dot, dist
        return best

    def onclick(self, event):
        '''Highlight the dot under the cursor while the button is held down.

        A press outside the axes has no data coordinates (``xdata`` is None).
        It only records that the button is down; previously it crashed in
        _distance().
        '''
        self.mouse_pressed = True
        if event.xdata is None:
            return
        selected_dot = self.detect_clicked_dot(self.dots, event)
        if (selected_dot is not None and selected_dot not in self.clicked_dots
                and self.valid_click):
            self.highlighted_dot = selected_dot
            selected_dot.set_color(SELECTION_COLOR)
            self.canvas.draw()

    def onrelease(self, event):
        '''Commit a selection when the button is released over a dot.

        A release within 0.3 data units of the highlighted dot's centre
        selects that dot. Axes limits are set to the figure size in inches
        by begin_tracking_button_clicked().

        Outcomes:
        - a tracked dot turns ``BLINKING_COLOR``;
        - any other dot turns ``INCORRECT_COLOR`` and decrements this trial's
          correct count;
        - a release over nothing restores the highlighted dot's colour.

        Releases outside the axes (``xdata`` is None) select nothing.
        '''
        self.mouse_pressed = False
        selected_dot = None
        if event.xdata is not None:
            release_xy = (event.xdata, event.ydata)
            selected_dot = self.detect_clicked_dot(self.dots, event)
            # Computing this distance with None coordinates used to raise.
            if (self.highlighted_dot is not None
                    and self._distance(self.highlighted_dot.center,
                                       release_xy) < 0.3):
                selected_dot = self.highlighted_dot

        if (selected_dot is not None and selected_dot not in self.clicked_dots
                and self.valid_click):
            self.clicked_dots.append(selected_dot)
            if selected_dot.id in self.tracked_dots:
                selected_dot.set_color(BLINKING_COLOR)
            else:
                selected_dot.set_color(INCORRECT_COLOR)
                self.correct_dots[self.trial_id] -= 1
            self.canvas.draw()
            self.dot_clicked()
        elif (selected_dot is None and self.highlighted_dot is not None
              and self.highlighted_dot not in self.clicked_dots):
            self.highlighted_dot.set_color(COLOR)
            self.canvas.draw()
        self.highlighted_dot = None

    def onmouse(self, event):
        '''If the mouse moves while the user has the cursor pressed the selection should be
Example #30
0
def save_png(fig: Figure,
             path: Union[None, str, pathlib.Path],
             width: Union[int, float] = None,
             height: Union[int, float] = None,
             unit: str = 'px',
             print_info: bool = False) -> Union[str, io.BytesIO]:
    """
    Save PNG image of the figure.

    :param fig:        Figure to save.
    :param path:       Full path of the image to save. If directory (string ending in slash - '/' or '\\') then
                       the figure window title is used as a file name. If `None`, in-memory :class:`io.BytesIO`
                       file will be generated and returned. A ".png" extension is appended when missing
                       (the check is case-insensitive). Missing directories are created.

    :param width:      Image width in `unit`. If not provided it will be left as it is.
    :param height:     Image height in `unit`. If not provided it will be left as it is.
    :param unit:       Unit of the image width and height, one of: 'px' (pixels), 'mm' (millimeters),
                       'cm' (centimeters), 'in' or 'inch' (inches).

    :param print_info: Whether to print information about saved file.

    :return: Full path of the generated image if `path` was provided or in-memory :class:`io.BytesIO` file.
    :raises ValueError: if `width` or `height` is given with an unsupported `unit`.
    """
    if path:
        directory, file_name = os.path.split(path)
        # Create the directory if not existent. os.makedirs('') raises, so skip
        # it when the path has no directory component (e.g. "plot.png").
        if directory:
            os.makedirs(directory, exist_ok=True)
        # If the provided path is only a directory, use window title as filename
        if not file_name:
            file_name = get_window_title(fig)
        # Image path must have .png extension!
        if os.path.splitext(file_name)[1].lower() != ".png":
            file_name += ".png"
        path = os.path.join(directory, file_name)

    dpi = fig.get_dpi()

    if width or height:
        # Divisor converting one `unit` into inches (matplotlib's size unit).
        units_per_inch = {'px': dpi, 'mm': 25.4, 'cm': 2.54, 'in': 1, 'inch': 1}
        if unit not in units_per_inch:
            raise ValueError(f"Unsupported size unit '{unit}'")
        divisor = units_per_inch[unit]
        size = fig.get_size_inches()
        # An unspecified dimension keeps its current size, which is already in
        # inches and must therefore not be converted.
        new_width = width / divisor if width else size[0]
        new_height = height / divisor if height else size[1]
        fig.set_size_inches(new_width, new_height)

    width = fig.get_figwidth()
    height = fig.get_figheight()
    width_px = int(round(width * dpi))
    height_px = int(round(height * dpi))
    width_mm = width * 25.4
    height_mm = height * 25.4

    if path:
        fig.savefig(path, dpi=dpi)
        ret = path
        if print_info:
            print(
                f"Saved plot ({width_px}x{height_px} px = {width_mm:.1f}x{height_mm:.1f} mm @ {dpi} dpi)"
                f" to '{os.path.normpath(path)}'")
    else:
        file = io.BytesIO()
        fig.savefig(file, dpi=dpi)
        file.seek(0)
        ret = file

    return ret
Example #31
0
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.
    """

    def __init__(self, rows=1, cols=0, title='', size=None, buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        :param rows, cols:  initial panel layout passed to set_panels().
        :param title:       title passed to set_title().
        :param size:        figure size, passed as Figure(figsize=...).
        :param buffering:   initial value of the buffering flag (see hold()).
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence, overridable via the 'plotter.colours' rc param.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()

        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()
        # line styles need to be set as a list of numbers to use set_dashes
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        # dict.has_key() does not exist in Python 3; use .get(). Unknown names
        # fall back to a solid dash sequence, as palette() does; the former '-'
        # string fallback is not a usable set_dashes() sequence.
        solid = self.lsalias['line']
        self.linestyles = [self.lsalias.get(ls, solid) for ls in styles]

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

        self.events = {'button_press':None,
                       'button_release':None,
                       'motion_notify':None}

    def _alive(self):
        # Return True if the GUI alives.
        if (not self.is_dead) and \
               self.figmgr and hasattr(self.figmgr, "num"):
            figid = self.figmgr.num
            # Make sure figid=0 is what asapplotter expects.
            # It might be already destroied/overridden by matplotlib
            # commands or other methods using asaplot.
            return _pylab_helpers.Gcf.has_fignum(figid) and \
                   (self.figmgr == _pylab_helpers.Gcf.get_fig_manager(figid))
        return False

    def _subplotsOk(self, rows, cols, npanel=0):
        """
        Check if the axes in subplots are actually the ones plotted on
        the figure. Returns a bool.
        This method is to detect manual layout changes using mpl methods.
        """
        # compare with user defined layout
        if (rows is not None) and (rows != self.rows):
            return False
        if (cols is not None) and (cols != self.cols):
            return False
        # check number of subplots
        figaxes = self.figure.get_axes()
        np = self.rows*self.cols
        if npanel > np:
            return False
        if len(figaxes) != np:
            return False
        if len(self.subplots) != len(figaxes):
            return False
        # compare axes instance in this class and on the plotter
        ok = True
        for ip in range(np):
            if self.subplots[ip]['axes'] != figaxes[ip]:
                ok = False
                break
        return ok

    ### Delete artists ###
    def clear(self):
        """
        Delete all lines from the current subplot.
        Line numbering will restart from 0.
        """

        #for i in range(len(self.lines)):
        #   self.delete(i)
        self.axes.clear()
        self.color = 0
        self.linestyle = 0
        self.lines = []
        self.subplots[self.i]['lines'] = self.lines

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.
        The remaining lines are NOT renumbered.
        """

        if numbers is None: numbers = [len(self.lines)-1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines):
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        line.set_linestyle('None')
                        self.lines[number] = None
        self.show()


    ### Set plot parameters ###
    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        if colormap:
            if isinstance(colormap,list):
                self.colormap = colormap
            elif isinstance(colormap,str):
                self.colormap = colormap.split()
        if 0 <= color < len(self.colormap):
            self.color = color
        if linestyles:
            self.linestyles = []
            if isinstance(linestyles,list):
                styles = linestyles
            elif isinstance(linestyles,str):
                styles = linestyles.split()
            for ls in styles:
                if self.lsalias.has_key(ls):
                    self.linestyles.append(self.lsalias.get(ls))
                else:
                    self.linestyles.append(self.lsalias.get('line'))
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        Any other value for loc else disables the legend:
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        """
        if isinstance(loc, int):
            self.loc = None
            if 0 <= loc <= 10: self.loc = loc
        else:
            self.loc = None
        #self.show()

    #def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, margin=None,ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        margin, if given, is a 6-sequence (left, bottom, right, top, wspace,
        hspace) passed to Figure.subplots_adjust().

        ganged (regular layouts only): panels after the first share x and y
        axes with panel 0, inter-panel spacing is squeezed to 0.0001, and tick
        labels of interior panels are suppressed.
        """
        # A new regular layout replaces any existing axes.
        # NOTE(review): set_title() is not visible here; presumably it restores
        # the title removed by figure.clear().
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if margin:
            lef, bot, rig, top, wsp, hsp = margin
            self.figure.subplots_adjust(
                left=lef,bottom=bot,right=rig,top=top,wspace=wsp,hspace=hsp)
            del lef,bot,rig,top,wsp,hsp

        if rows < 1: rows = 1

        # cols <= 0: choose the smallest near-square grid holding `rows` panels
        # (cols = ceil(sqrt(rows)), rows reduced by one if that still fits).
        if cols <= 0:
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            # Add a single panel at grid position n as the next subplot.
            i = len(self.subplots)

            self.subplots.append({})

            self.subplots[i]['axes']  = self.figure.add_subplot(rows,
                                            cols, n+1)
            self.subplots[i]['lines'] = []

            if i == 0: self.subplot(0)

            # Irregular layout: with rows*cols == 0, _subplotsOk() reports
            # False once any axes exist.
            self.rows = 0
            self.cols = 0

        else:
            # Build a full regular grid of (at most nplots) panels.
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols
            if ganged:
                hsp,wsp = None,None
                if rows > 1: hsp = 0.0001
                if cols > 1: wsp = 0.0001
                self.figure.subplots_adjust(wspace=wsp,hspace=hsp)
            for i in range(nplots):
                self.subplots.append({})
                self.subplots[i]['lines'] = []
                if not ganged:
                    self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                                cols, i+1)
                    if asaprcParams['plotter.axesformatting'] != 'mpl':
                        self.subplots[i]['axes'].xaxis.set_major_formatter(OldScalarFormatter())
                else:
                    if i == 0:
                        self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                                cols, i+1)
                        if asaprcParams['plotter.axesformatting'] != 'mpl':

                            self.subplots[i]['axes'].xaxis.set_major_formatter(OldScalarFormatter())
                    else:
                        # Later panels share both axes with panel 0.
                        self.subplots[i]['axes'] = self.figure.add_subplot(rows,
                                                cols, i+1,
                                                sharex=self.subplots[0]['axes'],
                                                sharey=self.subplots[0]['axes'])

                    # Suppress tick labelling for interior subplots.
                    # NOTE(review): Tick.label1On is an old matplotlib attribute;
                    # confirm it still hides labels in the matplotlib in use.
                    if i <= (rows-1)*cols - 1:
                        if i+cols < nplots:
                            # Suppress x-labels for frames with
                            # adjacent frames below them
                            for tick in self.subplots[i]['axes'].xaxis.majorTicks:
                                tick.label1On = False
                            #self.subplots[i]['axes'].xaxis.label.set_visible(False)
                    if i%cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in self.subplots[i]['axes'].yaxis.majorTicks:
                            tick.label1On = False
                        #self.subplots[i]['axes'].yaxis.label.set_visible(False)
                    # disable the first tick of [1:ncol-1] of the last row
                    #if i+1 < nplots:
                    #    self.subplots[i]['axes'].xaxis.majorTicks[0].label1On = False
                # set axes label state for interior subplots.
                # Hides y axis labels off the left column and x axis labels on
                # panels that have another panel below them.
                if i%cols:
                    self.subplots[i]['axes'].yaxis.label.set_visible(False)
                if (i <= (rows-1)*cols - 1) and (i+cols < nplots):
                    self.subplots[i]['axes'].xaxis.label.set_visible(False)
            self.rows = rows
            self.cols = cols
            self.subplot(0)
        # Explicitly unbinds locals; this has no functional effect.
        del rows,cols,n,nplots,margin,ganged,i

    def subplot(self, i=None, inc=None):
        """
        Make one panel, as defined by one or more invocations of
        set_panels(), the current one.

        i selects the 0-relative panel number; inc, if given, is then added
        to the current selection.  The selection wraps modulo the number of
        panels, and self.axes / self.lines are pointed at the chosen panel.
        Does nothing when no panels exist.
        """
        npanels = len(self.subplots)
        if not npanels:
            return
        if i is not None:
            self.i = i
        if inc is not None:
            self.i += inc
        self.i %= npanels
        panel = self.subplots[self.i]
        self.axes = panel['axes']
        self.lines = panel['lines']

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the current axes by calling the relevant
        Axes.set_*() method, e.g. set_axes('xlabel', 'Channel').

        'colour' is accepted as a synonym for 'color', both as the suffix of
        *what* (e.g. 'facecolour') and as the keyword argument 'colour'.
        Colour translation is otherwise as described in the doctext for
        palette().  Does nothing if *what* is None; afterwards calls
        show(hardrefresh=False).
        """
        if what is None:
            return
        if what.endswith('colour'):
            what = what[:-6] + 'color'

        # dict.has_key() is deprecated and gone in Python 3; 'in' works in
        # every version.
        if 'colour' in kwargs:
            kwargs['color'] = kwargs.pop('colour')

        getattr(self.axes, "set_%s" % what)(*args, **kwargs)
        self.show(hardrefresh=False)


    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method, e.g. set_figure('facecolour', 'white').

        'colour' is accepted as a synonym for 'color', both as the suffix of
        *what* and as a keyword name.  Keyword names are lower-cased before
        the call.  Colour translation is otherwise as described in the
        doctext for palette().  Does nothing if *what* is None; afterwards
        calls show(hardrefresh=False).
        """
        if what is None:
            return
        if what.endswith('colour'):
            what = what[:-6] + 'color'

        # Normalise keyword names; dict.items() works on Python 2 and 3
        # (iteritems() does not exist in Python 3).
        newargs = {}
        for key, value in kwargs.items():
            key = key.lower()
            if key == 'colour':
                key = 'color'
            newargs[key] = value

        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show(hardrefresh=False)


    def set_limits(self, xlim=None, ylim=None):
        """
        Apply axis limits to every panel.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().
        An entry of None, or omitting the argument entirely, keeps that
        panel's current limit.

        Side effect: self.axes and self.lines are left pointing at the last
        panel in self.subplots.
        """
        def overlay(current, requested):
            # Start from the panel's current limits; replace only the
            # entries the caller actually supplied.
            merged = list(current)
            if requested is not None:
                for pos, val in enumerate(requested):
                    if val is not None:
                        merged[pos] = val
            return merged

        for panel in self.subplots:
            self.axes = panel['axes']
            self.lines = panel['lines']
            newx = overlay(self.axes.get_xlim(), xlim)
            newy = overlay(self.axes.get_ylim(), ylim)
            self.axes.set_xlim(newx)
            self.axes.set_ylim(newy)


    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists (including number=None), attributes
        are recorded in self.attributes and used for the next line(s) to be
        plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  Keyword names are case-insensitive, and for non-US usage
        'colour' is recognized as synonymous with 'color'.

        Set the value to None to delete a recorded attribute; deleting one
        that was never recorded is a no-op.

        Colour translation is done as described in the doctext for palette().
        """
        # Explicit None check: 'None < int' comparisons raise TypeError on
        # Python 3, whereas on Python 2 they simply evaluated False.
        existing = number is not None and 0 <= number < len(self.lines)

        redraw = False
        for key, value in kwargs.items():
            key = key.lower()
            if key == 'colour':
                key = 'color'

            if existing:
                segments = self.lines[number]
                # A None slot has no segments to update.
                if segments is not None:
                    for segment in segments:
                        getattr(segment, "set_%s" % key)(value)
                    redraw = True
            elif value is None:
                self.attributes.pop(key, None)
            else:
                self.attributes[key] = value

        if redraw:
            self.show(hardrefresh=False)


    def get_line(self):
        """
        Return the dict of default line attributes that set_line() records
        and plot() applies to newly plotted lines.  The live dict is
        returned, not a copy.
        """
        current = self.attributes
        return current


    ### Actual plot methods ###
    def hist(self, x=None, y=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        x: bin start positions; defaults to range(len(y)).
        y: bin values; a numpy masked array or any sequence (an unmasked
           input is treated as fully valid).
        fmt is the line style as in plot(); add is passed through to plot().

        Nothing is drawn if y is None, if x and y differ in length, or if
        they are empty.  The step outline is built by duplicating points and
        handed to self.plot() as a masked array.
        """
        from numpy.ma import MaskedArray, getdata, getmaskarray
        if x is None:
            if y is None:
                return
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return

        if hasattr(y, "raw_mask"):
            # numpy < 1.1
            ymsk = y.raw_mask()
            ydat = y.raw_data()
        else:
            # getmaskarray always yields a full boolean array, even when y
            # carries no mask (numpy's scalar 'nomask') or is a plain list,
            # so it can be indexed per element below.
            ymsk = getmaskarray(y)
            ydat = getdata(y)

        npts = 2 * len(x)
        # Each bin start appears twice: x0, x0, x1, x1, ...
        x2 = [x[i // 2] for i in range(npts)]
        # The mask follows the same pairing as x2.
        m2 = [ymsk[i // 2] for i in range(npts)]
        # Start from zero, then repeat each value so the outline steps:
        # 0, y0, y0, y1, y1, ..., y[n-1]
        y2 = [0.0] + [ydat[(i - 1) // 2] for i in range(1, npts)]

        self.plot(x2, MaskedArray(y2, mask=m2, copy=0), fmt, add)


    def plot(self, x=None, y=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes, then call show().

        The argument list works a bit like the matlab plot() function: with
        only x (or only y) given, it is plotted against range(len(...)).

        fmt:  optional matplotlib format string passed to Axes.plot().
        add:  None to start a new line; a 1-based number of an existing line
              to append these segments to; 0 for the most recent line.  Any
              out-of-range value (including negative) starts a new line.
        """
        if x is None:
            if y is None:
                return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if fmt is None:
            line = self.axes.plot(x, y)
        else:
            line = self.axes.plot(x, y, fmt)
        # add a picker to lines for spectral value mode.
        # matplotlib.axes.plot returns a list of line object (1 element)
        line[0].set_picker(5.0)

        # Add to an existing line?  The former test 'len < add < 0' could
        # never be true, so bad values indexed (or crashed on) the wrong line.
        nlines = len(self.lines)
        if add == 0:
            add = nlines
        if add is None or not 1 <= add <= nlines:
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            i = add - 1
            self.lines[i].extend(line)
        segments = self.lines[i]

        # Set/reset attributes for the line.
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color':
                gotcolour = True
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        # Without an explicit colour, cycle through the colormap; with a
        # single-colour map, lines are told apart by dash style instead.
        if not gotcolour and len(self.colormap):
            single = len(self.colormap) == 1
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                if single:
                    segment.set_dashes(self.linestyles[self.linestyle])

            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()


    def tidy(self):
        """
        Switch off selected tick labels of a multi-panel layout.

        This needs to be executed after the first "refresh", since it works
        on the tick objects that exist at that point.  For every panel not
        in the left column the first x tick label is hidden; for panels in
        the left column other than the first, the last y tick label is
        hidden.  Does nothing for a single panel.
        """
        nplots = len(self.subplots)
        if nplots == 1:
            return
        # range/enumerate instead of xrange, which does not exist in Python 3.
        for i, sp in enumerate(self.subplots):
            ax = sp['axes']
            if i % self.cols:
                ax.xaxis.majorTicks[0].label1On = False
            elif i != 0:
                ax.yaxis.majorTicks[-1].label1On = False
            

    def set_title(self, title=None):
        """
        Draw the plot-window title as figure text centred near the top.

        A given title is remembered in self.title; when title is omitted the
        remembered one is drawn again.  Each call adds a new text artist.
        """
        if title is None:
            title = self.title
        else:
            self.title = title
        self.figure.text(0.5, 0.95, title, horizontalalignment='center')


    def text(self, *args, **kwargs):
        """
        Place text on the figure, forwarding all arguments unchanged to
        Figure.text(), then call show().
        """
        fig = self.figure
        fig.text(*args, **kwargs)
        self.show()

    def vline_with_label(self, x, y, label,
                         location='bottom', rotate=0.0, **kwargs):
        """
        Plot a vertical line with label on the current axes.

        x, y:     data ('world') coordinates of the feature to mark.
        label:    text placed at the outer end of the line.
        location: 'top' or 'bottom' (case-insensitive); the line runs from
                  the feature towards that edge of the axes.
        rotate:   label rotation in degrees; a positive value also reserves
                  room proportional to the label length.
        kwargs:   passed to Axes.axvline().

        Raises ValueError for any other location (previously this failed
        with an unbound local variable).  Autoscaling is switched off while
        drawing and switched on again afterwards, even if drawing fails.
        """
        loc = location.lower()
        if loc not in ('top', 'bottom'):
            raise ValueError("location must be 'top' or 'bottom', got %r"
                             % (location,))
        ax = self.axes
        # need this to suppress autoscaling during this function
        ax.set_autoscale_on(False)
        try:
            if loc == 'top':
                y = max(0.0, y)
            else:
                y = min(0.0, y)
            # Gap (in axes fraction) kept free for the label; a rough
            # estimate of the text bounding box when rotated.
            lbloffset = 0.06
            if rotate > 0.0:
                lbloffset = 0.03 * len(label)
            peakoffset = 0.01
            # Convert the data point to axes-relative coordinates.
            # matplotlib api change 0.98 is using transform now.
            if hasattr(ax.transData, "inverse_xy_tup"):
                xy0 = ax.transData.xy_tup((x, y))
                xy = ax.transAxes.inverse_xy_tup(xy0)
            else:
                xy0 = ax.transData.transform((x, y))
                xy = ax.transAxes.inverted().transform(xy0)
            if loc == 'top':
                ymin = xy[1] + peakoffset
                ymax = 1.0 - lbloffset
                valign = 'bottom'
                ylbl = ymax + 0.01
            else:
                ymin = lbloffset
                ymax = xy[1] - peakoffset
                valign = 'top'
                ylbl = ymin - 0.01
            # Label x in data coordinates, y in axes-fraction coordinates.
            trans = blended_transform_factory(ax.transData, ax.transAxes)
            ax.axvline(x, ymin, ymax, color='black', **kwargs)
            ax.text(x, ylbl, label, verticalalignment=valign,
                    horizontalalignment='center',
                    rotation=rotate, transform=trans)
        finally:
            ax.set_autoscale_on(True)

    def release(self):
        """
        Stop buffering graphics output and immediately call show() so that
        everything drawn while buffered is rendered.
        """
        # show() only redraws when self.buffering is false, so clear the
        # flag first.
        self.buffering = False
        self.show()


    def show(self, hardrefresh=True):
        """
        Show graphics dependent on the current buffering state.

        Nothing happens when hardrefresh is False or while output is
        buffered (self.buffering is true).  Otherwise each panel gets a
        legend built from its lines (when self.loc is not None), and tick
        label, axis label and title font sizes are rescaled from rcParams
        according to the panel grid (self.rows x self.cols).
        """
        if not hardrefresh or self.buffering:
            return

        if self.loc is not None:
            for sp in self.subplots:
                lines = []
                labels = []
                # Number every slot, including None ones, so default labels
                # match the 1-based line position.
                for i, line in enumerate(sp['lines'], 1):
                    if line is None:
                        continue
                    lines.append(line[0])
                    lbl = line[0].get_label()
                    if lbl == '':
                        lbl = str(i)
                    labels.append(lbl)

                if lines:
                    fp = FP(size=rcParams['legend.fontsize'])
                    # Shrink the legend font as lines/columns grow, but not
                    # below 8 points.
                    fsz = fp.get_size_in_points() - max(len(lines), self.cols)
                    fp.set_size(max(fsz, 8))
                    # loc by keyword works on every matplotlib version;
                    # passing it positionally is deprecated in recent ones.
                    sp['axes'].legend(tuple(lines), tuple(labels),
                                      loc=self.loc, prop=fp)

        from matplotlib.artist import setp
        # Floor division keeps the integer (Python 2) size reduction on
        # Python 3 as well.
        xts = (FP(size=rcParams['xtick.labelsize']).get_size_in_points()
               - self.cols // 2)
        yts = (FP(size=rcParams['ytick.labelsize']).get_size_in_points()
               - self.rows // 2)
        axsize = FP(size=rcParams['axes.labelsize']).get_size_in_points()
        tsize = (FP(size=rcParams['axes.titlesize']).get_size_in_points()
                 - self.cols // 2)
        # Axis labels shrink by the number of columns/rows in a grid.
        xoff = self.cols if self.cols > 1 else 0
        yoff = self.rows if self.rows > 1 else 0
        for sp in self.subplots:
            ax = sp['axes']
            ax.title.set_size(tsize)
            setp(ax.get_xticklabels(), fontsize=xts)
            setp(ax.get_yticklabels(), fontsize=yts)
            ax.xaxis.label.set_size(axsize - xoff)
            ax.yaxis.label.set_size(axsize - yoff)

    def save(self, fname=None, orientation=None, dpi=None, papertype=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  The image format is determined
        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
        file name is specified 'yyyymmdd_hhmmss.png' is created in the current
        directory.
        """
        from asap import rcParams
        if papertype is None:
            papertype = rcParams['plotter.papertype']
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        d = ['png','.ps','eps', 'svg']

        from os.path import expandvars
        fname = expandvars(fname)

        if fname[-3:].lower() in d:
            try:
                if fname[-3:].lower() == ".ps":
                    from matplotlib import __version__ as mv
                    w = self.figure.get_figwidth()
                    h = self.figure.get_figheight()

                    if orientation is None:
                        # oriented
                        if w > h:
                            orientation = 'landscape'
                        else:
                            orientation = 'portrait'
                    from matplotlib.backends.backend_ps import papersize
                    pw,ph = papersize[papertype.lower()]
                    ds = None
                    if orientation == 'landscape':
                        ds = min(ph/w, pw/h)
                    else:
                        ds = min(pw/w, ph/h)
                    ow = ds * w
                    oh = ds * h
                    self.figure.set_size_inches((ow, oh))
                    self.figure.savefig(fname, orientation=orientation,
                                        papertype=papertype.lower())
                    self.figure.set_size_inches((w, h))
                    print 'Written file %s' % (fname)
                else:
                    if dpi is None:
                        dpi =150
                    self.figure.savefig(fname,dpi=dpi)
                    print 'Written file %s' % (fname)
            except IOError, msg:
                #print 'Failed to save %s: Error msg was\n\n%s' % (fname, err)
                asaplog.post()
                asaplog.push('Failed to save %s: Error msg was\n\n%s' % (fname, str(msg)))
                asaplog.post( 'ERROR' )
                return
        else:
Example #32
0
class MatplotlibWidget(FigureCanvas):
    """Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).

    Embeds one matplotlib Figure (self.fig) in a Qt widget and exposes
    slots to plot wave.Wave data.  Call reset() before plotting: it creates
    the axes (self.axes) that plot(), plot1d() and plot2d() draw into.

    :param parent: parent Qt widget, or None
    :param name: unused; kept for call compatibility
    :param width: figure width in inches
    :param height: figure height in inches
    :param dpi: figure resolution
    :param bgcolor: figure face and edge colour (None for the default)
    """
    def __init__(self, parent=None, name=None, width=5, height=4, dpi=100,
                 bgcolor=None):
        fig = Figure(figsize=(width, height), dpi=dpi,
                     facecolor=bgcolor, edgecolor=bgcolor)
        # FigureCanvas.__init__ initialises the underlying QWidget itself;
        # a separate QWidget.__init__ beforehand initialised it twice.
        FigureCanvas.__init__(self, fig)
        self.fig = fig
        self.setParent(parent)
        # NOTE(review): this attribute shadows QObject.parent(); kept
        # because external code may read it.
        self.parent = parent
        self.setSizePolicy(QtGui.QSizePolicy.Expanding,
                           QtGui.QSizePolicy.Expanding)
        self.updateGeometry()

    def sizeHint(self):
        """Preferred size in pixels: figure size times figure dpi."""
        dpi = self.fig.get_dpi()
        # QSize takes ints; PyQt5 raises TypeError for floats.
        return QtCore.QSize(int(self.fig.get_figwidth() * dpi),
                            int(self.fig.get_figheight() * dpi))

    @QtCore.pyqtSlot()
    def clear(self):
        ''' Clears the figure. '''
        self.fig.clear()

    @QtCore.pyqtSlot()
    def reset(self):
        '''
        Clears the figure and prepares a new plot.
        The difference between this and clear() is that
        the latter only clears the figure, while this
        also prepares the canvas for the plot commands.
        '''
        self.clear()
        # plot1d()/plot2d() clear these axes on every call, replacing the
        # former Axes.hold(False), which matplotlib 3.0 removed.
        self.axes = self.fig.add_subplot(111)

    @QtCore.pyqtSlot(wave.Wave)
    def plot1d(self, data, redraw=True):
        '''
        Called to plot the specified wave. It will be passed
        to matplotlib's plot() as it is. This typically means
        that if it's a higher-D matrix, it will be plotted as a
        series of 1D graphs.  Previous content of the axes is replaced.
        '''
        w = wave.WCast(data)
        self.axes.clear()
        self.axes.plot(w.dim[0].range, w)
        if redraw:
            self.draw()

    @QtCore.pyqtSlot(wave.Wave)
    def plot2d(self, data, redraw=True):
        '''
        Called to plot the specified 2D wave. Uses matplotlib's
        imshow() to show the specified image.  Previous content of the
        axes is replaced.
        '''
        self.axes.clear()
        self.axes.imshow(data, aspect='auto', extent=wave.WCast(data).imlim)
        if redraw:
            self.draw()

    @QtCore.pyqtSlot(wave.Wave)
    def plot(self, data, redraw=True):
        '''
        Convenience wrapper for plot1d() or plot2d().
        Assuming that 'data' is one single Wave (or ndarray object),
        it calls plot1d() or plot2d(), depending on the dimensionality
        of the data.  For an iterable of waves only the first element is
        plotted; data of other dimensionality is described as text instead.
        '''
        if hasattr(data, '__iter__') and not hasattr(data, 'ndim'):
            # First element, or None for an empty container (reported below
            # instead of raising IndexError).
            data = next(iter(data), None)

        if not hasattr(data, 'ndim'):
            # Wrap in a tuple so tuple-valued data cannot break formatting.
            log.error("Don't know how to plot data type: %s" % (data,))
            return

        if data.ndim == 1:
            self.plot1d(data, redraw)
        elif data.ndim == 2:
            self.plot2d(data, redraw)
        else:
            self.axes.clear()
            w = data.view(wave.Wave)
            axinfo_str = ''.join([("axis %d: %f...%f (units: '%s')\n"
                                   % (j, d.offset, d.end, d.units))
                                  for d, j in zip(w.dim, range(w.ndim))])
            self.axes.text(0.05, 0.95, "Don't know how to display wave!\n\nname: "
                           "%s\ndimensions: %d\n%s\n%s" %
                           (w.infs('name'), w.ndim, axinfo_str, pprint.pformat(w.info)),
                           transform=self.axes.transAxes, va='top')
            self.draw()
def gui(*args):
    if len(args) < 1:
        fig = Figure(figsize=(10, 6))
        ax = fig.add_subplot(121)
        ax.set_xlabel("X axis")
        ax.set_ylabel("Y axis")
        pts = 1000

        mean1, mean2, var1, var2 = 4, 7, 1, 1
        x = [random.gauss(4, 1) for _ in range(pts)]
        y = [random.gauss(7, 1) for _ in range(pts)]
        bins = np.linspace(0, 10, 100)
        ax.hist(x, bins, alpha=0.5, label='x', color='pink')
        ax.hist(y, bins, alpha=0.5, label='y', color='deepskyblue')
        #        ax.plot([1,2,3,4,5,6],[2,2,5,5,3,4])
        ax.legend(loc='upper right')
        ax.grid(True)
        ax.tick_params(labelcolor='white',
                       top='on',
                       bottom='on',
                       left='on',
                       right='on')

        ax2 = fig.add_subplot(122)
        x = [i for i in range(1000)]
        y = [random.gauss(7, 1) for _ in range(pts)]
        ax2.plot(x, y)
        ax2.grid(False)

    #Otherwise grab the figure passed into the function
    if len(args) == 1:
        fig = args[0]
    #These are for global xlabels and other options later. A hidden axes
    globalax = fig.add_subplot(
        111, frame_on=False
    )  #ends up being "final" axes of fig.axes[-1] for global settings
    globalax.grid(False)
    globalax.tick_params(labelcolor='none',
                         top=False,
                         bottom=False,
                         left=False,
                         right=False)

    #getting info for dynamic GUI rendering listbox
    CURRENT_SUBPLOT = "global"
    subplot_strs = [f"subplot_{i}" for i in range(len(fig.axes))]
    subplot_strs = ["global"] + subplot_strs[:-1]

    column_legend = [[
        sg.Radio('TL', 'legendary', enable_events=True, key='-TL-'),
        sg.Radio('T', 'legendary', enable_events=True, key='-T-'),
        sg.Radio('TR', 'legendary', enable_events=True, key='-TR-')
    ],
                     [
                         sg.Radio('ML',
                                  'legendary',
                                  enable_events=True,
                                  key='-ML-'),
                         sg.Radio('M',
                                  'legendary',
                                  enable_events=True,
                                  key='-M-'),
                         sg.Radio('MR',
                                  'legendary',
                                  enable_events=True,
                                  key='-MR-'),
                         sg.Radio('No Legend',
                                  'legendary',
                                  enable_events=True,
                                  key='-NOLEGEND-')
                     ],
                     [
                         sg.Radio('BL',
                                  'legendary',
                                  enable_events=True,
                                  key='-BL-'),
                         sg.Radio('B',
                                  'legendary',
                                  enable_events=True,
                                  key='-B-'),
                         sg.Radio('BR',
                                  'legendary',
                                  enable_events=True,
                                  key='-BR-')
                     ]]

    #the main settings column for the program
    column1_frame = [[
        sg.Text('X Label'),
        sg.Input(key='-XLABEL-', enable_events=True, size=(30, 14))
    ],
                     [
                         sg.Text('Y Label'),
                         sg.Input(key='-YLABEL-',
                                  enable_events=True,
                                  size=(30, 14))
                     ],
                     [
                         sg.Text('XLimits', key='XLIM'),
                         sg.Input(key='-XMIN-',
                                  enable_events=True,
                                  size=(4, 14)),
                         sg.Input(key='-XMAX-',
                                  enable_events=True,
                                  size=(4, 14)),
                         sg.Text('YLimits'),
                         sg.Input(key='-YMIN-',
                                  enable_events=True,
                                  size=(4, 14)),
                         sg.Input(key='-YMAX-',
                                  enable_events=True,
                                  size=(4, 14))
                     ],
                     [
                         sg.Text('Ticks', key='-TICKTEXT-'),
                         sg.Checkbox("left",
                                     enable_events=True,
                                     key='-LEFTTICK-',
                                     default=True),
                         sg.Checkbox("bottom",
                                     enable_events=True,
                                     key='-BOTTOMTICK-',
                                     default=True),
                         sg.Checkbox("right",
                                     enable_events=True,
                                     key='-RIGHTTICK-',
                                     default=True),
                         sg.Checkbox("top",
                                     enable_events=True,
                                     key='-TOPTICK-',
                                     default=True)
                     ],
                     [
                         sg.Checkbox("Grid",
                                     enable_events=True,
                                     key='-GRID-',
                                     pad=((50, 0), 0)),
                         sg.Checkbox("Frame",
                                     enable_events=True,
                                     key='-FRAME-',
                                     default=True),
                         sg.Checkbox("Frame-Part",
                                     enable_events=True,
                                     key='-FRAMEPART-',
                                     default=False)
                     ], [sg.Text("Legend"),
                         sg.Column(column_legend)]]
    column1 = [[
        sg.Text('Choose subplot:',
                justification='center',
                font='Helvetica 14',
                key='-text2-')
    ],
               [
                   sg.Listbox(values=subplot_strs,
                              key='-SUBPLOT-',
                              size=(20, 3),
                              enable_events=True)
               ],
               [
                   sg.Text('Title',
                           justification='center',
                           font='Helvetica 14',
                           key='-OUT-'),
                   sg.Input(key='-TITLE-', enable_events=True, size=(32, 14))
               ],
               [
                   sg.Text('Font size',
                           pad=(0, (14, 0)),
                           justification='center',
                           font='Helvetica 12',
                           key='-OUT2-'),
                   sg.Slider(range=(1, 32),
                             key='-TITLESIZE-',
                             enable_events=True,
                             pad=(0, 0),
                             default_value=12,
                             size=(24, 15),
                             orientation='h',
                             font=("Helvetica", 10))
               ],
               [
                   sg.Frame('General Axes Options, global',
                            [[sg.Column(column1_frame)]],
                            key='-AXESBOX-',
                            pad=(0, (14, 0)))
               ], [sg.B('Save', key='-SAVE-')]]

    #And here is where we create the layout
    sg.theme('DarkBlue')
    layout = [[
        sg.Text('Matplotlib Editor',
                size=(20, 1),
                justification='center',
                font='Helvetica 20')
    ], [sg.Column(column1)]]

    layout2 = [[
        sg.Canvas(size=(fig.get_figwidth() * 100, fig.get_figheight() * 100),
                  background_color='black',
                  key='canvas')
    ]]

    #[sg.Listbox(values=pyplot.style.available, size=(20, 6), key='-STYLE-', enable_events=True)]
    window = sg.Window('Simple GUI to envision ROC curves', layout)
    window.Finalize(
    )  # needed to access the canvas element prior to reading the window

    window_g = sg.Window("Graphing", layout2, resizable=True)
    window_g.Finalize()
    canvas_elem = window_g['canvas']
    graph = FigureCanvasTkAgg(fig, master=canvas_elem.TKCanvas)
    canvas = canvas_elem.TKCanvas

    def update_graph():
        graph.draw()
        figure_x, figure_y, figure_w, figure_h = fig.bbox.bounds
        figure_w, figure_h = int(figure_w), int(figure_h)
        photo = Tk.PhotoImage(master=canvas, width=figure_w, height=figure_h)
        canvas.image = photo
        canvas.pack(fill="both", expand=True)
        canvas.create_image(fig.get_figwidth() * 100 / 2,
                            fig.get_figheight() * 100 / 2,
                            image=photo)
        #canvas.update(size=(size(window_g)[0],size(window_g)[1]))
        figure_canvas_agg = FigureCanvasAgg(fig)
        figure_canvas_agg.draw()
        _backend_tk.blit(photo,
                         figure_canvas_agg.get_renderer()._renderer,
                         (0, 1, 2, 3))

    update_graph()

    def frame_set(ax, value):
        ax.spines["top"].set_visible(value)
        ax.spines["right"].set_visible(value)
        ax.spines["bottom"].set_visible(value)
        ax.spines["left"].set_visible(value)

### THIS IS THE MAIN BULK OF THE PROGRAM -- MANAGING EVENTS FOR THE GUI ####

    class event_manager_class:
        def switch(self, event):
            if event != 'Exit' or None:
                event = event[1:-1]
            #print(event)
            getattr(self, event)()

        def TITLE(self):
            if CURRENT_SUBPLOT == 'global':
                fig.suptitle(values['-TITLE-'])
            else:
                fig.axes[CURRENT_SUBPLOT].set_title(values['-TITLE-'])

        def TITLESIZE(self):
            fontsize = int(values['-TITLESIZE-'])
            if CURRENT_SUBPLOT == 'global':
                if type(fig._suptitle) is matplotlib.text.Text:
                    fig.suptitle(values['-TITLE-'], fontsize=fontsize)
            else:
                fig.axes[CURRENT_SUBPLOT].set_title(
                    fig.axes[CURRENT_SUBPLOT].get_title(), fontsize=fontsize)

        def XLABEL(self):
            if CURRENT_SUBPLOT == 'global':
                globalax.set_xlabel(values['-XLABEL-'])
            else:
                fig.axes[CURRENT_SUBPLOT].set_xlabel(values['-XLABEL-'])

        def YLABEL(self):
            if CURRENT_SUBPLOT == 'global':
                globalax.set_ylabel(values['-YLABEL-'])
            else:
                fig.axes[CURRENT_SUBPLOT].set_ylabel(values['-YLABEL-'])

        def XMIN(self):
            try:
                limit_update = float(values['-XMIN-'])
            except ValueError:
                return
            if CURRENT_SUBPLOT == 'global':
                [sub.set_xlim(xmin=limit_update) for sub in fig.axes]
            else:
                fig.axes[CURRENT_SUBPLOT].set_xlim(xmin=limit_update)

        def XMAX(self):
            try:
                limit_update = float(values['-XMAX-'])
            except ValueError:
                return
            if CURRENT_SUBPLOT == 'global':
                [sub.set_xlim(xmax=limit_update) for sub in fig.axes]
            else:
                fig.axes[CURRENT_SUBPLOT].set_xlim(xmax=limit_update)

        def YMIN(self):
            try:
                limit_update = float(values['-YMIN-'])
            except ValueError:
                return
            if CURRENT_SUBPLOT == 'global':
                [sub.set_ylim(ymin=limit_update) for sub in fig.axes]
            else:
                fig.axes[CURRENT_SUBPLOT].set_ylim(ymin=limit_update)

        def YMAX(self):
            try:
                limit_update = float(values['-YMAX-'])
            except ValueError:
                return
            if CURRENT_SUBPLOT == 'global':
                [sub.set_ylim(ymax=limit_update) for sub in fig.axes]
            else:
                fig.axes[CURRENT_SUBPLOT].set_ylim(ymax=limit_update)

        def LEFTTICK(self):
            """Show or hide left-side ticks per the '-LEFTTICK-' value.

            The value is passed as ``tick_params(left=...)`` (expected to be a
            bool). In 'global' mode it is applied to every axes in
            ``fig.axes`` except the last (presumably the shared ``globalax`` —
            confirm); otherwise only to the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.tick_params(left=values['-LEFTTICK-'])
            else:
                fig.axes[CURRENT_SUBPLOT].tick_params(
                    left=values['-LEFTTICK-'])

        def TOPTICK(self):
            """Show or hide top-side ticks per the '-TOPTICK-' value.

            The value is passed as ``tick_params(top=...)`` (expected to be a
            bool). In 'global' mode it is applied to every axes in
            ``fig.axes`` except the last (presumably the shared ``globalax`` —
            confirm); otherwise only to the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.tick_params(top=values['-TOPTICK-'])
            else:
                fig.axes[CURRENT_SUBPLOT].tick_params(top=values['-TOPTICK-'])

        def RIGHTTICK(self):
            """Show or hide right-side ticks per the '-RIGHTTICK-' value.

            The value is passed as ``tick_params(right=...)`` (expected to be a
            bool). In 'global' mode it is applied to every axes in
            ``fig.axes`` except the last (presumably the shared ``globalax`` —
            confirm); otherwise only to the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.tick_params(right=values['-RIGHTTICK-'])
            else:
                fig.axes[CURRENT_SUBPLOT].tick_params(
                    right=values['-RIGHTTICK-'])

        def BOTTOMTICK(self):
            """Show or hide bottom-side ticks per the '-BOTTOMTICK-' value.

            The value is passed as ``tick_params(bottom=...)`` (expected to be
            a bool). In 'global' mode it is applied to every axes in
            ``fig.axes`` except the last (presumably the shared ``globalax`` —
            confirm); otherwise only to the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.tick_params(bottom=values['-BOTTOMTICK-'])
            else:
                fig.axes[CURRENT_SUBPLOT].tick_params(
                    bottom=values['-BOTTOMTICK-'])

        def GRID(self):
            """Toggle grid lines using the '-GRID-' value.

            The value is passed positionally to ``Axes.grid``. In 'global'
            mode it is applied to every axes in ``fig.axes`` except the last;
            otherwise only to the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.grid(values['-GRID-'])
            else:
                fig.axes[CURRENT_SUBPLOT].grid(values['-GRID-'])

        def FRAME(self):
            """Apply ``frame_set`` with the '-FRAME-' value to the target axes.

            In 'global' mode every axes in ``fig.axes`` except the last is
            updated; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    frame_set(sub, values['-FRAME-'])
            else:
                frame_set(fig.axes[CURRENT_SUBPLOT], values['-FRAME-'])

        def FRAMEPART(self):
            """Show or hide the top and right spines per the '-FRAMEPART-' value.

            In 'global' mode every axes in ``fig.axes`` except the last is
            updated; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                target_axes = fig.axes[0:-1]
            else:
                target_axes = [fig.axes[CURRENT_SUBPLOT]]
            # Only the top and right spines are toggled; left/bottom untouched.
            for sub in target_axes:
                for side in ("top", "right"):
                    sub.spines[side].set_visible(values['-FRAMEPART-'])

        # Legend radio-button handlers: each one places the legend at a fixed location.
        def TL(self):
            """Place the legend at 'upper left' (matplotlib loc code 2).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='upper left')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='upper left')

        def T(self):
            """Place the legend at 'upper center' (matplotlib loc code 9).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='upper center')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='upper center')

        def TR(self):
            """Place the legend at 'upper right' (matplotlib loc code 1).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='upper right')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='upper right')

        def ML(self):
            """Place the legend at 'center left' (matplotlib loc code 6).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='center left')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='center left')

        def M(self):
            """Place the legend at 'center' (matplotlib loc code 10).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='center')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='center')

        def MR(self):
            """Place the legend at 'center right' (matplotlib loc code 7).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='center right')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='center right')

        def BL(self):
            """Place the legend at 'lower left' (matplotlib loc code 3).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='lower left')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='lower left')

        def B(self):
            """Place the legend at 'lower center' (matplotlib loc code 8).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='lower center')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='lower center')

        def BR(self):
            """Place the legend at 'lower right' (matplotlib loc code 4).

            In 'global' mode every axes in ``fig.axes`` except the last gets a
            legend; otherwise only the selected subplot.
            """
            if CURRENT_SUBPLOT == 'global':
                for sub in fig.axes[0:-1]:
                    sub.legend(loc='lower right')
            else:
                fig.axes[CURRENT_SUBPLOT].legend(loc='lower right')

        def NOLEGEND(self):
            """Remove the legend from the target axes, if one is present.

            In 'global' mode every axes in ``fig.axes`` except the last is
            cleared; otherwise only the selected subplot. Uses
            ``Axes.get_legend()`` rather than ``legend().remove()``. The old
            approach built a throwaway legend just to delete it, which may also
            warn about missing labelled artists.
            """
            if CURRENT_SUBPLOT == 'global':
                target_axes = fig.axes[0:-1]
            else:
                target_axes = [fig.axes[CURRENT_SUBPLOT]]
            for sub in target_axes:
                existing = sub.get_legend()
                if existing is not None:
                    existing.remove()

        def SAVE(self):
            """Prompt for a file name and save ``fig`` to it.

            If the dialog yields no name (None when cancelled per PySimpleGUI's
            documented behaviour, or an empty string), nothing is saved. This
            avoids passing an invalid path to ``savefig``.
            """
            fname = sg.popup_get_file('Save figure', save_as=True)
            if not fname:
                return
            fig.savefig(fname)

################################################################################################

    event_manager = event_manager_class()

    while True:
        event, values = window.read(timeout=10)
        #print(event)
        #tic = time.perf_counter()

        if event == 'Exit' or event is None:
            break
        elif event == '__TIMEOUT__':
            pass
        #This gets a special elif bc I need to change the global variable CURRENT_SUBPLOT
        elif event == '-SUBPLOT-':
            idx = values['-SUBPLOT-'][0][-1]
            if idx.isdigit():
                CURRENT_SUBPLOT = int(idx)
            else:
                CURRENT_SUBPLOT = 'global'
            window['-AXESBOX-'].Update("General Axes Options, " +
                                       values['-SUBPLOT-'][0])
        else:
            event_manager.switch(event)
            update_graph()
        #toc = time.perf_counter()
        #print(toc-tic)

    #window['-OUT-'].update(CURRENT_SUBPLOT)

    window.close()
    window_g.close()