Ejemplo n.º 1
0
class PeakFinder(object):
    """Mark the maximum of an image inside a user-drawn rectangle.

    A RectangleSelector is attached to *ax*.  When a rectangle is completed,
    the first image on the axes is clipped to it, the (x, y) position of the
    clip's maximum is printed, and a circle is drawn there on *canvas*.
    """

    def __init__(self, ax, canvas):
        self.rectProps = dict(facecolor='red', edgecolor='white',
                              alpha=0.5, fill=True)
        self.indicatorProps = dict(facecolor='white', edgecolor='black',
                                   alpha=0.5, fill=True)
        self.__selector = RectangleSelector(ax, self.onSelect, drawtype='box',
                                            rectprops=self.rectProps)
        self.__axes = ax
        self.__canvas = canvas

    def onSelect(self, epress, erelease):
        """Selection callback: locate and mark the maximum inside the box.

        epress / erelease are the mouse press and release events.  Their
        xdata / ydata are truncated to int and used as (column, row) indices
        into the image array.
        """
        # Sort each axis so a drag in any direction (right-to-left,
        # bottom-to-top) gives a non-empty slice; previously the clip was
        # empty in that case and clipMatrix.max() raised.  Building tuples
        # (instead of map objects) also keeps indexing working on Python 3.
        col0, col1 = sorted((int(epress.xdata), int(erelease.xdata)))
        row0, row1 = sorted((int(epress.ydata), int(erelease.ydata)))
        ax = self.__axes
        # ax is already an Axes; Artist.get_axes() no longer exists.
        dataMatrix = ax.get_images()[0].get_array()
        clipMatrix = dataMatrix[row0:(row1 + 1), col0:(col1 + 1)]
        # First (row-major) position where the clip reaches its maximum.
        peakPos = nonzero(clipMatrix == clipMatrix.max())
        peakPos = (peakPos[1][0] + col0, peakPos[0][0] + row0)
        print(peakPos)
        circle = Circle(peakPos, 4, **self.indicatorProps)
        ax.add_patch(circle)
        # canvas.show() is deprecated; draw() re-renders the figure.
        self.__canvas.draw()

    def activate(self):
        """Enable the rectangle selector."""
        self.__selector.set_active(True)

    def deactivate(self):
        """Disable the rectangle selector."""
        self.__selector.set_active(False)

    @property
    def isActivate(self):
        """Whether the rectangle selector is currently active."""
        return self.__selector.active
Ejemplo n.º 2
0
class SeriesWidget(MplWidget):
    """Plot each row of a series as a line and attach a rectangle selector."""

    def __init__(self, parent=None):
        MplWidget.__init__(self, parent)

    def draw(self):
        """Redraw the figure from ``self.inputPorts``.

        Lines whose id is in ``selectedIds`` get linewidth 3, the others 1.
        A new RectangleSelector is attached to the fresh axes on every draw.
        """
        (self.coord, series, title, xlabel, ylabel, showLegend) = self.inputPorts
        colors = pylab.cm.jet(np.linspace(0, 1, series.values.shape[0]))
        # draw() can run before updateSelection() has ever set selectedIds;
        # treat that as "nothing selected" instead of raising AttributeError.
        selectedIds = getattr(self, 'selectedIds', None)
        if selectedIds is None:
            selectedIds = ()

        pylab.clf()
        pylab.title(title)
        ll = pylab.plot(series.values.T, linewidth=1)
        for pos, _ids in enumerate(series.ids):
            ll[pos].set_color(colors[pos])
            if _ids in selectedIds:
                ll[pos].set_linewidth(3)
        pylab.xlabel(xlabel)
        pylab.ylabel(ylabel)

        if showLegend:
            pylab.legend(pylab.gca().get_lines(),
                         series.labels,
                         numpoints=1, prop=dict(size='small'), loc='upper right')

        self.figManager.canvas.draw()
        self.rectSelector = RectangleSelector(pylab.gca(), self.onselect, drawtype='box',
                                              rectprops=dict(alpha=0.4, facecolor='yellow'))
        self.rectSelector.set_active(True)

    def updateSelection(self, selectedIds):
        """Store the ids to highlight and refresh the widget contents."""
        self.selectedIds = selectedIds
        self.updateContents()

    def onselect(self, eclick, erelease):
        """Rectangle-selector callback; intentionally does nothing yet."""
        pass
Ejemplo n.º 3
0
class Select(object):
    """Rectangle-select points of a series and highlight them with 'x' markers.

    ``self.data`` is sliced with ``.loc[num2date(a):num2date(b)]``, so it is
    presumably a datetime-indexed pandas object -- NOTE(review): confirm.
    The index of the most recent selection is stored in ``self.index``.
    """

    def __init__(self, ax, data, flags):
        from matplotlib.widgets import RectangleSelector
        selector = RectangleSelector(ax, self._onselect,
                                     drawtype='box', interactive=True)
        self._rs = selector
        # Keyboard toggling is not wired up; connect 'key_press_event' to
        # self._toggle (e.g. via plt.connect) to enable it.
        self.data = data
        self.flags = flags

    def _toggle(self, e):
        """Flip the selector between active and inactive on the 's' key."""
        if e.key != 's':
            return
        self._rs.set_active(not self._rs.active)

    def _onselect(self, *args):
        """Selection callback: keep the points strictly inside the rectangle."""
        xs, ys = self._rs.corners
        a, b = sorted(xs[:2])   # horizontal bounds (date numbers)
        c, d = sorted(ys[1:3])  # vertical bounds (values)
        window = self.data.loc[num2date(a):num2date(b)]
        inside = window[(window > c) & (window < d)].dropna()
        self.index = inside.index
        plt.sca(self._rs.ax)
        # Replace the previous highlight, if any.
        if hasattr(self, 'pl'):
            self.pl[0].remove()
        self.pl = plt.plot(inside, 'x')
        plt.draw()

    def redraw(self, data):
        """Plot the 'avg' cross-section of *data* for this object's column."""
        column = self.data.columns[0]
        averaged = data.xs('avg', 1, 'aggr')[column].dropna()
        plt.sca(self._rs.ax)
        plt.plot(averaged)
        plt.draw()
Ejemplo n.º 4
0
class SeriesWidget(MplWidget):
    """Plot each row of a series as a line and attach a rectangle selector."""

    def __init__(self, parent=None):
        MplWidget.__init__(self, parent)

    def draw(self):
        """Redraw the figure from ``self.inputPorts``.

        Lines whose id is in ``selectedIds`` get linewidth 3, the others 1.
        A new RectangleSelector is attached to the fresh axes on every draw.
        """
        (self.coord, series, title, xlabel, ylabel, showLegend) = self.inputPorts
        colors = pylab.cm.jet(np.linspace(0, 1, series.values.shape[0]))
        # draw() can run before updateSelection() has ever set selectedIds;
        # treat that as "nothing selected" instead of raising AttributeError.
        selectedIds = getattr(self, 'selectedIds', None)
        if selectedIds is None:
            selectedIds = ()

        pylab.clf()
        pylab.title(title)
        ll = pylab.plot(series.values.T, linewidth=1)
        for pos, _ids in enumerate(series.ids):
            ll[pos].set_color(colors[pos])
            if _ids in selectedIds:
                ll[pos].set_linewidth(3)
        pylab.xlabel(xlabel)
        pylab.ylabel(ylabel)

        if showLegend:
            pylab.legend(pylab.gca().get_lines(),
                         series.labels,
                         numpoints=1, prop=dict(size='small'), loc='upper right')

        self.figManager.canvas.draw()
        self.rectSelector = RectangleSelector(pylab.gca(), self.onselect, drawtype='box',
                                              rectprops=dict(alpha=0.4, facecolor='yellow'))
        self.rectSelector.set_active(True)

    def updateSelection(self, selectedIds):
        """Store the ids to highlight and refresh the widget contents."""
        self.selectedIds = selectedIds
        self.updateContents()

    def onselect(self, eclick, erelease):
        """Rectangle-selector callback; intentionally does nothing yet."""
        pass
Ejemplo n.º 5
0
class AreaSelector(Frozen):
    """RectangleSelector on *ax* that can be shown/hidden from the keyboard.

    Calling the instance with an event (it reads ``event.inaxes`` and
    ``event.key``, presumably a key-press event) handles two keys when the
    event is inside *ax*: 'q'/'Q' hides and disables the selector, and
    'a'/'A' shows and enables it.
    """

    def __init__(self, ax, line_select_callback):
        self.ax = ax
        self.rs = RectangleSelector(
            ax, line_select_callback,
            drawtype='box', useblit=False,
            button=[1, 3],  # left and right buttons only; middle is ignored
            minspanx=0, minspany=0,
            spancoords='pixels', interactive=True)

    def __call__(self, event):
        self.rs.update()
        if self.ax != event.inaxes:
            return
        if event.key in ['Q', 'q']:
            visible = False
        elif event.key in ['A', 'a']:
            visible = True
        else:
            return
        self.rs.to_draw.set_visible(visible)
        self.rs.set_active(visible)
Ejemplo n.º 6
0
class run():
    """Demo: plot 200 random points and attach a toggleable RectangleSelector.

    Construction opens the figure and blocks in plt.show().  Press 'q' to
    deactivate the selector and 'a' to re-activate it.
    """

    def __init__(self):
        fig, current_ax = plt.subplots()
        yvec = [25 + 3 * random.randn() for _ in range(200)]
        # Plot once.  The old code called plt.plot on the growing list inside
        # the loop, stacking 200 overlapping Line2D artists (~20k markers).
        plt.plot(yvec, 'o')

        self.RS = RectangleSelector(current_ax,
                                    self.line_select_callback,
                                    drawtype='box',
                                    useblit=True,
                                    button=[1, 3],
                                    minspanx=5,
                                    minspany=5,
                                    spancoords='pixels',
                                    interactive=True)
        plt.connect('key_press_event', self.toggle_selector)
        plt.show()

    def line_select_callback(self, eclick, erelease):
        'eclick and erelease are the press and release events'
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))

    def toggle_selector(self, event):
        """Deactivate the selector on 'q'/'Q', activate it on 'a'/'A'."""
        print(' Key pressed.')
        if event.key in ['Q', 'q'] and self.RS.active:
            print(' RectangleSelector deactivated.')
            self.RS.set_active(False)
        if event.key in ['A', 'a'] and not self.RS.active:
            print(' RectangleSelector activated.')
            self.RS.set_active(True)
Ejemplo n.º 7
0
class Ui_Form_bragg(QDialog):
    """Dialog showing *data* as a pcolormesh and recording rectangle selections.

    Each completed selection appends the floored press/release data
    coordinates to ``self.datas['x1']``, ``['y1']``, ``['x2']`` and ``['y2']``.
    """

    def __init__(self, data):
        super(Ui_Form_bragg, self).__init__()
        self.setObjectName("Form")
        self.resize(874, 726)
        self.data = data
        self.horizontalLayout = QtWidgets.QHBoxLayout(self)
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.cmap = 'binary'
        # Figure size in inches.  Not stored as self.width / self.height:
        # those names would shadow QWidget.width() / QWidget.height() and make
        # every call to them fail with "'int' object is not callable".
        self.fig_width = 8
        self.fig_height = 8
        self.dpi = 100
        self.fig = Figure(figsize=(self.fig_width, self.fig_height), dpi=self.dpi)
        self.canvas = FigureCanvas(self.fig)
        self.ax = self.fig.add_subplot(111)
        self.ax.axis('off')

        self.ax.pcolormesh(data, cmap=self.cmap)
        self.canvas.draw()
        # Storage for the selections, filled by line_select_callback.
        self.datas = {'x1': [], 'x2': [], 'y1': [], 'y2': []}
        self.rs = RectangleSelector(
            self.ax,
            self.line_select_callback,
            drawtype='line',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        self.retranslateUi(self)
        self.horizontalLayout.addWidget(self.canvas)

        QtCore.QMetaObject.connectSlotsByName(self)

        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    def on_click(self, event):
        """Toggle the selector on mouse presses.

        A left click always activates it, a right click activates it only when
        it is inactive, and any other press deactivates it.
        """
        # `and` binds tighter than `or`; the parentheses only make the existing
        # behaviour explicit.  NOTE(review): if the intent was
        # "(left or right) and not active", this condition needs changing.
        if event.button == 1 or (event.button == 3 and not self.rs.active):
            self.rs.set_active(True)
        else:
            self.rs.set_active(False)

    def line_select_callback(self, eclick, erelease):
        """Print the selection and store its floored corner coordinates."""
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
        print(" The button you used were: %s %s" %
              (eclick.button, erelease.button))

        self.datas['x1'].append(int(np.floor(eclick.xdata)))
        self.datas['y1'].append(int(np.floor(eclick.ydata)))
        self.datas['x2'].append(int(np.floor(erelease.xdata)))
        self.datas['y2'].append(int(np.floor(erelease.ydata)))

    def retranslateUi(self, Form):
        """Set the translatable window title."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
Ejemplo n.º 8
0
class MplCanvas(QWidget):
    """Qt widget hosting one matplotlib axes with a rectangle ROI selector.

    Emits ``roi_updated`` with the selector's ``extents`` each time a
    rectangle is completed.  Extra ``**kwargs`` are accepted and ignored.
    """

    roi_updated = pyqtSignal(tuple)

    def __init__(self, orient=0, axisoff=True, autoscale=False, **kwargs):
        super(MplCanvas, self).__init__()
        self.orient = orient
        self.setLayout(QVBoxLayout())

        # Figure
        self.fig, self.ax, self.canvas = self.figimage(axisoff=axisoff)
        self.rs = RectangleSelector(
            self.ax,
            self.line_select_callback,
            drawtype="box",
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords="pixels",
            interactive=True,
        )
        self.pressed = False
        self.ax.autoscale(enable=autoscale)
        self.layout().addWidget(self.canvas, 1)
        self.canvas.mpl_connect("button_press_event", self.on_press)

    def replot(self):
        """Clear the axes."""
        # Axes.clear() already performs cla(); calling both was redundant.
        self.ax.clear()

    def redraw(self):
        """Schedule a repaint of the canvas."""
        self.canvas.draw_idle()

    def figimage(self, scale=1, dpi=None, axisoff=True):
        """Create a figure, its Qt canvas and one subplot with hidden axes.

        scale multiplies the 10x10 inch figure size; dpi is passed to the
        Figure (None keeps matplotlib's default).  When axisoff is true the
        subplot margins are tightened.  Returns (fig, ax, canvas).
        """
        # Build the Figure directly instead of via plt.figure(): a pyplot
        # figure is registered with pyplot's figure manager, which keeps it
        # alive after this widget is gone and can show it in a separate
        # window on plt.show().
        from matplotlib.figure import Figure

        fig = Figure(figsize=(10 * scale, 10 * scale), dpi=dpi)
        canvas = FigureCanvasQTAgg(fig)
        ax = fig.add_subplot(111)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if axisoff:
            fig.subplots_adjust(left=0.03, bottom=0.05, right=0.97, top=0.99)
        canvas.draw()
        return fig, ax, canvas

    def on_press(self, event):
        """Take keyboard focus and toggle the selector on mouse presses.

        A left click always activates the selector, a right click activates it
        only when it is inactive, and any other press deactivates it.
        """
        self.setFocus()
        # Parentheses make the existing and/or precedence explicit.
        if event.button == 1 or (event.button == 3 and not self.rs.active):
            self.rs.set_active(True)
        else:
            self.rs.set_active(False)

    def line_select_callback(self, eclick, erelease):
        """Emit the selector's current extents through roi_updated."""
        self.roi_updated.emit(self.rs.extents)
class SelectData:
    """Interactively box-select a time window of one or more signals.

    boxSelect() plots every signal against ``timestamps``, blocks in
    plt.show() until the window is closed, and returns the sample indices
    between the horizontal bounds of the last drawn rectangle.
    """

    def __init__(self, timestamps, signals, button_names=None):
        # None instead of a [] default: a default list would be shared by
        # every instance created without button_names.
        self.button_names = [] if button_names is None else button_names
        self.signals = signals
        self.timestamps = timestamps
        self.pick_indices = range(len(timestamps))
        self.button_array = []

    def onclick(self, event):
        """Remember the most recent click event."""
        self.event = event

    def line_select_callback(self, eclick, erelease):
        """eclick and erelease are the press and release events"""
        self.x1, self.y1 = eclick.xdata, eclick.ydata
        self.x2, self.y2 = erelease.xdata, erelease.ydata

    def close_calback(self):
        """Close the current pyplot figure."""
        plt.close()

    def name_calback(self, event, name):
        """Print *name*."""
        print(name)

    def boxSelect(self):
        """Show the signals, wait for a box selection, return its index range."""
        fig, current_ax = plt.subplots()

        print(len(self.signals))
        for signal in self.signals:
            plt.plot(self.timestamps, signal)

        self.RS = RectangleSelector(
            current_ax,
            self.line_select_callback,
            drawtype='box',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels')
        self.RS.set_active(True)

        plt.show()
        examp_indices = self.subsetData(self.x1, self.x2)
        return examp_indices

    def subsetData(self, x1, x2):
        """Return range(start, end) of sample indices for times x1..x2.

        The sample spacing is taken from the two samples around the middle of
        ``timestamps`` (uniform sampling is assumed -- NOTE(review): confirm).
        """
        # Floor division keeps the index an int on Python 3 as well.
        mid = len(self.timestamps) // 2
        dt = self.timestamps[mid] - self.timestamps[mid - 1]
        x1_offset = x1 - self.timestamps[0]
        x2_offset = x2 - self.timestamps[0]
        start_idx = np.floor(x1_offset / dt).astype(int)
        end_idx = np.floor(x2_offset / dt).astype(int)
        return range(start_idx, end_idx)
Ejemplo n.º 10
0
class MplInteractiveWidget(MplWidget):
    """MplWidget whose plotted points can be picked with a rectangle.

    The indices of the points inside the rectangle are passed to
    select_indices() and emitted through ``selectionChanged``.
    """

    selectionChanged = pyqtSignal(object)

    def __init__(self, *args, **kwargs):

        super(MplInteractiveWidget, self).__init__(*args, **kwargs)

        # drawtype is 'box' or 'line' or 'none'; only the left mouse
        # button (1) starts a selection.
        self.RS = RectangleSelector(
            self.sc.axes,
            self.range_select_callback,
            drawtype='box',
            useblit=True,
            button=[1],
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)

    def select_range(self, r):
        """Select points strictly inside r = (xmin, xmax, ymin, ymax)."""
        if self.ploted_stuff is None:
            return

        xmin, xmax, ymin, ymax = r[0], r[1], r[2], r[3]
        inside = ((self.x > xmin) & (self.x < xmax)
                  & (self.y > ymin) & (self.y < ymax))
        idx = np.where(inside)[0]
        print(idx)
        self.select_indices(idx)
        self.selectionChanged.emit(idx)

    def range_select_callback(self, eclick, erelease) -> None:
        """Selector callback: select the points inside the current extents."""
        extents = self.RS.extents
        print(extents)
        self.select_range(extents)

    def toggle_selector(self, event) -> None:
        """Qt key handler: Q deactivates the selector, A re-activates it."""
        key = event.key()
        print(' Key pressed.', key)

        if key == QtCore.Qt.Key_Q and self.RS.active:
            print(' RectangleSelector deactivated.')
            self.RS.set_active(False)

        if key == QtCore.Qt.Key_A and not self.RS.active:
            print(' RectangleSelector activated.')
            self.RS.set_active(True)
Ejemplo n.º 11
0
class ROISelector(object):
    """Right-button rectangle selector that remembers the last ROI corners."""

    def __init__(self, artist):
        self.artist = artist
        roi_props = dict(facecolor='red', edgecolor='red', alpha=0.3, fill=True)
        self.selector = RectangleSelector(
            self.artist.axes, self.on_select,
            button=3, minspanx=5, minspany=5, spancoords='pixels',
            rectprops=roi_props)  # drawtype='box'
        self.coords = []

    def on_select(self, click, release):
        """Store [(x1, y1), (x2, y2)] from press/release, truncated to int."""
        self.coords = [(int(ev.xdata), int(ev.ydata)) for ev in (click, release)]

    def activate(self):
        """Enable the selector."""
        self.selector.set_active(True)

    def deactivate(self):
        """Disable the selector."""
        self.selector.set_active(False)
Ejemplo n.º 12
0
class ROISelector(object):
    """Right-button rectangle selector that remembers the last ROI corners."""

    def __init__(self, artist):
        self.artist = artist
        roi_props = dict(facecolor='red', edgecolor='red', alpha=0.3, fill=True)
        self.selector = RectangleSelector(
            self.artist.axes, self.on_select,
            button=3, minspanx=5, minspany=5, spancoords='pixels',
            rectprops=roi_props)
        self.coords = []

    def on_select(self, click, release):
        """Store [(x1, y1), (x2, y2)] from press/release, truncated to int."""
        self.coords = [(int(ev.xdata), int(ev.ydata)) for ev in (click, release)]

    def activate(self):
        """Enable the selector."""
        self.selector.set_active(True)

    def deactivate(self):
        """Disable the selector."""
        self.selector.set_active(False)
Ejemplo n.º 13
0
class MPLWidget(QtWidgets.QWidget):
    """Qt widget plotting data points, masked points and a fit curve.

    A RectangleSelector (inactive until maskSelect() arms it) lets the user
    box points; the resulting boolean array is sent to the data manager's
    updateMask().
    """

    def __init__(self, parent=None):
        QtWidgets.QWidget.__init__(self, parent)
        self.setupUi()
        self.maskingSetting = False
        self.rectSelect = RectangleSelector(self.ax, self.maskSelected,
                                            drawtype='box', useblit=True,
                                            interactive=False)
        # Armed on demand by maskSelect(); disarmed again after one selection.
        self.rectSelect.set_active(False)
        Line2D = matplotlib.lines.Line2D
        self.dataArtist = Line2D([], [], linestyle='', marker='.',
                                 markerfacecolor='b')
        self.ax.add_line(self.dataArtist)
        self.maskArtist = Line2D([], [], linestyle='', marker='x',
                                 markerfacecolor='r', markeredgecolor='r')
        self.ax.add_line(self.maskArtist)
        self.fitArtist = Line2D([], [], linestyle='--', color='k')
        self.ax.add_line(self.fitArtist)

    def setupUi(self):
        """Lay out toolbar and canvas vertically and create the axes."""
        layout = QtWidgets.QVBoxLayout(self)
        self.canvas = FigureCanvas(Figure(figsize=(5, 3)))
        self.toolbar = NavigationToolbar(self.canvas, self)
        for widget in (self.toolbar, self.canvas):
            layout.addWidget(widget)
        self.ax = self.canvas.figure.subplots()
        self.show()

    def setDataManager(self, dm):
        """Attach to *dm* and subscribe dataUpdated to its notifications."""
        self.dataManager = dm
        self.dataManager.attach(self.dataUpdated)

    def maskSelect(self, setting):
        """Arm the selector; *setting* is forwarded to updateMask afterwards."""
        self.maskingSetting = setting
        self.rectSelect.set_active(True)

    # FIXME: This won't work yet
    def maskSelected(self, click, release):
        """Selector callback: mask the data points inside the drawn box.

        Disarms the selector, builds a boolean array that is True for every
        row of ``dataManager.data`` strictly inside the rectangle, and passes
        it with ``maskingSetting`` to ``dataManager.updateMask``.  Returns None.
        """
        self.rectSelect.set_active(False)
        xlo = min(click.xdata, release.xdata)
        ylo = min(click.ydata, release.ydata)
        xhi = max(click.xdata, release.xdata)
        yhi = max(click.ydata, release.ydata)
        data = self.dataManager.data
        xs, ys = data[:, 0], data[:, 1]
        mask = (xs > xlo) & (xs < xhi) & (ys > ylo) & (ys < yhi)
        self.dataManager.updateMask(mask, self.maskingSetting)

    def dataUpdated(self, sender, name=None):
        """Observer callback: refresh the plotted artists from *sender*."""
        if name == "FitStarted":
            return
        if self.dataManager.dataIsValid is True:
            data = sender.data
            # NOTE(review): points are split per column using that column's
            # mask; confirm this is the intended use of the mask.
            keep_x = ~data.mask[:, 0]
            keep_y = ~data.mask[:, 1]
            self.dataArtist.set_xdata(data[keep_x, 0].data)
            self.dataArtist.set_ydata(data[keep_y, 1].data)
            self.maskArtist.set_xdata(data[data.mask[:, 0], 0].data)
            self.maskArtist.set_ydata(data[data.mask[:, 1], 1].data)
            self.ax.relim()
            self.ax.autoscale()
        has_fit = (self.dataManager.fitfuncIsValid is True
                   and self.dataManager.fit is not None)
        if has_fit:
            fit = sender.fit
            self.fitArtist.set_xdata(fit[:, 0])
            self.fitArtist.set_ydata(fit[:, 1])
        else:
            self.fitArtist.set_xdata([])
            self.fitArtist.set_ydata([])
        self.canvas.draw()
Ejemplo n.º 14
0
class RectangleSelection(object):
    """Show a grayscale image and let the user pick one rectangle on it.

    Construction blocks in plt.show().  The last selection is stored in
    ``self.rectangle`` as (minRow, minCol, maxRow, maxCol), truncated to int.
    """

    def __init__(self, img):
        self.rectangle = None
        self.img = img
        self.done = False

        # Figure with the image
        self.fig, self.ax = plt.subplots()
        plt.imshow(self.img, cmap='gray')

        self.RS = RectangleSelector(
            self.ax, self.onselect,
            drawtype='box', useblit=True,
            button=[1, 3],  # left and right buttons only; middle is ignored
            minspanx=5, minspany=5,
            spancoords='pixels', interactive=True)

        plt.connect('key_press_event', self.toggle_selector)
        plt.show()

    def onselect(self, e_click, e_release):
        """Store the rectangle as (minRow, minCol, maxRow, maxCol)."""
        rows = (e_click.ydata, e_release.ydata)
        cols = (e_click.xdata, e_release.xdata)
        self.rectangle = (int(min(rows)), int(min(cols)),
                          int(max(rows)), int(max(cols)))

    def toggle_selector(self, event):
        """'q'/'Q' deactivates the selector, 'a'/'A' re-activates it."""
        key = event.key
        active = self.RS.active
        if key in ('Q', 'q') and active:
            self.RS.set_active(False)
        if key in ('A', 'a') and not active:
            self.RS.set_active(True)
Ejemplo n.º 15
0
class BlockingRectangleSelector:
    """
    Blocking rectangle selector selects once then continues with script.

    After select() returns, the chosen box is available as x1 <= x2 and
    y1 <= y2 (data coordinates).
    """
    def __init__(self, ax=None):
        if ax is None: ax = gca()
        self.ax = ax

        # drawtype is 'box' or 'line' or 'none'
        self.selector = RectangleSelector(self.ax,
                                          self._callback,
                                          drawtype='box',
                                          useblit=True,
                                          minspanx=5,
                                          minspany=5,
                                          spancoords='pixels')
        self.selector.set_active(False)
        self.block = BlockingInput(self.ax.figure)

    def _callback(self, event1, event2):
        """
        Selection callback.  event1 and event2 are the press and release events.

        Stores the box with ordered corners and stops the blocking event loop.
        """
        self.x1, self.x2 = sorted((event1.xdata, event2.xdata))
        self.y1, self.y2 = sorted((event1.ydata, event2.ydata))
        print('stopping event loop')
        self.ax.figure.canvas.stop_event_loop_default()

    def select(self):
        """
        Wait for box to be selected on the axes.
        """
        self.selector.set_active(True)
        self.ax.figure.canvas.draw_idle()
        try:
            # Wait for selector to complete
            self.block()
        finally:
            # Always disarm, even if waiting was interrupted, so the selector
            # does not stay live on the figure.
            self.selector.set_active(False)
            # Make sure the graph is redrawn next time the event loop is shown
            self.ax.figure.canvas.draw_idle()

    def remove(self):
        """
        Close all pyplot figures via ``pylab.close('all')``.

        NOTE(review): this closes every open figure, not only the one holding
        this selector, and does not disconnect the selector itself.
        """
        pylab.close('all')
Ejemplo n.º 16
0
class BlockingRectangleSelector:
    """
    Blocking rectangle selector selects once then continues with script.

    After select() returns, the chosen box is available as x1 <= x2 and
    y1 <= y2 (data coordinates).
    """
    def __init__(self, ax=None):
        if ax is None: ax = gca()
        self.ax = ax

        # drawtype is 'box' or 'line' or 'none'
        self.selector = RectangleSelector(self.ax, self._callback,
                                          drawtype='box', useblit=True,
                                          minspanx=5, minspany=5,
                                          spancoords='pixels')
        self.selector.set_active(False)
        self.block = BlockingInput(self.ax.figure)

    def _callback(self, event1, event2):
        """
        Selection callback.  event1 and event2 are the press and release events.

        Stores the box with ordered corners and stops the blocking event loop.
        """
        self.x1, self.x2 = sorted((event1.xdata, event2.xdata))
        self.y1, self.y2 = sorted((event1.ydata, event2.ydata))
        print('stopping event loop')
        self.ax.figure.canvas.stop_event_loop_default()

    def select(self):
        """
        Wait for box to be selected on the axes.
        """
        self.selector.set_active(True)
        self.ax.figure.canvas.draw_idle()
        try:
            # Wait for selector to complete
            self.block()
        finally:
            # Always disarm, even if waiting was interrupted, so the selector
            # does not stay live on the figure.
            self.selector.set_active(False)
            # Make sure the graph is redrawn next time the event loop is shown
            self.ax.figure.canvas.draw_idle()

    def remove(self):
        """
        Close all pyplot figures via ``pylab.close('all')``.

        NOTE(review): this closes every open figure, not only the one holding
        this selector, and does not disconnect the selector itself.
        """
        pylab.close('all')
Ejemplo n.º 17
0
class MatplotlibWidget(QWidget):

    def __init__(self, parent=None, axtype='', title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', showtoolbar=True, dpi=100, *args, **kwargs):
        """Build the figure, canvas, optional toolbar and axes.

        axtype chooses the figure padding ('sp', 'legend', anything else) and
        'sp_fit' additionally patches the toolbar's zoom.  showtoolbar is a
        bool, or a tuple of NavigationToolbar2QT tool names to keep.
        title/xlabel/ylabel/xlim/ylim/xscale/yscale are forwarded to
        initial_axes().
        """
        super(MatplotlibWidget, self).__init__(parent)

        # initialize axes (ax) and plots (l)
        self.axtype = axtype
        self.ax = []  # list of axes
        self.l = {}  # all the plots, stored by name
        self.l['temp'] = []  # temporary lines
        self.txt = {}  # all text in fig
        self.leg = ''  # legend placeholder

        # padding depends on the axes type
        if axtype == 'sp':
            self.fig = Figure(tight_layout={'pad': 0.05}, dpi=dpi)
        elif axtype == 'legend':
            self.fig = Figure(dpi=dpi, facecolor='none')
        else:
            self.fig = Figure(tight_layout={'pad': 0.2}, dpi=dpi)

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setSizePolicy(QSizePolicy.Expanding,
                                  QSizePolicy.Expanding)
        self.canvas.setFocusPolicy(Qt.ClickFocus)
        self.canvas.setFocus()

        # layout
        self.vbox = QVBoxLayout()
        self.vbox.setContentsMargins(0, 0, 0, 0)  # no layout margins
        self.vbox.addWidget(self.canvas)
        self.setLayout(self.vbox)

        # toolbar, optionally restricted to the tool names in showtoolbar
        if isinstance(showtoolbar, tuple):
            class NavigationToolbar(NavigationToolbar2QT):
                toolitems = [t for t in NavigationToolbar2QT.toolitems if t[0] in showtoolbar]
        else:
            class NavigationToolbar(NavigationToolbar2QT):
                pass

        self.toolbar = NavigationToolbar(self.canvas, self)
        self.toolbar.setMaximumHeight(config_default['max_mpl_toolbar_height'])
        self.toolbar.setStyleSheet("QToolBar { border: 0px;}")

        if showtoolbar:
            self.vbox.addWidget(self.toolbar)
            if self.axtype == 'sp_fit':
                self.toolbar.press_zoom = types.MethodType(press_zoomX, self.toolbar)
        else:
            # Hide explicitly; otherwise a tiny empty toolbar is shown.
            self.toolbar.hide()

        # Forward the caller's yscale (it used to be hard-coded to 'linear').
        self.initial_axes(title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)

        self.canvas_draw()


    def initax_xy(self, *args, **kwargs):
        """Add one subplot (111) with a transparent face and append it to self.ax.

        Extra positional/keyword arguments are accepted and ignored.
        """
        self.ax.append(self.fig.add_subplot(111, facecolor='none'))


    def initax_xyy(self):
        '''
        Create a primary axes (via initax_xy) and a twin-x secondary axes.

        self.ax[0] gets a left y axis coloured color[0]; the twin, appended
        to self.ax, gets a right y axis coloured color[1].
        '''
        self.initax_xy()

        primary = self.ax[0]
        secondary = primary.twinx()
        # draw the primary axes above its twin
        primary.set_zorder(primary.get_zorder() + 1)

        right_color = color[1]
        secondary.tick_params(axis='y', labelcolor=right_color, color=right_color)
        secondary.yaxis.label.set_color(right_color)
        secondary.spines['right'].set_color(right_color)
        secondary.spines['left'].set_visible(False)

        left_color = color[0]
        primary.tick_params(axis='y', labelcolor=left_color, color=left_color)
        primary.yaxis.label.set_color(left_color)
        primary.spines['left'].set_color(left_color)
        primary.spines['right'].set_visible(False)

        self.ax.append(secondary)


    def init_sp(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Initialize a spectrum (sp[n]) plot on twin axes.

        ax[0]: lG, lGfit, ltollb, ltolub, lP, lPfit, strk, srec, lsp
        ax[1]: lB, lBfit
        Only *title* is used; the axis labels are fixed, and the remaining
        parameters are accepted for interface compatibility.
        '''
        self.initax_xyy()

        self.ax[0].margins(x=0)
        self.ax[1].margins(x=0)
        self.ax[0].margins(y=.05)
        self.ax[1].margins(y=.05)

        self.l['lG'] = self.ax[0].plot(
            [], [],
            marker='.',
            linestyle='none',
            markerfacecolor='none',
            color=color[0]
        )  # G
        self.l['lB'] = self.ax[1].plot(
            [], [],
            marker='.',
            linestyle='none',
            markerfacecolor='none',
            color=color[1]
        )  # B
        self.l['lGfit'] = self.ax[0].plot(
            [], [],
            color='k'
        )  # G fit
        self.l['lBfit'] = self.ax[1].plot(
            [], [],
            color='k'
        )  # B fit
        # 'strk' and 'srec' used to be plotted here as well as below; the first
        # pair was immediately overwritten in self.l and left as orphaned
        # artists on ax[0].  They are now created once, further down.

        self.l['ltollb'] = self.ax[0].plot(
            [], [],
            linestyle='--',
            color='k'
        )  # tolerance interval lines
        self.l['ltolub'] = self.ax[0].plot(
            [], [],
            linestyle='--',
            color='k'
        )  # tolerance interval lines

        self.l['lP'] = self.ax[0].plot(
            [], [],
            marker='.',
            linestyle='none',
            markerfacecolor='none',
            color=color[0]
        )  # polar plot
        self.l['lPfit'] = self.ax[0].plot(
            [], [],
            color='k'
        )  # polar plot fit
        self.l['strk'] = self.ax[0].plot(
            [], [],
            marker='+',
            linestyle='none',
            color='r'
        )  # center of tracking peak
        self.l['srec'] = self.ax[0].plot(
            [], [],
            marker='x',
            linestyle='none',
            color='g'
        )  # center of recording peak

        self.l['lsp'] = self.ax[0].plot(
            [], [],
            color=color[2]
        )  # peak freq span

        # add text
        self.txt['sp_harm'] = self.fig.text(0.01, 0.98, '', va='top', ha='left')
        self.txt['chi'] = self.fig.text(0.01, 0.01, '', va='bottom', ha='left')

        # axis labels
        self.set_ax(self.ax[0], title=title, xlabel=r'$f$ (Hz)', ylabel=r'$G_P$ (mS)')
        self.set_ax(self.ax[1], xlabel=r'$f$ (Hz)', ylabel=r'$B_P$ (mS)')

        self.ax[0].xaxis.set_major_locator(ticker.LinearLocator(3))



    def update_sp_text_harm(self, harm):
        '''
        Show the harmonic label in the top-left figure text.

        harm: an int is converted to its string form before display; any other
              value is passed to set_text unchanged.
        '''
        label = str(harm) if isinstance(harm, int) else harm
        self.txt['sp_harm'].set_text(label)


    def update_sp_text_chi(self, chi=None):
        '''
        Show the chi-squared value of the fit in the bottom-left figure text.

        chi: number or None
            Displayed with 4 decimals. None clears the text. A value of 0
            (a perfect fit) is displayed instead of being treated as "no value".
        '''
        if chi is None:
            self.txt['chi'].set_text('')
        else:
            self.txt['chi'].set_text(r'$\chi^2$ = {:.4f}'.format(chi))


    def init_sp_fit(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the spectra fit plot on the two axes created by initax_xyy
        creates the (empty) lines:
            ax[0]: .l['lG'] (G), .l['lGfit'] (G fit),
                   .l['strk'] (center of tracking peak),
                   .l['srec'] (center of recording peak),
                   .l['lsp'] (peak freq span)
            ax[1]: .l['lB'] (B), .l['lBfit'] (B fit)
        and adds two span selectors on ax[0]:
            left-button drag  -> sp_spanselect_zoomin_callback
            right-button drag -> sp_spanselect_zoomout_callback
        NOTE: title, xlabel, ylabel, xlim, ylim, xscale, yscale, *args and
        **kwargs are accepted for a common init_* signature but are not used
        here; the axis labels are fixed and no title is set.
        '''
        self.initax_xyy()

        # NOTE: the plotting order below is also the drawing order of the lines
        self.l['lG'] = self.ax[0].plot(
            [], [], 
            marker='.', 
            linestyle='none',
            markerfacecolor='none', 
            color=color[0]
        ) # G
        self.l['lB'] = self.ax[1].plot(
            [], [], 
            marker='.', 
            linestyle='none',
            markerfacecolor='none', 
            color=color[1]
        ) # B
        # self.l['lGpre'] = self.ax[0].plot(
        #     [], [], 
        #     marker='.', 
        #     linestyle='none',
        #     markerfacecolor='none', 
        #     color='gray'
        # ) # previous G
        # self.l['lBpre'] = self.ax[1].plot(
        #     [], [], 
        #     marker='.', 
        #     linestyle='none',
        #     markerfacecolor='none', 
        #     color='gray'
        # ) # previous B
        self.l['lGfit'] = self.ax[0].plot(
            [], [], 
            color='k'
        ) # G fit
        self.l['lBfit'] = self.ax[1].plot(
            [], [], 
            color='k'
        ) # B fit
        self.l['strk'] = self.ax[0].plot(
            [], [],
            marker='+',
            linestyle='none',
            color='r'
        ) # center of tracking peak
        self.l['srec'] = self.ax[0].plot(
            [], [],
            marker='x',
            linestyle='none',
            color='g'
        ) # center of recording peak

        self.l['lsp'] = self.ax[0].plot(
            [], [],
            color=color[2]
        ) # peak freq span

        # set labels of both axes (fixed, independent of the arguments)
        self.set_ax(self.ax[0], xlabel=r'$f$ (Hz)',ylabel=r'$G_P$ (mS)')
        self.set_ax(self.ax[1], xlabel=r'$f$ (Hz)',ylabel=r'$B_P$ (mS)')

        # three evenly spaced x ticks on the G axes
        self.ax[0].xaxis.set_major_locator(ticker.LinearLocator(3))

        # self.ax[0].xaxis.set_major_locator(plt.AutoLocator())
        # self.ax[0].xaxis.set_major_locator(plt.LinearLocator())
        # self.ax[0].xaxis.set_major_locator(plt.MaxNLocator(3))

        # no x padding (data spans the full width), 5 % y padding
        self.ax[0].margins(x=0)
        self.ax[1].margins(x=0)
        self.ax[0].margins(y=.05)
        self.ax[1].margins(y=.05)
        # self.ax[1].sharex = self.ax[0]

        # self.ax[0].autoscale()
        # self.ax[1].autoscale()

        # add span selectors (both on ax[0], told apart by mouse button)
        # NOTE(review): span_stays/rectprops are deprecated in newer matplotlib
        # (replaced by interactive/props) — confirm the pinned version before upgrading
        self.span_selector_zoomin = SpanSelector(
            self.ax[0], 
            self.sp_spanselect_zoomin_callback,
            direction='horizontal', 
            useblit=True,
            button=[1],  # left click
            minspan=5,
            span_stays=False,
            rectprops=dict(facecolor='red', alpha=0.2)
        )        

        self.span_selector_zoomout = SpanSelector(
            self.ax[0], 
            self.sp_spanselect_zoomout_callback,
            direction='horizontal', 
            useblit=True,
            button=[3],  # right
            minspan=5,
            span_stays=False,
            rectprops=dict(facecolor='blue', alpha=0.2)
        )        


    def sp_spanselect_zoomin_callback(self, xclick, xrelease): 
        '''
        Left-button span selector callback: zoom ax[0] in to the selected range.

        xclick, xrelease: the two x ends of the span reported by SpanSelector.
        '''
        target_ax = self.ax[0]
        target_ax.set_xlim(xclick, xrelease)


    def sp_spanselect_zoomout_callback(self, xclick, xrelease): 
        '''
        Right-button span selector callback: zoom ax[0] out.

        The currently visible x range is squeezed into the selected span: the
        new span is (current span)**2 / (selected span), and the new center is
        moved away from the selection center by the same ratio. The old view
        therefore ends up where the selection was drawn. (This assumes the
        UIModules converters translate between start/stop and center/span.)
        '''
        ax = self.ax[0]
        view_f1, view_f2 = ax.get_xlim()
        view_center, view_span = UIModules.converter_startstop_to_centerspan(view_f1, view_f2)
        pick_center, pick_span = UIModules.converter_startstop_to_centerspan(xclick, xrelease)

        zoom = view_span / pick_span # > 1 when the selection is narrower than the view
        out_span = view_span * zoom
        out_center = view_center * (1 + zoom) - pick_center * zoom

        out_f1, out_f2 = UIModules.converter_centerspan_to_startstop(out_center, out_span)
        ax.set_xlim(out_f1, out_f2)
       

    def init_sp_polar(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Initialize the polar (G_P vs. B_P) spectrum plot on a single axes.

        Creates the empty lines .l['l'] (measured G vs. B), .l['lfit'] (fit)
        and .l['lsp'] (fit in span range), labels the axes and forces an equal
        aspect ratio. The title/label/limit/scale arguments are accepted for a
        common init_* signature but are not used.
        '''
        self.initax_xy()
        polar_ax = self.ax[0]

        # plotting order is kept: it is the drawing order of the lines
        line_styles = (
            ('l', dict(marker='.', linestyle='none', markerfacecolor='none', color=color[0])), # G vs. B
            ('lfit', dict(color='k')), # fit
            ('lsp', dict(color=color[2])), # fit in span range
        )
        for key, style in line_styles:
            self.l[key] = polar_ax.plot([], [], **style)

        self.set_ax(polar_ax, xlabel=r'$G_P$ (mS)', ylabel=r'$B_P$ (mS)')
        polar_ax.set_aspect('equal')


    def init_data(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the data plots (mpl_plt1 & mpl_plt2) on a single axes
        for each odd harmonic n = 1, 3, ... up to max_harmonic creates the
        (empty) lines:
            .l['l<n>']  : data line, pickable
            .l['lm<n>'] : marked points, pickable, same color as .l['l<n>']
            .l['lt<n>'] : temporary points, same color as .l['l<n>']
            .l['ls<n>'] : points inside the rectangle selector
        plus .l['lp'] (picked point). Also sets up:
            .sel_mode = 'none' and .cid = None (pick_event connection id)
            .rect_selector (created inactive; toggled by the toolbar switch)
            .pushButton_selectorswitch, added to the toolbar
        NOTE: only ylabel is used; the x label is fixed to 'Time (s)', and
        title, xlabel, xlim, ylim, xscale, yscale are ignored.
        '''
        self.sel_mode = 'none' # 'none', 'selector', 'picker'

        self.initax_xy()

        # no color given: each harmonic takes the next color of the axes color
        # cycle, so this loop must run before the lines that copy its colors
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['l' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                markerfacecolor='none', 
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            ) # l
        
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['lm' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                linestyle='none',
                picker=True,
                pickradius=5, # 5 points tolerance
                label='lm'+str(i),
                alpha=0.75,
            ) # marked points of line

            self.l['lt' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                linestyle='none',
                # picker=None, # picker is off by default
                label='lt'+str(i),
                alpha=0.35,
            ) # temporary points of line
            
            self.l['ls' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                markeredgecolor=color[1], 
                markerfacecolor=color[1],
                alpha= 0.5,
                linestyle='none',
                label='ls'+str(i),
            ) # points in rectangle_selector

        self.l['lp'] = self.ax[0].plot(
            [], [], 
            marker='+', 
            markersize = 12,
            markeredgecolor=color[1],
            markeredgewidth=1, 
            alpha= 1,
            linestyle='none',
            label='',
        ) # points of picker
        
        self.cid = None # for storing the cid of pick_event

        # set labels of ax[0] (x label is fixed)
        self.set_ax(self.ax[0], xlabel='Time (s)',ylabel=ylabel)

        # self.ax[0].autoscale()

        # add rectangle_selector (left button, minimum size 5 px each way)
        # NOTE(review): drawtype/rectprops/maxdist/marker_props are deprecated in
        # newer matplotlib — confirm the pinned version before upgrading
        self.rect_selector = RectangleSelector(
            self.ax[0], 
            self.data_rectselector_callback,
            drawtype='box',
            button=[1], # left
            useblit=True,
            minspanx=5,
            minspany=5,
            # lineprops=None,
            rectprops=dict(edgecolor = 'black', facecolor='none', alpha=0.2, fill=False),
            spancoords='pixels', # default 'data'
            maxdist=10,
            marker_props=None,
            interactive=False, # change rect after drawn
            state_modifier_keys=None,
        )  
        # start inactive; data_rectsleector_picker_switch activates it
        self.rect_selector.set_active(False)

        # create a checkable toggle button for the selector
        self.pushButton_selectorswitch = QPushButton()
        self.pushButton_selectorswitch.setText('')
        self.pushButton_selectorswitch.setCheckable(True)
        self.pushButton_selectorswitch.setFlat(True)
        # icon
        icon_sel = QIcon()
        icon_sel.addPixmap(QPixmap(':/button/rc/selector.svg'), QIcon.Normal, QIcon.Off)
        self.pushButton_selectorswitch.setIcon(icon_sel)
        
        # clicked(checked) toggles selector + picker mode
        self.pushButton_selectorswitch.clicked.connect(self.data_rectsleector_picker_switch)

        # add it to toolbar
        self.toolbar.addWidget(self.pushButton_selectorswitch)
        toolbar_children = self.toolbar.children()
        # toolbar_children.insert(6, toolbar_children.pop(-1)) # this does not move the icon position
        # NOTE(review): relies on the home button being Qt child #4 of the
        # toolbar — fragile across matplotlib/Qt versions
        toolbar_children[4].clicked.connect(self.data_show_all) # 4 is the home button

        # NOTE below does not work
        # add selector switch button to toolbar
        # self.fig.canvas.manager.toolmanager.add_tool('Data Selector', SelectorSwitch, selector=self.rect_selector)

        # add button to toolbar
        # self.canvas.manager.toolbar.add_tool(self.fig.canvas.manager.toolmanager.get_tool('DataSelector'), 'toolgroup')


    def data_rectsleector_picker_switch(self, checked):
        '''
        Slot of the selector toggle button on the toolbar.

        checked True: activate the rectangle selector, connect the pick_event
            handler (connection id stored in self.cid), and switch off an
            active pan/zoom toolbar mode so that clicks reach the selector.
        checked False: deactivate the selector, clear the selection lines
            .l['ls<n>'], disconnect the pick_event handler, clear the picked
            point .l['lp'] and reset self.sel_mode to 'none'.
        '''
        if not checked:
            self.rect_selector.set_active(False)
            selection_keys = ['ls' + str(harm) for harm in range(1, int(config_default['max_harmonic'] + 2), 2)]
            self.clr_lines(l_list=selection_keys)
            self.canvas.mpl_disconnect(self.cid)
            self.clr_lines(l_list=['lp'])
            self.sel_mode = 'none'
            return

        logger.info(True)
        self.rect_selector.set_active(True)
        self.cid = self.canvas.mpl_connect('pick_event', self.data_onpick_callback)

        mode = self.toolbar.mode
        logger.info('%s', mode)
        # matplotlib's toolbar pan()/zoom() toggle their tool, so calling the
        # active one turns it off (mode strings of matplotlib >= 3.3)
        if mode == "pan/zoom":
            self.toolbar.pan()
        elif mode == "zoom rect":
            self.toolbar.zoom()


    def data_rectselector_callback(self, eclick, erelease):
        '''
        callback of the rectangle selector (active while the selector switch
        on the toolbar is checked)

        For every odd harmonic, the points of the plotted 'l<n>' lines that lie
        inside the dragged rectangle are drawn on .l['ls<n>']. The 'lm<n>'
        lines are searched only when none of the 'l<n>' lines holds pandas
        Series data.
        self.sel_mode becomes 'selector' when at least one line holding Series
        data was examined (even if none of its points is inside the
        rectangle), otherwise 'none'.
        '''
        # clear pick data .l['lp']
        self.clr_lines(l_list=['lp'])

        logger.info(eclick) 
        logger.info(erelease) 
        x1, x2 = sorted([eclick.xdata, erelease.xdata]) # x1 < x2
        y1, y2 = sorted([eclick.ydata, erelease.ydata]) # y1 < y2

        # list of update_data() dicts for the selected points
        sel_list = []
        # int() like every other harmonic loop in this class, so a float
        # max_harmonic from the config does not make range() raise
        harms = [str(harm) for harm in range(1, int(config_default['max_harmonic']+2), 2)]
        # find the points in rect
        for l_str in ['l', 'lm']: # only one will be not empty
            for harm in harms:
                # clear .l['ls<n>']
                self.clr_lines(l_list=['ls'+harm])

                # get data from current plotted lines
                harm_x, harm_y = self.l[l_str + harm][0].get_data()

                if isinstance(harm_x, pd.Series): # if data is series (not empty)
                    sel_bool = harm_x.between(x1, x2) & harm_y.between(y1, y2)

                    # save data for plotting selected data
                    sel_list.append({'ln': 'ls'+harm, 'x': harm_x[sel_bool], 'y': harm_y[sel_bool]})
            if (l_str == 'l') and sel_list: # UI mode showall
                break
                #TODO It can also set the display mode from UI (showall/showmarked) and do the loop by the mode

        # sel_list is non-empty when any line with data was examined
        self.sel_mode = 'selector' if sel_list else 'none'

        # plot the selected data
        self.update_data(*sel_list)


    def data_onpick_callback(self, event):
        '''
        callback function of mpl_data pick_event

        Clears the rectangle-selector points .l['ls<n>'], marks the first
        picked point of the picked line on .l['lp'] and sets self.sel_mode to
        'picker'. The label of .l['lp'] becomes '<picked line label>_<ind>'.
        The picked line's data are accessed with .iloc/.name, so they are
        expected to be pandas Series.
        '''
        # clear selector data .l['ls<n>']
        self.clr_lines(l_list=['ls'+ str(i) for i in range(1, int(config_default['max_harmonic']+2), 2)])

        thisline = event.artist
        x_p = thisline.get_xdata()
        y_p = thisline.get_ydata()
        ind = event.ind[0] # several points may be within pickradius; use the first
        logger.info(thisline) 
        logger.info(thisline.get_label()) 
        logger.info(x_p.name) 
        logger.info(y_p.name) 
        logger.info(ind) 

        # plot
        logger.info('x_p %s', x_p) 
        x_pick = x_p.iloc[ind]
        y_pick = y_p.iloc[ind]
        logger.info('%s %s', x_pick, y_pick) 
        # Line2D.set_data needs sequences (passing scalars is deprecated since
        # matplotlib 3.7), so wrap the single point in one-element lists
        self.l['lp'][0].set_data([x_pick], [y_pick])
        self.l['lp'][0].set_label(thisline.get_label() + '_' + str(ind)) # transfer the label of picked line and ind to 'lp'
        self.canvas_draw()

        # set
        self.sel_mode = 'picker'


    def init_contour(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the mechanics_contour1 & mechanics_contour2 plots
        creates:
            .l['C'] (filled contour on ax[0])
            .l['colorbar'] (colorbar drawn in ax[1])
            .l['l<n>'] (empty lines) and .l['p<n>'], .l['pm<n>'] (empty
            errorbar plots with the color of .l['l<n>']) for odd n up to
            max_harmonic
        optional kwargs:
            X, Y, Z: contour grid. If any of the three is missing, a placeholder
                grid built from config_default['contour_array'] with random Z
                is used.
            levels: passed to contourf. Default: config_default['contour_array']
                ['levels'] evenly spaced values between min(Z) and max(Z).
            cmap: colormap. Default: config_default['contour_array']['cmap'].
        Of the named arguments only title is used.
        NOTE: this function should be used every time contour is changed
        '''
        logger.info("self.l.get('C'): %s", self.l.get('C')) 
        logger.info('kwargs: %s', kwargs.keys()) 

        # if self.l.get('colorbar'):
        #     self.l['colorbar'].remove()
        if self.ax:
            # axes exist already (ax[0] contour, ax[1] colorbar): clear both
            self.ax[0].cla()
            self.ax[1].clear()
        else:
            self.initax_xy()
            # create axes for the colorbar
            self.ax.append(make_axes_locatable(self.ax[0]).append_axes("right", size="5%", pad="2%"))

        if not 'X' in kwargs or not 'Y' in kwargs or not 'Z' in kwargs:
            num = config_default['contour_array']['num']
            phi_lim = config_default['contour_array']['phi_lim']
            dlam_lim = config_default['contour_array']['dlam_lim']

            # initiate X, Y, Z data
            x = np.linspace(phi_lim[0], phi_lim[1], num=num)
            y = np.linspace(dlam_lim[0], dlam_lim[1], num=num)
            # X holds the dlam values, Y the phi values (x label d/lambda, y label Phi)
            X, Y = np.meshgrid(y, x)
            Z = np.random.rand(*X.shape)
        else:
            X = kwargs.get('X')
            Y = kwargs.get('Y')
            Z = kwargs.get('Z')

        if 'levels' in kwargs:
            levels = kwargs.get('levels')
        else:
            # config value is the number of levels; spread them over the Z range
            levels = config_default['contour_array']['levels']
            levels = np.linspace(np.min(Z), np.max(Z), levels)
        
        if 'cmap' in kwargs:
            cmap = kwargs.get('cmap')
        else:
            logger.info('cmap not in kwargs') 
            cmap = config_default['contour_array']['cmap']
        
        logger.info('levels %s', type(levels), ) 
        self.l['C'] = self.ax[0].contourf(
            X, Y, Z, # X, Y, Z
            levels=levels, 
            cmap=cmap,
        ) # contour

        self.l['colorbar'] = plt.colorbar(self.l['C'], cax=self.ax[1]) # colorbar
        self.l['colorbar'].locator = ticker.MaxNLocator(nbins=6)
        self.l['colorbar'].update_ticks()

        # no color given: colors come from the axes color cycle in loop order;
        # the errorbar plots below copy them
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['l' + str(i)] = self.ax[0].plot(
                [], [], 
                # marker='o', 
                markerfacecolor='none', 
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            ) # l

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['p' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                xerr=np.nan,
                yerr=np.nan,
                marker='o', 
                markerfacecolor='none', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['pm' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                yerr=np.nan,
                xerr=np.nan,
                marker='o', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop marked   

        # set labels and title of ax[0] (set_ax_font also resizes the colorbar ticks)
        self.set_ax(self.ax[0], xlabel=r'$d/\lambda$',ylabel=r'$\Phi$ ($\degree$)', title=title)

        self.canvas_draw()
        # freeze the current limits so later plotted data do not rescale the contour
        self.ax[0].autoscale(enable=False)


    def init_legendfig(self, *args, **kwargs):
        ''' 
        plot a figure with only the harmonic legend

        One empty line per odd harmonic (1, 3, ... up to max_harmonic) is
        plotted on ax[0] only to create legend handles (labels are the harmonic
        numbers). The figure legend is laid out in one row at the top, and the
        axes are hidden.
        *args and **kwargs are accepted for compatibility and ignored.
        '''
        self.initax_xy()

        # plot order defines the colors (axes color cycle) and the legend order
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.ax[0].plot([], [], label=i)
        self.leg = self.fig.legend(
            loc='upper center', 
            bbox_to_anchor=(0.5, 1),
            borderaxespad=0.,
            borderpad=0.,
            ncol=int((config_default['max_harmonic']+1)/2), # one column per odd harmonic
            frameon=False, 
            facecolor='none',
            labelspacing=0.0, 
            columnspacing=0.5
        )
        self.canvas_draw()
        
        # Apply axes items and fonts. This must run after self.leg exists,
        # because set_ax_font styles the legend texts for the 'legend' axtype.
        # *args/**kwargs are not forwarded: set_ax ignores them anyway, and
        # forwarding positional args after the keywords raised TypeError.
        self.set_ax(self.ax[0], title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear')

        self.ax[0].set_axis_off() # turn off the axis
 

    def init_prop(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the property plot on a single axes
        for each odd harmonic n = 1, 3, ... up to max_harmonic creates:
            .l['l<n>']  : empty pickable line (color from the axes color cycle)
            .l['p<n>']  : empty errorbar plot, hollow markers, same color
            .l['pm<n>'] : empty errorbar plot for marked points, filled
                          markers, same color
        and applies title, labels, limits and scales to ax[0].
        (.l['lm<n>'] lines are currently disabled, see the commented block.)
        '''
        self.initax_xy()

        # no color given: colors come from the axes color cycle in loop order;
        # the errorbar plots below copy them
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['l' + str(i)] = self.ax[0].plot(
                [], [], 
                # marker='o', 
                markerfacecolor='none', 
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            ) # l
        
        # for i in range(1, int(config_default['max_harmonic']+2), 2):
        #     self.l['lm' + str(i)] = self.ax[0].plot(
        #         [], [], 
        #         # marker='o', 
        #         color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
        #         picker=True,
        #         pickradius=5, # 5 points tolerance
        #         label='lm'+str(i),
        #         alpha=0.75,
        #     ) # marked points of line

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['p' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                xerr=np.nan,
                yerr=np.nan,
                marker='o', 
                markerfacecolor='none', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['pm' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                yerr=np.nan,
                xerr=np.nan,
                marker='o', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop marked
        
        # apply title, labels, limits and scales to ax[0]
        self.set_ax(self.ax[0], title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)

    # def sizeHint(self):
    #     return QSize(*self.get_width_height())

    # def minimumSizeHint(self):
    #     return QSize(10, 10)

    def set_ax(self, ax, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Apply title, labels, scales and limits to ``ax``, then its font sizes.

        Extra *args/**kwargs are accepted and ignored.
        '''
        items = dict(title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)
        self.set_ax_items(ax, **items)
        self.set_ax_font(ax)


    def set_ax_items(self, ax, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Set the texts, scales and limits of ``ax``.

        A falsy title/xlabel/ylabel leaves that text untouched; a None scale
        or limit leaves that property untouched. Scales are applied before
        limits. xlim/ylim are unpacked into set_xlim/set_ylim (e.g. a
        (low, high) pair). Extra *args/**kwargs are ignored.
        '''
        for text, setter_name in ((title, 'set_title'), (xlabel, 'set_xlabel'), (ylabel, 'set_ylabel')):
            if text:
                getattr(ax, setter_name)(text)

        for scale, setter_name in ((xscale, 'set_xscale'), (yscale, 'set_yscale')):
            if scale is not None:
                getattr(ax, setter_name)(scale)

        for lim, setter_name in ((xlim, 'set_xlim'), (ylim, 'set_ylim')):
            if lim is not None:
                getattr(ax, setter_name)(*lim)


    def set_ax_font(self, ax, *args, **kwargs):
        '''
        Apply the configured font sizes to ``ax`` and to the figure texts.

        The 'sp' axtype uses the mpl_sp_* config sizes, every other axtype the
        mpl_* sizes. Title and axis labels get fontsize + 1; tick labels and
        offset texts get fontsize. For 'contour' the colorbar tick labels are
        resized too, for 'legend' the legend texts. Texts in self.txt use
        mpl_sp_harmfontsize for 'sp_harm' and the text font size otherwise.
        '''
        if self.axtype == 'sp':
            size_keys = ('mpl_sp_fontsize', 'mpl_sp_legfontsize', 'mpl_sp_txtfontsize')
        else:
            size_keys = ('mpl_fontsize', 'mpl_legfontsize', 'mpl_txtfontsize')
        fontsize, legfontsize, txtfontsize = [config_default[k] for k in size_keys]

        if self.axtype == 'contour':
            cbar_yaxis = self.l['colorbar'].ax.yaxis
            for ticklabel in cbar_yaxis.get_ticklabels():
                ticklabel.set_size(fontsize)
            cbar_yaxis.offsetText.set_size(fontsize)

        if self.axtype == 'legend':
            plt.setp(self.leg.get_texts(), fontsize=legfontsize)

        label_size = fontsize + 1
        ax.title.set_fontsize(label_size)
        ax.xaxis.label.set_size(label_size)
        ax.yaxis.label.set_size(label_size)
        ax.tick_params(labelsize=fontsize)
        ax.xaxis.offsetText.set_size(fontsize)
        ax.yaxis.offsetText.set_size(fontsize)

        # figure texts: the harmonic label has its own size
        for key, text in self.txt.items():
            if key == 'sp_harm':
                text.set_fontsize(config_default['mpl_sp_harmfontsize'])
            else:
                text.set_fontsize(txtfontsize)
        


    def resize(self, event):
        '''
        Resize handler: re-run tight_layout so the axes fit the new figure size.

        event: the resize event (unused). The handler must be connected
        elsewhere, e.g. self.canvas.mpl_connect("resize_event", self.resize).
        '''
        figure = self.fig
        figure.tight_layout(pad=1.08)

    def initial_axes(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Initialize the axes according to self.axtype.

        Supported axtypes: 'xy', 'xyy', 'sp', 'sp_fit', 'sp_polar', 'data',
        'contour', 'legend', 'prop'; any other value does nothing.
        'xy'/'xyy'/'legend' initializers get no arguments, 'contour' also gets
        *args/**kwargs, and the rest get the named title/label/limit/scale
        arguments only.
        '''
        axtype = self.axtype
        common = dict(title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)

        if axtype == 'xy':
            self.initax_xy()
            return
        if axtype == 'xyy':
            self.initax_xyy()
            return
        if axtype == 'legend':
            self.init_legendfig()
            return
        if axtype == 'contour':
            # only the contour initializer receives the extra arguments (X, Y, Z, levels, cmap, ...)
            self.init_contour(*args, **common, **kwargs)
            return

        initializer_name = {
            'sp': 'init_sp',
            'sp_fit': 'init_sp_fit',
            'sp_polar': 'init_sp_polar',
            'data': 'init_data',
            'prop': 'init_prop',
        }.get(axtype)
        if initializer_name is not None:
            getattr(self, initializer_name)(**common)

    def update_data(self, *args):
        ''' 
        Update the data of the lines described by the given dicts, then reset
        the limits of every touched axes and redraw once.

        Each arg is a dict {'ln':, 'x':, 'y':, 'xerr':, 'yerr':, 'label':}
            ln: string, key of the line in self.l
            x : x data
            y : y data
            xerr, yerr: x/y errors. When either key is present, the data are
                applied only if self.l[ln] is an ErrorbarContainer, and BOTH
                keys are then required.
            label (optional): new label of self.l[ln][0]
        NOTE: don't use this func to update contour
        '''
        axs = set() # axes whose limits need to be reset
        
        for arg in args:
            keys = arg.keys()
            # read ln first so that 'label' below always targets this arg's line
            # (previously it could hit a stale ln from an earlier arg, or be undefined)
            ln = arg['ln']

            if ('xerr' in keys) or ('yerr' in keys): # errorbar with caps and barlines
                if isinstance(self.l[ln], ErrorbarContainer): # type match 
                    # the errorbar plots are initialized with xerr and yerr, so both exist
                    x = arg['x'] 
                    y = arg['y'] 
                    xerr = arg['xerr'] 
                    yerr = arg['yerr']

                    line, caplines, barlinecols = self.l[ln]
                    line.set_data(x, y)
                    # ending points of the errorbars: left, right, bottom, top
                    error_positions = (x-xerr,y), (x+xerr,y), (x,y-yerr), (x,y+yerr) 
                    # Update the caplines 
                    for i, pos in enumerate(error_positions): 
                        caplines[i].set_data(pos) 
                    # Update the error bars (horizontal, then vertical)
                    barlinecols[0].set_segments(zip(zip(x-xerr,y), zip(x+xerr,y))) 
                    barlinecols[1].set_segments(zip(zip(x,y-yerr), zip(x,y+yerr))) 
                    axs.add(line.axes)
            else: # not errorbar
                x = arg['x'] 
                y = arg['y']
                self.l[ln][0].set_data(x, y)
                axs.add(self.l[ln][0].axes)
            
            if 'label' in keys: # with label
                self.l[ln][0].set_label(arg['label'])

        for ax in axs:
            self.reset_ax_lim(ax)

        self.canvas_draw()


    def get_data(self, ls=[]):
        '''
        Collect the plotted data of several stored lines.

        Args:
            ls: keys into ``self.l``; the first artist of each entry is read.

        Returns:
            list of ``(xdata, ydata)`` tuples, in the same order as ``ls``.
        '''
        pairs = (self.l[key][0].get_data() for key in ls)
        return [(xs, ys) for xs, ys in pairs]


    def del_templines(self, ax=None):
        '''
        Delete all temporary lines stored in ``self.l['temp']``.

        Each temporary line is detached from the axes that owns it, the
        'temp' list is reset to empty, then ``ax`` is rescaled and the canvas
        redrawn.

        Args:
            ax: axes to rescale afterwards; None means ``self.ax[0]``.
        '''
        if ax is None:
            ax = self.ax[0]

        for l_temp in self.l['temp']:
            # Artist.remove() detaches the line from whichever axes holds it.
            # ``ax.lines.remove`` fails when the line was plotted on another
            # axes, and ``ax.lines`` is not a mutable list in newer matplotlib.
            l_temp[0].remove()
        self.l['temp'] = []  # re-initialize

        self.reset_ax_lim(ax)
        self.canvas_draw()


    def clr_all_lines(self):
        '''Clear every stored line; equivalent to ``self.clr_lines()``.'''
        clear = self.clr_lines
        clear()


    def clr_lines(self, l_list=None):
        '''
        Clear the data of the lines stored in ``self.l``.

        Every entry except the special keys 'temp', 'C' and 'colorbar' has its
        data emptied when ``l_list`` is None or the key is in ``l_list``. The
        artists stay on their axes; only their data is removed. An
        ``ErrorbarContainer`` entry also has its caplines and error-bar
        segments emptied. Temporary lines under 'temp' are always deleted
        through ``del_templines``, regardless of ``l_list``.

        Afterwards every axes is rescaled and the canvas is redrawn. The
        rescale is skipped when a contour entry 'C' exists, so contour limits
        are kept.

        Args:
            l_list: keys of ``self.l`` to clear; None (default) clears all.
        '''
        for key in self.l:
            if key == 'temp':
                for ax in self.ax:
                    self.del_templines(ax=ax)
                continue
            if key in ('C', 'colorbar'):
                continue  # not plain line entries; left untouched
            if l_list is not None and key not in l_list:
                continue

            if isinstance(self.l[key], ErrorbarContainer):  # errorbar plot
                logger.debug('clearing errorbar %s', key)
                line, caplines, barlinecols = self.l[key]
                line.set_data([], [])
                # Iterate over what actually exists. There may be fewer than
                # 4 caplines / 2 bar collections (e.g. no caps, or only one
                # of xerr / yerr), so fixed indices could raise IndexError.
                for capline in caplines:
                    capline.set_data([], [])
                for barlinecol in barlinecols:
                    barlinecol.set_segments([])
            else:  # line plot
                self.l[key][0].set_data([], [])

        is_contour = 'C' in self.l
        logger.debug('it is a contour: %s', is_contour)
        if not is_contour:
            # Contour limits are not reset. Rescale every axes explicitly
            # rather than relying on a loop variable that is only bound when
            # a 'temp' key was visited.
            for ax in self.ax:
                self.reset_ax_lim(ax)
        self.canvas_draw()


    def change_style(self, line_list, **kwargs):
        '''
        Set matplotlib properties on the first artist of each listed entry.

        Each keyword ``prop=value`` is applied as ``artist.set_<prop>(value)``,
        for example ``change_style(['a'], linestyle='--', markersize=4)``.

        Args:
            line_list: keys of ``self.l`` whose first artist is restyled.
            **kwargs: property names (the suffix of an artist ``set_*``
                method, same keywords as in matplotlib) and their values.

        Raises:
            AttributeError: if an artist has no ``set_<prop>`` method.
            KeyError: if a key of ``line_list`` is not in ``self.l``.
        '''
        for key, val in kwargs.items():
            setter_name = 'set_' + key
            for l in line_list:
                # Look the setter up directly instead of eval() on a formatted
                # string. Keys/values containing quotes then work, nothing can
                # be injected, and non-string values (numbers, RGB tuples) are
                # passed through unchanged.
                getattr(self.l[l][0], setter_name)(val)


    def new_plt(self, xdata=[], ydata=[], title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        Re-initialize the axes and plot each (x, y) pair as a new line on ax[0].

        Args:
            xdata: per-line x data, e.g. ``[[x1...], [x2...]]``.
            ydata: per-line y data, e.g. ``[[y1...], [y2...]]``; paired with
                ``xdata`` by position (zip semantics: extra entries ignored).
            title, xlabel, ylabel, xlim, ylim, xscale, yscale: forwarded to
                ``self.set_ax`` for ``self.ax[0]``.
            *args, **kwargs: extra arguments forwarded to ``self.set_ax``.

        Line i is stored as ``self.l[i]`` (the list returned by ``plot``).
        '''
        self.initax_xy()
        # configure ax[0] with the caller's title, labels, limits and scales
        self.set_ax(self.ax[0], *args, title=title, xlabel=xlabel, ylabel=ylabel,
                    xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale, **kwargs)

        # enumerate yields (i, (x, y)); the pair must be unpacked as a tuple
        for i, (x, y) in enumerate(zip(xdata, ydata)):
            self.l[i] = self.ax[0].plot(
                x, y,
                marker='o',
                markerfacecolor='none',
            )

        self.ax[0].autoscale()


    def add_temp_lines(self, ax=None, xlist=[], ylist=[], label_list=[]):
        '''
        Draw dashed temporary lines and keep them in ``self.l['temp']``.

        Line i is plotted from ``xlist[i]`` and ``ylist[i]``, so each line has
        its own x data. Because the three lists are zipped, the shortest one
        decides how many lines are drawn. Every line uses ``color[-1]``, where
        ``color`` is a sequence defined outside this method.

        Args:
            ax: axes to draw on; None means ``self.ax[0]``.
            xlist: per-line x data.
            ylist: per-line y data.
            label_list: per-line labels; an empty list means "no labels".
        '''
        logger.info('add_temp_lines')
        labels = label_list if len(label_list) != 0 else [''] * len(xlist)
        for x, y, label in zip(xlist, ylist, labels):
            ax = self.ax[0] if ax is None else ax
            temp_lines = self.l['temp']
            temp_lines.append(ax.plot(
                x, y,
                linestyle='--',
                color=color[-1],
            ))
            if label:
                temp_lines[-1][0].set_label(label)

        self.canvas_draw()


    def canvas_draw(self):
        '''
        Redraw the canvas after its data changed.

        Renders synchronously and then flushes pending GUI events, so the
        new figure becomes visible right away.
        '''
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()


    def data_show_all(self):
        '''
        Show all data on every axes.

        Autoscaling may have been switched off by zoom/pan, so it is turned
        back on before each axes is refit to its data.
        '''
        for axis in self.ax:
            axis.set_autoscale_on(True)
            self.reset_ax_lim(axis)


    def reset_ax_lim(self, ax):
        '''
        Fit the limits of ``ax`` to its currently visible data.

        Data limits are recomputed from visible artists only, and the view is
        rescaled tightly on both x and y. This changes the display and also
        where the toolbar's home button goes back to.

        Args:
            ax: the matplotlib Axes to rescale.
        '''
        ax.relim(visible_only=True)
        ax.autoscale_view(tight=True, scalex=True, scaley=True)
class Human_Reviser(Frame):
    """Tk page for manually revising AI-detected keypoints.

    Workflow:
    1. Choose a folder of frames, a pickled AI keypoint result and a pickled
       review-flags file, then press "Start Revising".
    2. Each frame opens in a matplotlib window. Keypoints flagged 0 or 2 are
       re-placed by clicking: left click = visible, right click = invisible.
    3. Draw a head bounding box. 'u' undoes the last click and 'y' moves on
       to the next frame.
    4. After the last frame the revised result is saved with
       ``helpers.savepkl``.

    Per-frame interaction state is kept in module-level globals (``fix``,
    ``fixed``, ``bbox``, ``result``, ...) that the matplotlib callbacks share.
    """

    # File-name suffix (text after the last '_') -> model name.
    _MODEL_BY_SUFFIX = {
        "opencv": "OpenCV",
        "hg": "Hourglass",
        "fRCNN": "Faster R-CNN",
        "fan": "FAN",
    }
    # Models this page knows how to revise (single-person: pose 0 only).
    _SINGLE_PERSON_MODELS = ("Hourglass", "Faster R-CNN", "FAN")

    def __init__(self, parent, controller):
        """Build the page widgets.

        Args:
            parent: Tk parent widget.
            controller: application controller; must provide
                ``show_frame(menu, name, branch)`` and ``exit()``.
        """
        Frame.__init__(self, parent)
        label = Label(self, text='Human Reviser', font=LARGE_FONT)
        label.grid(row=0, column=2, columnspan=1, sticky="e", padx=10, pady=10)

        # Choose and display resource file
        self.output_res = Text(self,
                               width=55,
                               height=1,
                               wrap="word",
                               bg="white")
        self.output_res.grid(row=1,
                             column=0,
                             columnspan=5,
                             sticky="nw",
                             padx=10,
                             pady=5)
        btn_file = Button(self,
                          text="Choose Resource",
                          command=self.choose_resource)
        btn_file.grid(row=1, column=6, sticky="nw", padx=2, pady=5)

        # Choose and display AI keypoints file
        self.output_AI = Text(self,
                              width=55,
                              height=1,
                              wrap="word",
                              bg="white")
        self.output_AI.grid(row=2,
                            column=0,
                            columnspan=5,
                            sticky="nw",
                            padx=10,
                            pady=5)
        btn_file = Button(self,
                          text="Choose AI Result",
                          command=self.choose_AIkpts)
        btn_file.grid(row=2, column=6, sticky="nw", padx=2, pady=5)

        # Choose and display flags file
        self.output_Flags = Text(self,
                                 width=55,
                                 height=1,
                                 wrap="word",
                                 bg="white")
        self.output_Flags.grid(row=3,
                               column=0,
                               columnspan=5,
                               sticky="nw",
                               padx=10,
                               pady=5)
        btn_file = Button(self,
                          text="Choose Review Result",
                          command=self.choose_flags)
        btn_file.grid(row=3, column=6, sticky="nw", padx=2, pady=5)

        btn_fix = Button(self,
                         text="Start Revising",
                         command=self.revise_label)
        btn_fix.grid(row=4, column=1, sticky="nw", padx=5, pady=10)

        # ``branch`` is read from module scope; it selects the menu that
        # "Go Back" returns to.
        if branch == 1:
            menu = BodyMenu
            name = "BodyMenu"
        elif branch == 2:
            menu = FaceMenu
            name = "FaceMenu"
        else:
            menu = BranchMenu
            name = "BranchMenu"

        button1 = Button(
            self,
            text='Go Back',  # likewise StartPage
            command=lambda: controller.show_frame(menu, name, branch))
        button1.grid(row=4, column=2, sticky="nw", padx=40, pady=10)

        button2 = Button(
            self,
            text='Exit',  # likewise StartPage
            command=lambda: controller.exit())
        button2.grid(row=4, column=3, sticky="nw", padx=10, pady=10)

    def choose_resource(self):
        """Ask for the folder of images/frames and display its path."""
        self.resource = filedialog.askdirectory()
        self.output_res.delete(0.0, END)
        self.output_res.insert(END, self.resource)

    def choose_AIkpts(self):
        """Ask for the AI keypoints pickle and infer the model from its name.

        The model comes from the text after the last '_' of the file name
        (e.g. ``video_hg.pkl`` -> "Hourglass"). An unknown suffix or a
        cancelled dialog sets ``self.model`` to None.
        """
        self.AIfile = filedialog.askopenfilename(
            initialdir='.',
            filetypes=(("Pickle File", "*.pkl"), ("All Files", "*.*")),
            title="Choose a file")
        self.output_AI.delete(0.0, END)
        self.output_AI.insert(END, self.AIfile)

        if not self.AIfile:  # dialog cancelled
            self.model = None
            return

        # extract model name
        filename, _ = os.path.splitext(os.path.basename(self.AIfile))
        string = filename.split('_')[-1]
        print(string)
        # None (not []) for an unknown suffix: easy to test for, and never
        # used as a dict key (a list is unhashable).
        self.model = self._MODEL_BY_SUFFIX.get(string)

    def choose_flags(self):
        """Ask for the pickled review-flags file and display its path."""
        self.Flagsfile = filedialog.askopenfilename(
            initialdir='.',
            filetypes=(("Pickle File", "*.pkl"), ("All Files", "*.*")),
            title="Choose a file")
        self.output_Flags.delete(0.0, END)
        self.output_Flags.insert(END, self.Flagsfile)

    def revise_label(self):
        """Load the chosen inputs and start revising from the first frame.

        Builds the global ``result`` (a deep copy of the AI keypoints, plus
        an image-name list and empty per-frame box lists), loads the flags
        into ``self.dict_flags`` and opens frame 0. Shows an error dialog
        and returns when an input is missing or the model is unsupported.
        """
        global dict_flags, result

        # Validate up front. Otherwise these cases fail later with an
        # AttributeError (frames_kpts never set) or a TypeError/KeyError.
        if not all(getattr(self, attr, "") for attr in ("resource", "AIfile", "Flagsfile")):
            messagebox.showerror(
                "Error",
                "Please choose the resource, the AI result and the review result first.")
            return
        model = getattr(self, "model", None)
        if model not in self._SINGLE_PERSON_MODELS:
            messagebox.showerror("Error", "Unsupported AI result model: {}".format(model))
            return

        dict_flags.clear()

        # load images list
        patterns = ('*.jpg', '*.png', '*.jpeg')
        files_grabbed = []
        for pattern in patterns:
            files_grabbed.extend(glob.glob(os.path.join(self.resource, pattern)))
        self.im_list = sorted(files_grabbed)
        if not self.im_list:
            messagebox.showerror("Error", "No images found in " + self.resource)
            return

        # load AI kpts array
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(self.AIfile, 'rb') as f:
            data = pickle.load(f)

        org_kpts = data['all_keyps'][1]
        result_kpts = copy.deepcopy(org_kpts)
        result['images'] = []
        result['all_keyps'] = [[], result_kpts]
        result['all_boxes'] = [[] for i in range(len(result_kpts))]
        # alias: edits through self.frames_kpts also change result['all_keyps'][1]
        self.frames_kpts = result['all_keyps'][1]

        self.num_frames = len(self.frames_kpts)
        print("Total frames: ", self.num_frames)

        # load reviewed flags array
        with open(self.Flagsfile, 'rb') as f:
            self.dict_flags = pickle.load(f)

        self.idx = 0
        self.show_flags()

    def show_flags(self):
        """Open the revision window for frame ``self.idx``.

        Applies the review flags to ``result``:
        - flag -1 (false positive): x, y set to the sentinel -45;
        - flag 2 (false negative): x, y start at keypoint 8 (the neck
          keypoint according to the original comment).

        Every flattened keypoint index flagged 0 or 2 is recorded in ``fix``;
        these must be re-placed by clicking. The frame is then drawn with its
        keypoints, flagged points are marked in red, and the call blocks in
        ``plt.show()`` until the window closes.
        """
        global num_kpts, num_poses, txt_list, plot_list, fix, fixed, vis_pose_idx, result, bbox
        # reset per-frame interaction state
        fix = []
        fixed = []
        txt_list = []
        plot_list = []
        num_kpts = 0
        num_poses = 0
        vis_pose_idx = []
        bbox = []

        # load current image
        im_name = os.path.basename(self.im_list[self.idx])
        img = cv2.imread(self.im_list[self.idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        result['images'].append(im_name)

        # load current flags (one list of per-keypoint flags per pose)
        lists_flags = self.dict_flags[str(self.idx)]
        arr_flags = np.asarray(lists_flags)
        flat_arr_flags = arr_flags.flatten()
        print("flag", arr_flags)

        num_poses = len(self.frames_kpts[self.idx])
        num_kpts = self.frames_kpts[self.idx][0].shape[1]
        single_person = self.model in self._SINGLE_PERSON_MODELS
        # keypoint arrays of this frame; row 0 holds x and row 1 holds y
        frame_kpts = result['all_keyps'][1][self.idx]

        # false positive keypoints (incorrectly detected): set x, y to -45
        delete = np.where(flat_arr_flags == -1)[0]
        print('delete', delete)
        for flat_idx in delete:
            pose, joint = divmod(int(flat_idx), num_kpts)
            if single_person:
                frame_kpts[pose][0][joint] = -45
                frame_kpts[pose][1][joint] = -45

        # false negative keypoints (incorrectly undetected): start them at
        # the coordinates of keypoint 8 (neck)
        insert = np.where(flat_arr_flags == 2)[0]
        print('insert', insert)
        for flat_idx in insert:
            pose, joint = divmod(int(flat_idx), num_kpts)
            if single_person:
                frame_kpts[pose][0][joint] = frame_kpts[pose][0][8]
                frame_kpts[pose][1][joint] = frame_kpts[pose][1][8]

        # load current keypoints
        lists_kpts = self.frames_kpts[self.idx]
        print("kpt", lists_kpts)

        if single_person:
            # single-person has only one pose
            vis_pose_idx = [0]
        lists_vis, flatten_vis = helpers.viskpts(img, lists_kpts, vis_pose_idx,
                                                 self.model)

        # flattened indices of the keypoints the user must re-place
        fix = [i for i, flag in enumerate(flat_arr_flags) if flag == 0 or flag == 2]

        self.fig, (self.ax1, self.ax2) = plt.subplots(
            1, 2, gridspec_kw={'width_ratios': [1, 6]}, figsize=(10, 6))
        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6;
        # the figure manager's method works on old and new versions.
        self.fig.canvas.manager.set_window_title(im_name)

        # draw keypoints on image
        img = helpers.drawkpts(img, lists_kpts, lists_vis, self.model)

        # mark keypoints flagged 0 or 2 in red, labelled "<pose>_<joint>"
        for num in range(len(lists_kpts)):
            points = np.append([lists_kpts[num][0]], [lists_kpts[num][1]], axis=0)
            flags = lists_flags[num]
            for i in range(points.shape[1]):
                if flags[i] == 0 or flags[i] == 2:
                    plt.plot(int(points[0, i]),
                             int(points[1, i]),
                             'ro',
                             markersize=8)
                    plt.text(int(points[0, i]),
                             int(points[1, i]),
                             str(num) + "_" + str(i),
                             color='r',
                             fontsize=12)

        # display keypoints reference in left subplot
        self.display_annotation(self.ax1)
        # show image with keypoints and flags in right subplot
        self.ax2.imshow(img)
        self.ax2.set_axis_off()
        plt.tight_layout()

        # bind button and key with figure
        self.fig.canvas.mpl_connect('button_press_event', self.onclick_revise)
        self.fig.canvas.mpl_connect('key_press_event', self.onkey_revise)
        # rectangle selector for the head bounding box; enabled later by
        # onclick_revise once every keypoint has been placed
        self.rs = RectangleSelector(
            self.ax2,
            self.line_select_callback,
            useblit=False,
            button=[1],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        self.rs.set_active(False)
        plt.show()

    def update_flags(self):
        """Write the revisions of frame ``self.idx`` into ``result`` and advance.

        Steps:
        - append a visibility row of ones to every pose and store ``bbox``
          once per pose;
        - copy each click in ``fixed`` ([x, y, visible]) into the keypoint
          it replaces (the matching entry of ``fix``);
        - for Faster R-CNN, remap visibility 1 -> 2 and 0 -> 1, and set it
          to 0 for keypoints at the (-45, -45) sentinel.

        Then the next frame is shown; after the last frame ``result`` is
        saved.
        """
        global result
        frame_kpts = result['all_keyps'][1][self.idx]
        for num in range(num_poses):
            frame_kpts[num] = np.append(frame_kpts[num],
                                        np.ones((1, num_kpts)),
                                        axis=0)
            result['all_boxes'][self.idx].append(bbox)
        # fixed[i] is the click that replaces flattened keypoint fix[i]
        for flat_idx, (x, y, vis) in zip(fix, fixed):
            pose, joint = divmod(flat_idx, num_kpts)
            frame_kpts[pose][0][joint] = x
            frame_kpts[pose][1][joint] = y
            frame_kpts[pose][-1][joint] = vis

        if self.model == "Faster R-CNN":
            for num in range(num_poses):
                kpts = frame_kpts[num]
                for joint in range(len(kpts[-1])):
                    if kpts[-1][joint] == 1:
                        kpts[-1][joint] = 2
                    elif kpts[-1][joint] == 0:
                        kpts[-1][joint] = 1
                    if kpts[0][joint] == -45 and kpts[1][joint] == -45:
                        kpts[-1][joint] = 0

        self.idx = self.idx + 1
        if self.idx < len(self.frames_kpts):
            self.show_flags()
        else:
            helpers.savepkl(result, self.resource, "gt")
            messagebox.showinfo(
                "Information",
                "All frames are revised and keypoints are saved!")
            plt.close()

    def onclick_revise(self, event):
        """Mouse handler: place the next keypoint, or enable the box selector.

        While keypoints remain to be placed, a click inside the image places
        the next one:
        - left click: visible, labelled '_vis', stored with flag 1;
        - right click: invisible, labelled '_invis', stored with flag 0.

        Once all are placed, a click enables the head bounding-box selector.
        """
        global txt_list, fixed, plot_list
        if len(fixed) < len(fix):
            if event.inaxes == self.ax2 and event.button in (1, 3):
                visible = 1 if event.button == 1 else 0
                suffix = '_vis' if visible else '_invis'
                plot, = plt.plot(event.xdata, event.ydata, 'bo', markersize=8)
                txt = plt.text(event.xdata,
                               event.ydata,
                               str(fix[len(fixed)] % num_kpts) + suffix,
                               horizontalalignment='right',
                               verticalalignment='bottom',
                               color='b',
                               fontsize=12)
                fixed.append([event.xdata, event.ydata, visible])
                plot_list.append(plot)
                txt_list.append(txt)
            self.fig.canvas.draw()
        else:
            print("Please select head bounding box.")
            self.rs.set_active(True)
            if bbox != []:
                print(
                    "Revising has done!, Please press 'y' to revise next image..."
                )

    def onkey_revise(self, event):
        """Key handler: 'u' undoes the last placed keypoint.

        'y' closes the window and moves on, but only once every keypoint is
        placed and a bounding box has been drawn.
        """
        global txt_list, plot_list, fixed
        if event.key == "u" and len(fixed) != 0:
            print("revise undo")
            plot_list[-1].remove()
            txt_list[-1].remove()
            self.fig.canvas.draw()
            del plot_list[-1]
            del txt_list[-1]
            del fixed[-1]
        elif event.key == "y" and len(fixed) == len(fix) and bbox != []:
            # also covers fix == []: fixed can only grow while shorter than fix
            plt.close()
            print("Revise next image")
            self.update_flags()
        else:
            print("Please continue to revise!")

    def line_select_callback(self, eclick, erelease):
        """Store the dragged rectangle as the global ``bbox`` [x1, y1, x2, y2]."""
        global bbox
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        bbox = [x1, y1, x2, y2]
        print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))

    def display_annotation(self, ax):
        """Write the comma-separated reference text for ``self.model`` into ``ax``.

        ``dict_model[self.model]`` is split on ',' and each entry is written
        on its own line under a "Keypoints Reference:" heading.
        """
        str_list = dict_model[self.model].split(",")
        ax.set_axis_off()
        ax.set_ylim((0, len(str_list) + 2))
        for i in range(len(str_list)):
            ax.text(0, (len(str_list) - i), str_list[i], fontsize=9)
        ax.text(0, (len(str_list) + 1), "Keypoints Reference:", fontsize=12)
Ejemplo n.º 19
0
    def do_ui(self, workspace, pixel_data, labels):
        '''Display a modal wx dialog for outlining objects by hand.

        Tools (radio box):
        - TOOL_OUTLINE: draw a lasso around an object;
        - TOOL_ZOOM_IN: drag a rectangle to zoom;
        - TOOL_ERASE: click an object to erase it.

        ``labels`` is edited in place. Each new outline becomes label
        ``labels.max() + 1``; erasing a label clears it and shifts every
        higher label down by one.

        Args:
            workspace: provides ``workspace.frame``, the dialog's parent.
            pixel_data: image shown in the dialog; its first two dimensions
                are rows and columns.
            labels: label matrix with the same 2-D shape as ``pixel_data``;
                modified in place.
        '''
        import matplotlib
        from matplotlib.widgets import Lasso, RectangleSelector
        from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg
        import wx

        style = wx.DEFAULT_DIALOG_STYLE | wx.RESIZE_BORDER
        dialog_box = wx.Dialog(workspace.frame, -1,
                               "Identify objects manually",
                               style = style)
        sizer = wx.BoxSizer(wx.VERTICAL)
        dialog_box.SetSizer(sizer)
        sub_sizer = wx.BoxSizer(wx.HORIZONTAL)
        sizer.Add(sub_sizer, 1, wx.EXPAND)
        figure = matplotlib.figure.Figure()
        axes = figure.add_subplot(1,1,1)
        panel = FigureCanvasWxAgg(dialog_box, -1, figure)
        sub_sizer.Add(panel, 1, wx.EXPAND)
        #
        # The controls are the radio buttons for tool selection and
        # a zoom out button
        #
        controls_sizer = wx.BoxSizer(wx.VERTICAL)
        sub_sizer.Add(controls_sizer, 0,
                      wx.ALIGN_CENTER_HORIZONTAL | wx.ALIGN_TOP |
                      wx.EXPAND|wx.ALL, 10)

        tool_choice = wx.RadioBox(dialog_box, -1, "Active tool",
                                  style = wx.RA_VERTICAL,
                                  choices = [TOOL_OUTLINE, TOOL_ZOOM_IN,
                                             TOOL_ERASE])
        tool_choice.SetSelection(0)
        controls_sizer.Add(tool_choice, 0, wx.ALIGN_LEFT | wx.ALIGN_TOP)
        zoom_out_button = wx.Button(dialog_box, -1, "Zoom out")
        zoom_out_button.Disable()
        controls_sizer.Add(zoom_out_button, 0, wx.ALIGN_CENTER_HORIZONTAL)
        erase_last_button = wx.Button(dialog_box, -1, "Erase last")
        erase_last_button.Disable()
        controls_sizer.Add(erase_last_button, 0, wx.ALIGN_CENTER_HORIZONTAL)
        erase_all_button = wx.Button(dialog_box, -1, "Erase all")
        erase_all_button.Disable()
        controls_sizer.Add(erase_all_button, 0, wx.ALIGN_CENTER_HORIZONTAL)

        # stack of ((xmin, ymin), (xmax, ymax)) zoom rectangles, newest last
        zoom_stack = []

        ########################
        #
        # The drawing function
        #
        ########################
        def draw():
            '''Redraw the outlined image and apply the current zoom.'''
            assert isinstance(axes, matplotlib.axes.Axes)
            if len(axes.images) > 0:
                # Artist.remove() works on all matplotlib versions. Newer
                # ones expose axes.images as a read-only list, so
                # ``del axes.images[0]`` fails there.
                axes.images[0].remove()
            image = self.draw_outlines(pixel_data, labels)
            axes.imshow(image)
            if len(zoom_stack) > 0:
                axes.set_xlim(zoom_stack[-1][0][0], zoom_stack[-1][1][0])
                axes.set_ylim(zoom_stack[-1][0][1], zoom_stack[-1][1][1])
            else:
                axes.set_xlim(0, pixel_data.shape[1])
                axes.set_ylim(0, pixel_data.shape[0])
            figure.canvas.draw()
            panel.Refresh()

        ##################################
        #
        # The erase last button
        #
        ##################################
        def on_erase_last(event):
            '''Erase the most recently added object (the highest label).'''
            erase_label = labels.max()
            if erase_label > 0:
                labels[labels == erase_label] = 0
                labels[labels > erase_label] -= 1
                draw()
                if labels.max() == 0:
                    erase_last_button.Disable()
                    erase_all_button.Disable()
            else:
                erase_last_button.Disable()
                erase_all_button.Disable()

        dialog_box.Bind(wx.EVT_BUTTON, on_erase_last, erase_last_button)

        ##################################
        #
        # The erase all button
        #
        ##################################
        def on_erase_all(event):
            '''Erase every object.'''
            labels[labels > 0] = 0
            draw()
            erase_all_button.Disable()
            erase_last_button.Disable()

        dialog_box.Bind(wx.EVT_BUTTON, on_erase_all, erase_all_button)

        ##################################
        #
        # The zoom-out button
        #
        ##################################
        def on_zoom_out(event):
            '''Return to the previous zoom level.'''
            zoom_stack.pop()
            if len(zoom_stack) == 0:
                zoom_out_button.Disable()
            draw()

        dialog_box.Bind(wx.EVT_BUTTON, on_zoom_out, zoom_out_button)

        ##################################
        #
        # Zoom selector callback
        #
        ##################################

        def on_zoom_in(event_click, event_release):
            '''Push the dragged rectangle onto the zoom stack and redraw.'''
            xmin = min(event_click.xdata, event_release.xdata)
            xmax = max(event_click.xdata, event_release.xdata)
            ymin = min(event_click.ydata, event_release.ydata)
            ymax = max(event_click.ydata, event_release.ydata)
            zoom_stack.append(((xmin, ymin), (xmax, ymax)))
            draw()
            zoom_out_button.Enable()

        zoom_selector = RectangleSelector(axes, on_zoom_in, drawtype='box',
                                          rectprops = dict(edgecolor='red',
                                                           fill=False),
                                          useblit=True,
                                          minspanx=2, minspany=2,
                                          spancoords='data')
        zoom_selector.set_active(False)

        ##################################
        #
        # Lasso selector callback
        #
        ##################################

        current_lasso = []
        def on_lasso(vertices):
            '''Turn the lasso polygon into a new filled label.'''
            lasso = current_lasso.pop()
            figure.canvas.widgetlock.release(lasso)
            mask = np.zeros(pixel_data.shape[:2], int)
            new_label = np.max(labels) + 1
            # drop vertices recorded outside the axes (None coordinates)
            vertices = [x for x in vertices
                        if x[0] is not None and x[1] is not None]
            for i in range(len(vertices)):
                v0 = (int(vertices[i][1]), int(vertices[i][0]))
                i_next = (i+1) % len(vertices)
                v1 = (int(vertices[i_next][1]), int(vertices[i_next][0]))
                draw_line(mask, v0, v1, new_label)
            mask = fill_labeled_holes(mask)
            labels[mask != 0] = new_label
            draw()
            if labels.max() > 0:
                erase_all_button.Enable()
                erase_last_button.Enable()

        ##################################
        #
        # Mouse button down
        #
        ##################################

        def on_left_mouse_down(event):
            '''Start a lasso or erase an object, depending on the active tool.'''
            if figure.canvas.widgetlock.locked():
                return
            if event.inaxes != axes:
                return

            idx = tool_choice.GetSelection()
            tool = tool_choice.GetItemLabel(idx)
            if tool == TOOL_OUTLINE:
                lasso = Lasso(axes, (event.xdata, event.ydata), on_lasso)
                lasso.line.set_color('red')
                current_lasso.append(lasso)
                figure.canvas.widgetlock(lasso)
            elif tool == TOOL_ERASE:
                row, col = int(event.ydata), int(event.xdata)
                # A click on the far edge of the axes maps to index == shape.
                if not (0 <= row < labels.shape[0] and 0 <= col < labels.shape[1]):
                    return
                erase_label = labels[row, col]
                if erase_label > 0:
                    labels[labels == erase_label] = 0
                    labels[labels > erase_label] -= 1
                draw()
                if labels.max() == 0:
                    erase_all_button.Disable()
                    erase_last_button.Disable()

        figure.canvas.mpl_connect('button_press_event', on_left_mouse_down)

        ######################################
        #
        # Radio box change
        #
        ######################################
        def on_radio(event):
            '''Enable the zoom rectangle only while the zoom tool is selected.'''
            idx = tool_choice.GetSelection()
            tool = tool_choice.GetItemLabel(idx)
            # Switch the selector off again for the other tools. Otherwise a
            # drag made with the outline or erase tool could also zoom.
            zoom_selector.set_active(tool == TOOL_ZOOM_IN)

        tool_choice.Bind(wx.EVT_RADIOBOX, on_radio)

        button_sizer = wx.StdDialogButtonSizer()
        sizer.Add(button_sizer, 0, wx.ALIGN_RIGHT | wx.EXPAND | wx.ALL, 10)
        button_sizer.AddButton(wx.Button(dialog_box, wx.ID_OK))
        button_sizer.Realize()
        draw()
        dialog_box.Fit()
        dialog_box.ShowModal()
        dialog_box.Destroy()
Ejemplo n.º 20
0
class ProjectionWidget(MplWidget):
    """Widget that shows a 2D projection as an interactive scatter plot.

    Points can optionally be annotated with their labels, points whose id is
    currently selected are drawn with a thicker outline, and a rectangle drag
    selects points and forwards their ids to the coordinator.
    """

    def __init__(self, parent=None):
        MplWidget.__init__(self, parent)

        self.showLabels = False
        self.toolBarType = QProjectionToolBar

    def draw(self):
        """Draw the projection with matplotlib.

        Reads ``(coord, matrix, title)`` from ``self.inputPorts``; columns 0
        and 1 of ``matrix.values`` are used as the x and y coordinates.
        Points whose id is in ``self.selectedIds`` get a thicker outline.
        After drawing, a rectangle selector is installed on the current axes.
        """
        (self.coord, self.matrix, title) = self.inputPorts

        # Map each id to its row so selected ids can be located in O(1).
        id2pos = {idd: pos for (pos, idd) in enumerate(self.matrix.ids)}
        circleSize = np.ones(len(self.matrix.ids))
        for selId in self.selectedIds:
            # Selections can be pushed from other modules and may contain ids
            # that are not part of this matrix; ignore those.
            pos = id2pos.get(selId)
            if pos is not None:
                circleSize[pos] = 4

        pylab.clf()
        pylab.axis("off")

        pylab.title(title)
        pylab.scatter(
            self.matrix.values[:, 0],
            self.matrix.values[:, 1],
            s=40,
            linewidth=circleSize,
            marker="o",
        )

        # draw labels
        if self.showLabels and self.matrix.labels is not None:
            for label, xy in zip(self.matrix.labels, self.matrix.values):
                pylab.annotate(
                    str(label),
                    xy=xy,
                    xytext=(5, 5),
                    textcoords="offset points",
                    bbox=dict(boxstyle="round,pad=0.2", fc="yellow", alpha=0.5),
                )

        self.figManager.canvas.draw()

        # Set Selectors
        self.rectSelector = RectangleSelector(
            pylab.gca(), self.onselect, drawtype="box", rectprops=dict(alpha=0.4, facecolor="yellow")
        )
        self.rectSelector.set_active(True)

    def updateSelection(self, selectedIds):
        """Store ``selectedIds`` as the current selection and redraw."""
        self.selectedIds = selectedIds
        self.updateContents()

    def onselect(self, eclick, erelease):
        """RectangleSelector callback.

        Collects the ids of all points inside the rectangle spanned by the
        press and release events and passes them to
        ``self.coord.notifyModules``. Does nothing when there is no
        coordinator or when either event lacks data coordinates.
        """
        if self.coord is None:
            return
        xs = (eclick.xdata, erelease.xdata)
        ys = (eclick.ydata, erelease.ydata)
        # xdata/ydata are None for events outside the axes; min/max would fail.
        if any(v is None for v in xs + ys):
            return
        region = Bbox.from_extents(min(xs), min(ys), max(xs), max(ys))

        selectedIds = [
            idd
            for (xy, idd) in zip(self.matrix.values, self.matrix.ids)
            if region.contains(xy[0], xy[1])
        ]
        self.coord.notifyModules(selectedIds)
Ejemplo n.º 21
0
class TaylorDiagramWidget(MplWidget):
    """Widget that plots statistics on a Taylor diagram with rectangle selection.

    Row 0 of the input statistics becomes the diagram's reference point and
    the remaining rows are added as samples. Selected ids are drawn with a
    thicker marker edge.
    """

    def __init__(self, parent=None):
        MplWidget.__init__(self, parent)

        # Candidate marker styles. draw() currently always uses 'o'; the list
        # is kept so external code reading `markers` keeps working.
        self.markers = ['o','x','*',',','+','.','s','v','<','>','^','D','h','H','_','8',
                        'd',3,0,1,2,7,4,5,6,'1','3','4','2','|','x']

    def draw(self):
        """Draw the Taylor diagram.

        Reads ``(coord, stats, title, showLegend)`` from ``self.inputPorts``.
        Column 0 of ``stats.values`` is read as the standard deviation and
        column 1 as the correlation coefficient of each entry.
        """
        (self.coord, self.stats, title, showLegend) = self.inputPorts

        stds, corrs = self.stats.values[:,0], self.stats.values[:,1]
        # Cartesian positions of every entry (reference included); onselect
        # uses these for hit-testing.
        self.Xs = stds*corrs
        self.Ys = stds*np.sin(np.arccos(corrs))

        colors = pylab.cm.jet(np.linspace(0,1,len(self.stats.ids)))

        pylab.clf()
        fig = pylab.figure(str(self))
        dia = taylor_diagram.TaylorDiagram(stds[0], corrs[0], fig=fig, label=self.stats.labels[0])
        # Color the reference point with the first colormap entry.
        dia.samplePoints[0].set_color(colors[0])
        if self.stats.ids[0] in self.selectedIds:
            dia.samplePoints[0].set_markeredgewidth(3)

        # add models to Taylor diagram
        for i, (_id, stddev, corrcoef) in enumerate(zip(self.stats.ids[1:], stds[1:], corrs[1:])):
            label = self.stats.labels[i+1]
            size = 3 if _id in self.selectedIds else 1
            dia.add_sample(stddev, corrcoef,
                           marker='o',
                           ls='',
                           mfc=colors[i+1],
                           mew=size,
                           label=label
                           )

        # Add grid
        dia.add_grid()

        # Add RMS contours, and label them
        contours = dia.add_contours(levels=5, colors='0.5') # 5 levels in grey
        pylab.clabel(contours, inline=1, fontsize=10, fmt='%.1f')

        # Add a figure legend and title
        if showLegend:
            fig.legend(dia.samplePoints,
                       [ p.get_label() for p in dia.samplePoints ],
                       numpoints=1, prop=dict(size='small'), loc='upper right')
        fig.suptitle(title, size='x-large') # Figure title
        self.figManager.canvas.draw()

        self.rectSelector = RectangleSelector(pylab.gca(), self.onselect, drawtype='box',
                                              rectprops=dict(alpha=0.4, facecolor='yellow'))
        self.rectSelector.set_active(True)

    def updateSelection(self, selectedIds):
        """Store ``selectedIds`` as the current selection and redraw."""
        self.selectedIds = selectedIds
        self.updateContents()

    def onselect(self, eclick, erelease):
        """RectangleSelector callback.

        Collects the ids of all entries whose (Xs, Ys) position lies inside
        the dragged rectangle and passes them to
        ``self.coord.notifyModules``. Does nothing when there is no
        coordinator or when either event lacks data coordinates.
        """
        if self.coord is None:
            return
        xs = (eclick.xdata, erelease.xdata)
        ys = (eclick.ydata, erelease.ydata)
        # xdata/ydata are None for events outside the axes; min/max would fail.
        if any(v is None for v in xs + ys):
            return
        region = Bbox.from_extents(min(xs), min(ys), max(xs), max(ys))

        selectedIds = [idd for (x, y, idd) in zip(self.Xs, self.Ys, self.stats.ids)
                       if region.contains(x, y)]
        self.coord.notifyModules(selectedIds)
Ejemplo n.º 22
0
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    # Rectangle properties for the interactive RectangleSelector created in
    # init_callbacks.
    selector_rectprops = dict(facecolor='red', edgecolor='black',
                              alpha=0.5, fill=True)
    # Line properties for a selector; not referenced by this class itself.
    selector_lineprops = dict(color='black', linestyle='-',
                              linewidth=2, alpha=0.5)

    def __init__(self, data=None, bands=None, classes=None, source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed, with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display).

        Keyword arguments:

            Any keyword that can be provided to :func:`~spectral.graphics.graphics.get_rgb`
            or :func:`matplotlib.imshow`.
        '''

        import spectral
        from spectral import settings
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        # Default class colors. Assigned before set_classes runs below so that
        # a `colors` keyword forwarded to set_classes is not overwritten.
        self.class_colors = spectral.spy_colors
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        self._class_alpha = settings.imshow_class_alpha

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises:

            ValueError if the first two dimensions of `data` differ from the
            image shape established by previously set data or classes.
        '''
        from .graphics import _get_rgb_kwargs

        self.data = data
        self.bands = bands

        # Route get_rgb keywords to set_rgb_options; the rest go to imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        elif data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously '
                             'set data.')
        self.imshow_data_kwargs.update(kwargs)
        # Interpolation is handled by the `interpolation` property rather than
        # being passed through with the other imshow keywords.
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs['interpolation']
            self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Sets parameters affecting RGB display of data.

        Accepts any keyword supported by :func:`~spectral.graphics.graphics.get_rgb`.
        Raises ValueError for any other keyword.
        '''
        from .graphics import _get_rgb_kwargs

        for k in kwargs:
            if k not in _get_rgb_kwargs:
                raise ValueError('Unexpected keyword: {0}'.format(k))
        self.rgb_kwargs = kwargs.copy()
        if self.is_shown:
            self._update_data_rgb()
            self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values (and metadata) used for display.'''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used.
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number of rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.
            Keywords belonging to `get_rgb` are ignored.

        Raises:

            ValueError if the first two dimensions of `classes` differ from
            the image shape established by previously set data.
        '''
        from .graphics import _get_rgb_kwargs
        self.classes = classes
        if classes is None:
            return
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        elif classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with '
                             'previously set data.')
        if colors is not None:
            self.class_colors = colors

        kwargs = {k: v for (k, v) in kwargs.items()
                  if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs['interpolation']
            self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Renders the image data.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                If `mode` is not provided, a mode will be automatically
                selected, based on the data set in the ImageView.

            `fignum` (int):

                Figure number of the matplotlib figure in which to display
                the ImageView. If not provided, a new figure will be created.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            msg = 'ImageView.show should only be called once.'
            warnings.warn(UserWarning(msg))
            return

        set_mpl_interactive()

        kwargs = {}
        if fignum is not None:
            kwargs['num'] = fignum
        if settings.imshow_figure_size is not None:
            kwargs['figsize'] = settings.imshow_figure_size
        plt.figure(**kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is None:
            self._guess_mode()
        else:
            self.set_display_mode(mode)

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.

        Connects the mouse and keyboard handlers, listens for
        'spy_classes_modified' events on the (possibly shared)
        `callbacks_common` registry, and, unless disabled by
        `settings.imshow_enable_rectangle_selector`, creates an initially
        inactive RectangleSelector that stores selections via
        `_select_rectangle`.
        '''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: adopt the event's classes if this view
        # has none yet, then redraw.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()
        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops=self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # The selector is optional: on failure, interactive labeling is
            # disabled with a warning instead of breaking show().
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`).

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`).
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            # Notify views sharing callbacks_common (unset until shown).
            if self.callbacks_common is not None:
                event = SpyMplEvent('spy_classes_modified')
                event.classes = self.classes
                event.nchanged = n
                self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away. There is no selector when it
            # was disabled in settings or could not be created.
            if self.selector is not None:
                self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        '''RectangleSelector callback for a completed drag.

        Converts both events to (row, col), clips the rectangle to the image
        and stores it in `self.selection` as [row_start, row_stop, col_start,
        col_stop] (stop exclusive, the form `label_region` accepts). The
        selection is cleared when either event is outside the axes or the
        rectangle lies entirely outside the image.
        '''
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
           (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data.'''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        # class_colors are 0-255 triplets; matplotlib colormaps expect 0-1.
        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        # NoNorm makes each integer class value index the colormap directly.
        kwargs.update({'cmap': cm, 'vmin': 0, 'norm': NoNorm(),
                       'interpolation': self._interpolation})
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Rebuilds `class_rgb` from `classes`.

        In overlay mode, class 0 is masked so the data shows through.
        '''
        if self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes, mask=(self.classes==0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        if alpha < 0 or alpha > 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the axes title (no effect until the view is shown).'''
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        # Explicit extent so pixel (r, c) is centered at x=c, y=r.
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.set_display_mode(self.display_mode)
        # Share the registry so class edits propagate between the two views.
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out about its center (`scale` > 1 zooms in).'''
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.

        Returns an empty string for positions outside the image.
        '''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: the class value is optional in the status text.
                pass
        return s

    def __str__(self):
        '''Returns a summary of display bands, interpolation and RGB limits.'''
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += '  {0:<20}:  {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation is None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += '  {0:<20}:  {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += '  {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += '    {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
Ejemplo n.º 23
0
class plot_2d_data(wx.Frame):
    """Generic 2d plotting window with optional x/y slice plots.

    Inputs are:
    - data: 2d array of values,
    - extent: (x_min, x_max, y_min, y_max) of the data, as passed to imshow,
    - caller: stored as ``self.caller``,
    - scale: "log" or "lin"/"linear" colour scale for the 2d image,
    - window_title: title of the frame,
    - pixel_mask: mask used during summation - must have same dimensions as
      data (only data entries corresponding to nonzero values in pixel_mask
      are summed); defaults to all ones,
    - plot_title, x_label and y_label are added to the 2d plot as you might
      expect,
    - parent: accepted for API compatibility but not forwarded to wx; the
      frame is always created without a wx parent.
    """

    # Returned by figure_list_dialog() when the user cancels the dialog.
    _CANCELLED = object()

    def __init__(
        self,
        data,
        extent,
        caller=None,
        scale="log",
        window_title="log plot",
        pixel_mask=None,
        plot_title="data plot",
        x_label="x",
        y_label="y",
        parent=None,
    ):
        if scale not in ("log", "lin", "linear"):
            # Any other value used to leave the colour scale unset, which then
            # broke the mouse-move and log/lin toggle handlers.
            raise ValueError("scale must be 'log', 'lin' or 'linear', not %r" % (scale,))
        wx.Frame.__init__(self, parent=None, title=window_title, pos=wx.DefaultPosition, size=wx.Size(800, 600))
        self.extent = extent
        self.data = data
        self.caller = caller
        self.window_title = window_title
        x_range = extent[0:2]
        self.x_min, self.x_max = x_range
        y_range = extent[2:4]
        self.y_min, self.y_max = y_range
        self.plot_title = plot_title
        self.x_label = x_label
        self.y_label = y_label
        self.slice_xy_range = (x_range, y_range)
        # Results of the most recent sliceplot(); read by the save_* methods.
        self.x = None
        self.y = None
        self.slice_x_data = None
        self.slice_y_data = None
        self.ID_QUIT = wx.NewId()
        self.ID_LOGLIN = wx.NewId()
        self.ID_UPCOLLIM = wx.NewId()
        self.ID_LOWCOLLIM = wx.NewId()

        menubar = wx.MenuBar()
        filemenu = wx.Menu()
        # The Quit entry used a hard-coded id and was never bound, so it did
        # nothing. Closing the frame triggers EVT_CLOSE -> onExit.
        quit_item = wx.MenuItem(filemenu, self.ID_QUIT, "&Quit\tCtrl+Q")
        filemenu.AppendItem(quit_item)
        self.Bind(wx.EVT_MENU, lambda evt: self.Close(), id=self.ID_QUIT)

        plotmenu = wx.Menu()
        self.menu_log_lin_toggle = plotmenu.Append(
            self.ID_LOGLIN, "Plot 2d data with log color scale", "plot 2d on log scale", kind=wx.ITEM_CHECK
        )
        self.Bind(wx.EVT_MENU, self.toggle_2d_plot_scale, id=self.ID_LOGLIN)
        plotmenu.Append(self.ID_UPCOLLIM, "Set upper limit of color map", "Set upper limit of color map")
        self.Bind(wx.EVT_MENU, self.set_new_upper_color_limit, id=self.ID_UPCOLLIM)
        plotmenu.Append(self.ID_LOWCOLLIM, "Set lower limit of color map", "Set lower limit of color map")
        self.Bind(wx.EVT_MENU, self.set_new_lower_color_limit, id=self.ID_LOWCOLLIM)

        menubar.Append(filemenu, "&File")
        menubar.Append(plotmenu, "&Plot")
        self.SetMenuBar(menubar)
        self.Centre()

        # `is None`: `array == None` is elementwise and cannot be used in `if`.
        if pixel_mask is None:
            pixel_mask = ones(data.shape)

        if pixel_mask.shape != data.shape:
            print("Warning: pixel mask shape incompatible with data")
            pixel_mask = ones(data.shape)

        self.pixel_mask = pixel_mask

        self.show_data = transpose(data.copy())
        # Smallest value above floating-point noise. Half of it is added to the
        # data before taking the log, so zero-valued pixels stay finite. Fall
        # back to 1.0 when nothing is positive instead of failing on an empty
        # min().
        positive = self.data[self.data > 1e-17]
        self.minimum_intensity = positive.min() if positive.size else 1.0

        self.fig = Figure(dpi=80, figsize=(5, 5))
        fig = self.fig
        self.canvas = Canvas(self, -1, self.fig)
        self.show_sliceplots = False  # slice plots start hidden (sliceplots_off below)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.canvas, 1, wx.TOP | wx.LEFT | wx.EXPAND)

        self.toolbar = MyNavigationToolbar(self.canvas, True, self)
        self.toolbar.Realize()
        if wx.Platform == "__WXMAC__":
            # Mac platform (OSX 10.3, MacPython) does not seem to cope with
            # having a toolbar in a sizer. This work-around gets the buttons
            # back, but at the expense of having the toolbar at the top
            self.SetToolBar(self.toolbar)
        else:
            # On Windows platform, default window size is incorrect, so set
            # toolbar width to figure width.
            _, th = self.toolbar.GetSizeTuple()
            fw, _ = self.canvas.GetSizeTuple()
            # By adding toolbar in sizer, we are able to put it at the bottom
            # of the frame - so appearance is closer to GTK version.
            # As noted above, doesn't work for Mac.
            self.toolbar.SetSize(wx.Size(fw, th))
            self.sizer.Add(self.toolbar, 0, wx.LEFT | wx.EXPAND)

        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetFieldsCount(2)
        self.statusbar.SetStatusWidths([-1, -2])
        self.statusbar.SetStatusText("Current Position:", 0)

        self.canvas.mpl_connect("motion_notify_event", self.onmousemove)
        self.mapper = FigureImage(self.fig)

        ax = fig.add_subplot(221, label="2d_plot")
        fig.sx = fig.add_subplot(222, label="sx", picker=True)
        fig.sx.xaxis.set_picker(True)
        fig.sx.yaxis.set_picker(True)
        fig.sx.yaxis.set_ticks_position("right")
        fig.sx.set_zorder(1)
        fig.sz = fig.add_subplot(223, label="sz", picker=True)
        fig.sz.xaxis.set_picker(True)
        fig.sz.yaxis.set_picker(True)
        fig.sz.set_zorder(1)
        # 'box' is RectangleSelector's default; the `drawtype` argument was
        # deprecated in matplotlib 3.5 and later removed, so it is omitted.
        self.RS = RectangleSelector(ax, self.onselect, useblit=True)
        fig.slice_overlay = None

        ax.set_position([0.125, 0.1, 0.7, 0.8])
        fig.cb = fig.add_axes([0.85, 0.1, 0.05, 0.8])
        fig.cb.set_zorder(2)

        fig.ax = ax
        fig.ax.set_zorder(2)
        self.axes = ax
        ax.set_title(plot_title)
        if scale == "log":
            self.show_data = log(self.data.copy().T + self.minimum_intensity / 2.0)
            self.__scale = "log"
            # NOTE(review): the label says log10, but `log` is presumably the
            # natural logarithm (numpy). Confirm which one is intended.
            self.fig.cb.set_xlabel(r"$\log_{10}I$")
            self.menu_log_lin_toggle.Check(True)
        else:  # "lin" or "linear" (validated at the top)
            self.__scale = "lin"
            self.fig.cb.set_xlabel("$I$")
            self.menu_log_lin_toggle.Check(False)

        im = self.axes.imshow(
            self.show_data, interpolation="nearest", aspect="auto", origin="lower", cmap=cm.jet, extent=extent
        )
        fig.im = im
        ax.set_xlabel(x_label, size="large")
        ax.set_ylabel(y_label, size="large")
        self.toolbar.update()
        zoom_colorbar(im=im, cax=fig.cb)

        self.SetSizer(self.sizer)
        self.Fit()

        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.OnContext)
        self.Bind(wx.EVT_CLOSE, self.onExit)
        self.sliceplots_off()
        self.SetSize(wx.Size(800, 600))
        self.canvas.draw()

    def onExit(self, event):
        """EVT_CLOSE handler: destroy this frame."""
        self.Destroy()

    def exit(self, event):
        """Exit the whole wx application (not referenced within this class)."""
        wx.GetApp().Exit()

    def _prompt_color_limit(self, index, which):
        """Ask for a new colour-map limit and apply it to the 2d image.

        index: 0 to change the lower limit, 1 to change the upper limit.
        which: "lower"/"upper", used in the dialog text.
        Cancelled dialogs and input that does not parse as a float are ignored.
        """
        clim = list(self.fig.im.get_clim())
        dlg = wx.TextEntryDialog(
            None, "Change %s limit of color map (currently %f)" % (which, clim[index]), defaultValue="%f" % clim[index]
        )
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return
            try:
                clim[index] = float(dlg.GetValue())
            except ValueError:
                # Ignore unparsable input rather than raising inside a GUI handler.
                return
        finally:
            dlg.Destroy()
        # The colorbar labels are saved and restored around set_clim, since
        # the colorbar may be redrawn when the limits change.
        xlab = self.fig.cb.get_xlabel()
        ylab = self.fig.cb.get_ylabel()
        self.fig.im.set_clim(tuple(clim))
        self.fig.cb.set_xlabel(xlab)
        self.fig.cb.set_ylabel(ylab)
        self.fig.canvas.draw()

    def set_new_upper_color_limit(self, evt=None):
        """Prompt for and apply a new upper colour-map limit."""
        self._prompt_color_limit(1, "upper")

    def set_new_lower_color_limit(self, evt=None):
        """Prompt for and apply a new lower colour-map limit."""
        self._prompt_color_limit(0, "lower")

    def OnContext(self, evt):
        """Right-click handler: show the context menu for the axes under the cursor."""
        # Build a matplotlib MouseEvent from the wx event (wx y grows downward,
        # matplotlib y grows upward) so matplotlib resolves the axes that was hit.
        mpl_x = evt.X
        mpl_y = self.fig.canvas.GetSize()[1] - evt.Y
        mpl_mouseevent = matplotlib.backend_bases.MouseEvent("button_press_event", self.canvas, mpl_x, mpl_y, button=3)

        if mpl_mouseevent.inaxes == self.fig.ax:
            self.area_context(mpl_mouseevent, evt)
        elif self.show_sliceplots and mpl_mouseevent.inaxes in (self.fig.sx, self.fig.sz):
            self.lineplot_context(mpl_mouseevent, evt)

    def area_context(self, mpl_mouseevent, evt):
        """Pop up the grid / log-lin / colour-map menu for the 2d image."""
        area_popup = wx.Menu()
        item1 = area_popup.Append(wx.ID_ANY, "&Grid on/off", "Toggle grid lines")
        wx.EVT_MENU(self, item1.GetId(), self.OnGridToggle)
        cmapmenu = CMapMenu(self, callback=self.OnColormap, mapper=self.mapper, canvas=self.canvas)
        item2 = area_popup.Append(wx.ID_ANY, "&Toggle log/lin", "Toggle log/linear scale")
        wx.EVT_MENU(self, item2.GetId(), lambda evt: self.toggle_log_lin(mpl_mouseevent))
        area_popup.AppendMenu(wx.ID_ANY, "Colourmaps", cmapmenu)
        self.PopupMenu(area_popup, evt.GetPositionTuple())
        # PopupMenu is modal; free the menu (and its submenu) once it returns.
        area_popup.Destroy()

    def figure_list_dialog(self):
        """Let the user pick an existing pylab figure or a new one.

        Returns the chosen figure number, None for "New Figure", or
        ``self._CANCELLED`` if the dialog is dismissed.
        """
        figure_list = [None] + list(get_fignums())
        figure_list_names = ["New Figure"] + ["Figure " + str(num) for num in figure_list[1:]]
        dlg = wx.SingleChoiceDialog(None, "Choose figure number", "", figure_list_names)
        dlg.SetSize(wx.Size(640, 480))
        try:
            if dlg.ShowModal() != wx.ID_OK:
                # Previously a cancel left selection_num unbound and crashed.
                return self._CANCELLED
            selection_num = dlg.GetSelection()
        finally:
            dlg.Destroy()
        return figure_list[selection_num]

    def lineplot_context(self, mpl_mouseevent, evt):
        """Pop up the log-lin / save / pop-out menu for a slice plot."""
        popup = wx.Menu()
        item1 = popup.Append(wx.ID_ANY, "&Toggle log/lin", "Toggle log/linear scale of slices")
        wx.EVT_MENU(self, item1.GetId(), lambda evt: self.toggle_log_lin(mpl_mouseevent))
        if mpl_mouseevent.inaxes == self.fig.sx:
            item2 = popup.Append(wx.ID_ANY, "Save x slice", "save this slice")
            wx.EVT_MENU(self, item2.GetId(), self.save_x_slice)
            item3 = popup.Append(wx.ID_ANY, "&Popout plot", "Open this data in a figure window")
            wx.EVT_MENU(self, item3.GetId(), lambda evt: self.popout_x_slice())
        elif mpl_mouseevent.inaxes == self.fig.sz:
            item2 = popup.Append(wx.ID_ANY, "Save y slice", "save this slice")
            wx.EVT_MENU(self, item2.GetId(), self.save_y_slice)
            item3 = popup.Append(wx.ID_ANY, "&Popout plot", "Open this data in a new plot window")
            wx.EVT_MENU(self, item3.GetId(), lambda evt: self.popout_y_slice())
        self.PopupMenu(popup, evt.GetPositionTuple())
        popup.Destroy()

    def _slice_description(self):
        """Return the 'sliceplot([..],[..])' text for the current slice, with a leading newline."""
        (x_lo, x_hi), (y_lo, y_hi) = self.slice_xy_range
        return "\nsliceplot([%f,%f],[%f,%f])" % (x_lo, x_hi, y_lo, y_hi)

    def _ask_text(self, prompt, default):
        """Show a text-entry dialog; return the entered text, or `default` if cancelled."""
        dlg = wx.TextEntryDialog(None, prompt, defaultValue=default)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                return dlg.GetValue()
            return default
        finally:
            dlg.Destroy()

    def _popout_slice(self, ax, axis_label, swap_xy, figure_num, label):
        """Copy the first line of slice axes `ax` into a pylab figure.

        axis_label: x-axis label used when new axes are created.
        swap_xy: True for the x slice, which is drawn as (value, y) and must
            be swapped back to (y, value).
        figure_num / label: as documented on popout_x_slice / popout_y_slice.
        """
        if figure_num is None:
            figure_num = self.figure_list_dialog()
            if figure_num is self._CANCELLED:
                return
        # figure(None) makes matplotlib create a figure with the next free number.
        fig = figure(figure_num)
        slice_desc = self._slice_description()
        if figure_num is None or not fig.axes:
            title = self._ask_text("Enter title for plot", self.plot_title + slice_desc)
            new_ax = fig.add_subplot(111)
            new_ax.set_title(title, size="large")
            new_ax.set_xlabel(axis_label, size="x-large")
            new_ax.set_ylabel("$I_{summed}$", size="x-large")
        else:
            new_ax = fig.axes[0]
        if label is None:
            label = self._ask_text(
                "Enter data label (for plot legend)", self.window_title + ": " + self.plot_title + slice_desc
            )
        x, y = ax.lines[0].get_data()
        if swap_xy:
            x, y = y, x
        new_ax.plot(x, y, label=label)
        # Attach the legend to the axes just plotted into, not whatever gca() is.
        lg = new_ax.legend(prop=FontProperties(size="small"))
        drag_lg = DraggableLegend(lg)
        drag_lg.connect()
        fig.canvas.draw()
        fig.show()

    def popout_y_slice(self, event=None, figure_num=None, label=None):
        """Copy the y slice (intensity against x) into a pylab figure.

        figure_num: target figure; None asks the user (who may choose a new one).
        label: legend label; None asks the user.
        """
        self._popout_slice(self.fig.sz, self.x_label, False, figure_num, label)

    def popout_x_slice(self, event=None, figure_num=None, label=None):
        """Copy the x slice (intensity against y) into a pylab figure.

        figure_num: target figure; None asks the user (who may choose a new one).
        label: legend label; None asks the user.
        """
        self._popout_slice(self.fig.sx, self.y_label, True, figure_num, label)

    def _save_slice(self, out_file_name, which, coord_name, coords, values):
        """Write one slice as a two-column, tab-separated text file.

        which: "x" or "y", the slice being saved (column header and message).
        coord_name: name of the coordinate column.
        Asks for a file name when `out_file_name` is None and returns without
        writing if that dialog is cancelled.
        """
        if out_file_name is None:
            dlg = wx.FileDialog(None, "Save 2d data as:", "", "", "", wx.FD_SAVE)
            try:
                if dlg.ShowModal() != wx.ID_OK:
                    return
                out_file_name = dlg.GetPath()
            finally:
                dlg.Destroy()
        (x_lo, x_hi), (y_lo, y_hi) = self.slice_xy_range
        with open(out_file_name, "w") as out_file:
            out_file.write("#" + self.plot_title + "\n")
            out_file.write("#xmin: " + str(x_lo) + "\n")
            out_file.write("#xmax: " + str(x_hi) + "\n")
            out_file.write("#ymin: " + str(y_lo) + "\n")
            out_file.write("#ymax: " + str(y_hi) + "\n")
            out_file.write("#%s\tslice_%s_data\n" % (coord_name, which))
            # `is not None`: comparing an array with == None is elementwise.
            if values is not None:
                for coord, value in zip(coords, values):
                    out_file.write(str(coord) + "\t" + str(value) + "\n")
        print("saved %s slice in %s" % (which, out_file_name))

    def save_x_slice(self, event=None, outFileName=None):
        """Save the x slice (slice_x_data against y) to `outFileName` (asks if None)."""
        self._save_slice(outFileName, "x", "y", self.y, self.slice_x_data)

    def save_y_slice(self, event=None, outFileName=None):
        """Save the y slice (slice_y_data against x) to `outFileName` (asks if None)."""
        self._save_slice(outFileName, "y", "x", self.x, self.slice_y_data)

    def OnGridToggle(self, event):
        """Toggle grid lines on the 2d image."""
        self.fig.ax.grid()
        self.fig.canvas.draw_idle()

    def OnColormap(self, name):
        """Apply the colour map `name` chosen from the context menu."""
        print("Selected colormap %s" % (name,))
        self.fig.im.set_cmap(get_cmap(name))
        self.fig.canvas.draw()

    def toggle_2d_plot_scale(self, event=None):
        """Switch the 2d image between log and linear colour scales."""
        if self.__scale == "log":
            self.show_data = self.data.T
            label, new_scale = "$I$", "lin"
        elif self.__scale == "lin":
            self.show_data = log(self.data.copy().T + self.minimum_intensity / 2.0)
            label, new_scale = r"$\log_{10}I$", "log"
        else:
            return
        self.fig.im.set_array(self.show_data)
        self.fig.im.autoscale()
        self.fig.cb.set_xlabel(label)
        self.__scale = new_scale
        self.menu_log_lin_toggle.Check(new_scale == "log")
        self.statusbar.SetStatusText("%s scale" % self.__scale, 0)
        self.fig.canvas.draw_idle()

    def toggle_log_lin(self, event):
        """Toggle log/linear scale of the axes under the mouse event.

        The 2d image toggles its colour scale, "sz" toggles its y scale and
        "sx" toggles its x scale (the intensity axis of each slice plot).
        """
        ax = event.inaxes
        label = ax.get_label()

        if label == "2d_plot":
            self.toggle_2d_plot_scale()
            return
        if label == "sz":
            get_scale, set_scale = ax.get_yscale, ax.set_yscale
        elif label == "sx":
            get_scale, set_scale = ax.get_xscale, ax.set_xscale
        else:
            return
        new_scale = {"log": "linear", "linear": "log"}.get(get_scale())
        if new_scale is not None:
            set_scale(new_scale)
            ax.figure.canvas.draw_idle()

    def onmousemove(self, event):
        """Show the current scale and cursor data coordinates in the status bar."""
        if event.inaxes:
            x, y = event.xdata, event.ydata
            self.statusbar.SetStatusText("%s scale x = %.3g, y = %.3g" % (self.__scale, x, y), 1)

    def onselect(self, eclick, erelease):
        """RectangleSelector callback: slice the selected box and echo the call."""
        x_range = [eclick.xdata, erelease.xdata]
        y_range = [eclick.ydata, erelease.ydata]
        ax = eclick.inaxes
        self.sliceplot((x_range, y_range), ax)
        print("sliceplot(([%f,%f],[%f,%f]))" % (x_range[0], x_range[1], y_range[0], y_range[1]))

    def sliceplots_off(self):
        """Hide the slice plots and let the 2d image fill the figure."""
        self.fig.ax.set_position([0.125, 0.1, 0.7, 0.8])
        self.fig.cb.set_position([0.85, 0.1, 0.05, 0.8])
        self.fig.sx.set_visible(False)
        self.fig.sz.set_visible(False)
        if self.fig.slice_overlay:
            self.fig.slice_overlay[0].set_visible(False)
        self.RS.set_active(False)
        self.show_sliceplots = False
        self.fig.canvas.draw()

    def sliceplots_on(self):
        """Shrink the 2d image and show the slice plots beside and below it."""
        self.fig.ax.set_position([0.125, 0.53636364, 0.35227273, 0.36363636])
        self.fig.cb.set_position([0.49, 0.53636364, 0.02, 0.36363636])
        self.fig.sx.set_position([0.58, 0.53636364, 0.35227273, 0.36363636])
        self.fig.sx.set_visible(True)
        self.fig.sz.set_visible(True)
        if self.fig.slice_overlay:
            self.fig.slice_overlay[0].set_visible(True)
        self.RS.set_active(True)
        self.show_sliceplots = True
        self.fig.canvas.draw()

    def toggle_sliceplots(self):
        """switch between views with and without slice plots"""
        if self.show_sliceplots:
            self.sliceplots_off()
        else:
            self.sliceplots_on()

    def show_slice_overlay(self, x_range, y_range, x, slice_y_data, y, slice_x_data):
        """Draw the slice results and the summing-region overlay.

        The sum along y (slice_y_data against x) is plotted below the data in
        axes "sz"; the sum along x (slice_x_data against y) is plotted to the
        right in axes "sx". A transparent white rectangle is overlaid on the
        image to show the summing region.
        """
        from matplotlib.ticker import FormatStrFormatter, ScalarFormatter

        if self.fig is None:
            print("No figure for this dataset is available")
            return

        fig = self.fig
        ax = fig.ax
        extent = fig.im.get_extent()

        if fig.slice_overlay is None:
            fig.slice_overlay = ax.fill(
                [x_range[0], x_range[1], x_range[1], x_range[0]],
                [y_range[0], y_range[0], y_range[1], y_range[1]],
                fc="white",
                alpha=0.3,
            )
            fig.ax.set_ylim(extent[2], extent[3])
        else:
            fig.slice_overlay[0].xy = [
                (x_range[0], y_range[0]),
                (x_range[1], y_range[0]),
                (x_range[1], y_range[1]),
                (x_range[0], y_range[1]),
            ]

        def scalar_formatter():
            # Formatters keep a reference to the Axis they are attached to, so
            # each axis gets its own instance instead of sharing one.
            fmt = ScalarFormatter(useMathText=True)
            fmt.set_powerlimits((-2, 4))
            return fmt

        fig.sz.clear()
        fig.sz.yaxis.set_major_formatter(scalar_formatter())
        fig.sz.xaxis.set_major_formatter(FormatStrFormatter("%.2g"))
        if len(x):
            fig.sz.set_xlim(x[0], x[-1])
        fig.sz.plot(x, slice_y_data)
        fig.sx.clear()
        fig.sx.xaxis.set_major_formatter(scalar_formatter())
        fig.sx.yaxis.set_ticks_position("right")
        fig.sx.yaxis.set_major_formatter(FormatStrFormatter("%.2g"))
        if len(y):
            fig.sx.set_ylim(y[0], y[-1])
        fig.sx.plot(slice_x_data, y)

        fig.im.set_extent(extent)
        fig.canvas.draw()

    def copy_intensity_range_from(self, other_plot):
        """Copy the colour limits of another plot_2d_data window."""
        if isinstance(other_plot, type(self)):
            xlab = self.fig.cb.get_xlabel()
            ylab = self.fig.cb.get_ylabel()

            self.fig.im.set_clim(other_plot.fig.im.get_clim())
            self.fig.cb.set_xlabel(xlab)
            self.fig.cb.set_ylabel(ylab)
            self.fig.canvas.draw()

    def sliceplot(self, xy_range, ax=None):
        """Slice the data inside `xy_range` = (x_range, y_range) and plot it.

        Turns the slice plots on, stores the results on self (x, y,
        slice_x_data, slice_y_data, slice_xy_range) and draws them with the
        overlay rectangle. `ax` is accepted but unused.
        """
        self.sliceplots_on()
        x_range, y_range = xy_range
        x, slice_y_data, y, slice_x_data = self.do_xy_slice(x_range, y_range)
        self.x = x
        self.slice_y_data = slice_y_data
        self.y = y
        self.slice_x_data = slice_x_data
        self.slice_xy_range = xy_range

        self.show_slice_overlay(x_range, y_range, x, slice_y_data, y, slice_x_data)

    def do_xy_slice(self, x_range, y_range):
        """Average the data inside a rectangle along each axis.

        x_range, y_range: two bounds each, in data coordinates, in any order.
        The bounds are snapped to the nearest pixel edges and clipped to the
        image. Returns (x_vals, slice_y_data, y_vals, slice_x_data):
        slice_y_data[i] is the mean of the unmasked pixels at x pixel i over
        the selected y pixels (0 if all are masked), x_vals holds the matching
        x coordinates, and slice_x_data / y_vals are the same along the other
        axis.
        """
        pixels = self.pixel_mask
        # Zero masked pixels in a copy. The previous code zeroed them in
        # self.data itself, silently modifying the caller's array.
        data = self.data.copy()
        data[pixels == 0] = 0
        # Counts of unmasked pixels contributing to each sum.
        normalization_matrix = ones(data.shape)
        normalization_matrix[pixels == 0] = 0

        x_size, y_size = data.shape
        global_x_range = self.x_max - self.x_min
        global_y_range = self.y_max - self.y_min

        def to_pixel(value, origin, span, size):
            # Nearest pixel edge for a data coordinate, clipped to [0, size] so
            # selections past the image edge cannot produce negative (wrapping)
            # slice indices.
            pixel = int(round(float(value - origin) / span * size))
            return min(max(pixel, 0), size)

        # sorted() also corrects the sign switch caused by a reversed extent.
        x_pixel_min, x_pixel_max = sorted(to_pixel(v, self.x_min, global_x_range, x_size) for v in x_range)
        y_pixel_min, y_pixel_max = sorted(to_pixel(v, self.y_min, global_y_range, y_size) for v in y_range)

        # Data coordinates of the snapped pixel edges (float division even for ints).
        new_x_min = float(x_pixel_min) / x_size * global_x_range + self.x_min
        new_x_max = float(x_pixel_max) / x_size * global_x_range + self.x_min
        new_y_min = float(y_pixel_min) / y_size * global_y_range + self.y_min
        new_y_max = float(y_pixel_max) / y_size * global_y_range + self.y_min

        window = (slice(x_pixel_min, x_pixel_max), slice(y_pixel_min, y_pixel_max))
        counts = normalization_matrix[window]
        region = data[window]
        y_norm_factor = counts.sum(axis=1)
        x_norm_factor = counts.sum(axis=0)
        # make sure the normalization has a minimum value of 1 everywhere,
        # to avoid divide by zero errors:
        y_norm_factor[y_norm_factor == 0] = 1
        x_norm_factor[x_norm_factor == 0] = 1

        slice_y_data = region.sum(axis=1) / y_norm_factor
        slice_x_data = region.sum(axis=0) / x_norm_factor

        x_vals = (
            arange(slice_y_data.shape[0], dtype="float") / slice_y_data.shape[0] * (new_x_max - new_x_min) + new_x_min
        )
        y_vals = (
            arange(slice_x_data.shape[0], dtype="float") / slice_x_data.shape[0] * (new_y_max - new_y_min) + new_y_min
        )

        return x_vals, slice_y_data, y_vals, slice_x_data
class RectChooser(object):
    """Choose lightcurves by dragging a rectangle over a magnitude/FRMS plot.

    The rectangle's x extent selects a magnitude range and its y extent an
    FRMS range (each half-open, [min, max)). The matching indices are shown
    via ``display_class(fitsfile, all_axes).display_lightcurves(...)``, and
    ``buttons[0]`` / ``buttons[1]`` are connected to its ``previous`` /
    ``next`` methods.
    """

    # Keys handled by toggle_selector: MOUSEUP deactivates the selector,
    # MOUSEDOWN re-activates it.
    MOUSEUP = ['Q', 'q']
    MOUSEDOWN = ['A', 'a']

    def __init__(self, fitsfile, ax, mags, frms, all_axes, buttons, use_hjd=False,
            display_class=LightcurveDisplay):
        self.fitsfile = fitsfile
        self.ax = ax
        self.mags = mags
        self.frms = frms
        self.all_axes = all_axes
        self.buttons = buttons
        self.use_hjd = use_hjd
        self.display_class = display_class
        # 'box' is RectangleSelector's default; the `drawtype` argument was
        # deprecated in matplotlib 3.5 and later removed, so it is omitted.
        self.selector = RectangleSelector(self.ax, self.on_event)
        # Result of display_lightcurves() for the current selection, or None.
        self.l = None

    def on_event(self, eclick, erelease):
        """RectangleSelector callback: load the lightcurves inside the box.

        eclick / erelease are the press and release mouse events. Their xdata
        bound the magnitude range and their ydata bound the FRMS range.
        """
        logger.debug('startposition: (%s, %s)', eclick.xdata, eclick.ydata)
        logger.debug('endposition: (%s, %s)', erelease.xdata, erelease.ydata)

        # The box can be dragged in any direction; order each pair of bounds.
        mag_min, mag_max = sorted((eclick.xdata, erelease.xdata))
        frms_min, frms_max = sorted((eclick.ydata, erelease.ydata))

        logger.info('Using magnitude range %s to %s', mag_min, mag_max)
        logger.info('Using frms range %s to %s', frms_min, frms_max)

        chosen = ((self.mags >= mag_min) &
                  (self.mags < mag_max) &
                  (self.frms >= frms_min) &
                  (self.frms < frms_max))

        if not chosen.any():
            logger.error("No lightcurves chosen, please try again")
            return

        if self.l is not None:
            logger.debug('Lightcurve display present')
            self.reset_buttons()
        self.load_lightcurves(np.arange(self.mags.size)[chosen])

    def reset_buttons(self):
        """Disconnect the button callbacks and remove the current display's FRMS line."""
        self.buttons[0].disconnect(self.prev_cid)
        self.buttons[1].disconnect(self.next_cid)
        self.l.remove_frms_line()
        # Reset to None instead of `del self.l`, so the `self.l is not None`
        # check in on_event keeps working even if the next load fails.
        self.l = None

    def load_lightcurves(self, indices):
        """Display the lightcurves at `indices` and wire up the prev/next buttons."""
        self.l = self.display_class(self.fitsfile, self.all_axes).display_lightcurves(self.mags,
                self.frms, indices, use_hjd=self.use_hjd)

        self.prev_cid = self.buttons[0].on_clicked(self.l.previous)
        self.next_cid = self.buttons[1].on_clicked(self.l.next)

    def toggle_selector(self, event):
        """Key-press handler: 'q'/'Q' deactivates the selector, 'a'/'A' activates it."""
        logger.debug(' Key pressed.')
        if event.key in self.MOUSEUP and self.selector.active:
            logger.debug('RectangleSelector deactivated')
            self.selector.set_active(False)

        if event.key in self.MOUSEDOWN and not self.selector.active:
            logger.debug('RectangleSelector activated')
            self.selector.set_active(True)
Ejemplo n.º 25
0
class ImProcc():
    """
    Class that contains all the methods necessary to process microscope images in a GUI.

    It keeps the region of interest (ROI) drawn on the main canvas, the
    processed cell shapes, the z-frame each cell belongs to and the table of
    cell-cell connections, and draws them on ``controller.img.ax``.
    """
    def __init__(self, parent, controller):
        """
        Initialize class.

        self.click holds the (x, y) data coordinates where the ROI drag
        started and self.release the (x, y) data coordinates where it ended.
        Both stay [None, None] until a ROI has been drawn.
        """
        self.click = [None, None]
        self.release = [None, None]

        self.image_roi = None

        self.parent = parent
        self.controller = controller

        # dictionary containing cell shape objects { Cell # (str) : cell object }
        self.shapecells = dict()
        # { zframe (str) : iterable of cell ids drawn in that frame }
        self.cell_zframes = dict()

        # pandas DataFrame containing the cells connections and their coordinates
        row = ['zframe', 'cell1', 'cell2', 'centerX', 'centerY']
        data = np.empty((0, 5), int)
        self.connections = pd.DataFrame(data.tolist(), columns=row)

    @staticmethod
    def _remove_artists(artists):
        """Remove every artist in *artists* from the axes that owns it.

        Iterates over a copy because ``Artist.remove()`` mutates the axes'
        artist list. This replaces ``ax.lines = []`` style assignments, which
        newer matplotlib versions no longer allow.
        """
        for artist in list(artists):
            artist.remove()

    def _roi_slices(self):
        """Return ``(row_slice, col_slice)`` covering the selected ROI.

        The selector reports float data coordinates in press/release order.
        They are rounded to pixel indices, ordered (so right-to-left or
        bottom-to-top drags work) and clipped at 0, because a negative start
        index would wrap around to the end of the array.
        """
        x_lo, x_hi = sorted(max(0, int(round(v)))
                            for v in (self.click[0], self.release[0]))
        y_lo, y_hi = sorted(max(0, int(round(v)))
                            for v in (self.click[1], self.release[1]))
        return slice(y_lo, y_hi), slice(x_lo, x_hi)

    def roi_selector(self):
        """
        Toggle the rectangle selector used to pick the region of interest.

        When ``controller.roiON`` is set, a new interactive RectangleSelector
        is created on ``parent.ax`` and the previous ROI, if any, is restored.
        Otherwise the existing selector, if one was created, is deactivated.
        """

        controller = self.controller

        parent = self.parent

        def line_select_callback(eclick, erelease):
            'eclick and erelease are the press and release events'
            self.click[:] = eclick.xdata, eclick.ydata
            self.release[:] = erelease.xdata, erelease.ydata
            self.controller.canvas.draw()

        if controller.roiON.get():
            self.RS = RectangleSelector(
                parent.ax,
                line_select_callback,
                drawtype='box',
                useblit=True,
                button=[1, 3],  # don't use middle button
                minspanx=5,
                minspany=5,
                spancoords='pixels',
                interactive=True)
            if self.click[0] is not None:
                # Re-show the ROI chosen the last time the selector was on.
                self.RS.to_draw.set_visible(True)
                self.controller.canvas.draw()
                self.RS.extents = (self.click[0], self.release[0],
                                   self.click[1], self.release[1])
            else:
                self.controller.canvas.draw()

        else:

            try:
                self.RS.set_active(False)
                self.controller.canvas.draw()

            except AttributeError:
                # No selector has been created yet.
                pass

    def process_image(self):
        """
        Process the image of interest with default options.

        Currently this only extracts the selected ROI (see apply_threshold).
        """

        self.apply_threshold()

    def apply_threshold(self):
        """
        Extract the pixels inside the selected ROI into ``self.image_roi``.

        For single-frame images (``parent.shape[2] == 1``) the ROI is cut from
        ``parent.file[0]``. Otherwise it is cut from the z-frame currently
        selected by ``controller.scrollbar``. If no ROI has been drawn, an
        error dialog is shown instead.
        """

        if self.click[0] is None:
            messagebox.showerror("Error", "Select a ROI to process")
            return

        parent = self.parent
        rows, cols = self._roi_slices()
        if parent.shape[2] == 1:
            # images[0] timepoint 0
            self.image_roi = parent.file[0][rows, cols]
        else:
            # images[n][1,:,:] timepoint n, stack 1
            frame_displ = int(round(self.controller.scrollbar.get()))
            self.image_roi = parent.file[0][frame_displ, rows, cols]

    def open_modify(self):
        """
        Function to initialize the window to modify the image processing.
        """

        self.modproc = modw.ImModify(parent=self, controller=self.controller)

    def open_modify_body_skel(self):
        """
        Open the window to modify the cell body skeleton of the cell selected
        in the listbox. Shows an error dialog if no cell is selected.
        """

        controller = self.controller
        try:
            # Listbox row i shows cell id i + 1.
            cell_id = controller.lbox.curselection()[0] + 1

            self.modcellbody = modbody.CellBodyModify(
                img=self.parent,
                cellshape=self.shapecells[str(cell_id)],
                cellID=cell_id,
                parent=self,
                controller=self.controller)
        except IndexError:

            messagebox.showerror(
                "Error",
                "Select a cell to modify from the list and press again the button."
            )

    def manual_selector(self):
        """
        Function to initialize the window to manually select cell contour.
        """

        self.cellobject = singleCellShape(
            parent=self,
            controller=self.controller)  # object with single cell contour
        self.manselec = ms.ManualSelector(img=self.parent,
                                          parent=self,
                                          controller=self.controller)

    def create_cell_list(self, procfile):
        """
        Create the list of processed cells shown in the listbox of the main GUI window.
        This function is called when a project file is loaded.

        Cells are added in numeric id order. The listbox position of each
        cell must equal ``id - 1``, and the widgets are only enabled when cell
        1 is added.
        """

        self.shapecells = procfile['shapecells']
        self.cell_zframes = procfile['cell_zframes']
        self.connections = procfile['connections']

        for key in sorted(self.shapecells, key=int):
            self.add_item_cell_list(int(key))

    def add_item_cell_list(self, idx):
        """
        Append cell ``idx`` to the processed-cell listbox of the main GUI window.

        Adding cell 1 also enables the listbox and the cell-related menus and
        buttons. The entry text is coloured with the cell contour colour.
        """
        controller = self.controller

        if idx == 1:
            # turn on buttons
            controller.lbl_cell_list.config(state="normal")
            controller.lbox.config(state="normal")
            controller.filemenu.showMenu.entryconfig(2, state='normal')
            controller.show_cellshapeON.set(1)
            controller.display_selectedbtn.config(state="normal")
            controller.filemenu.summaryMenu.entryconfig(1, state='normal')
            controller.filemenu.summaryMenu.entryconfig(2, state='normal')
            controller.modifyBodyBtn.config(state="normal")

            #  focus_set method to move focus back to the scrollbar of the mainGUI
            controller.scrollbar.focus_set()

        controller.lbox.insert(tk.END, 'Cell # ' + str(idx))
        # contour['color'] channels are scaled by 255 to build a '#rrggbb' string.
        color = '#%02x%02x%02x' % tuple([
            int(i * 255) for i in self.shapecells[str(idx)].contour['color']
        ])
        controller.lbox.itemconfig(idx - 1, {'fg': color})

    def cbtn_show_cellprocessed(self):
        """
        Callback of the controller.cbtn_showcell checkbutton: show the
        processed cell shapes on the main GUI window, or remove all lines
        from the image axes when unchecked.
        """

        controller = self.controller

        if controller.show_cellshapeON.get():
            self.show_cellprocessed()
        else:
            self._remove_artists(controller.img.ax.lines)
            controller.canvas.draw()

    def display_single_cell_processing(self, shapeobj, **linekwargs):
        """
        Draw the closed cell contour of ``shapeobj`` on the main GUI canvas.

        ``linekwargs`` are forwarded to ``plt.Line2D``.
        """
        controller = self.controller
        xs = shapeobj.contour['allxpoints']
        ys = shapeobj.contour['allypoints']
        # Repeat the first point so the outline is closed.
        l = plt.Line2D(xs + [xs[0]],
                       ys + [ys[0]],
                       color=shapeobj.contour['color'],
                       **linekwargs)
        controller.img.ax.add_line(l)

    def display_single_cell_skeleton(self, shapeobj, **linekwargs):
        """
        Draw the cell skeleton of ``shapeobj`` on the main GUI canvas.

        Body paths are red, secondary protrusion paths white and primary
        protrusion paths yellow. Each path is an array indexed as
        ``path[:, 1]`` (x) and ``path[:, 0]`` (y).
        """
        controller = self.controller

        color_cellbody = '#ff0000'
        for path in shapeobj.skelbody['paths']:
            l = plt.Line2D(path[:, 1],
                           path[:, 0],
                           color=color_cellbody,
                           **linekwargs)
            controller.img.ax.add_line(l)

        color_secondary = '#ffffff'  # '#ffc03e'
        for protrusion in shapeobj.skelprot['secondary-paths']:
            for path in protrusion:
                l = plt.Line2D(path[:, 1],
                               path[:, 0],
                               color=color_secondary,
                               **linekwargs)
                controller.img.ax.add_line(l)

        # color = '#FFC125'
        color_primary = '#ffff00'
        for path in shapeobj.skelprot['primary-path']:
            l = plt.Line2D(path[:, 1],
                           path[:, 0],
                           color=color_primary,
                           **linekwargs)
            controller.img.ax.add_line(l)

    def display_cell_connections(self, connobj):
        """
        Draw a red circle (radius 15) at every (centerX, centerY) of the
        connections DataFrame ``connobj`` on the main GUI canvas.
        """
        controller = self.controller
        rgbCol = (1, 0, 0)  # red
        patches = [Circle((x, y), radius=15)
                   for x, y in zip(connobj.centerX.tolist(),
                                   connobj.centerY.tolist())]
        connectCollection = PatchCollection(patches, facecolors=rgbCol)
        controller.img.ax.add_collection(connectCollection)

    def show_cellprocessed(self):
        """
        Redraw the processed cells of the z-frame currently selected by the
        scrollbar (contours, skeletons and connections) in the main GUI window.
        """

        controller = self.controller
        ax = controller.img.ax

        self._remove_artists(ax.lines)
        self._remove_artists(ax.collections)

        zframe = round(controller.scrollbar.get())

        if str(zframe) in self.cell_zframes:
            for nroi in self.cell_zframes[str(zframe)]:
                shape = self.shapecells[str(nroi)]
                self.display_single_cell_processing(shapeobj=shape)
                self.display_single_cell_skeleton(shapeobj=shape)
            self.display_cell_connections(
                connobj=self.connections[self.connections.zframe == zframe])

        controller.canvas.draw()

    def display_cell_selected(self):
        """
        Move the main GUI window to the z-frame of the cell selected in the
        listbox and show its processed shape. Shows an error dialog if no cell
        is selected.
        """

        controller = self.controller
        try:
            cell_id = controller.lbox.curselection()[0] + 1
            im_idx = self.shapecells[str(cell_id)].zframe

            controller.scrollbar.set(im_idx)
            controller.show_cellshapeON.set(1)
            self.show_cellprocessed()
        except IndexError:
            messagebox.showerror(
                "Error", 'Select a cell to show from the '
                'Cell Processed List'
                ' and press again the button.')

    def print_summary(self):
        """
        Open the window that shows, and lets the user save, the parameters
        extracted from the cell shapes.
        """

        controller = self.controller

        ps.PrintParameters(parent=self, controller=controller)
Ejemplo n.º 26
0
class Application(QDialog):
    """
    Main dialog of the ECoG time-series viewer.

    Lays out three matplotlib canvases (a top strip, the channel plot and a
    bottom strip) plus button/line-edit control panels, then creates an
    ``ecogTSGUI`` model for ``pathName``. The model receives the axes, figures
    and line edits; every button here forwards to a method of that model.
    """
    keyPressed = QtCore.pyqtSignal(QtCore.QEvent)

    # Recording opened when no path is given (kept for backward compatibility).
    DEFAULT_PATH = '/Users/bendichter/Desktop/Chang/data/EC125/EC125_B22'

    def __init__(self, parent=None, pathName=None):
        """
        Build the GUI and load the recording.

        :param parent: optional Qt parent widget, forwarded to ``QDialog``.
        :param pathName: recording path handed to ``ecogTSGUI``; defaults to
            ``DEFAULT_PATH``.
        """
        super(Application, self).__init__(parent)
        if pathName is None:
            pathName = self.DEFAULT_PATH
        # Text artist shown by the data cursor; [] means "no annotation shown".
        self.temp = []
        # matplotlib connection ids and the bad-interval selector. They are
        # kept so each handler is connected at most once and can be removed.
        self.cursor_cid = None
        self.delete_cid = None
        self.toggle_selector = None
        self.keyPressed.connect(self.on_key)
        self.init_gui()
        self.showMaximized()
        self.setWindowTitle('')
        self.show()

        # run the main file
        parameters = {}
        parameters['pars'] = {'Axes': [self.axes1, self.axes2, self.axes3],
                              'Figure': [self.figure1, self.figure2, self.figure3]}
        parameters['editLine'] = {'qLine0': self.qline0, 'qLine1': self.qline1,
                                  'qLine2': self.qline2, 'qLine3': self.qline3,
                                  'qLine4': self.qline4}
        self.model = ecogTSGUI(pathName, parameters)

    def keyPressEvent(self, event):
        """Let QDialog handle the key, then re-emit it through ``keyPressed``."""
        super(Application, self).keyPressEvent(event)
        self.keyPressed.emit(event)

    def on_key(self, event):
        """Keyboard shortcuts: W/S scroll channels up/down, A/D page back/forward."""
        key = event.key()
        if key == QtCore.Qt.Key_W:
            self.model.channel_Scroll_Up()
        elif key == QtCore.Qt.Key_S:
            self.model.channel_Scroll_Down()
        elif key == QtCore.Qt.Key_A:
            self.model.page_back()
            print("Left")
        elif key == QtCore.Qt.Key_D:
            self.model.page_forward()
            print('Right')

    def init_gui(self):
        """Create the figures/canvases, buttons and line edits and lay them out."""
        vbox = QVBoxLayout()

        groupbox1 = QGroupBox('')
        groupbox1.setFixedHeight(50)
        formlayout4 = QFormLayout()
        self.figure3 = Figure()
        self.axes3 = self.figure3.add_subplot(111)
        canvas2 = FigureCanvas(self.figure3)
        self.axes3.set_axis_off()
        self.figure3.tight_layout(rect=[0, 0, 1, 1])
        formlayout4.addWidget(canvas2)
        groupbox1.setLayout(formlayout4)

        groupbox2 = QGroupBox('Channels Plot')
        groupbox2.setFixedHeight(580)
        formlayout5 = QFormLayout()

        self.figure1 = Figure()
        self.axes1 = self.figure1.add_subplot(111)
        plt.rc('axes', prop_cycle=(cycler('color', ['b', 'g'])))
        self.figure1.tight_layout(rect=[0, 0, 1, 1])
        self.canvas = FigureCanvas(self.figure1)
        formlayout5.addWidget(self.canvas)
        groupbox2.setLayout(formlayout5)

        groupbox3 = QGroupBox('')
        groupbox3.setFixedHeight(40)
        formlayout6 = QFormLayout()
        self.figure2 = Figure()
        self.axes2 = self.figure2.add_subplot(111)
        canvas1 = FigureCanvas(self.figure2)
        self.axes2.set_axis_off()
        self.figure2.tight_layout(rect=[0, 0, 1, 1])
        formlayout6.addWidget(canvas1)
        groupbox3.setLayout(formlayout6)

        vbox.addWidget(groupbox1)
        vbox.addWidget(groupbox2)
        vbox.addWidget(groupbox3)

        hbox = QHBoxLayout()
        panel1 = QGroupBox('Panel')
        panel1.setFixedHeight(100)
        form1 = QFormLayout()
        self.push1 = QPushButton('Data Cursor On')
        self.push1.setFixedWidth(200)
        self.push1.clicked.connect(self.Data_Cursor)
        self.push2 = QPushButton('Get Ch')
        self.push2.clicked.connect(self.On_Click)
        self.push2.setFixedWidth(200)
        self.push3 = QPushButton('Save Bad Intervals')
        self.push3.clicked.connect(self.SaveBadIntervals)
        self.push3.setFixedWidth(200)
        self.push4 = QPushButton('Select Bad intervals')
        self.push4.clicked.connect(self.SelectBadInterval)
        self.push4.setFixedWidth(200)
        self.push5 = QPushButton('Delete Intervals')
        self.push5.clicked.connect(self.DeleteBadInterval)
        self.push5.setFixedWidth(200)
        form1.addRow(self.push1, self.push2)
        form1.addRow(self.push4, self.push5)
        form1.addWidget(self.push3)
        panel1.setLayout(form1)
        panel2 = QGroupBox('Signal Type')
        panel2.setFixedWidth(200)
        form2 = QFormLayout()
        self.rbtn1 = QCheckBox('raw ECoG')
        self.rbtn1.setChecked(True)
        self.rbtn2 = QCheckBox('High Gamma')
        self.rbtn2.setChecked(False)
        form2.addWidget(self.rbtn1)
        form2.addWidget(self.rbtn2)
        panel2.setLayout(form2)

        panel3 = QGroupBox('Plot Controls')
        gridLayout = QGridLayout()
        qlabel1 = QLabel('Ch selected #')
        qlabel1.setFixedWidth(70)
        gridLayout.addWidget(qlabel1, 0, 0)
        self.qline1 = QLineEdit('1')
        self.qline0 = QLineEdit('16')
        self.qline0.returnPressed.connect(self.channelDisplayed)
        self.qline2 = QLineEdit('0.01')
        self.qline2.returnPressed.connect(self.start_location)
        self.qline1.setFixedWidth(40)
        self.qline2.setFixedWidth(40)
        self.pushbtn1 = QPushButton('^')
        self.pushbtn1.clicked.connect(self.scroll_up)
        self.pushbtn1.setFixedWidth(30)
        self.pushbtn2 = QPushButton('v')
        self.pushbtn2.clicked.connect(self.scroll_down)
        self.pushbtn3 = QPushButton('<<')
        self.pushbtn3.clicked.connect(self.page_backward)
        self.pushbtn4 = QPushButton('<')
        self.pushbtn4.clicked.connect(self.scroll_backward)
        self.pushbtn5 = QPushButton('>>')
        self.pushbtn5.clicked.connect(self.page_forward)
        self.pushbtn6 = QPushButton('>')
        self.pushbtn6.clicked.connect(self.scroll_forward)
        self.pushbtn3.setFixedWidth(30)
        self.pushbtn4.setFixedWidth(30)
        self.pushbtn5.setFixedWidth(30)
        self.pushbtn6.setFixedWidth(30)
        self.pushbtn2.setFixedWidth(30)

        gridLayout.addWidget(self.qline1, 0, 1)
        gridLayout.addWidget(self.qline0, 0, 2)
        gridLayout.addWidget(self.pushbtn1, 0, 3)
        gridLayout.addWidget(self.pushbtn2, 0, 4)
        gridLayout.addWidget(self.pushbtn3, 0, 5)
        gridLayout.addWidget(self.pushbtn4, 0, 6)

        qlabel2 = QLabel('Interval start(s)')
        qlabel2.setFixedWidth(80)
        gridLayout.addWidget(qlabel2, 0, 7)
        gridLayout.addWidget(self.qline2, 0, 8)
        gridLayout.addWidget(self.pushbtn6, 0, 9)
        gridLayout.addWidget(self.pushbtn5, 0, 10)
        qlabel3 = QLabel('Window')
        qlabel3.setFixedWidth(60)
        self.qline3 = QLineEdit('25')
        self.qline3.returnPressed.connect(self.plot_interval)
        self.qline3.setFixedWidth(40)

        gridLayout.addWidget(qlabel3, 0, 11)
        gridLayout.addWidget(self.qline3, 0, 12)
        qlabel4 = QLabel('Vertical Scale')

        qlabel4.setFixedWidth(60)
        self.qline4 = QLineEdit('1')
        self.qline4.setFixedWidth(40)

        gridLayout.addWidget(qlabel4, 0, 13)
        gridLayout.addWidget(self.qline4, 0, 14)

        self.pushbtn7 = QPushButton('*2')
        self.pushbtn7.clicked.connect(self.verticalScaleIncrease)
        self.pushbtn8 = QPushButton('/2')
        self.pushbtn8.clicked.connect(self.verticalScaleDecrease)
        self.pushbtn7.setFixedWidth(30)
        self.pushbtn8.setFixedWidth(30)
        gridLayout.addWidget(self.pushbtn7, 0, 15)
        gridLayout.addWidget(self.pushbtn8, 0, 16)
        panel3.setLayout(gridLayout)
        hbox.addWidget(panel1)
        hbox.addWidget(panel2)
        hbox.addWidget(panel3)
        vbox.addLayout(hbox)
        self.setLayout(vbox)

    def SaveBadIntervals(self):
        self.model.pushSave()

    def _stop_delete_mode(self):
        """Disconnect the click-to-delete handler if it is connected."""
        if self.delete_cid is not None:
            self.figure1.canvas.mpl_disconnect(self.delete_cid)
            self.delete_cid = None

    def SelectBadInterval(self):
        """Start rectangle selection of bad intervals on the channel plot.

        Leaves delete mode and deactivates any previous selector, so each
        drag is reported by exactly one selector.
        """
        self._stop_delete_mode()
        if self.toggle_selector is not None:
            self.toggle_selector.set_active(False)
        self.toggle_selector = RectangleSelector(
            self.axes1, self.line_select_callback, useblit=True,
            spancoords='pixels', drawtype='box',
            minspanx=5, minspany=5,
            rectprops=dict(facecolor='peachpuff', edgecolor=None,
                           alpha=0.04, fill=True))
        self.toggle_selector.to_draw.set_visible(True)

    def On_Click(self):
        self.model.getChannel()

    def line_select_callback(self, eclick, erelease):
        """Record the x-extent of a finished rectangle as a bad interval."""
        x1 = eclick.xdata
        x2 = erelease.xdata
        # Store [start, end] regardless of drag direction.
        BadInterval = sorted([x1, x2])
        self.model.addBadTimeSeg(BadInterval)
        self.model.refreshScreen()

    def DeleteBadInterval(self):
        """Enter delete mode: clicks on the channel plot are passed to
        ``model.deleteInterval`` and rectangle selection is paused."""
        if self.delete_cid is None:
            self.delete_cid = self.figure1.canvas.mpl_connect(
                'button_press_event', self.get_coordinates)
        if self.toggle_selector is not None:
            self.toggle_selector.set_active(False)

    def get_coordinates(self, event):
        """Delete-mode click handler; ignores clicks outside the axes."""
        x = event.xdata
        if x is None:
            return
        self.model.deleteInterval(x)

    def get_cursor_position(self, event):
        """Data-cursor click handler: annotate the clicked (x, y) on the plot.

        The previous annotation is replaced. Clicks outside the axes (no data
        coordinates) are ignored.
        """
        x = event.xdata
        y = event.ydata
        if x is None or y is None:
            return
        self.model.x_cur = x
        self.model.y_cur = y
        props = dict(boxstyle='round', facecolor='y', alpha=0.5)
        text_ = self.axes1.text(x, y, 'x:' + str(round(x, 2)) + '\n' + 'y:' + str(round(y, 2)), bbox=props)

        if self.temp:
            self.temp.remove()
        self.temp = text_

        self.figure1.canvas.draw()

    def Data_Cursor(self):
        """Toggle the data cursor on/off, following the button label."""
        canvas = self.figure1.canvas
        if self.push1.text() == 'Data Cursor On':
            self.push1.setText('Data Cursor Off')
            if self.cursor_cid is None:
                self.cursor_cid = canvas.mpl_connect('button_press_event', self.get_cursor_position)

        elif self.push1.text() == 'Data Cursor Off':
            # Stop reacting to clicks, then clear the current annotation.
            if self.cursor_cid is not None:
                canvas.mpl_disconnect(self.cursor_cid)
                self.cursor_cid = None
            if self.temp:
                self.temp.remove()
                self.temp = []
                canvas.draw()

            self.push1.setText('Data Cursor On')

    def scroll_up(self):
        self.model.channel_Scroll_Up()

    def scroll_down(self):
        self.model.channel_Scroll_Down()

    def page_backward(self):
        self.model.page_back()

    def scroll_backward(self):
        self.model.scroll_back()

    def page_forward(self):
        self.model.page_forward()

    def scroll_forward(self):
        self.model.scroll_forward()

    def verticalScaleIncrease(self):
        self.model.verticalScaleIncrease()

    def verticalScaleDecrease(self):
        self.model.verticalScaleDecrease()

    def plot_interval(self):
        self.model.plot_interval()

    def start_location(self):
        self.model.start_location()

    def channelDisplayed(self):
        self.model.nChannels_Displayed()
Ejemplo n.º 27
0
class GUIrocket(object):
    def __init__(self, bodies, title, axisbgcol='black'):
        """
        Create the game window for one set of bodies and draw the initial shot.

        bodies is a dict-like mapping of 1 or more:
           <ID int>: {'density': <float>,
                      'radius': <float>,
                      'position': [<float>, <float>]}
        The IDs must be exactly 0..len(bodies)-1, because setup_pars looks
        the bodies up with range(self.N).

        title -- used as the axes title, stored as self.name and passed to
                 common.gen_versioner.
        axisbgcol -- background colour of the main axes.

        Side effects: opens matplotlib figure number next_fighandle (module
        global), calls plt.show(), then increments next_fighandle.
        """
        global next_fighandle

        # Sim setup
        self.model = None
        # context objects (lines of interest, domains, etc)
        self.context_objects = []
        # external tracked objects (measures, etc)
        self.tracked_objects = []
        #
        self.selected_object = None
        self.selected_object_temphandle = None
        #
        self.current_domain_handler = dom.GUI_domain_handler(self)

        # name of task controlling mouse click event handler
        self.mouse_wait_state_owner = None

        # last output from a UI action
        self.last_output = None

        # if defined, will be refreshed on each Go!
        self.calc_context = None

        # Graphics widgets to be set for application
        self.widgets = {}


        # --- SPECIFIC TO BOMBARDIER
        # Setup shoot params
        self.vel = 0.8
        self.ang = 0
        self.da = 0.005
        self.dv = 0.0005
        # used for color-coding trajectories by speed
        self.maxspeed = 2.2

        # one time graphics setup
        # for plot handles
        self.trajline = None
        self.startpt = None
        self.endpt = None
        self.quartiles = None
        # Currently unused
        #self.vtext = None
        #self.atext = None

        # ---------

        # ---- Make these generic in a parent class, then
        # specifically configured to bombardier here

        # Axes background colour
        self.axisbgcol = axisbgcol

        # Move these to a _recreate method that can be reused for un-pickling
        self.fig = figure(next_fighandle, figsize=(14,9))
        self.fignum = next_fighandle
        plt.subplots_adjust(left=0.09, right=0.98, top=0.95, bottom=0.1,
                               wspace=0.2, hspace=0.23)
        # NOTE(review): `axisbg` is the old matplotlib keyword (deprecated
        # since 2.0 in favour of `facecolor`); confirm the targeted version.
        self.ax = plt.axes([0.05, 0.12, 0.9, 0.85], axisbg=axisbgcol)
        self.ax.set_title(title)
        self.name = title
        # Connection ids are kept only in locals, so these key handlers are
        # never disconnected.
        evKeyOn = self.fig.canvas.mpl_connect('key_press_event', self.key_on)
        evKeyOff = self.fig.canvas.mpl_connect('key_release_event', self.key_off)

        # Angle slider spans [-maxangle, maxangle] (module global).
        AngSlide = plt.axes([0.1, 0.055, 0.65, 0.03])
        self.widgets['AngBar'] = Slider(AngSlide, 'Shoot Angle', -maxangle, maxangle,
                                            valinit=self.ang, color='b',
                                            dragging=False, valfmt='%2.3f')
        self.widgets['AngBar'].on_changed(self.updateAng)

        VelSlide = plt.axes([0.1, 0.02, 0.65, 0.03])
        self.widgets['VelBar'] = Slider(VelSlide, 'Shoot Speed', 0.01, 2,
                                            valinit=self.vel, color='b',
                                            dragging=False, valfmt='%1.4f')
        self.widgets['VelBar'].on_changed(self.updateVel)

        # assume max of N-2 planetoid bodies + target + source
        self.N = len(bodies)
        self.gen_versioner = common.gen_versioner(os.path.abspath('.'),
                                                         self.name,
                                                         'simgen_N%i'%self.N,
                                                         gentype, 1)

        # Make this more generic for ABC
        self.setup_pars(bodies)

        # --- END OF BOMBARDIER SPECIFICS

        # Move these to a _recreate method that can be reused for un-pickling
        GoButton = Button(plt.axes([0.005, 0.1, 0.045, 0.03]), 'Go!')
        GoButton.on_clicked(self.go)
        self.widgets['Go'] = GoButton

        # Line selector starts disabled; something else must enable it.
        self.RS_line = RectangleSelector(self.ax, self.onselect_line, drawtype='line') #,
#                                         minspanx=0.005, minspany=0.005)
        self.RS_line.set_active(False)
#        self.RS_box = RectangleSelector(self.ax, self.onselect_box, drawtype='box',
#                                        minspanx=0.005, minspany=0.005)
#        self.RS_box.set_active(False)

        # context_changed flag set when new objects created using declare_in_context(),
        # and unset when Generator is created with the new context code included
        self.context_changed = False
        self.setup_gen()
        self.traj = None
        self.pts = None

        self.mouse_cid = None # event-connection ID
        self.go(run=False)
        # force call to graphics_refresh because run=False above
        self.graphics_refresh(cla=False)
        # NOTE(review): with a non-interactive backend plt.show() blocks until
        # the window is closed, so next_fighandle is only bumped afterwards.
        plt.show()

        # next_fighandle for whenever a new model is put in a new figure (new game instance)
        next_fighandle += 1


    def graphics_refresh(self, cla=True):
        """Redraw bodies, trajectory and context objects on ``self.ax``.

        :param cla: clear the axes first; pass False to draw on top of the
            existing artists.

        The view is then fixed to an equal aspect ratio, x in
        [-xdomain_halfwidth, xdomain_halfwidth] and y in [0, 1], and the
        canvas is redrawn.
        """
        axes = self.ax
        if cla:
            axes.cla()
        for draw_step in (self.plot_bodies, self.plot_traj, self.plot_context):
            draw_step()
        axes.set_aspect('equal')
        axes.set_xlim(-xdomain_halfwidth, xdomain_halfwidth)
        axes.set_ylim(0, 1)
        self.fig.canvas.draw()

    def plot_context(self):
        """Call ``show()`` on every context object, then on every tracked object.

        Tracked objects display outside the main GUI window.
        """
        for group in (self.context_objects, self.tracked_objects):
            for item in group:
                item.show()

    # Methods for pickling protocol
    def __getstate__(self):
        d = copy(self.__dict__)
        for fname, finfo in self._funcreg.items():
            try:
                del d[fname]
            except KeyError:
                pass
        # delete MPL objects
        for obj in some_list:
            try:
                del d['']
            except KeyError:
                pass
        return d

    def __setstate__(self, state):
        """Restore pickled state (pickle protocol) -- unfinished.

        Copies ``state`` into the instance dict and sets ``self._stuff`` to
        None.

        NOTE(review): ``something`` is not defined in this method or anywhere
        visible in this class, so unless it is a module-level name this
        raises NameError after the state has been restored. The intended
        condition for calling ``_recreate`` is unknown -- TODO decide. Note
        that ``_recreate`` currently raises NotImplementedError.
        """
        # INCOMPLETE!
        self.__dict__.update(state)
        self._stuff = None
        if something != {}:
            self._recreate() # or re-call __init__

    def _recreate(self):
        raise NotImplementedError

    def declare_in_context(self, con_obj):
        """Register *con_obj* as a context object.

        Also marks the generator as stale (``context_changed = True``) so the
        next ``setup_gen`` rebuilds it with the object's extra events,
        functions, parameters and auxiliary variables.
        """
        self.context_changed = True
        registry = self.context_objects
        registry.append(con_obj)

    def __str__(self):
        return self.name

    def setup_pars(self, data):
        """Derive per-body lists and the simulation parameter dict from *data*.

        *data* maps body id -> {'position': [x, y], 'radius': r,
        'density': rho}. The ids are looked up as 0..self.N-1. Fills
        ``self.radii``, ``self.density``, ``self.pos`` (Point2D objects),
        ``self.masses`` (density * pi * r**2) and ``self.body_pars``. The
        latter holds G followed by the r<i>, m<i>, bx<i> and by<i> entries,
        in that order. It also sets the shot's initial position and velocity
        (``icpos`` / ``icvel``).
        """
        # Should generalize to non-bombardier application
        indices = range(self.N)
        radius_by_id = {}
        density_by_id = {}
        point_by_id = {}
        for body_id, body in data.items():
            bx, by = body['position'][0], body['position'][1]
            point_by_id[body_id] = pp.Point2D(bx, by, labels={'body': body_id})
            radius_by_id[body_id] = body['radius']
            density_by_id[body_id] = body['density']

        self.radii = [radius_by_id[k] for k in indices]
        self.density = [density_by_id[k] for k in indices]
        self.pos = [point_by_id[k] for k in indices]  # planet positions
        self.masses = [density_by_id[k] * np.pi * r * r
                       for k, r in enumerate(self.radii)]

        groups = (
            ('r', self.radii),
            ('m', self.masses),
            ('bx', [point_by_id[k][0] for k in indices]),
            ('by', [point_by_id[k][1] for k in indices]),
        )
        body_pars = {'G': G}  # global param for gravitational constant
        for prefix, values in groups:
            body_pars.update(('%s%i' % (prefix, k), value)
                             for k, value in enumerate(values))
        self.body_pars = body_pars

        self.icpos = np.array((0.0, 0.08))
        self.icvel = np.array((0.0, 0.0))

    def setup_gen(self):
        """Make ``self.model`` ready for the current bodies and context.

        If context objects were declared since the last build
        (``context_changed``), the flag is cleared and a new generator is
        built with ``make_gen``. Otherwise a saved generator named
        ``sim_N<N>_fig<fignum>`` is loaded through ``gen_versioner`` and
        given the current ``body_pars``. If loading raises an ordinary
        exception, the generator is rebuilt with ``make_gen`` instead.
        """
        name = 'sim_N%i' % self.N + '_fig%i' % self.fignum
        if self.context_changed:
            self.context_changed = False
            self.make_gen(self.body_pars, name)
            return
        try:
            self.model = self.gen_versioner.load_gen(name)
        except Exception:
            # Catch only ordinary errors so KeyboardInterrupt / SystemExit
            # still propagate instead of triggering a silent rebuild.
            self.make_gen(self.body_pars, name)
        else:
            self.model.set(pars=self.body_pars)


    def make_gen(self, pardict, name):
        # scrape GUI diagnostic object extras for generator
        extra_events = []
        extra_fnspecs = {}
        extra_pars = {}
        extra_auxvars = {}
        for gui_obj in self.context_objects:
            extra_events.append(gui_obj.extra_events)
            extra_fnspecs.update(gui_obj.extra_fnspecs)
            extra_pars.update(gui_obj.extra_pars)
            extra_auxvars.update(gui_obj.extra_auxvars)

        Fx_str = ""
        Fy_str = ""
        for i in range(self.N):
            Fx_str += "-G*m%i*(x-bx%i)/pow(d(x,y,bx%i,by%i),3)" % (i,i,i,i)
            Fy_str += "-G*m%i*(y-by%i)/pow(d(x,y,bx%i,by%i),3)" % (i,i,i,i)

        DSargs = args()
        DSargs.varspecs = {'vx': Fx_str, 'x': 'vx',
                           'vy': Fy_str, 'y': 'vy',
                           'Fx_out': 'Fx(x,y)', 'Fy_out': 'Fy(x,y)',
                           'speed': 'sqrt(vx*vx+vy*vy)',
                           'bearing': '90-180*atan2(vy,vx)/pi'}
        DSargs.varspecs.update(extra_auxvars)
        auxfndict = {'Fx': (['x', 'y'], Fx_str),
                     'Fy': (['x', 'y'], Fy_str),
                     'd': (['xx', 'yy', 'x1', 'y1'], "sqrt((xx-x1)*(xx-x1)+(yy-y1)*(yy-y1))")
                    }
        DSargs.auxvars = ['Fx_out', 'Fy_out', 'speed', 'bearing'] + \
            extra_auxvars.keys()
        DSargs.pars = pardict
        DSargs.pars.update(extra_pars)
        DSargs.fnspecs = auxfndict
        DSargs.fnspecs.update(extra_fnspecs)
        DSargs.algparams = {'init_step':0.001,
                            'max_step': 0.01,
                            'max_pts': 20000,
                            'maxevtpts': 2,
                            'refine': 5}

        targetlang = \
            self.gen_versioner._targetlangs[self.gen_versioner.gen_type]

        # Events for external boundaries (left, right, top, bottom)
        Lev = Events.makeZeroCrossEvent('x+%f'%xdomain_halfwidth, -1,
                                        {'name': 'Lev',
                                         'eventtol': 1e-5,
                                         'precise': True,
                                         'term': True},
                                        varnames=['x'],
                                        targetlang=targetlang)
        Rev = Events.makeZeroCrossEvent('x-%f'%xdomain_halfwidth, 1,
                                        {'name': 'Rev',
                                         'eventtol': 1e-5,
                                         'precise': True,
                                         'term': True},
                                        varnames=['x'],
                                        targetlang=targetlang)
        Tev = Events.makeZeroCrossEvent('y-1', 1,
                                        {'name': 'Tev',
                                         'eventtol': 1e-5,
                                         'precise': True,
                                         'term': True},
                                        varnames=['y'],
                                        targetlang=targetlang)
        Bev = Events.makeZeroCrossEvent('y', -1,
                                        {'name': 'Bev',
                                         'eventtol': 1e-5,
                                         'precise': True,
                                         'term': True},
                                        varnames=['y'],
                                        targetlang=targetlang)

        # Events for planetoids
        bevs = []
        for i in range(self.N):
            bev = Events.makeZeroCrossEvent('d(x,y,bx%i,by%i)-r%i' % (i,i,i),
                                            -1,
                                        {'name': 'b%iev' %i,
                                         'eventtol': 1e-5,
                                         'precise': True,
                                         'term': True},
                                        varnames=['x','y'],
                                        parnames=pardict.keys(),
                                        fnspecs=auxfndict,
                                        targetlang=targetlang)
            bevs.append(bev)

        DSargs.events = [Lev, Rev, Tev, Bev] + bevs + extra_events
        DSargs.checklevel = 2
        DSargs.ics = {'x': self.icpos[0], 'y': self.icpos[1],
                      'vx': 0., 'vy': 1.5}
        DSargs.name = name
        DSargs.tdomain = [0, 10000]
        DSargs.tdata = [0, 50]

        # turns arguments into Generator then embed into Model object
        self.model = self.gen_versioner.make(DSargs)


    def go(self, run=True):
        """
        Put the projectile on the launcher and optionally fire it.

        The launch point is at distance ``self.radii[0]`` from the origin in
        the direction given by ``self.ang`` (degrees from vertical, clamped
        to +/- maxangle). The velocity has magnitude ``self.vel`` in that
        same direction. Both are written into the model's initial
        conditions. When ``run`` is true the model is integrated and the
        trajectory redrawn; the canvas is redrawn in every case.

        Note: This method can only start a trajectory from the
        launcher at the bottom of the screen!

        To shoot from a specific point that's been set by hand,
        call self.run() then self.graphics_refresh(cla=False)
        """
        angle = self.ang
        speed = self.vel
        # Angle is relative to vertical, limited to +/- maxangle degrees; a
        # value outside that range is assumed to be vestigial from a
        # different initial condition, so it is clamped.
        angle = maxangle if angle > maxangle else (
            -maxangle if angle < -maxangle else angle)
        theta = pi*(angle-90)/180.
        launcher_r = self.radii[0]
        initial = {'vx': speed*cos(theta),
                   'vy': -speed*sin(theta),
                   'x': launcher_r*cos(theta),
                   'y': -launcher_r*sin(theta)}
        self.model.set(ics=initial)
        if run:
            self.run()
            self.graphics_refresh(cla=False)
        self.fig.canvas.draw()
        plt.draw()

    def set(self, pair, ic=None, by_vel=False):
        """Set solution pair (ang, speed) and optional (x, y) initial condition.

        ``ang`` is in degrees from vertical, using the same convention as
        ``go``: ``rad = pi*(ang-90)/180``, ``vx = speed*cos(rad)`` and
        ``vy = -speed*sin(rad)``.

        Parameters
        ----------
        pair : sequence of length 2
            ``(ang, speed)``, or ``(vx, vy)`` when ``by_vel`` is True.
        ic : dict, optional
            Exactly the keys 'x' and 'y'. When given, the position and the
            velocity are written directly into the model's initial
            conditions (reconstructing a partial trajectory out in space).
            When omitted, the angle and speed sliders are set instead.
        by_vel : bool, optional
            If True (default False), the pair is treated as ``(vx, vy)``
            and converted to ``(ang, speed)``. This applies with or without
            ``ic``.
        """
        assert len(pair) == 2
        if ic is not None:
            assert 'x' in ic and 'y' in ic and len(ic) == 2
        if by_vel:
            vx, vy = pair
            # Exact inverse of the (ang, speed) -> (vx, vy) map above. It is
            # the same formula as the model's 'bearing' auxiliary variable.
            # The former "180*atan2(vy,vx)/pi - 90" returned -ang.
            ang = 90 - 180*atan2(vy, vx)/pi
            vel = sqrt(vx*vx + vy*vy)
        else:
            ang, vel = pair

        if ic is None:
            # regular shooting from the launcher: drive the sliders
            self.setAng(ang)
            self.setVel(vel)
            return

        self.model.set(ics=ic)
        self.icpos = ic
        # can't set ang and vel according to rules for regular
        # shooting because we are reconstructing a partial
        # trajectory out in space
        self.ang, self.vel = ang, vel
        if not by_vel:
            rad = pi*(ang-90)/180.
            vx = vel*cos(rad)
            vy = -vel*sin(rad)
        self.model.set(ics={'vx': vx, 'vy': vy})


    def setAng(self, ang):
        """Push ``ang`` into the 'AngBar' widget via its ``set_val``."""
        angle_bar = self.widgets['AngBar']
        angle_bar.set_val(ang)

    def setVel(self, vel):
        """Push ``vel`` into the 'VelBar' widget via its ``set_val``."""
        speed_bar = self.widgets['VelBar']
        speed_bar.set_val(vel)

    def updateAng(self, ang):
        """Store the launch angle, clamped to [-maxangle, maxangle], then
        reposition the launcher with ``go(run=False)`` (no integration)."""
        self.ang = -maxangle if ang < -maxangle else (
            maxangle if ang > maxangle else ang)
        self.go(run=False)

    def updateVel(self, vel):
        """Store the launch speed, then reposition the launcher.

        Speeds below 0.01 are reported and raised to 0.01. The launcher is
        then refreshed with ``go(run=False)``, which does not integrate.
        """
        if vel < 0.01:
            # Single-argument print() is valid on Python 2 and 3, consistent
            # with the other print() calls in this class.
            print("Velocity must be >= 0.01")
            vel = 0.01
        self.vel = vel
        self.go(run=False)

    def key_on(self, ev):
        """Key-press handler.

        Records the key in ``self._key`` and dispatches on it:

        * keys in ``da_dict`` / ``dv_dict`` change the launch angle / speed
          by the mapped multiple of ``self.da`` / ``self.dv`` and sync the
          corresponding slider;
        * 'g' runs the simulation from the launcher;
        * 'l' activates the line selector;
        * ' ' arms a one-shot click handler that reports forces;
        * 's' arms a one-shot click handler that snaps to the trajectory;
        * ``dom_key`` starts growing a domain (a criterion function must be
          assigned first).

        If a domain selection is in progress and the key is one of
        ``change_mouse_state_keys``, the domain handler is cleared first.
        """
        self._key = k = ev.key  # keep record of last keypress
        # TEMP debug output (single-argument print: valid on Python 2 and 3)
        print("Pressed %s" % (k,))
        if self.mouse_wait_state_owner == 'domain' and \
           k in change_mouse_state_keys:
            # reset state of domain handler first
            self.current_domain_handler.event('clear')

        if k in da_dict:
            Da = da_dict[k]*self.da
            self.updateAng(self.ang+Da)
            self.widgets['AngBar'].set_val(self.ang)
        elif k in dv_dict:
            Dv = dv_dict[k]*self.dv
            self.updateVel(self.vel+Dv)
            self.widgets['VelBar'].set_val(self.vel)
        elif k == 'g':
            print("Go! Running simulation.")
            self.go()
        elif k == 'l':
            print("Make a line of interest")
            self.RS_line.set_active(True)
            self.mouse_wait_state_owner = 'line'
        elif k == ' ':
            print("Forces at clicked mouse point")
            self._arm_click_handler(self.mouse_event_force, 'forces')
        elif k == 's':
            print("Snap clicked mouse point to closest point on trajectory")
            self._arm_click_handler(self.mouse_event_snap, 'snap')
        elif k == dom_key:
            print("Click on domain seed point then initial radius point")
            # grow domain
            if self.current_domain_handler.func is None:
                print("Assign a domain criterion function first!")
                return
            # this call may have side-effects
            self.current_domain_handler.event('key')
            self.mouse_wait_state_owner = 'domain'

    def _arm_click_handler(self, handler, owner):
        """Make ``handler`` the single pending button-release handler and
        record ``owner`` as the mouse wait-state owner."""
        # The click handlers only ever disconnect ``self.mouse_cid``. If an
        # earlier, still-pending handler's cid were simply overwritten here,
        # that handler would stay connected for good, so drop it first.
        # Disconnecting an already-disconnected cid is a no-op.
        previous_cid = getattr(self, 'mouse_cid', None)
        if previous_cid is not None:
            self.fig.canvas.mpl_disconnect(previous_cid)
        self.mouse_cid = self.fig.canvas.mpl_connect('button_release_event',
                                                     handler)
        self.mouse_wait_state_owner = owner

    def key_off(self, ev):
        """Key-release handler: forget the last recorded key press."""
        # the event itself carries nothing we need on release
        self._key = None

    def plot_bodies(self):
        """Draw each of the ``self.N`` bodies that has a positive radius.

        A body is drawn on ``self.ax`` as a filled circle: green when its
        density is 0, grey otherwise. A black dot marks its centre, and its
        index is written as a text label offset by (-0.016, -0.008) from the
        centre, with the label's y clipped to [0.01, 0.96].
        """
        for idx in range(self.N):
            cx, cy = self.pos[idx]
            radius = self.radii[idx]
            if not radius > 0:
                continue
            colour = 'green' if self.density[idx] == 0 else 'grey'
            self.ax.add_artist(plt.Circle((cx, cy), radius, color=colour))
            self.ax.plot(cx, cy, 'k.')
            label_y = min(0.96, max(0.01, cy - 0.008))
            self.ax.text(cx - 0.016, label_y, str(idx))

    def plot_traj(self, pts=None, with_speeds=True):
        """
        Draw (or update) the trajectory and its marker artists on ``self.ax``.

        Parameters
        ----------
        pts : pointset, optional
            Points to draw; defaults to ``self.pts``. Nothing is drawn when
            both are None.
        with_speeds : bool, optional
            If True (default), draw a "heat map" line whose segment colours
            encode speed (``jet`` colormap normalised to [0, self.maxspeed]).
            Otherwise draw a plain '.-' line, white on a black axes
            background and black otherwise.

        The first point gets a yellow square, the last a red star, and the
        points at 25%, 50% and 75% of the sample get diamond markers. Marker
        artists are created on first use and moved on later calls; the
        previous trajectory line is removed and replaced.
        """
        if pts is None:
            if self.pts is not None:
                pts = self.pts
            else:
                # nothing to plot
                return
        col = 'w' if self.axisbgcol == 'black' else 'k'
        firstpt = pts[0]
        lastpt = pts[-1]

        # set_xdata/set_ydata get 1-element lists: newer matplotlib versions
        # no longer accept bare scalars as line data.
        if self.startpt is None:
            self.startpt = self.ax.plot(firstpt['x'], firstpt['y'], 'ys',
                                        markersize=15)[0]
        else:
            self.startpt.set_xdata([firstpt['x']])
            self.startpt.set_ydata([firstpt['y']])

        if self.trajline is not None:
            self.trajline.remove()
        # the first and last samples are left out of the line itself
        xs = pts['x'][1:-1]
        ys = pts['y'][1:-1]
        if with_speeds:
            # Slice the speeds exactly like xs/ys so colour k belongs to the
            # segment starting at (xs[k], ys[k]). The unsliced array coloured
            # every segment with the previous sample's speed.
            speeds = pts['speed'][1:-1]
            norm = mpl.colors.Normalize(vmin=0, vmax=self.maxspeed)
            cmap = plt.cm.jet  # gist_heat
            RGBAs = cmap(norm(speeds))
            segments = [((x0, y0), (x1, y1)) for x0, y0, x1, y1
                        in zip(xs[:-1], ys[:-1], xs[1:], ys[1:])]
            linecollection = mpl.collections.LineCollection(segments,
                                                            colors=RGBAs)
            self.trajline = self.ax.add_collection(linecollection)
        else:
            self.trajline = self.ax.plot(xs, ys, col+'.-')[0]

        if self.endpt is None:
            self.endpt = self.ax.plot(lastpt['x'], lastpt['y'], 'r*',
                                      markersize=17)[0]
        else:
            self.endpt.set_xdata([lastpt['x']])
            self.endpt.set_ydata([lastpt['y']])

        # diamond markers at 25%, 50% and 75% of the sample
        n = len(pts)
        quartile_pts = [pts[int(frac*n)] for frac in (0.25, 0.5, 0.75)]
        if self.quartiles is None:
            self.quartiles = [self.ax.plot(q['x'], q['y'], col+'d',
                                           markersize=10)[0]
                              for q in quartile_pts]
        else:
            for marker, q in zip(self.quartiles, quartile_pts):
                marker.set_xdata([q['x']])
                marker.set_ydata([q['y']])
        plt.draw()


    def run(self, tmax=None):
        """Integrate the model and cache the result.

        Computes the trajectory named 'test' (forcing recomputation), stores
        it in ``self.traj`` and its sampled points in ``self.pts``, then
        calls ``self.calc_context()`` when one is set. ``tmax`` is accepted
        but not used.
        """
        self.model.compute('test', force=True)
        traj = self.model.trajectories['test']
        self.traj = traj
        self.pts = traj.sample()
        if self.calc_context is None:
            return
        # Update calc context
        self.calc_context()

    def get_forces(self, x, y):
        """
        Return the attraction of each body at the point (x, y).

        Returns two dictionaries keyed by body index (0 to N-1): the net
        force magnitude, and the force vector ``(Fx, Fy)``.

        Body i contributes ``-m_i * (p - b_i) / d**3``. Here ``m_i``, ``bx_i``
        and ``by_i`` are read from the model parameters, and ``d`` is
        ``pp.distfun(x, y, bx_i, by_i)``.

        NOTE(review): no explicit ``G`` factor is applied here, whereas the
        model's vector field uses ``-G*m_i*...``. Confirm whether this
        difference is intended.
        """
        # Bombardier specific
        pars = self.model.query('pars')
        magnitudes = {}
        vectors = {}
        for body in range(self.N):
            mass = pars['m%i' % body]
            bx = pars['bx%i' % body]
            by = pars['by%i' % body]
            denom = pow(pp.distfun(x, y, bx, by), 3)
            fx = -mass*(x-bx)/denom
            fy = -mass*(y-by)/denom
            vectors[body] = (fx, fy)
            magnitudes[body] = sqrt(fx*fx + fy*fy)
        return magnitudes, vectors

    def set_planet_data(self, n, data):
        """Update the parameters of body ``n`` and reset the plot axes.

        Parameters
        ----------
        n : int
            Body index, ``0 <= n < self.N``.
        data : dict
            Any of the keys 'r' (radius), 'x', 'y' (centre coordinates) and
            'd' (density). When ``data`` is non-empty, the body's mass
            parameter is recomputed as ``G*d*pi*r*r`` from the new radius
            and density, or from the current ones where not updated.

        Raises
        ------
        ValueError
            If ``data`` contains any other key. Nothing is modified in that
            case.

        The axes are cleared and re-limited and the launcher repositioned.
        Bodies and the trajectory are not redrawn here.
        """
        assert n in range(self.N)
        # Validate everything before mutating any state.
        bad_keys = [key for key in data if key not in ('r', 'x', 'y', 'd')]
        if bad_keys:
            raise ValueError("Invalid parameter key: %s" % bad_keys[0])

        # default to old radius and density, unless updated (for masses)
        r = self.model.query('pars')['r%i' % n]
        d = self.density[n]
        pardict = {}
        for key, val in data.items():
            if key == 'r':
                pardict['r%i' % n] = val
                r = val
                self.radii[n] = r
            elif key == 'x':
                pardict['bx%i' % n] = val
                # Unpack instead of using .x/.y: after the first update the
                # stored position is a plain tuple.
                _, old_y = self.pos[n]
                self.pos[n] = (val, old_y)
            elif key == 'y':
                pardict['by%i' % n] = val
                old_x, _ = self.pos[n]
                self.pos[n] = (old_x, val)
            else:  # key == 'd' (validated above)
                d = val
                self.density[n] = d
        if data:
            pardict['m%i' % n] = G*d*np.pi*r*r

        self.model.set(pars=pardict)
        self.body_pars.update(pardict)
        self.ax.cla()
        self.ax.set_aspect('equal')
        self.ax.set_xlim(-xdomain_halfwidth, xdomain_halfwidth)
        self.ax.set_ylim(0, 1)

        # cla() detached every artist, so drop all cached handles. Otherwise
        # plot_traj would keep moving detached quartile markers instead of
        # re-creating them, and the click handlers would try to remove an
        # already-cleared highlight.
        self.trajline = None
        self.startpt = None
        self.endpt = None
        self.quartiles = None
        self.selected_object_temphandle = None
        self.go(run=False)

    def mouse_event_force(self, ev):
        """One-shot click handler: report the forces at the clicked point.

        Prints the clicked coordinates and the force magnitudes from
        ``get_forces``, and stores ``(magnitudes, vectors)`` in
        ``self.last_output``. It then makes the point the selected object,
        marks it with a green dot (replacing any previous marker),
        disconnects itself and clears the mouse wait state.

        Clicks outside the axes (no data coordinates) are ignored and the
        handler stays armed.
        """
        if ev.xdata is None or ev.ydata is None:
            # click landed outside the axes; keep waiting for a real point
            return
        print("\n(%.4f, %.4f)" % (ev.xdata, ev.ydata))
        fs, fvecs = self.get_forces(ev.xdata, ev.ydata)
        print(fs)
        print("Last output = (force mag dict, force vector dict)")
        self.last_output = (fs, fvecs)
        self.selected_object = pp.Point2D(ev.xdata, ev.ydata)
        if self.selected_object_temphandle is not None:
            self.selected_object_temphandle.remove()
        self.selected_object_temphandle = self.ax.plot(ev.xdata, ev.ydata,
                                                       'go')[0]
        self.fig.canvas.draw()
        self.fig.canvas.mpl_disconnect(self.mouse_cid)
        self.mouse_wait_state_owner = None

    def mouse_event_snap(self, ev):
        """One-shot click handler: snap the click to the nearest trajectory point.

        Searches ``self.pts`` with ``pp.find_pt_nophase_2D`` (eps=0.1) and
        stores the ``(index, distance, point)`` result in
        ``self.last_output``. It then makes that point the selected object,
        marks it with a green dot (replacing any previous marker),
        disconnects itself and clears the mouse wait state.

        Without a trajectory the click is only reported and the handler
        stays armed. Clicks outside the axes are ignored. If no point is
        close enough, the handler gives up: it disconnects and clears the
        wait state.
        """
        if self.pts is None:
            print("No trajectory defined")
            return
        if ev.xdata is None or ev.ydata is None:
            # click landed outside the axes; keep waiting for a real point
            return
        print("\nClick: (%.4f, %.4f)" % (ev.xdata, ev.ydata))
        # have to guess phase, use widest tolerance
        try:
            data = pp.find_pt_nophase_2D(self.pts,
                                         pp.Point2D(ev.xdata, ev.ydata),
                                         eps=0.1)
        except ValueError:
            print("No nearby point found. Try again")
            self.fig.canvas.mpl_disconnect(self.mouse_cid)
            # the handler is gone, so release the wait state as well
            self.mouse_wait_state_owner = None
            return
        self.last_output = data
        x_snap = data[2]['x']
        y_snap = data[2]['y']
        self.selected_object = pp.Point2D(x_snap, y_snap)
        if self.selected_object_temphandle is not None:
            self.selected_object_temphandle.remove()
        self.selected_object_temphandle = self.ax.plot(x_snap, y_snap, 'go')[0]
        self.fig.canvas.draw()
        print("Last output = (index, distance, point)")
        print("            = (%i, %.3f, (%.3f, %.3f))" % (data[0], data[1],
                                                          x_snap, y_snap))
        self.fig.canvas.mpl_disconnect(self.mouse_cid)
        self.mouse_wait_state_owner = None

    def onselect_line(self, eclick, erelease):
        """Line-selector callback: turn a primary-button drag into a line.

        Drags made with any other button are ignored. On success the new
        ``graphics.line_GUI`` becomes ``self.selected_object``, the line
        selector is deactivated and the mouse wait state is cleared.
        """
        if eclick.button != 1:
            # only the left (primary) button creates a line
            return
        start = pp.Point2D(eclick.xdata, eclick.ydata)
        end = pp.Point2D(erelease.xdata, erelease.ydata)
        self.selected_object = graphics.line_GUI(self, self.ax, start, end)
        print("Created line as new selected object, now give it a name")
        print("  by writing this object's selected_object.name attribute")
        self.RS_line.set_active(False)
        self.mouse_wait_state_owner = None

    def onselect_box(self, eclick, erelease):
        """Box-selector callback: only releases the mouse wait state; the
        selected box extents are not used."""
        setattr(self, 'mouse_wait_state_owner', None)
Ejemplo n.º 28
0
def select_roi(image, rois=None, ax=None, axim=None, qtapp=None):
    """Interactively select rectangular regions of interest on an image.

    A matplotlib ``RectangleSelector`` is attached to the Axes, and every
    completed drag records one ROI. The function returns immediately; the
    returned list is filled in later as selections are made.

    Parameters
    ----------
    image : (M, N[, 3]) array
        Grayscale or RGB image.
    rois : list, optional
        If given, ROIs are written into this very list object (even when it
        is empty). Without ``qtapp`` they are appended; with ``qtapp`` they
        are stored at ``qtapp.image_index``. Otherwise a new list is created.
    ax : matplotlib Axes, optional
        The Axes on which to do the plotting. Defaults to ``axim.axes`` when
        ``axim`` is given, otherwise to the Axes of a new figure.
    axim : matplotlib AxesImage, optional
        An existing AxesImage on which to show the image.
    qtapp : QtApplication
        The main Qt application for ROI selection. If given, the ROIs will
        be inserted at the right location for the image index.

    Returns
    -------
    rois : list of tuple of ints
        The selected regions, in the form
        [((row_start, row_end), (col_start, col_end)), ...].

    Raises
    ------
    ValueError
        If ``image`` is neither 2D nor 3D.

    Notes
    -----
    Click and drag to select a rectangle; releasing the button records it.
    Press 'a' to re-activate the selector if it has been deactivated.

    Examples
    --------
    >>> from skimage import data
    >>> rois = select_roi(data.camera())  # doctest: +SKIP
    >>> plt.show()  # doctest: +SKIP
    """
    if image.ndim not in (2, 3):
        raise ValueError('Only 2D grayscale or RGB images are supported.')

    if ax is None:
        if axim is not None:
            # reuse the Axes that already hosts the image
            ax = axim.axes
        else:
            _, ax = plt.subplots()
    if axim is None:
        ax.clear()
        axim = ax.imshow(image, cmap="magma")
        ax.set_axis_off()
    else:
        axim.set_array(image)
    # Use a caller-supplied list as-is, even if empty: selections arrive
    # after this function returns, and the caller must see them.
    if rois is None:
        rois = []

    def toggle_selector(event):
        if event.key in ['A', 'a'] and not toggle_selector.RS.active:
            toggle_selector.RS.set_active(True)

    def onselect(eclick, erelease):
        starts = round(eclick.ydata), round(eclick.xdata)
        ends = round(erelease.ydata), round(erelease.xdata)
        slices = tuple((int(s), int(e)) for s, e in zip(starts, ends))
        if qtapp is None:
            rois.append(slices)
        else:
            index = qtapp.image_index
            rois[index] = slices
            qtapp.rectangle_selector.set_active(False)
            qtapp.select_next_image()
            qtapp.rectangle_selector.set_active(True)

    selector = RectangleSelector(ax, onselect, useblit=True)
    # Keep a reference so the widget remains active (from the matplotlib
    # RectangleSelector gallery example). The key handler needs it in both
    # modes; previously it was unset with qtapp, so 'a' raised AttributeError.
    toggle_selector.RS = selector
    if qtapp is not None:
        qtapp.rectangle_selector = selector
    ax.figure.canvas.mpl_connect('key_press_event', toggle_selector)
    selector.set_active(True)
    return rois
Ejemplo n.º 29
0
class RectanglePixelRegion(PixelRegion):
    """
    A rectangle in pixel coordinates.

    Parameters
    ----------
    center : `~regions.PixCoord`
        The position of the center of the rectangle.
    width : `float`
        The width of the rectangle (before rotation) in pixels
    height : `float`
        The height of the rectangle (before rotation) in pixels
    angle : `~astropy.units.Quantity`, optional
        The rotation angle of the rectangle, measured anti-clockwise. If set to
        zero (the default), the width axis is lined up with the x axis.
    meta : `~regions.RegionMeta` object, optional
        A dictionary which stores the meta attributes of this region.
    visual : `~regions.RegionVisual` object, optional
        A dictionary which stores the visual meta attributes of this region.

    Examples
    --------

    .. plot::
        :include-source:

        import numpy as np
        from astropy.coordinates import Angle
        from regions import PixCoord, RectanglePixelRegion
        import matplotlib.pyplot as plt

        x, y = 15, 10
        width, height = 8, 5
        angle = Angle(30, 'deg')

        fig, ax = plt.subplots(1, 1)

        center = PixCoord(x=x, y=y)
        reg = RectanglePixelRegion(center=center, width=width,
                                   height=height, angle=angle)
        patch = reg.as_artist(facecolor='none', edgecolor='red', lw=2)
        ax.add_patch(patch)

        plt.xlim(0, 30)
        plt.ylim(0, 20)
        ax.set_aspect('equal')
    """
    _params = ('center', 'width', 'height', 'angle')
    center = ScalarPix('center')
    width = ScalarLength('width')
    height = ScalarLength('height')
    angle = QuantityLength('angle')

    def __init__(self,
                 center,
                 width,
                 height,
                 angle=0 * u.deg,
                 meta=None,
                 visual=None):
        self.center = center
        self.width = width
        self.height = height
        self.angle = angle
        self.meta = meta or {}
        self.visual = visual or {}

    @property
    def area(self):
        """Region area (float)"""
        return self.width * self.height

    def contains(self, pixcoord):
        """
        Test whether ``pixcoord`` lies strictly inside the rectangle.

        The offset from the center is rotated into the rectangle's frame and
        compared with the half width/height. If ``meta['include']`` is False,
        the result is inverted.
        """
        cos_angle = np.cos(self.angle)
        sin_angle = np.sin(self.angle)
        dx = pixcoord.x - self.center.x
        dy = pixcoord.y - self.center.y
        dx_rot = cos_angle * dx + sin_angle * dy
        # sign is irrelevant: only the absolute value is compared below
        dy_rot = sin_angle * dx - cos_angle * dy
        in_rect = (np.abs(dx_rot) < self.width * 0.5) & (np.abs(dy_rot) <
                                                         self.height * 0.5)
        if self.meta.get('include', True):
            return in_rect
        else:
            return np.logical_not(in_rect)

    def to_sky(self, wcs):
        """
        Convert to a `RectangleSkyRegion` using ``wcs``.

        Width and height are scaled by the local pixel scale at the center,
        and the angle is adjusted for the local north direction.
        """
        # TODO: write a pixel_to_skycoord_scale_angle
        center = pixel_to_skycoord(self.center.x, self.center.y, wcs)
        _, scale, north_angle = skycoord_to_pixel_scale_angle(center, wcs)
        width = Angle(self.width / scale, 'deg')
        height = Angle(self.height / scale, 'deg')
        return RectangleSkyRegion(center,
                                  width,
                                  height,
                                  angle=self.angle -
                                  (north_angle - 90 * u.deg),
                                  meta=self.meta,
                                  visual=self.visual)

    @property
    def bounding_box(self):
        """
        The minimal bounding box (`~regions.BoundingBox`) enclosing the
        exact rectangular region.
        """

        w2 = self.width / 2.
        h2 = self.height / 2.
        cos_angle = np.cos(self.angle)  # self.angle is a Quantity
        sin_angle = np.sin(self.angle)  # self.angle is a Quantity
        # half-extents of the two diagonals; the larger one bounds each axis
        dx1 = abs(w2 * cos_angle - h2 * sin_angle)
        dy1 = abs(w2 * sin_angle + h2 * cos_angle)
        dx2 = abs(w2 * cos_angle + h2 * sin_angle)
        dy2 = abs(w2 * sin_angle - h2 * cos_angle)
        dx = max(dx1, dx2)
        dy = max(dy1, dy2)

        xmin = self.center.x - dx
        xmax = self.center.x + dx
        ymin = self.center.y - dy
        ymax = self.center.y + dy

        return BoundingBox.from_float(xmin, xmax, ymin, ymax)

    def to_mask(self, mode='center', subpixels=5):
        """
        Return a `RegionMask` with the rectangle's overlap with each pixel
        of its bounding box.

        ``mode='center'`` is handled as ``'subpixels'`` with a single
        subpixel. ``'subpixels'`` samples ``subpixels`` x ``subpixels``
        points per pixel, and any other validated mode requests the exact
        overlap.
        """

        # NOTE: assumes this class represents a single (scalar) rectangle

        self._validate_mode(mode, subpixels)

        if mode == 'center':
            mode = 'subpixels'
            subpixels = 1

        # Find bounding box and mask size
        bbox = self.bounding_box
        ny, nx = bbox.shape

        # Find position of pixel edges and recenter so that the rectangle is
        # at the origin
        xmin = float(bbox.ixmin) - 0.5 - self.center.x
        xmax = float(bbox.ixmax) - 0.5 - self.center.x
        ymin = float(bbox.iymin) - 0.5 - self.center.y
        ymax = float(bbox.iymax) - 0.5 - self.center.y

        if mode == 'subpixels':
            use_exact = 0
        else:
            use_exact = 1

        fraction = rectangular_overlap_grid(
            xmin,
            xmax,
            ymin,
            ymax,
            nx,
            ny,
            self.width,
            self.height,
            self.angle.to(u.rad).value,
            use_exact,
            subpixels,
        )

        return RegionMask(fraction, bbox=bbox)

    def as_artist(self, origin=(0, 0), **kwargs):
        """
        Matplotlib patch object for this region (`matplotlib.patches.Rectangle`).

        Parameters
        ----------
        origin : array_like, optional
            The ``(x, y)`` pixel position of the origin of the displayed image.
            Default is (0, 0).
        kwargs : `dict`
            All keywords that a `~matplotlib.patches.Rectangle` object accepts

        Returns
        -------
        patch : `~matplotlib.patches.Rectangle`
            Matplotlib rectangle patch
        """
        from matplotlib.patches import Rectangle
        xy = self._lower_left_xy()
        xy = xy[0] - origin[0], xy[1] - origin[1]
        width = self.width
        height = self.height
        # From the docstring: MPL expects "rotation in degrees (anti-clockwise)"
        angle = self.angle.to('deg').value

        mpl_params = self.mpl_properties_default('patch')
        mpl_params.update(kwargs)

        return Rectangle(xy=xy,
                         width=width,
                         height=height,
                         angle=angle,
                         **mpl_params)

    def _update_from_mpl_selector(self, *args, **kwargs):
        # Copy the selector's (unrotated) extents back into this region,
        # then notify the user callback, if any.
        xmin, xmax, ymin, ymax = self._mpl_selector.extents
        self.center = PixCoord(x=0.5 * (xmin + xmax), y=0.5 * (ymin + ymax))
        self.width = (xmax - xmin)
        self.height = (ymax - ymin)
        self.angle = 0. * u.deg
        if self._mpl_selector_callback is not None:
            self._mpl_selector_callback(self)

    def as_mpl_selector(self,
                        ax,
                        active=True,
                        sync=True,
                        callback=None,
                        **kwargs):
        """
        Matplotlib editable widget for this region (`matplotlib.widgets.RectangleSelector`)

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The Matplotlib axes to add the selector to.
        active : bool, optional
            Whether the selector should be active by default.
        sync : bool, optional
            If `True` (the default), the region will be kept in sync with the
            selector. Otherwise, the selector will be initialized with the
            values from the region but the two will then be disconnected.
        callback : func, optional
            If specified, this function will be called every time the region is
            updated. This only has an effect if ``sync`` is `True`. If a
            callback is set, it is called for the first time once the selector
            has been created.
        kwargs
            Additional keyword arguments are passed to
            `matplotlib.widgets.RectangleSelector`. They take precedence over
            the defaults set here (``interactive=True`` and ``rectprops``
            built from ``self.visual``).

        Returns
        -------
        selector : `matplotlib.widgets.RectangleSelector`
            The Matplotlib selector.

        Raises
        ------
        Exception
            If a selector is already attached to this region.
        NotImplementedError
            If the rectangle is rotated.

        Notes
        -----
        Once a selector has been created, you will need to keep a reference to
        it until you no longer need it. In addition, you can enable/disable the
        selector at any point by calling ``selector.set_active(True)`` or
        ``selector.set_active(False)``.
        """

        from matplotlib.widgets import RectangleSelector

        if hasattr(self, '_mpl_selector'):
            raise Exception(
                "Cannot attach more than one selector to a region.")

        if self.angle.value != 0:
            raise NotImplementedError(
                "Cannot create matplotlib selector for rotated rectangle.")

        if sync:
            sync_callback = self._update_from_mpl_selector
        else:

            def sync_callback(*args, **kwargs):
                pass

        selector_kwargs = {
            'interactive': True,
            'rectprops': {
                'edgecolor': self.visual.get('color', 'black'),
                'facecolor': 'none',
                'linewidth': self.visual.get('linewidth', 1),
                'linestyle': self.visual.get('linestyle', 'solid')
            },
        }
        # Forward the caller's extra keyword arguments as documented (they
        # used to be silently dropped); they may override the defaults above.
        selector_kwargs.update(kwargs)
        self._mpl_selector = RectangleSelector(ax, sync_callback,
                                               **selector_kwargs)
        self._mpl_selector.extents = (self.center.x - self.width / 2,
                                      self.center.x + self.width / 2,
                                      self.center.y - self.height / 2,
                                      self.center.y + self.height / 2)
        self._mpl_selector.set_active(active)
        self._mpl_selector_callback = callback

        if sync and self._mpl_selector_callback is not None:
            self._mpl_selector_callback(self)

        return self._mpl_selector

    @property
    def corners(self):
        """
        Return the x, y coordinate pairs that define the corners
        """

        corners = [
            (-self.width / 2, -self.height / 2),
            (self.width / 2, -self.height / 2),
            (self.width / 2, self.height / 2),
            (-self.width / 2, self.height / 2),
        ]
        # row-vector rotation: corners @ rotmat rotates anti-clockwise by angle
        rotmat = [[np.cos(self.angle), np.sin(self.angle)],
                  [-np.sin(self.angle),
                   np.cos(self.angle)]]

        return np.dot(corners, rotmat) + np.array(
            [self.center.x, self.center.y])

    def to_polygon(self):
        """
        Return a 4-cornered polygon equivalent to this rectangle
        """
        x, y = self.corners.T
        vertices = PixCoord(x=x, y=y)
        return PolygonPixelRegion(vertices=vertices,
                                  meta=self.meta,
                                  visual=self.visual)

    def _lower_left_xy(self):
        """
        Compute lower left `xy` position.

        This is used for the conversion to matplotlib in ``as_artist``

        Taken from http://photutils.readthedocs.io/en/latest/_modules/photutils/aperture/rectangle.html#RectangularAperture.plot
        """
        hw = self.width / 2.
        hh = self.height / 2.
        sint = np.sin(self.angle)
        cost = np.cos(self.angle)
        dx = (hh * sint) - (hw * cost)
        dy = -(hh * cost) - (hw * sint)
        x = self.center.x + dx
        y = self.center.y + dy
        return x, y

    def rotate(self, center, angle):
        """Make a rotated region.

        Rotates counter-clockwise for positive ``angle``.

        Parameters
        ----------
        center : `PixCoord`
            Rotation center point
        angle : `~astropy.coordinates.Angle`
            Rotation angle

        Returns
        -------
        region : `RectanglePixelRegion`
            Rotated region (an independent copy)
        """
        center = self.center.rotate(center, angle)
        angle = self.angle + angle
        return self.copy(center=center, angle=angle)
Ejemplo n.º 30
0
class ObjectLabeler():
    """Interactive matplotlib tool for drawing labelled bounding boxes on image frames.

    The ``.jpg`` files in ``frameDirectory`` are shown in a fixed shuffled
    order, skipping frames the current ``$USER`` has already annotated in
    ``annotationFile``.  For each frame the user draws boxes, tags each one
    Male/Female/Unknown, and moves on; the annotation CSV is rewritten every
    time a frame is finished.

    Parameters
    ----------
    frameDirectory : str
        Directory containing the frame images.
    annotationFile : str
        CSV file annotations are read from (if it exists) and saved to.
    number : int
        The figure closes once this many frames were annotated this session.
    projectID
        Value stored in the ``ProjectID`` column of every saved row.
    """

    # Columns of the per-frame table of boxes (``self.f_dt``)
    _FRAME_COLUMNS = ['ProjectID', 'Framefile', 'Sex', 'Box', 'User', 'DateTime']
    # Columns of the annotation CSV (adds the per-frame box count ``Nfish``)
    _FILE_COLUMNS = ['ProjectID', 'Framefile', 'Nfish', 'Sex', 'Box', 'User',
                     'DateTime']

    def __init__(self, frameDirectory, annotationFile, number, projectID):

        self.frameDirectory = frameDirectory
        self.annotationFile = annotationFile
        self.number = number
        self.projectID = projectID

        # Ignore macOS resource-fork files ("._*"); fixed seed -> same order every run
        self.frames = sorted([
            x for x in os.listdir(self.frameDirectory)
            if '.jpg' in x and '._' not in x
        ])
        random.Random(4).shuffle(self.frames)
        assert len(self.frames) > 0, 'No .jpg frames found in ' + str(
            self.frameDirectory)

        # Frame currently shown, and frames annotated during this session
        self.frame_index = 0
        self.annotated_frames = []

        # Not read by this class; presumably used by Annotation - TODO confirm
        self.coords = ()

        # All annotations (previous sessions + this one)
        if os.path.exists(self.annotationFile):
            self.dt = pd.read_csv(self.annotationFile, index_col=0)
        else:
            self.dt = pd.DataFrame(columns=list(self._FILE_COLUMNS))
        # Boxes added to the current frame, not yet saved
        self.f_dt = self._newFrameTable()

        # Get user and current time (one timestamp for the whole session)
        self.user = os.getenv('USER')
        self.now = datetime.datetime.now()

        # Box currently being built
        self.annotation = Annotation(self)

        # Text listing the boxes added to the current frame, one per line
        self.annotation_text = ''

        # Start figure
        self._createFigure()

    def _newFrameTable(self):
        """Return an empty DataFrame for the boxes of a single frame."""
        return pd.DataFrame(columns=list(self._FRAME_COLUMNS))

    def _skipAnnotatedFrames(self):
        """Advance ``frame_index`` past frames the current user already annotated.

        Stops at ``len(self.frames)`` instead of indexing past the end.
        """
        while self.frame_index < len(self.frames):
            frame = self.frames[self.frame_index]
            done = (self.dt.Framefile == frame) & (self.dt.User == self.user)
            if not done.any():
                break
            self.frame_index += 1

    def _loadFrame(self):
        """Open the current frame and return it enhanced by the saturation slider."""
        self.img = Image.open(
            os.path.join(self.frameDirectory, self.frames[self.frame_index]))
        self.converter = ImageEnhance.Color(self.img)
        return self.converter.enhance(self.slid_saturation.val)

    def _showFrame(self):
        """Put the current frame into the existing image artist, retitle and redraw."""
        self.image_obj.set_array(self._loadFrame())
        self.ax_image.set_title('Frame ' + str(self.frame_index) + ': ' +
                                self.frames[self.frame_index])
        self.fig.canvas.draw()

    def _removeAnnotationPatches(self):
        """Remove the drawn box patches from the image axes.

        Uses ``Artist.remove`` because assigning to ``Axes.patches`` is not
        supported by newer matplotlib releases.  The RectangleSelector's own
        artists are kept (only hidden) so the selector keeps working.
        """
        selector_artists = list(getattr(self.RS, 'artists', ()))
        keep = {id(artist) for artist in selector_artists}
        for patch in list(self.ax_image.patches):
            if id(patch) not in keep:
                patch.remove()
        # Hide the previous selection so it does not linger on the next frame;
        # the selector shows its artists again on the next press.
        for artist in selector_artists:
            artist.set_visible(False)

    def _createFigure(self):
        """Build the labelling GUI for the first unannotated frame and show it.

        If the current user has already annotated every frame, a message is
        printed and no figure is opened.
        """
        # Skip frames this user already annotated in earlier sessions
        self._skipAnnotatedFrames()
        if self.frame_index >= len(self.frames):
            print('All frames have already been annotated by user ' +
                  str(self.user))
            return

        self.fig = fig = plt.figure(1, figsize=(10, 7))

        # Image subplot
        self.ax_image = fig.add_axes([0.05, 0.2, .8, 0.75])

        # Saturation slider (needed before the first image is enhanced)
        self.ax_saturation = fig.add_axes([0.1, 0.08, 0.2, 0.03])
        self.slid_saturation = Slider(self.ax_saturation,
                                      'Saturation',
                                      0,
                                      10,
                                      valinit=1,
                                      valstep=.1)

        # Plot the first image
        self.image_obj = self.ax_image.imshow(self._loadFrame())
        self.ax_image.set_title('Frame ' + str(self.frame_index) + ': ' +
                                self.frames[self.frame_index])

        # Selector used to draw bounding boxes
        self.RS = RectangleSelector(
            self.ax_image,
            self._grabBoundingBox,
            drawtype='box',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        self.RS.set_active(True)

        # Radio buttons choosing the label of the next box
        self.ax_radio = fig.add_axes([0.85, 0.85, 0.125, 0.1])
        self.radio_names = [
            r"$\bf{M}$" + 'ale', r"$\bf{F}$" + 'emale', r"$\bf{U}$" + 'nknown'
        ]
        self.bt_radio = RadioButtons(self.ax_radio,
                                     self.radio_names,
                                     active=0,
                                     activecolor='blue')

        # Buttons for adding / clearing a box
        self.ax_boxAdd = fig.add_axes([0.85, 0.775, 0.125, 0.04])
        self.bt_boxAdd = Button(self.ax_boxAdd, r"$\bf{A}$" + 'dd Box')
        self.ax_boxClear = fig.add_axes([0.85, 0.725, 0.125, 0.04])
        self.bt_boxClear = Button(self.ax_boxClear, r"$\bf{C}$" + 'lear Box')

        # Buttons for resetting, saving or revisiting frames
        self.ax_frameClear = fig.add_axes([0.85, 0.375, 0.125, 0.04])
        self.bt_frameClear = Button(self.ax_frameClear,
                                    r"$\bf{R}$" + 'eset Frame')
        self.ax_frameAdd = fig.add_axes([0.85, 0.325, 0.125, 0.04])
        self.bt_frameAdd = Button(self.ax_frameAdd, r"$\bf{N}$" + 'ext Frame')
        self.ax_framePrevious = fig.add_axes([0.85, 0.275, 0.125, 0.04])
        self.bt_framePrevious = Button(self.ax_framePrevious,
                                       r"$\bf{P}$" + 'revious Frame')

        # Quit button
        self.ax_quit = fig.add_axes([0.85, 0.175, 0.125, 0.04])
        self.bt_quit = Button(self.ax_quit, r"$\bf{Q}$" + 'uit and save')

        # Text areas: boxes of this frame, box count, and error messages
        self.ax_cur_text = fig.add_axes([0.85, 0.575, 0.125, 0.14])
        self.ax_cur_text.set_axis_off()
        self.cur_text = self.ax_cur_text.text(0,
                                              1,
                                              '',
                                              fontsize=8,
                                              verticalalignment='top')

        self.ax_all_text = fig.add_axes([0.85, 0.425, 0.125, 0.19])
        self.ax_all_text.set_axis_off()
        self.all_text = self.ax_all_text.text(0,
                                              1,
                                              '',
                                              fontsize=9,
                                              verticalalignment='top')

        self.ax_error_text = fig.add_axes([0.3, 0.05, .6, 0.1])
        self.ax_error_text.set_axis_off()
        self.error_text = self.ax_error_text.text(0,
                                                  1,
                                                  '',
                                                  fontsize=14,
                                                  color='red',
                                                  verticalalignment='top')

        # Keyboard shortcuts
        self.fig.canvas.mpl_connect('key_press_event', self._keypress)

        # Turn off the buttons' hover handler (cids[2]); per the original
        # author it interfered with the selection rectangle staying displayed.
        for button in (self.bt_boxAdd, self.bt_boxClear, self.bt_frameAdd,
                       self.bt_frameClear, self.bt_framePrevious,
                       self.bt_quit):
            self.fig.canvas.mpl_disconnect(button.cids[2])

        # Connect buttons to their handlers
        self.bt_boxAdd.on_clicked(self._addBoundingBox)
        self.bt_boxClear.on_clicked(self._clearBoundingBox)
        self.bt_frameClear.on_clicked(self._clearFrame)
        self.bt_framePrevious.on_clicked(self._previousFrame)
        self.bt_frameAdd.on_clicked(self._nextFrame)
        self.bt_quit.on_clicked(self._quit)
        self.slid_saturation.on_changed(self._updateFrame)

        # Show figure
        plt.show()

    def _grabBoundingBox(self, eclick, erelease):
        """RectangleSelector callback: store the box as integer (x, y, width, height)."""
        self.error_text.set_text('')

        # Convert the press/release display coordinates to image coordinates
        to_image = self.ax_image.transData.inverted()
        x0, y0 = to_image.transform((eclick.x, eclick.y))
        x1, y1 = to_image.transform((erelease.x, erelease.y))
        x0, y0, x1, y1 = (int(v) for v in (x0, y0, x1, y1))

        self.annotation.coords = (min(x0, x1), min(y0, y1), abs(x0 - x1),
                                  abs(y0 - y1))

    def _keypress(self, event):
        """Keyboard shortcuts: m/f/u pick the label, a/c/n/r/p/q mirror the buttons."""
        sexes = ['m', 'f', 'u']
        if event.key in sexes:
            self.bt_radio.set_active(sexes.index(event.key))
            return
        handlers = {
            'a': self._addBoundingBox,
            'c': self._clearBoundingBox,
            'n': self._nextFrame,
            'r': self._clearFrame,
            'p': self._previousFrame,
            'q': self._quit,
        }
        handler = handlers.get(event.key)
        if handler is not None:
            handler(event)

    def _addBoundingBox(self, event):
        """Add the selected box, with the label chosen in the radio buttons, to this frame."""
        if self.annotation.coords == ():
            self.error_text.set_text('Error: Bounding box not set')
            self.fig.canvas.draw()
            return

        # Map the displayed radio label to the stored code
        stored_names = ['m', 'f', 'u']
        self.annotation.sex = stored_names[self.radio_names.index(
            self.bt_radio.value_selected)]

        self.annotation.addRectangle()

        outrow = self.annotation.retRow()
        if isinstance(outrow, str):
            # retRow reports problems as a message string instead of a row
            self.error_text.set_text(outrow)
            self.fig.canvas.draw()
            return

        # One row per added box: _clearBoundingBox removes boxes by dropping
        # the last row, so rows must stay in step with the drawn boxes.
        self.f_dt.loc[len(self.f_dt)] = outrow

        self.annotation_text += self.annotation.sex + ':' + str(
            self.annotation.coords) + '\n'
        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.annotation.reset()

        self.fig.canvas.draw()

    def _clearBoundingBox(self, event):
        """Remove the most recently added box of this frame (if any)."""
        if not self.annotation.removePatches():
            return

        # Drop only the last line; identical earlier boxes stay listed
        lines = self.annotation_text.splitlines(True)
        self.annotation_text = ''.join(lines[:-1])

        self.annotation.reset()

        self.f_dt.drop(self.f_dt.tail(1).index, inplace=True)

        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.fig.canvas.draw()

    def _nextFrame(self, event):
        """Save this frame's boxes to ``annotationFile`` and show the next unannotated frame.

        A frame without boxes is saved as one row with ``Nfish`` = 0.  The
        figure closes when no frames remain or ``number`` frames have been
        annotated in this session.
        """
        if self.annotation.coords != ():
            self.error_text.set_text(
                'Save or clear (esc) current annotation before moving on')
            self.fig.canvas.draw()
            return

        if len(self.f_dt) == 0:
            self.f_dt.loc[0] = [
                self.projectID, self.frames[self.frame_index], '', '',
                self.user, self.now
            ]
            self.f_dt['Nfish'] = 0
        else:
            self.f_dt['Nfish'] = len(self.f_dt)
        # DataFrame.append was removed in pandas 2.0; concat is equivalent
        self.dt = pd.concat([self.dt, self.f_dt], sort=True)
        # Save after every frame so work survives the user quitting
        self.dt.to_csv(self.annotationFile,
                       sep=',',
                       columns=list(self._FILE_COLUMNS))
        self.f_dt = self._newFrameTable()
        self.annotated_frames.append(self.frame_index)

        self._removeAnnotationPatches()
        self.annotation = Annotation(self)
        self.annotation_text = ''

        # Advance to the next frame this user has not annotated yet
        self.frame_index += 1
        self._skipAnnotatedFrames()

        if self.frame_index >= len(self.frames) or len(
                self.annotated_frames) == self.number:
            # Done: close the figure without loading another frame
            plt.close(self.fig)
            return

        self.cur_text.set_text('')
        self.all_text.set_text('')
        self._showFrame()

    def _updateFrame(self, event):
        """Slider callback: re-apply the saturation to the current image."""
        img = self.converter.enhance(self.slid_saturation.val)
        self.image_obj.set_array(img)
        self.fig.canvas.draw()

    def _clearFrame(self, event):
        """Discard every box added to the current frame."""
        print('Clearing')
        self.f_dt = self._newFrameTable()
        self._removeAnnotationPatches()
        self.annotation_text = ''
        self.annotation = Annotation(self)

        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.fig.canvas.draw()

    def _previousFrame(self, event):
        """Return to the last frame saved this session and drop this user's rows for it."""
        if not self.annotated_frames:
            self.error_text.set_text('No previous frame to return to')
            self.fig.canvas.draw()
            return

        self.frame_index = self.annotated_frames.pop()
        frame = self.frames[self.frame_index]
        # Only remove the current user's annotations; keep other users' rows
        self.dt = self.dt[~((self.dt.Framefile == frame) &
                            (self.dt.User == self.user))]
        self._clearFrame(event)
        self._showFrame()

    def _quit(self, event):
        """Close the figure (frames already moved past are saved)."""
        plt.close(self.fig)
Ejemplo n.º 31
0
class BoxEditor(object):
    """Rubber-band box selector that draws the selected area as a polygon.

    Wraps matplotlib's ``RectangleSelector`` (left mouse button only).  After
    each selection the previously drawn polygon is removed and a new one
    covering the selected rectangle is added to ``axes``.
    """
    # Polygon currently drawn on the axes; None when nothing is drawn
    polygon = None

    def __init__(self, axes, canvas):
        """Create the rectangle selector on ``axes``; ``canvas`` is redrawn after changes."""
        self.axes = axes
        self.canvas = canvas
        self.rectangle_selector = RectangleSelector(axes,
                                                    self.line_select_callback,
                                                    drawtype='box',
                                                    useblit=True,
                                                    button=[1],
                                                    minspanx=5,
                                                    minspany=5,
                                                    spancoords='pixels')

    def line_select_callback(self, eclick, erelease):
        """RectangleSelector callback: replace the drawn polygon with the new box."""
        x1_val, y1_val = eclick.xdata, eclick.ydata
        x2_val, y2_val = erelease.xdata, erelease.ydata
        # Corners in order: press point, (press x, release y), release point,
        # (release x, press y)
        xy_values = np.array([
            [x1_val, y1_val],
            [x1_val, y2_val],
            [x2_val, y2_val],
            [x2_val, y1_val],
        ])
        self.reset_polygon()
        # polygon_alpha is not defined in this class; presumably a
        # module-level constant elsewhere in the file - TODO confirm
        self.polygon = Polygon(xy_values, animated=False, alpha=polygon_alpha)
        self.axes.add_patch(self.polygon)
        self.canvas.draw()

    def enable(self):
        """Enable the box selector."""
        self.rectangle_selector.set_active(True)

    def disable(self):
        """Remove any drawn box, deactivate the selector and redraw."""
        self.reset_polygon()
        self.rectangle_selector.set_active(False)
        self.canvas.draw()

    def reset_polygon(self):
        """Remove the drawn polygon from the axes, if there is one."""
        if self.polygon is not None:
            self.polygon.remove()
            self.polygon = None

    def reset(self):
        """Reset the box selector by removing the drawn polygon."""
        self.reset_polygon()
Ejemplo n.º 32
0
class MainWindow(QMainWindow, Ui_MainWindow):
    """Qt main window embedding a matplotlib plot of user-entered equations.

    The expression typed in ``lineEditEq`` is evaluated over ``x`` (a
    ``linspace`` built from the start/stop/num boxes) and drawn into the last
    line of the axes.  Buttons switch between a rectangle zoom
    (``RectangleSelector``), horizontal/vertical span zooms (``SpanSelector``)
    and a cross-hair cursor (``MultiCursor``); the axis limits are mirrored in
    four edit boxes.  "Play Movie" redraws the last curve point by point.

    The ``on_<widget>_<signal>`` slots are not connected explicitly here;
    presumably ``setupUi`` (pyuic output) connects them by name.
    """

    def __init__(self):

        # Initialise the QMainWindow
        super().__init__()

        # Build the widgets described in Ui_MainWindow
        self.setupUi(self)

        # Figure, canvas and a single axes
        self.fig = Figure(figsize=(100, 100))
        self.canvas = FigureCanvas(self.fig)
        self.ax = self.fig.subplots()

        # Dash-dotted reference lines at y=0 and x=0
        self.ax.axhline(linewidth=1, linestyle="dashdot", color="#6E6E6E")
        self.ax.axvline(linewidth=1, linestyle="dashdot", color="#6E6E6E")

        self.ax.set_title("Simple plot tool built with Python")

        # Shared look of the zoom selectors
        rectprops = dict(facecolor='gray', alpha=0.5)

        # Terminal diagnostics on click and on mouse motion
        self.canvas.mpl_connect('button_press_event', self.on_mouse_press)
        self.canvas.mpl_connect('motion_notify_event', self.on_move_mouse)

        # Rectangle zoom, switched on by the "Rect" button
        # https://matplotlib.org/api/widgets_api.html#matplotlib.widgets.RectangleSelector
        self.RS = RectangleSelector(self.ax,
                                    self.on_select_zoom_box,
                                    useblit=True,
                                    rectprops=rectprops)
        self.RS.set_active(False)

        # Vertical / horizontal span zooms, switched on by their buttons
        # https://matplotlib.org/api/widgets_api.html#matplotlib.widgets.SpanSelector
        self.SSv = SpanSelector(self.ax,
                                self.on_vert_zoom,
                                'vertical',
                                useblit=True,
                                rectprops=rectprops)
        self.SSh = SpanSelector(self.ax,
                                self.on_hor_zoom,
                                'horizontal',
                                useblit=True,
                                rectprops=rectprops)
        self.SSv.set_active(False)
        self.SSh.set_active(False)

        # Cross-hair cursor, active by default
        # https://matplotlib.org/api/widgets_api.html#matplotlib.widgets.MultiCursor
        self.MC = MultiCursor(self.canvas, (self.ax, ),
                              useblit=True,
                              horizOn=True,
                              vertOn=True,
                              linewidth=1,
                              color="#C8C8C8")
        self.MC.set_active(True)

        # Embed the matplotlib canvas in the Qt layout
        self.verticalLayout.addWidget(self.canvas)

        # Empty line that the equation is drawn into.  Helper lines were
        # added to the axes above, so user plot lines are always the last ones.
        self.ax.plot([], [])
        self.lines = 1  # Number of user plot lines

        self.ax.set_xlabel("x axis")
        self.ax.set_ylabel("y axis")

        # Path of the last plotted curve (used by the movie player)
        self.path = []

        # Plot the initial equation and fit the view to it
        self.on_lineEditEq_returnPressed()
        self.on_pushButtonHome_clicked()

        self.pushButtonPlayMovie.setText("Play Movie \n in last plot\n(►)")
        self.running = False

    def on_move_mouse(self, event):
        """Print the data coordinates under the mouse on every motion event."""
        if event.inaxes:
            print("\nPosition :==============")
            print("x = ", event.xdata, " | y = ", event.ydata)
            print("MultiCursor active? ", self.MC.active)
        else:
            print("Mouse out of axes")

    def on_mouse_press(self, event: matplotlib.backend_bases.MouseEvent):
        """Print the axes' lines and the click position to the terminal.

        Called for every click on the figure canvas, inside or outside the
        axes.  For a click over an axes, ``event.inaxes`` is that axes and
        ``event.xdata``/``event.ydata`` give the position in data coordinates.
        https://matplotlib.org/api/backend_bases_api.html#matplotlib.backend_bases.MouseEvent
        """
        # Clear the terminal
        clc()

        if event.inaxes:
            print("Polylines objects: =================")
            for i, line in enumerate(event.inaxes.lines):
                print("line [", i, "]: ", line)

            print("\nPosition :==============")
            print("x = ", event.xdata, " | y = ", event.ydata)
            self.canvas.draw()
        else:
            print("Clicked out of axes")

    def on_select_zoom_box(self, eclick: matplotlib.backend_bases.MouseEvent,
                           erelease: matplotlib.backend_bases.MouseEvent):
        """Zoom to the rectangle drawn with the RectangleSelector, then disable it.

        Arguments:
            eclick -- mouse event at button press
            erelease -- mouse event at button release
        """
        # Re-enable the cursor first: per the original author this resets the
        # axis limits, so it would override the limits below if done later.
        self.MC.set_active(True)
        self.ax.set_xlim(eclick.xdata, erelease.xdata)
        self.ax.set_ylim(eclick.ydata, erelease.ydata)
        self.get_limits()
        self.canvas.draw()
        print("")
        self.RS.set_active(False)

    def on_vert_zoom(self, vmin: float, vmax: float):
        """Zoom the y axis to [vmin, vmax] (vertical SpanSelector callback)."""
        self.MC.set_active(True)
        self.ax.set_ylim(vmin, vmax)
        self.get_limits()
        # Redraw so the zoom is visible immediately, as the rectangle zoom does
        self.canvas.draw()
        self.SSv.set_active(False)

    def on_hor_zoom(self, hmin: float, hmax: float):
        """Zoom the x axis to [hmin, hmax] (horizontal SpanSelector callback)."""
        self.MC.set_active(True)
        self.ax.set_xlim(hmin, hmax)
        self.get_limits()
        # Redraw so the zoom is visible immediately, as the rectangle zoom does
        self.canvas.draw()
        self.SSh.set_active(False)

    def set_limits(self):
        """Apply the limits typed in the four limit boxes to the axes.

        Non-numeric input leaves the axes unchanged instead of raising
        inside the Qt slot.
        """
        try:
            xinf = float(self.lineEditXinf.text())
            xsup = float(self.lineEditXsup.text())
            yinf = float(self.lineEditYinf.text())
            ysup = float(self.lineEditYsup.text())
        except ValueError:
            return

        self.ax.set_xlim(xinf, xsup)
        self.ax.set_ylim(yinf, ysup)

        self.canvas.draw()

        self.get_limits()

    def get_limits(self):
        """Write the current axes limits (2 decimals) into the four limit boxes."""
        xmin, xmax = self.ax.get_xlim()
        ymin, ymax = self.ax.get_ylim()
        self.lineEditXinf.setText("{:0.2f}".format(xmin))
        self.lineEditXsup.setText("{:0.2f}".format(xmax))
        self.lineEditYinf.setText("{:0.2f}".format(ymin))
        self.lineEditYsup.setText("{:0.2f}".format(ymax))

    @QtCore.pyqtSlot()
    def on_lineEditEq_returnPressed(self):
        """Evaluate the equation box over ``x`` and draw it into the last line.

        ``x`` is ``linspace(start, stop, num)`` from the three range boxes.
        Invalid range values, or an expression that fails to evaluate, leave
        the plot unchanged.  Constant curves (all x or all y equal) are drawn
        grey and dash-dotted, others solid black.
        """
        try:
            start = float(self.lineEditStart.text())
            stop = float(self.lineEditStop.text())
            num = int(self.lineEditNum.text())
            x = linspace(start, stop, num)
        except ValueError:
            return None

        # SECURITY NOTE: eval runs arbitrary Python typed by the user; this is
        # only acceptable for a local, single-user tool.  The expression sees
        # the locals above (notably ``x``) and this module's global names.
        try:
            y = eval(self.lineEditEq.text())
        except Exception:
            return None

        curve = self.ax.lines[-1]
        curve.set_data(x, y)

        # Read back the broadcast data from the line's path
        path = curve.get_path()
        x = path.vertices[:, 0]
        y = path.vertices[:, 1]

        # Constant (or empty) curves are styled as reference lines
        if len(x) == 0 or all(x == x[0]) or all(y == y[0]):
            curve.set_color("#969696")
            curve.set_linestyle("dashdot")
        else:
            curve.set_color("#000000")
            curve.set_linestyle("solid")

        self.path = curve.get_path()

        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_lineEditStart_returnPressed(self):
        self.on_lineEditEq_returnPressed()

    @QtCore.pyqtSlot()
    def on_lineEditStop_returnPressed(self):
        self.on_lineEditEq_returnPressed()

    @QtCore.pyqtSlot()
    def on_lineEditNum_returnPressed(self):
        self.on_lineEditEq_returnPressed()

    @QtCore.pyqtSlot()
    def on_lineEditXinf_returnPressed(self):
        self.set_limits()

    @QtCore.pyqtSlot()
    def on_lineEditXsup_returnPressed(self):
        self.set_limits()

    @QtCore.pyqtSlot()
    def on_lineEditYinf_returnPressed(self):
        self.set_limits()

    @QtCore.pyqtSlot()
    def on_lineEditYsup_returnPressed(self):
        self.set_limits()

    @QtCore.pyqtSlot()
    def on_pushButtonHome_clicked(self):
        """Autoscale the axes to the current data and update the limit boxes."""
        self.ax.set_autoscale_on(True)
        self.ax.relim()
        self.ax.autoscale_view()

        self.canvas.draw()

        self.get_limits()

    @QtCore.pyqtSlot()
    def on_pushButtonAddPlot_clicked(self):
        """Start a new curve and focus the (cleared) equation box."""
        # Add an empty line unless the last plot line is still empty
        if self.lines <= 0 or len(self.ax.lines[-1].get_xdata()) > 0:
            self.ax.plot([], [])
            self.lines += 1

        self.lineEditEq.setText("")
        self.lineEditEq.setFocus()

    @QtCore.pyqtSlot()
    def on_pushButtonDelPlot_clicked(self):
        """Remove the last user plot line, if any."""
        if self.lines <= 0:
            return

        # Artist.remove() works on every matplotlib version; Axes.lines is a
        # read-only list in newer releases, so .pop() is not supported there.
        self.ax.lines[-1].remove()

        self.canvas.draw()

        self.lines -= 1

        self.path = self.ax.lines[-1].get_path()

    @QtCore.pyqtSlot()
    def on_pushButtonRect_clicked(self):
        """Switch to rectangle zoom."""
        self.MC.set_active(False)
        self.SSv.set_active(False)
        self.SSh.set_active(False)
        self.RS.set_active(True)
        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_pushButtonHor_clicked(self):
        """Switch to horizontal span zoom."""
        self.MC.set_active(False)
        self.RS.set_active(False)
        self.SSv.set_active(False)
        self.SSh.set_active(True)
        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_pushButtonVert_clicked(self):
        """Switch to vertical span zoom."""
        self.MC.set_active(False)
        self.RS.set_active(False)
        self.SSh.set_active(False)
        self.SSv.set_active(True)
        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_lineEditDeltaT_editingFinished(self):
        self.update_Dt()

    @QtCore.pyqtSlot(str)
    def on_lineEditDeltaT_textChanged(self, text=None):
        # ``text`` is the new box content delivered by the str signal;
        # update_Dt re-reads the box itself.
        self.update_Dt()

    @QtCore.pyqtSlot()
    def on_lineEditDeltaT_returnPressed(self):
        self.update_Dt()

    @QtCore.pyqtSlot()
    def on_pushButtonPlayMovie_clicked(self):
        """Toggle point-by-point playback of ``self.path`` in the last line.

        If the last line already shows the whole path it is cleared first;
        then one vertex is added per step.  Clicking again while playing
        (handled by the event loop between steps) stops the playback.
        """
        if self.running:
            self.running = False
            return

        self.running = True
        self.pushButtonPlayMovie.setText("Pause Movie \n( ▍▍)")

        xt = self.path.vertices[:, 0]
        yt = self.path.vertices[:, 1]
        # Restart from an empty line if the curve is already fully drawn
        if all(self.path.vertices[-1, :] ==
               self.ax.lines[-1].get_path().vertices[-1, :]):
            self.ax.lines[-1].set_data([], [])

        # Continue from the number of points currently drawn
        i = len(self.ax.lines[-1].get_path().vertices[:, 0])
        start_loop = time()
        intervals = []

        while self.running and i < len(self.path.vertices[:, 1]):
            i += 1
            self.ax.lines[-1].set_data(xt[0:i], yt[0:i])
            # NOTE(review): sleep() blocks the GUI for 1 s, then
            # start_event_loop(1) processes events (e.g. a Pause click) for up
            # to another second.  The global Dt from update_Dt is not used
            # yet; the author reported sleep(Dt)/plt.pause(Dt) did not work.
            sleep(1)
            self.canvas.start_event_loop(1)
            intervals.append("Step " + str(i) + ": " +
                             str(time() - start_loop))
            print(intervals[-1])
            start_loop = time()
            self.canvas.draw()
        self.running = False
        print(array(intervals))
        self.pushButtonPlayMovie.setText("Play Movie \n in last plot\n(►)")

    def update_Dt(self):
        """Parse the Δt box into the module-level ``Dt`` (1.0 if not a number)."""
        global Dt
        try:
            Dt = max([float(self.lineEditDeltaT.text()), 1e-30])
        except ValueError:
            Dt = 1.0
        print("Δt = ", Dt)
Ejemplo n.º 33
0
class plot_2d_data(wx.Frame):
    """Generic 2d plotting frame with interactive summed slice plots.

    Inputs are:
    - data: 2d array of values, indexed [x, y] (it is transposed for display),
    - extent: ``[x_min, x_max, y_min, y_max]`` of the data (passed to imshow),
    - caller: stored as ``self.caller``; not otherwise used by this class,
    - scale: 'log' for a log colour scale; 'lin'/'linear' (or any other
      value) for a linear one,
    - window_title: title of the frame,
    - pixel_mask: array with the same shape as data; only entries where it is
      nonzero contribute to slice sums (defaults to all ones),
    - plot_title, x_label and y_label are added to the 2d plot,
    - parent: printed for debugging; the frame is always created top-level.
    """

    # Returned by figure_list_dialog when the dialog is cancelled; None is
    # already taken to mean "New Figure".
    _DIALOG_CANCELLED = object()

    def __init__(self, data, extent, caller=None, scale='log',
                 window_title='log plot', pixel_mask=None,
                 plot_title="data plot", x_label="x", y_label="y",
                 parent=None):
        wx.Frame.__init__(self, parent=None, title=window_title,
                          pos=wx.DefaultPosition, size=wx.Size(800, 600))
        print(parent)
        self.extent = extent
        self.data = data
        self.caller = caller
        self.window_title = window_title
        x_range = extent[0:2]
        self.x_min, self.x_max = x_range
        y_range = extent[2:4]
        self.y_min, self.y_max = y_range
        self.plot_title = plot_title
        self.x_label = x_label
        self.y_label = y_label
        self.slice_xy_range = (x_range, y_range)
        self.ID_QUIT = wx.NewId()
        self.ID_LOGLIN = wx.NewId()
        self.ID_UPCOLLIM = wx.NewId()
        self.ID_LOWCOLLIM = wx.NewId()

        self._build_menus()
        self.Centre()

        # `is None`: `array == None` is elementwise and cannot be used in `if`.
        if pixel_mask is None:
            pixel_mask = ones(data.shape)
        if pixel_mask.shape != data.shape:
            print("Warning: pixel mask shape incompatible with data")
            pixel_mask = ones(data.shape)
        self.pixel_mask = pixel_mask

        self.show_data = transpose(data.copy())
        # Smallest clearly-positive intensity (the 1e-17 threshold ignores
        # floating-point residue).  Half of it is added before taking the log
        # so that zero pixels stay finite.  Fall back to 1.0 when no pixel is
        # positive; .min() of an empty selection would raise.
        positive = self.data[self.data > 1e-17]
        self.minimum_intensity = positive.min() if positive.size else 1.0

        self.fig = Figure(dpi=80, figsize=(5, 5))
        fig = self.fig
        self.canvas = Canvas(self, -1, self.fig)
        self.show_sliceplots = False  # slice plots start hidden (sliceplots_off below)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.canvas, 1, wx.TOP | wx.LEFT | wx.EXPAND)

        self._add_toolbar()

        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetFieldsCount(2)
        self.statusbar.SetStatusWidths([-1, -2])
        self.statusbar.SetStatusText("Current Position:", 0)

        self.canvas.mpl_connect('motion_notify_event', self.onmousemove)
        self.mapper = FigureImage(self.fig)

        # Main image axes plus the two slice-plot axes.  'sx' is placed to the
        # right of the image by sliceplots_on; 'sz' is subplot 223, below it.
        ax = fig.add_subplot(221, label='2d_plot')
        fig.sx = fig.add_subplot(222, label='sx', picker=True)
        fig.sx.xaxis.set_picker(True)
        fig.sx.yaxis.set_picker(True)
        fig.sx.yaxis.set_ticks_position('right')
        fig.sx.set_zorder(1)
        fig.sz = fig.add_subplot(223, label='sz', picker=True)
        fig.sz.xaxis.set_picker(True)
        fig.sz.yaxis.set_picker(True)
        fig.sz.set_zorder(1)
        self.RS = RectangleSelector(ax, self.onselect, drawtype='box', useblit=True)
        fig.slice_overlay = None

        ax.set_position([0.125, 0.1, 0.7, 0.8])
        fig.cb = fig.add_axes([0.85, 0.1, 0.05, 0.8])
        fig.cb.set_zorder(2)

        fig.ax = ax
        fig.ax.set_zorder(2)
        self.axes = ax
        ax.set_title(plot_title)
        if scale == 'log':
            self.show_data = log(self.data.copy().T + self.minimum_intensity / 2.0)
            self.__scale = 'log'
            # NOTE(review): the label says log10, but `log` is whatever log
            # function is in scope here (numpy's `log` is the natural log).
            self.fig.cb.set_xlabel(r'$\log_{10}I$')
            self.menu_log_lin_toggle.Check(True)
        else:
            # 'lin' / 'linear'.  Any other value used to leave the scale unset
            # and crash later in onmousemove; it now falls back to linear.
            self.__scale = 'lin'
            self.fig.cb.set_xlabel('$I$')
            self.menu_log_lin_toggle.Check(False)

        im = self.axes.imshow(self.show_data, interpolation='nearest', aspect='auto',
                              origin='lower', cmap=cm.jet, extent=extent)
        fig.im = im
        ax.set_xlabel(x_label, size='large')
        ax.set_ylabel(y_label, size='large')
        self.toolbar.update()
        zoom_colorbar(im=im, cax=fig.cb)

        self.SetSizer(self.sizer)
        self.Fit()

        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.OnContext)
        self.Bind(wx.EVT_CLOSE, self.onExit)
        self.sliceplots_off()
        self.SetSize(wx.Size(800, 600))
        self.canvas.draw()

    def _build_menus(self):
        """Create the File (Quit) and Plot (log/lin, colour limits) menus."""
        menubar = wx.MenuBar()
        filemenu = wx.Menu()
        quit_item = wx.MenuItem(filemenu, self.ID_QUIT, '&Quit\tCtrl+Q')
        filemenu.AppendItem(quit_item)
        # Close() raises EVT_CLOSE, which onExit handles by destroying the frame.
        self.Bind(wx.EVT_MENU, lambda evt: self.Close(), id=self.ID_QUIT)

        plotmenu = wx.Menu()
        self.menu_log_lin_toggle = plotmenu.Append(
            self.ID_LOGLIN, 'Plot 2d data with log color scale',
            'plot 2d on log scale', kind=wx.ITEM_CHECK)
        self.Bind(wx.EVT_MENU, self.toggle_2d_plot_scale, id=self.ID_LOGLIN)
        plotmenu.Append(self.ID_UPCOLLIM, 'Set upper limit of color map',
                        'Set upper limit of color map')
        self.Bind(wx.EVT_MENU, self.set_new_upper_color_limit, id=self.ID_UPCOLLIM)
        plotmenu.Append(self.ID_LOWCOLLIM, 'Set lower limit of color map',
                        'Set lower limit of color map')
        self.Bind(wx.EVT_MENU, self.set_new_lower_color_limit, id=self.ID_LOWCOLLIM)

        menubar.Append(filemenu, '&File')
        menubar.Append(plotmenu, '&Plot')
        self.SetMenuBar(menubar)

    def _add_toolbar(self):
        """Create the navigation toolbar and place it for the current platform."""
        self.toolbar = MyNavigationToolbar(self.canvas, True, self)
        self.toolbar.Realize()
        if wx.Platform == '__WXMAC__':
            # Mac platform (OSX 10.3, MacPython) does not seem to cope with
            # having a toolbar in a sizer. This work-around gets the buttons
            # back, but at the expense of having the toolbar at the top.
            self.SetToolBar(self.toolbar)
        else:
            # On Windows the default window size is incorrect, so set the
            # toolbar width to the figure width.  Adding the toolbar to the
            # sizer puts it at the bottom of the frame (closer to GTK).
            _, toolbar_height = self.toolbar.GetSizeTuple()
            canvas_width, _ = self.canvas.GetSizeTuple()
            self.toolbar.SetSize(wx.Size(canvas_width, toolbar_height))
            self.sizer.Add(self.toolbar, 0, wx.LEFT | wx.EXPAND)

    def onExit(self, event):
        """EVT_CLOSE handler: destroy this frame."""
        self.Destroy()

    def exit(self, event):
        """Exit the whole wx application."""
        wx.GetApp().Exit()

    def _prompt_color_limit(self, index, prompt):
        """Ask for a new colour-map limit and apply it.

        index: 0 for the lower limit, 1 for the upper limit.
        prompt: dialog text with one %f placeholder for the current value.
        Cancelling, or entering something that is not a number, leaves the
        colour map unchanged.
        """
        clim = list(self.fig.im.get_clim())
        current = clim[index]
        dlg = wx.TextEntryDialog(None, prompt % current, defaultValue="%f" % current)
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return
            try:
                clim[index] = float(dlg.GetValue())
            except ValueError:
                return
            # set_clim can disturb the colourbar labels, so save and restore them.
            xlab = self.fig.cb.get_xlabel()
            ylab = self.fig.cb.get_ylabel()
            self.fig.im.set_clim(tuple(clim))
            self.fig.cb.set_xlabel(xlab)
            self.fig.cb.set_ylabel(ylab)
            self.fig.canvas.draw()
        finally:
            dlg.Destroy()

    def set_new_upper_color_limit(self, evt=None):
        """Menu handler: prompt for a new upper colour-map limit."""
        self._prompt_color_limit(1, "Change upper limit of color map (currently %f)")

    def set_new_lower_color_limit(self, evt=None):
        """Menu handler: prompt for a new lower colour-map limit."""
        self._prompt_color_limit(0, "Change lower limit of color map (currently %f)")

    def OnContext(self, evt):
        """Right-click handler: show the context menu for the axes under the cursor."""
        print(self.show_sliceplots)
        mpl_x = evt.X
        # wx measures y from the top of the canvas, matplotlib from the bottom.
        mpl_y = self.fig.canvas.GetSize()[1] - evt.Y
        mpl_mouseevent = matplotlib.backend_bases.MouseEvent(
            'button_press_event', self.canvas, mpl_x, mpl_y, button=3)

        if mpl_mouseevent.inaxes == self.fig.ax:
            self.area_context(mpl_mouseevent, evt)
        elif ((mpl_mouseevent.inaxes == self.fig.sx or mpl_mouseevent.inaxes == self.fig.sz)
              and self.show_sliceplots):
            self.lineplot_context(mpl_mouseevent, evt)

    def area_context(self, mpl_mouseevent, evt):
        """Popup menu for the 2d image: grid, log/lin toggle, colour maps."""
        area_popup = wx.Menu()
        item1 = area_popup.Append(wx.ID_ANY, '&Grid on/off', 'Toggle grid lines')
        wx.EVT_MENU(self, item1.GetId(), self.OnGridToggle)
        cmapmenu = CMapMenu(self, callback=self.OnColormap, mapper=self.mapper, canvas=self.canvas)
        item2 = area_popup.Append(wx.ID_ANY, '&Toggle log/lin', 'Toggle log/linear scale')
        wx.EVT_MENU(self, item2.GetId(), lambda evt: self.toggle_log_lin(mpl_mouseevent))
        area_popup.AppendMenu(wx.ID_ANY, "Colourmaps", cmapmenu)
        self.PopupMenu(area_popup, evt.GetPositionTuple())

    def figure_list_dialog(self):
        """Ask which pyplot figure to plot into.

        Returns the chosen figure number, None for "New Figure", or
        ``self._DIALOG_CANCELLED`` if the dialog was dismissed.
        """
        figure_list = [None] + list(get_fignums())
        figure_list_names = ['New Figure'] + ['Figure ' + str(num) for num in figure_list[1:]]
        dlg = wx.SingleChoiceDialog(None, 'Choose figure number', '', figure_list_names)
        dlg.SetSize(wx.Size(640, 480))
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return self._DIALOG_CANCELLED
            selection_num = dlg.GetSelection()
        finally:
            dlg.Destroy()
        print(selection_num)
        return figure_list[selection_num]

    def lineplot_context(self, mpl_mouseevent, evt):
        """Popup menu for a slice plot: log/lin toggle, save, pop out."""
        popup = wx.Menu()
        item1 = popup.Append(wx.ID_ANY, '&Toggle log/lin', 'Toggle log/linear scale of slices')
        wx.EVT_MENU(self, item1.GetId(), lambda evt: self.toggle_log_lin(mpl_mouseevent))
        if mpl_mouseevent.inaxes == self.fig.sx:
            item2 = popup.Append(wx.ID_ANY, "Save x slice", "save this slice")
            wx.EVT_MENU(self, item2.GetId(), self.save_x_slice)
            item3 = popup.Append(wx.ID_ANY, '&Popout plot', 'Open this data in a figure window')
            wx.EVT_MENU(self, item3.GetId(), lambda evt: self.popout_x_slice())
        elif mpl_mouseevent.inaxes == self.fig.sz:
            item2 = popup.Append(wx.ID_ANY, "Save y slice", "save this slice")
            wx.EVT_MENU(self, item2.GetId(), self.save_y_slice)
            item3 = popup.Append(wx.ID_ANY, '&Popout plot', 'Open this data in a new plot window')
            wx.EVT_MENU(self, item3.GetId(), lambda evt: self.popout_y_slice())
        self.PopupMenu(popup, evt.GetPositionTuple())

    @staticmethod
    def _text_entry(prompt, default):
        """Prompt for a string; return ``default`` if the dialog is cancelled."""
        dlg = wx.TextEntryDialog(None, prompt, defaultValue=default)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                return dlg.GetValue()
            return default
        finally:
            dlg.Destroy()

    def _popout_slice(self, slice_ax, x_label, swap_xy, figure_num, label):
        """Copy the first line of ``slice_ax`` into a pyplot figure.

        slice_ax: slice axes to copy from (``self.fig.sx`` or ``self.fig.sz``).
        x_label: x-axis label used when a new axes is created.
        swap_xy: True for ``sx``, whose line is stored as (intensity, y); the
            copy is drawn as (y, intensity).
        figure_num: target figure number.  None asks the user; nothing happens
            if that dialog is cancelled.
        label: legend label; None asks the user.
        """
        if figure_num is None:
            figure_num = self.figure_list_dialog()
            if figure_num is self._DIALOG_CANCELLED:
                return
        # With None, matplotlib creates a figure with the next free number.
        fig = figure(figure_num)
        (x0, x1), (y0, y1) = self.slice_xy_range
        slice_desc = '\nsliceplot([%f,%f],[%f,%f])' % (x0, x1, y0, y1)
        if figure_num is None or not fig.axes:
            default_title = self.plot_title + slice_desc
            title = self._text_entry('Enter title for plot', default_title)
            new_ax = fig.add_subplot(111)
            new_ax.set_title(title, size='large')
            new_ax.set_xlabel(x_label, size='x-large')
            new_ax.set_ylabel('$I_{summed}$', size='x-large')
        else:
            new_ax = fig.axes[0]
        if label is None:
            default_label = self.window_title + ': ' + self.plot_title + slice_desc
            label = self._text_entry('Enter data label (for plot legend)', default_label)
        x, y = slice_ax.lines[0].get_data()
        if swap_xy:
            x, y = y, x
        new_ax.plot(x, y, label=label)
        # Attach the legend to the axes just plotted into, not pyplot's gca().
        lg = new_ax.legend(prop=FontProperties(size='small'))
        drag_lg = DraggableLegend(lg)
        drag_lg.connect()
        fig.canvas.draw()
        fig.show()

    def popout_y_slice(self, event=None, figure_num=None, label=None):
        """Copy the y slice (summed intensity vs x) into a pyplot figure."""
        self._popout_slice(self.fig.sz, self.x_label, False, figure_num, label)

    def popout_x_slice(self, event=None, figure_num=None, label=None):
        """Copy the x slice (summed intensity vs y) into a pyplot figure."""
        self._popout_slice(self.fig.sx, self.y_label, True, figure_num, label)

    @staticmethod
    def _ask_save_filename():
        """Show a save dialog; return the chosen path, or None if cancelled."""
        dlg = wx.FileDialog(None, "Save 2d data as:", '', "", "", wx.FD_SAVE)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                return dlg.GetDirectory() + '/' + dlg.GetFilename()
            return None
        finally:
            dlg.Destroy()

    def _write_slice(self, outFileName, column_header, axis_values, slice_data):
        """Write the title, the slice bounds and the two data columns to a file.

        The data columns are skipped when no slice has been computed yet.
        """
        (x_lo, x_hi), (y_lo, y_hi) = self.slice_xy_range
        with open(outFileName, 'w') as outFile:
            outFile.write('#' + self.plot_title + '\n')
            outFile.write('#xmin: ' + str(x_lo) + '\n')
            outFile.write('#xmax: ' + str(x_hi) + '\n')
            outFile.write('#ymin: ' + str(y_lo) + '\n')
            outFile.write('#ymax: ' + str(y_hi) + '\n')
            outFile.write(column_header)
            if slice_data is not None and axis_values is not None:
                for x, y in zip(axis_values, slice_data):
                    outFile.write(str(x) + "\t" + str(y) + "\n")

    def save_x_slice(self, event=None, outFileName=None):
        """Write the x slice (y, slice_x_data) to a text file.

        Prompts for a file name when ``outFileName`` is None; returns without
        writing if that dialog is cancelled.
        """
        if outFileName is None:
            outFileName = self._ask_save_filename()
            if outFileName is None:
                return
        self._write_slice(outFileName, "#y\tslice_x_data\n",
                          getattr(self, 'y', None), getattr(self, 'slice_x_data', None))
        print('saved x slice in %s' % (outFileName))

    def save_y_slice(self, event=None, outFileName=None):
        """Write the y slice (x, slice_y_data) to a text file.

        Prompts for a file name when ``outFileName`` is None; returns without
        writing if that dialog is cancelled.
        """
        if outFileName is None:
            outFileName = self._ask_save_filename()
            if outFileName is None:
                return
        self._write_slice(outFileName, "#x\tslice_y_data\n",
                          getattr(self, 'x', None), getattr(self, 'slice_y_data', None))
        print('saved y slice in %s' % (outFileName))

    def OnGridToggle(self, event):
        """Toggle grid lines on the 2d image."""
        self.fig.ax.grid()
        self.fig.canvas.draw_idle()

    def OnColormap(self, name):
        """Switch the 2d image to the colour map called ``name``."""
        print("Selected colormap %s" % (name,))
        self.fig.im.set_cmap(get_cmap(name))
        self.fig.canvas.draw()

    def toggle_2d_plot_scale(self, event=None):
        """Switch the 2d image between linear and log colour scale."""
        if self.__scale == 'log':
            self.show_data = self.data.T
            cb_label = '$I$'
            new_scale = 'lin'
        elif self.__scale == 'lin':
            self.show_data = log(self.data.copy().T + self.minimum_intensity / 2.0)
            cb_label = r'$\log_{10}I$'
            new_scale = 'log'
        else:
            return
        self.fig.im.set_array(self.show_data)
        self.fig.im.autoscale()
        self.fig.cb.set_xlabel(cb_label)
        self.__scale = new_scale
        self.menu_log_lin_toggle.Check(new_scale == 'log')
        self.statusbar.SetStatusText("%s scale" % self.__scale, 0)
        self.fig.canvas.draw_idle()

    def toggle_log_lin(self, event):
        """Toggle log/linear scale of the axes the (matplotlib) event is in."""
        ax = event.inaxes
        label = ax.get_label()

        if label == '2d_plot':
            self.toggle_2d_plot_scale()
            return
        # The intensity axis is y for 'sz' and x for 'sx' (see show_slice_overlay).
        if label == 'sz':
            get_scale, set_scale = ax.get_yscale, ax.set_yscale
        elif label == 'sx':
            get_scale, set_scale = ax.get_xscale, ax.set_xscale
        else:
            return
        scale = get_scale()
        if scale == 'log':
            set_scale('linear')
        elif scale == 'linear':
            set_scale('log')
        else:
            return
        ax.figure.canvas.draw_idle()

    def onmousemove(self, event):
        """Show the scale and the cursor data position in the status bar."""
        if event.inaxes:
            x, y = event.xdata, event.ydata
            self.statusbar.SetStatusText("%s scale x = %.3g, y = %.3g" % (self.__scale, x, y), 1)

    def onselect(self, eclick, erelease):
        """RectangleSelector callback: slice the data inside the drawn box."""
        x_range = [eclick.xdata, erelease.xdata]
        y_range = [eclick.ydata, erelease.ydata]
        ax = eclick.inaxes
        self.sliceplot((x_range, y_range), ax)
        print('sliceplot(([%f,%f],[%f,%f]))' % (x_range[0], x_range[1], y_range[0], y_range[1]))

    def sliceplots_off(self):
        """Show only the 2d image (full size) and disable box selection."""
        self.fig.ax.set_position([0.125, 0.1, 0.7, 0.8])
        self.fig.cb.set_position([0.85, 0.1, 0.05, 0.8])
        self.fig.sx.set_visible(False)
        self.fig.sz.set_visible(False)
        if self.fig.slice_overlay:
            self.fig.slice_overlay[0].set_visible(False)
        self.RS.set_active(False)
        self.show_sliceplots = False
        self.fig.canvas.draw()

    def sliceplots_on(self):
        """Shrink the 2d image, show both slice plots, enable box selection."""
        self.fig.ax.set_position([0.125, 0.53636364, 0.35227273, 0.36363636])
        self.fig.cb.set_position([0.49, 0.53636364, 0.02, 0.36363636])
        self.fig.sx.set_position([0.58, 0.53636364, 0.35227273, 0.36363636])
        self.fig.sx.set_visible(True)
        self.fig.sz.set_visible(True)
        if self.fig.slice_overlay:
            self.fig.slice_overlay[0].set_visible(True)
        self.RS.set_active(True)
        self.show_sliceplots = True
        self.fig.canvas.draw()

    def toggle_sliceplots(self):
        """Switch between views with and without slice plots."""
        if self.show_sliceplots:
            self.sliceplots_off()
        else:
            self.sliceplots_on()

    def show_slice_overlay(self, x_range, y_range, x, slice_y_data, y, slice_x_data):
        """Draw precomputed slices and the summing-region overlay.

        The sum along y (vs x) is plotted in 'sz', below the data; the sum
        along x (vs y) is plotted in 'sx', to the right of the data.  A
        transparent white polygon on the image marks the summing region.
        """
        from matplotlib.ticker import FormatStrFormatter, ScalarFormatter

        if self.fig is None:
            print('No figure for this dataset is available')
            return

        def scalar_fmt():
            # A formatter is bound to one Axis, so create a fresh one per axis.
            fmt = ScalarFormatter(useMathText=True)
            fmt.set_powerlimits((-2, 4))
            return fmt

        fig = self.fig
        ax = fig.ax
        extent = fig.im.get_extent()

        corners_x = [x_range[0], x_range[1], x_range[1], x_range[0]]
        corners_y = [y_range[0], y_range[0], y_range[1], y_range[1]]
        if fig.slice_overlay is None:
            fig.slice_overlay = ax.fill(corners_x, corners_y, fc='white', alpha=0.3)
            fig.ax.set_ylim(extent[2], extent[3])
        else:
            fig.slice_overlay[0].xy = list(zip(corners_x, corners_y))

        fig.sz.clear()
        fig.sz.yaxis.set_major_formatter(scalar_fmt())
        fig.sz.xaxis.set_major_formatter(FormatStrFormatter('%.2g'))
        fig.sz.set_xlim(x[0], x[-1])
        fig.sz.plot(x, slice_y_data)

        fig.sx.clear()
        fig.sx.xaxis.set_major_formatter(scalar_fmt())
        fig.sx.yaxis.set_ticks_position('right')
        fig.sx.yaxis.set_major_formatter(FormatStrFormatter('%.2g'))
        fig.sx.set_ylim(y[0], y[-1])
        fig.sx.plot(slice_x_data, y)

        fig.im.set_extent(extent)
        fig.canvas.draw()

    def copy_intensity_range_from(self, other_plot):
        """Copy the colour limits of another plot_2d_data instance."""
        if isinstance(other_plot, type(self)):
            xlab = self.fig.cb.get_xlabel()
            ylab = self.fig.cb.get_ylabel()

            self.fig.im.set_clim(other_plot.fig.im.get_clim())
            self.fig.cb.set_xlabel(xlab)
            self.fig.cb.set_ylabel(ylab)
            self.fig.canvas.draw()

    def sliceplot(self, xy_range, ax=None):
        """Sum the data inside ``xy_range = (x_range, y_range)`` and show it.

        Stores the slice results on self (x, slice_y_data, y, slice_x_data,
        slice_xy_range) for later saving/popping out.  ``ax`` is unused.
        """
        self.sliceplots_on()
        x_range, y_range = xy_range
        x, slice_y_data, y, slice_x_data = self.do_xy_slice(x_range, y_range)
        self.x = x
        self.slice_y_data = slice_y_data
        self.y = y
        self.slice_x_data = slice_x_data
        self.slice_xy_range = xy_range

        self.show_slice_overlay(x_range, y_range, x, slice_y_data, y, slice_x_data)

    @staticmethod
    def _pixel_bounds(data_range, origin, span, size):
        """Convert a data-coordinate interval into clamped integer pixel bounds.

        data_range: two data coordinates, in either order.
        origin, span: data coordinate of pixel 0 and the data extent of all
            ``size`` pixels (span may be negative for a reversed extent).
        Returns (p_min, p_max) with 0 <= p_min <= p_max <= size.  When
        size > 0 the window is at least one pixel wide, so a click without
        a drag still yields a non-empty slice.
        """
        lo, hi = min(data_range), max(data_range)
        p_lo = int(round((lo - origin) / float(span) * size))
        p_hi = int(round((hi - origin) / float(span) * size))
        p_min, p_max = min(p_lo, p_hi), max(p_lo, p_hi)
        # Clamp to the array: negative indices would otherwise wrap around.
        p_min = min(max(p_min, 0), size)
        p_max = min(max(p_max, 0), size)
        if p_max == p_min and size > 0:
            if p_max < size:
                p_max += 1
            else:
                p_min -= 1
        return p_min, p_max

    def do_xy_slice(self, x_range, y_range):
        """Sum the masked data inside the box, once along y and once along x.

        Returns four arrays: x values, the sum along y vs x (slice_y_data),
        y values, and the sum along x vs y (slice_x_data).  Each sum is
        divided by the number of unmasked pixels that contributed (minimum 1).
        """
        print('doing xy slice')
        pixels = self.pixel_mask
        # Work on a copy: zeroing masked pixels must not modify self.data
        # (which is the caller's array, not a copy).
        data = self.data.copy()
        data[pixels == 0] = 0
        normalization_matrix = ones(data.shape)
        normalization_matrix[pixels == 0] = 0

        x_size, y_size = data.shape
        global_x_range = self.x_max - self.x_min
        global_y_range = self.y_max - self.y_min

        x_pixel_min, x_pixel_max = self._pixel_bounds(x_range, self.x_min, global_x_range, x_size)
        y_pixel_min, y_pixel_max = self._pixel_bounds(y_range, self.y_min, global_y_range, y_size)

        # Data coordinates of the pixel-aligned box actually summed.
        new_x_min = float(x_pixel_min) / x_size * global_x_range + self.x_min
        new_x_max = float(x_pixel_max) / x_size * global_x_range + self.x_min
        new_y_min = float(y_pixel_min) / y_size * global_y_range + self.y_min
        new_y_max = float(y_pixel_max) / y_size * global_y_range + self.y_min

        window = (slice(x_pixel_min, x_pixel_max), slice(y_pixel_min, y_pixel_max))
        data_box = data[window]
        norm_box = normalization_matrix[window]
        y_norm_factor = norm_box.sum(axis=1)
        x_norm_factor = norm_box.sum(axis=0)
        # Minimum of 1 everywhere to avoid dividing by zero in fully masked rows.
        y_norm_factor[y_norm_factor == 0] = 1
        x_norm_factor[x_norm_factor == 0] = 1

        slice_y_data = data_box.sum(axis=1) / y_norm_factor
        slice_x_data = data_box.sum(axis=0) / x_norm_factor

        x_vals = arange(slice_y_data.shape[0], dtype='float') / slice_y_data.shape[0] * (new_x_max - new_x_min) + new_x_min
        y_vals = arange(slice_x_data.shape[0], dtype='float') / slice_x_data.shape[0] * (new_y_max - new_y_min) + new_y_min

        return x_vals, slice_y_data, y_vals, slice_x_data
Ejemplo n.º 34
0
class Plot(wx.Panel):
    """wx panel showing a matplotlib figure with a rectangle selector.

    The selection is used to crop ``MainWindow.dataNew`` (reshaped to
    ``(MainWindow.mapy, MainWindow.mapx, ...)``).  Key bindings on the canvas:

    - ``q``/``Q``: deactivate the selector and clear ``MainWindow.Select``;
    - ``ctrl+alt+a``: reactivate the selector (it is also reactivated on any
      key press while ``MainWindow.Select`` is true);
    - ``s``/``S``, or ``enter`` while the selector is active: after
      confirmation, crop to the last selection and close the panel.
    """

    def __init__(self, parent, id=-1, dpi=None, **kwargs):
        wx.Panel.__init__(self, parent, id, **kwargs)
        # Keep the figure/axes on the instance so several panels don't share
        # them; also mirror them onto the class for callers that read
        # Plot.figure / Plot.ax.
        self.figure, self.ax = plt.subplots(dpi=dpi, figsize=(6, 5))
        Plot.figure, Plot.ax = self.figure, self.ax
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.toolbar = NavigationToolbar(self.canvas)
        self.toolbar.Realize()

        # Integer corner coordinates of the last selection (set by onselect).
        self.coordinate = None
        self.toggle_selector_RS = RectangleSelector(self.ax, self.onselect, drawtype='box', button=1,
                                                    spancoords='data', interactive=True)
        self.figure.canvas.mpl_connect('key_press_event', self.toggle_selector)

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.canvas, 1, wx.EXPAND)
        sizer.Add(self.toolbar, 0, wx.LEFT | wx.EXPAND)
        self.SetSizer(sizer)

    def onselect(self, eclick, erelease):
        """Record the rectangle just drawn.

        ``eclick``/``erelease`` are the matplotlib press/release events.  The
        selector's ``geometry`` is stored, truncated to ints, in
        ``self.coordinate``.
        """
        print('startposition: (%f, %f)' % (eclick.xdata, eclick.ydata))
        print('endposition  : (%f, %f)' % (erelease.xdata, erelease.ydata))
        print('used button  : ', eclick.button)
        self.coordinate = self.toggle_selector_RS.geometry.astype(int)
        print(self.coordinate)

    def toggle_selector(self, event):
        """Handle canvas key presses (see the class docstring for bindings)."""
        print('Key pressed.')
        selector = self.toggle_selector_RS
        if event.key in ['Q', 'q'] and selector.active:
            print('RectangleSelector deactivated.')
            selector.set_active(False)
            MainWindow.Select = False

        if MainWindow.Select and not selector.active:
            print('RectangleSelector activated.')
            selector.set_active(True)
        if event.key == "ctrl+alt+a" and not selector.active:
            print('RectangleSelector activated.')
            selector.set_active(True)

        # Operator grouping made explicit: 's'/'S' crop whether or not the
        # selector is active, 'enter' only while it is.
        # NOTE(review): confirm 's' is not also meant to require an active selector.
        wants_crop = event.key in ['s', 'S'] or (event.key == 'enter' and selector.active)
        if not wants_crop:
            return
        if self.coordinate is None:
            # No rectangle drawn yet; there is nothing to crop to.
            print('No region selected.')
            return
        dlg = wx.MessageDialog(None, "Are you sure for the selection? \n(This will overwrite the previous result.)", "Select Region", wx.YES_NO | wx.ICON_QUESTION)
        try:
            if dlg.ShowModal() == wx.ID_YES:
                self._crop_to_selection()
        finally:
            dlg.Destroy()

    def _crop_to_selection(self):
        """Crop MainWindow.dataNew to self.coordinate, update mapy/mapx, close."""
        print("crop image")
        print(MainWindow.dataNew.shape)
        # NOTE(review): matplotlib documents geometry[0, :] as the y and
        # geometry[1, :] as the x corner coordinates.  Rows are bounded by
        # geometry[1, :] here, and columns by geometry[:, 1], which mixes one
        # y and one x value.  Kept unchanged; verify against how the map is
        # displayed.
        row_coords = self.coordinate[1, :]
        col_coords = self.coordinate[:, 1]
        row_lo, row_hi = min(row_coords), max(row_coords)
        col_lo, col_hi = min(col_coords), max(col_coords)
        data = MainWindow.dataNew.reshape((MainWindow.mapy, MainWindow.mapx,
                                           MainWindow.dataNew.shape[1], MainWindow.dataNew.shape[2]))
        data = data[row_lo:row_hi, col_lo:col_hi, :, :]
        MainWindow.mapy = row_hi - row_lo
        MainWindow.mapx = col_hi - col_lo
        data = data.reshape(MainWindow.mapy * MainWindow.mapx, data.shape[-2], data.shape[-1])
        MainWindow.dataNew = data

        self.Close(True)
        print(MainWindow.dataNew.shape)
class QMPLFitterWidget(QMPLWidget):
    """
    Qt4 Widget with matplotlib elements modified to include interactive
    fitting functionality.

    A checkable toolbar action enables a RectangleSelector. The horizontal
    extent of the dragged rectangle becomes the fit region of
    ``self.fitter``. The fitted model and a text summary of the optimal
    parameters are then drawn on ``self.axes``.
    """
    def __init__(self, parent=None):
        # Initialize QMPLWidget
        super(QMPLFitterWidget, self).__init__(parent)

        # Rectangle selector for interactive fit-region selection (left mouse
        # button only; drags smaller than 5 px in either direction are
        # ignored). Starts disabled and is switched on from the toolbar.
        self.selector = RectangleSelector(self.axes,
                                          self.rect_select_callback,
                                          drawtype="box",
                                          useblit=True,
                                          button=[1],
                                          minspanx=5,
                                          minspany=5,
                                          spancoords="pixels")
        self.selector.set_active(False)

        # Add a "fitter" object to the widget
        self.fitter = Fitter()
        # Artists drawn for the current fit (None while no fit is shown)
        self.fit_line = None
        self.fit_textbox = None

        # Modify navigation toolbar to include a fitting action/icon
        self.expand_toolbar()

    # TODO: This is hacky - look at subclassing the NavigationToolbar2Qt to
    # get the desired behavior
    def expand_toolbar(self):
        """
        Modify the default navigation toolbar to include a new icon for
        activating/deactivating an interactive fitter, plus one for clearing
        the current fit.
        """
        # Add a separator to the end of the toolbar
        self.mpl_toolbar.addSeparator()

        # Add interactive fit action. Qt flips a checkable action's checked
        # state before emitting `triggered`; the slot follows that state so
        # the button can switch fitting off as well as on.
        fit_icon = QtGui.QIcon("/".join((RESOURCE_PATH, "gaus.svg")))
        self.fit_action = self.mpl_toolbar.addAction(fit_icon,
                                                     "Interactive fitting",
                                                     self._on_fit_action_triggered)
        self.fit_action.setCheckable(True)

        # Add fit-clearing action
        clearfit_icon = QtGui.QIcon("/".join((RESOURCE_PATH, "clear.svg")))
        self.clearfit_action = self.mpl_toolbar.addAction(clearfit_icon,
                                                          "Clear current fit",
                                                          self.clear_fit)

    def _on_fit_action_triggered(self, *_args):
        """
        Toolbar slot: enable or disable the fitter to match the fit action's
        checked state. Any signal arguments (e.g. ``checked``) are ignored.
        """
        if self.fit_action.isChecked():
            self.activate_fitter()
        else:
            self.deactivate_fitter()

    def rect_select_callback(self, click_event, release_event):
        """
        Handle a completed rectangle selection.

        Only the x extent of the selection is used. It becomes the fitter's
        region (``xmin``/``xmax``), the fit is run and drawn, and the
        interactive-fitting action is switched back off.
        """
        x1, x2 = click_event.xdata, release_event.xdata
        # Order the bounds so a right-to-left drag still gives xmin <= xmax.
        self.fitter.xmin, self.fitter.xmax = min(x1, x2), max(x1, x2)
        try:
            self.fitter.fit()
            self.show_fit()
        finally:
            # Turn the fitting action back off even if the fit raised, so the
            # toolbar button and the selector stay in sync.
            self.deactivate_fitter()

    def activate_fitter(self):
        """Check the fit action and enable the rectangle selector."""
        self.fit_action.setChecked(True)
        self.selector.set_active(True)

    def deactivate_fitter(self):
        """Uncheck the fit action and disable the rectangle selector."""
        self.fit_action.setChecked(False)
        self.selector.set_active(False)

    # TODO: The functionality in this method should be moved to the Fitter
    # class. See branch feat/fitter_textbox for aborted attempt
    def generate_text_summary(self):
        """
        Summarize the fit results, presenting the optimal parameters in text.

        Returns one line per parameter (``p[i] = value $\\pm$ error``) under
        a header line, or None when the fitter has no optimal parameters.
        """
        # Return None if no optimized params to summarize
        if self.fitter.popt is None:
            return None
        # Raw strings: `\pm`/`\sigma` are mathtext commands, not escapes.
        hdr = r"Optimal Parameters $(\pm 1\sigma)$:" + "\n"
        summary = "\n".join(
            r"  p[%s] = %.3f $\pm$ %.3f" % (p, val, err)
            for p, (val, err) in enumerate(zip(self.fitter.popt, self.fitter.perr)))
        return hdr + summary

    def show_fit(self):
        """
        Draw the model with the optimized params on the axis, replacing any
        fit drawn earlier, and add a text box summarizing the parameters.

        Does nothing when the fitter has no optimized parameters.
        """
        if self.fitter.popt is None:
            return
        # Remove artists from a previous fit; once the attributes below are
        # overwritten, clear_fit could no longer reach them.
        self._remove_fit_artists()
        # Generate x-values in ROI at 2x density of actual data
        x = np.linspace(self.fitter.xmin, self.fitter.xmax,
                        2 * self.fitter.xf.shape[0])
        y = self.fitter.model(x, *self.fitter.popt)
        # TODO: How to choose/set the color/texture of fitted model
        self.fit_line = self.axes.plot(x, y, "m-")[0]
        # Add a textbox summarizing the fit
        # TODO: Customize where/how summary of text shows up. Currently, have
        # textbox pop up in upper-left corner
        self.fit_textbox = self.axes.text(0.0, 1.0,   # Axes coordinates
                                          self.generate_text_summary(),
                                          horizontalalignment="left",
                                          verticalalignment="top",
                                          transform=self.axes.transAxes,
                                          bbox={"facecolor" : "yellow",
                                                "alpha"     : 0.5,
                                                "pad"       : 0})
        self.canvas.draw()

    def _remove_fit_artists(self):
        """Detach the fit line and summary text box from the axes, if present."""
        if self.fit_line is not None:
            # Artist.remove() is the supported API; mutating `axes.lines`
            # directly is not supported by newer matplotlib releases.
            self.fit_line.remove()
            self.fit_line = None
        if self.fit_textbox is not None:
            self.fit_textbox.remove()
            self.fit_textbox = None

    def clear_fit(self):
        """
        Remove artists associated with interactive fitting from the canvas.
        """
        self._remove_fit_artists()
        self.canvas.draw()

    # TODO: Explicit wrapper of axes.plot - figure out how to do this
    # implicitly (class decorator? Populate obj dict with axes.__dict__?)
    def plot(self, *args, **kwargs):
        """
        Wrapper around Axes.plot that also hands the plotted line's x, y data
        to the fitter, then redraws the canvas.

        The call must produce exactly one line; otherwise the tuple unpacking
        raises ValueError. Returns the created line.
        """
        # Pass arguments along to axis method
        line, = self.axes.plot(*args, **kwargs)
        # Set fitter properties
        self.fitter.set_data(*line.get_data())
        # Render
        self.canvas.draw()
        return line

    def hist(self, *args, **kwargs):
        """
        Wrapper around Axes.hist that hands the bin centres and counts to the
        fitter, then redraws the canvas.

        Returns the ``(counts, bin_edges, patches)`` tuple from Axes.hist.
        """
        # Create histogram first
        h, be, patches = self.axes.hist(*args, **kwargs)
        # Set fitter using the midpoint of each bin
        # TODO: Only linear interp of binning here!
        bc = be[:-1] + (be[1:] - be[:-1]) / 2.0
        self.fitter.set_data(bc, h)
        # Render
        self.canvas.draw()
        return h, be, patches
Ejemplo n.º 36
0
class TaylorDiagramWidget(MplWidget):
    """
    Matplotlib widget that draws a Taylor diagram of per-series statistics.

    The user can select series by dragging a rectangle over their points;
    the selection is forwarded to ``self.coord.notifyModules``.
    """
    def __init__(self, parent=None):
        MplWidget.__init__(self, parent)

        # Candidate marker styles. Currently unused: draw() uses 'o' for every
        # sample (the per-sample marker is commented out there).
        self.markers = [
            'o', 'x', '*', ',', '+', '.', 's', 'v', '<', '>', '^', 'D', 'h',
            'H', '_', '8', 'd', 3, 0, 1, 2, 7, 4, 5, 6, '1', '3', '4', '2',
            '|', 'x'
        ]

    def draw(self):
        """
        Render the Taylor diagram from the widget's input ports.

        Reads ``(coord, stats, title, showLegend)`` from ``self.inputPorts``.
        Column 0 of ``stats.values`` is plotted as the standard deviation and
        column 1 as the correlation. Row 0 is the reference sample; the other
        rows are added as samples. Series whose id is in ``self.selectedIds``
        get a thicker marker edge. A rectangle selector is (re)attached to
        the current axes for interactive selection.
        """
        (self.coord, self.stats, title, showLegend) = self.inputPorts

        stds, corrs = self.stats.values[:, 0], self.stats.values[:, 1]
        # Cartesian position of each point on the diagram (radius = std,
        # angle = arccos(corr)); onselect() hit-tests against these.
        self.Xs = stds * corrs
        self.Ys = stds * np.sin(np.arccos(corrs))

        colors = pylab.cm.jet(np.linspace(0, 1, len(self.stats.ids)))

        # Select this widget's figure first, then clear it. Calling
        # pylab.clf() beforehand cleared whichever figure happened to be
        # current, which need not be this one.
        fig = pylab.figure(str(self))
        fig.clf()
        dia = taylor_diagram.TaylorDiagram(stds[0],
                                           corrs[0],
                                           fig=fig,
                                           label=self.stats.labels[0])
        # Color the reference point with the first colormap entry
        dia.samplePoints[0].set_color(colors[0])
        if self.stats.ids[0] in self.selectedIds:
            dia.samplePoints[0].set_markeredgewidth(3)

        # add models to Taylor diagram
        for i, (_id, stddev, corrcoef) in enumerate(
                zip(self.stats.ids[1:], stds[1:], corrs[1:])):
            label = self.stats.labels[i + 1]
            size = 3 if _id in self.selectedIds else 1
            dia.add_sample(
                stddev,
                corrcoef,
                marker='o',  #self.markers[i],
                ls='',
                mfc=colors[i + 1],
                mew=size,
                label=label)

        # Add grid
        dia.add_grid()

        # Add RMS contours, and label them
        contours = dia.add_contours(levels=5, colors='0.5')  # 5 levels in grey
        pylab.clabel(contours, inline=1, fontsize=10, fmt='%.1f')

        # Add a figure legend and title
        if showLegend:
            fig.legend(dia.samplePoints,
                       [p.get_label() for p in dia.samplePoints],
                       numpoints=1,
                       prop=dict(size='small'),
                       loc='upper right')
        fig.suptitle(title, size='x-large')  # Figure title
        self.figManager.canvas.draw()

        # Disconnect the selector from a previous draw() so canvas callbacks
        # don't pile up across redraws.
        previous_selector = getattr(self, 'rectSelector', None)
        if previous_selector is not None:
            previous_selector.disconnect_events()
        self.rectSelector = RectangleSelector(pylab.gca(),
                                              self.onselect,
                                              drawtype='box',
                                              rectprops=dict(
                                                  alpha=0.4,
                                                  facecolor='yellow'))
        self.rectSelector.set_active(True)

    def updateSelection(self, selectedIds):
        """Store the new selection and refresh the widget's contents."""
        self.selectedIds = selectedIds
        self.updateContents()

    def onselect(self, eclick, erelease):
        """
        RectangleSelector callback: notify ``self.coord`` with the ids of the
        points whose diagram position lies inside the dragged rectangle
        (data coordinates). Does nothing when there is no coordinator.
        """
        if self.coord is None:
            return

        left, right = sorted((eclick.xdata, erelease.xdata))
        bottom, top = sorted((eclick.ydata, erelease.ydata))
        region = Bbox.from_extents(left, bottom, right, top)

        selectedIds = [idd for (x, y, idd) in zip(self.Xs, self.Ys, self.stats.ids)
                       if region.contains(x, y)]
        self.coord.notifyModules(selectedIds)
Ejemplo n.º 37
0
class InteractiveCut(object):
    """
    Interactive 1D cut driven by a RectangleSelector on a slice plot.

    Drawing a rectangle plots a cut through the selected region. For a newly
    drawn rectangle the cut is horizontal when the rectangle is wider than it
    is tall (in pixels), vertical otherwise. After a selection exists,
    pressing a mouse button and dragging replots the cut from the selector's
    current extents on every mouse motion until the button is released.
    """

    def __init__(self, slice_plot, canvas, ws_title):
        self.slice_plot = slice_plot
        self._canvas = canvas
        self._ws_title = ws_title

        self.horizontal = None
        # Canvas callback ids, by slot:
        # [0] motion_notify_event, [1] button_release_event,
        # [2] button_press_event, [3] draw_event.
        self.connect_event = [None, None, None, None]
        self._cut_plotter_presenter = CutPlotterPresenter()
        # Pixel geometry of the previous selection:
        # [press x, press y, release x, release y, |dx|, |dy|].
        self._rect_pos_cache = [0, 0, 0, 0, 0, 0]
        self.rect = RectangleSelector(self._canvas.figure.gca(),
                                      self.plot_from_mouse_event,
                                      drawtype='box',
                                      useblit=True,
                                      button=[1, 3],
                                      spancoords='pixels',
                                      interactive=True)

        self._connect(3, 'draw_event', self.redraw_rectangle)
        self._canvas.draw()

    def _connect(self, index, event_name, handler):
        """
        Connect *handler* to the canvas event *event_name* and store its id in
        ``connect_event[index]``. Any callback already stored in that slot is
        disconnected first, so repeated calls never stack duplicate handlers.
        """
        self._disconnect(index)
        self.connect_event[index] = self._canvas.mpl_connect(event_name, handler)

    def _disconnect(self, index):
        """Disconnect the callback stored in ``connect_event[index]``, if any."""
        cid = self.connect_event[index]
        if cid is not None:
            self._canvas.mpl_disconnect(cid)
            self.connect_event[index] = None

    def plot_from_mouse_event(self, eclick, erelease):
        """RectangleSelector callback: plot the cut for the new selection."""
        # Make axis orientation sticky, until user selects entirely new rectangle.
        rect_pos = [
            eclick.x, eclick.y, erelease.x, erelease.y,
            abs(erelease.x - eclick.x),
            abs(erelease.y - eclick.y)
        ]
        # "Entirely new" means every component moved by more than 0.1 px, so
        # moving or resizing the existing rectangle keeps the orientation.
        rectangle_changed = all(
            abs(new - old) > 0.1
            for new, old in zip(rect_pos, self._rect_pos_cache))
        if rectangle_changed:
            self.horizontal = abs(erelease.x - eclick.x) > abs(erelease.y -
                                                               eclick.y)
        self.plot_cut(eclick.xdata, erelease.xdata, eclick.ydata,
                      erelease.ydata)
        # Replace (not add) the press handler. Previously every selection
        # stacked another one; each spawned a motion handler that end_drag
        # could not disconnect, so cuts kept replotting on every motion.
        self._connect(2, 'button_press_event', self.clicked)
        self._rect_pos_cache = rect_pos

    def plot_cut(self, x1, x2, y1, y2, store=False):
        """
        Plot (and register) the cut for the rectangle [x1, x2] x [y1, y2].
        Does nothing unless x2 > x1 and y2 > y1.
        """
        if x2 > x1 and y2 > y1:
            ax, integration_start, integration_end = self.get_cut_parameters(
                (x1, y1), (x2, y2))
            # Integrate along the axis perpendicular to the cut.
            units = self._canvas.figure.gca().get_yaxis().units if self.horizontal else \
                self._canvas.figure.gca().get_xaxis().units
            integration_axis = Axis(units, integration_start, integration_end,
                                    0)
            self._cut_plotter_presenter.plot_interactive_cut(
                str(self._ws_title), ax, integration_axis, store)
            self._cut_plotter_presenter.store_icut(self._ws_title, self)

    def get_cut_parameters(self, pos1, pos2):
        """
        Build the cut axis and integration range from two (x, y) corners.

        When ``self.horizontal`` is true the cut runs along x (tuple index 0)
        and integrates over y; otherwise the roles are swapped.

        :return: ``(cut_axis, integration_start, integration_end)``
        """
        # Indexing with the bool picks tuple element 0 (x) or 1 (y).
        start = pos1[not self.horizontal]
        end = pos2[not self.horizontal]
        units = self._canvas.figure.gca().get_xaxis().units if self.horizontal else \
            self._canvas.figure.gca().get_yaxis().units
        step = get_limits(get_workspace_handle(self._ws_title), units)[2]
        ax = Axis(units, start, end, step)
        integration_start = pos1[self.horizontal]
        integration_end = pos2[self.horizontal]
        return ax, integration_start, integration_end

    def clicked(self, event):
        """On mouse press, replot the cut on every motion until release."""
        self._connect(0, 'motion_notify_event',
                      lambda x: self.plot_cut(*self.rect.extents))
        self._connect(1, 'button_release_event', self.end_drag)

    def end_drag(self, event):
        """Stop live replotting once the mouse button is released."""
        self._disconnect(0)
        self._disconnect(1)

    def redraw_rectangle(self, event):
        """draw_event handler: redraw the selector while it is active."""
        if self.rect.active:
            self.rect.update()

    def save_cut(self):
        """Plot and store the current cut; return its output workspace name."""
        x1, x2, y1, y2 = self.rect.extents
        self.plot_cut(x1, x2, y1, y2, store=True)
        self.update_workspaces()
        ax, integration_start, integration_end = self.get_cut_parameters(
            (x1, y1), (x2, y2))
        return output_workspace_name(str(self._ws_title), integration_start,
                                     integration_end)

    def update_workspaces(self):
        self.slice_plot.update_workspaces()

    def clear(self):
        """Leave interactive-cut mode: deactivate the selector and drop every
        canvas callback this object registered."""
        self._cut_plotter_presenter.set_is_icut(self._ws_title, False)
        self.rect.set_active(False)
        for index in range(len(self.connect_event)):
            self._disconnect(index)
        self._canvas.draw()

    def flip_axis(self):
        """Swap cut and integration axes and replot."""
        self.horizontal = not self.horizontal
        self.plot_cut(*self.rect.extents)

    def window_closing(self):
        self.slice_plot.toggle_interactive_cuts()
        self.slice_plot.plot_window.action_interactive_cuts.setChecked(False)
Ejemplo n.º 38
0
from matplotlib.widgets import  RectangleSelector
from pylab import *

def onselect(eclick, erelease):
    """Print the press/release data coordinates and the mouse button used.

    eclick and erelease are matplotlib events at press and release.
    """
    print(' startposition : (%f, %f)' % (eclick.xdata, eclick.ydata))
    print(' endposition   : (%f, %f)' % (erelease.xdata, erelease.ydata))
    # The Python 2 form `print 'a', b` put one space between the items;
    # that space is written out explicitly so the output is unchanged.
    print(' used button   :  %s' % (eclick.button,))

def toggle_selector(event):
    """Key handler for the selector stored on ``toggle_selector.RS``.

    'q'/'Q' deactivates an active selector; 'a'/'A' re-activates an
    inactive one. Every key press is logged.
    """
    print(' Key pressed.')
    if event.key in ['Q', 'q'] and toggle_selector.RS.active:
        print(' RectangleSelector deactivated.')
        toggle_selector.RS.set_active(False)
    if event.key in ['A', 'a'] and not toggle_selector.RS.active:
        print(' RectangleSelector activated.')
        toggle_selector.RS.set_active(True)

# Demo: plot sin(x) on [0, 1] and report rectangle selections made on it.
x = arange(100) / 99.0
y = sin(x)
# Call figure(); the old `fig = figure` bound the function itself.
fig = figure()
ax = subplot(111)
ax.plot(x, y)

RS = RectangleSelector(ax, onselect, drawtype='box')
# toggle_selector() looks the selector up on itself, so attach it there and
# hook the handler to key presses ('q' deactivates, 'a' re-activates).
toggle_selector.RS = RS
connect('key_press_event', toggle_selector)
RS.set_active(True)
show()
print('done: put breakpoint here!')
Ejemplo n.º 39
0
class InteractiveCut(object):
    """
    Interactive 1D cut driven by a RectangleSelector on a slice plot.

    Drawing a rectangle plots a cut through the selected region. For a newly
    drawn rectangle the cut is horizontal when the rectangle is wider than it
    is tall (in pixels), vertical otherwise. After a selection exists,
    pressing a mouse button and dragging replots the cut from the selector's
    current extents on every mouse motion until the button is released.
    """

    def __init__(self, slice_plot, canvas, ws_title):
        self.slice_plot = slice_plot
        self._canvas = canvas
        self._ws_title = ws_title
        self._en_unit = slice_plot.get_slice_cache().energy_axis.e_unit
        self._en_from_meV = EnergyUnits(self._en_unit).factor_from_meV()

        self.horizontal = None
        # Canvas callback ids, by slot:
        # [0] motion_notify_event, [1] button_release_event,
        # [2] button_press_event, [3] draw_event.
        self.connect_event = [None, None, None, None]
        # We need to access the CutPlotterPresenter instance of the particular CutPlot (window) we are using
        # But there is no way to get without changing the active category then calling the GlobalFigureManager.
        # So we create a new temporary here. After the first time we plot a 1D plot, the correct category is set
        # and we can get the correct CutPlot instance and its CutPlotterPresenter
        self._cut_plotter_presenter = CutPlotterPresenter()
        self._is_initial_cut_plotter_presenter = True
        # Pixel geometry of the previous selection:
        # [press x, press y, release x, release y, |dx|, |dy|].
        self._rect_pos_cache = [0, 0, 0, 0, 0, 0]
        self.rect = RectangleSelector(self._canvas.figure.gca(), self.plot_from_mouse_event,
                                      drawtype='box', useblit=True,
                                      button=[1, 3], spancoords='pixels', interactive=True)

        self._connect(3, 'draw_event', self.redraw_rectangle)
        self._canvas.draw()

    def _connect(self, index, event_name, handler):
        """
        Connect *handler* to the canvas event *event_name* and store its id in
        ``connect_event[index]``. Any callback already stored in that slot is
        disconnected first, so repeated calls never stack duplicate handlers.
        """
        self._disconnect(index)
        self.connect_event[index] = self._canvas.mpl_connect(event_name, handler)

    def _disconnect(self, index):
        """Disconnect the callback stored in ``connect_event[index]``, if any."""
        cid = self.connect_event[index]
        if cid is not None:
            self._canvas.mpl_disconnect(cid)
            self.connect_event[index] = None

    def plot_from_mouse_event(self, eclick, erelease):
        """RectangleSelector callback: plot the cut for the new selection."""
        # Make axis orientation sticky, until user selects entirely new rectangle.
        rect_pos = [eclick.x, eclick.y, erelease.x, erelease.y,
                    abs(erelease.x - eclick.x), abs(erelease.y - eclick.y)]
        # "Entirely new" means every component moved by more than 0.1 px, so
        # moving or resizing the existing rectangle keeps the orientation.
        rectangle_changed = all(abs(new - old) > 0.1
                                for new, old in zip(rect_pos, self._rect_pos_cache))
        if rectangle_changed:
            self.horizontal = abs(erelease.x - eclick.x) > abs(erelease.y - eclick.y)
        self.plot_cut(eclick.xdata, erelease.xdata, eclick.ydata, erelease.ydata)
        # Replace (not add) the press handler. Previously every selection
        # stacked another one; each spawned a motion handler that end_drag
        # could not disconnect, so cuts kept replotting on every motion.
        self._connect(2, 'button_press_event', self.clicked)
        self._rect_pos_cache = rect_pos

    def plot_cut(self, x1, x2, y1, y2, store=False):
        """
        Plot (and register) the cut for the rectangle [x1, x2] x [y1, y2].
        Does nothing unless x2 > x1 and y2 > y1.
        """
        if x2 > x1 and y2 > y1:
            ax, integration_start, integration_end = self.get_cut_parameters((x1, y1), (x2, y2))
            # Integrate along the axis perpendicular to the cut.
            units = self._canvas.figure.gca().get_yaxis().units if self.horizontal else \
                self._canvas.figure.gca().get_xaxis().units
            integration_axis = Axis(units, integration_start, integration_end, 0, self._en_unit)
            cut = Cut(ax, integration_axis, None, None)
            self._cut_plotter_presenter.plot_interactive_cut(str(self._ws_title), cut, store)
            self._cut_plotter_presenter.set_is_icut(True)
            if self._is_initial_cut_plotter_presenter:
                # First time we've plotted a 1D cut - get the true CutPlotterPresenter
                from mslice.plotting.pyplot import GlobalFigureManager
                self._cut_plotter_presenter = GlobalFigureManager.get_active_figure().plot_handler._cut_plotter_presenter
                self._is_initial_cut_plotter_presenter = False
                GlobalFigureManager.disable_make_current()
            self._cut_plotter_presenter.store_icut(self)

    def get_cut_parameters(self, pos1, pos2):
        """
        Build the cut axis and integration range from two (x, y) corners.

        When ``self.horizontal`` is true the cut runs along x (tuple index 0)
        and integrates over y; otherwise the roles are swapped. The step is
        the workspace limit for the cut units scaled by the meV conversion
        factor computed in __init__.

        :return: ``(cut_axis, integration_start, integration_end)``
        """
        # Indexing with the bool picks tuple element 0 (x) or 1 (y).
        start = pos1[not self.horizontal]
        end = pos2[not self.horizontal]
        units = self._canvas.figure.gca().get_xaxis().units if self.horizontal else \
            self._canvas.figure.gca().get_yaxis().units
        step = get_limits(get_workspace_handle(self._ws_title), units)[2] * self._en_from_meV
        ax = Axis(units, start, end, step, self._en_unit)
        integration_start = pos1[self.horizontal]
        integration_end = pos2[self.horizontal]
        return ax, integration_start, integration_end

    def clicked(self, event):
        """On mouse press, replot the cut on every motion until release."""
        self._connect(0, 'motion_notify_event',
                      lambda x: self.plot_cut(*self.rect.extents))
        self._connect(1, 'button_release_event', self.end_drag)

    def end_drag(self, event):
        """Stop live replotting once the mouse button is released."""
        self._disconnect(0)
        self._disconnect(1)

    def redraw_rectangle(self, event):
        """draw_event handler: redraw the selector while it is active."""
        if self.rect.active:
            self.rect.update()

    def save_cut(self):
        """Plot and store the current cut; return its output workspace name."""
        x1, x2, y1, y2 = self.rect.extents
        self.plot_cut(x1, x2, y1, y2, store=True)
        self.update_workspaces()
        ax, integration_start, integration_end = self.get_cut_parameters((x1, y1), (x2, y2))
        return output_workspace_name(str(self._ws_title), integration_start, integration_end)

    def update_workspaces(self):
        self.slice_plot.update_workspaces()

    def clear(self):
        """Leave interactive-cut mode: deactivate the selector and drop every
        canvas callback this object registered."""
        self._cut_plotter_presenter.set_is_icut(False)
        self.rect.set_active(False)
        for index in range(len(self.connect_event)):
            self._disconnect(index)
        self._canvas.draw()

    def flip_axis(self):
        """Swap cut and integration axes and replot."""
        self.horizontal = not self.horizontal
        self.plot_cut(*self.rect.extents)

    def window_closing(self):
        self.slice_plot.toggle_interactive_cuts()
        self.slice_plot.plot_window.action_interactive_cuts.setChecked(False)
Ejemplo n.º 40
0
def gui():
    """""" """""" """""" """
     G  L  O  B  A  L  S
    """ """""" """""" """"""
    #    global dispInterf_state
    #    global stop_event

    global root, livefeed_canvas, imageplots_frame

    global slider_exposure, measurementfolder_name, calibrationfolder_name
    global button_initcamera, button_release, button_saveconfig
    global text_config, tbox, entry_measfolder, entry_calibfolder, piezosteps_var, CheckVar, output_text
    global a, colors, canvas_plot, line, ax

    global updateconfig_event

    global POI, ROI
    global displaymidline_state
    global piezo_dispaxis, calib_linear_region
    global toggle_selector_RS

    POI, ROI = [], []
    displaymidline_state = False
    measurementfolder_name = 'stack'
    calibrationfolder_name = 'stack'
    calib_linear_region = [19, None]

    #    dispInterf_state = 0

    q = Queue()
    stdout_queue = Queue()

    beginlive_event = Event()
    stoplive_event = Event()
    release_event = Event()
    #    stop_event = Event()
    piezostep_event = Event()
    updateconfig_event = Event()
    plotmidline_event = Event()
    """
    MAIN WINDOW
    """
    root = tk.Tk()
    root.iconbitmap('winicon.ico')
    #root.wm_attributes('-topmost', 1)
    #    w, h = root.winfo_screenwidth(), root.winfo_screenheight()
    #root.geometry("%dx%d+0+0" % (w, h))
    root.title('White light interferometry: Topography')
    root.configure(background='grey')
    """
    MENU
    """
    menubar = tk.Menu(root)
    filemenu = tk.Menu(menubar, tearoff=0)
    #filemenu.add_command(label = 'Load image', command=openimage)
    filemenu.add_command(label='Save displayed image')
    menubar.add_cascade(label='File', menu=filemenu)

    optionsmenu = tk.Menu(menubar, tearoff=0)
    optionsmenu.add_command(label='Configure camera')
    menubar.add_cascade(label='Options', menu=optionsmenu)

    menubar.add_command(label='Help')
    """""" """""" """""" """""" """""" """""" """""" """""" """""" """""" """""" """
            L        A        Y        O        U        T
    """ """""" """""" """""" """""" """""" """""" """""" """""" """""" """""" """"""
    """
    2 MAIN TABS
    """
    tabControl = ttk.Notebook(root)
    tab_ops = ttk.Frame(tabControl)
    tab_config = ttk.Frame(tabControl)
    tabControl.add(tab_ops, text='Operations')
    tabControl.add(tab_config, text='Camera configuration')
    tabControl.grid(row=0, sticky='we')
    """
    CAMERA CONFIGURATIONS FROM .INI FILE
    """
    text_config = tk.Text(tab_config, bg='gray18', fg='thistle1')
    text_config.grid(row=0, sticky='we')

    scrollb = tk.Scrollbar(tab_config, command=text_config.yview)
    scrollb.grid(row=0, column=1, sticky='nswe')
    text_config['yscrollcommand'] = scrollb.set

    root.config_ini = 'config.ini'
    file_contents = open(root.config_ini).read()
    text_config.insert('end', file_contents)

    tbox_contents = text_config.get('1.0', 'end')
    # c.char is a fixed array, so in case future editing needs more array space,
    # an expanded array is passed as an argument
    tbox = Array(ctypes.c_char, bytes(tbox_contents + '\n' * 10, 'utf8'))

    text_config.tag_config('section', foreground='khaki1')
    update_tboxEmbellishments()
    """
    SAVE CONFIGURATION CHANGES
    """
    button_saveconfig = tk.Button(tab_config,
                                  text="Save changes",
                                  bg="white",
                                  fg="black",
                                  command=update_config)

    button_saveconfig.grid(row=0, column=2, padx=10, pady=10)
    """
    CAMERA CONNECTION/DISCONNECTION FRAME
    """
    cameraonoff_frame = tk.Frame(tab_ops)
    cameraonoff_frame.grid(row=0, sticky='we')

    cameraonoff_frame.grid_rowconfigure(0, weight=1)
    cameraonoff_frame.grid_columnconfigure(0, weight=1)
    cameraonoff_frame.grid_columnconfigure(1, weight=1)
    """
    INITIALIZE CAMERA
    """
    button_initcamera = tk.Button(cameraonoff_frame,
                       text="INITIALIZE CAMERA",
                       bg="white",
                       fg="black",
                       command=lambda: create_cameraprocess(q, release_event, \
                            beginlive_event, stoplive_event, piezostep_event, \
                            updateconfig_event, plotmidline_event, stdout_queue))

    button_initcamera.grid(row=0, column=0, padx=10, pady=10)
    """
    RELEASE CAMERA
    """
    button_release = tk.Button(
        cameraonoff_frame,
        text="RELEASE CAMERA",
        bg="white",
        fg="black",
        command=lambda: notify_releasecamera(release_event))
    button_release.grid(row=0, column=1, padx=10, pady=10)

    #    """
    #    EXPOSURE TIME CONFIGURATION
    #    """
    #    label_exposure = tk.Label(cameraonoff_frame, text='Exposure time: ')
    #    label_exposure.grid(row=0,column=2,padx=10,pady=10,sticky='we')
    #
    #    slider_exposure = tk.Scale(cameraonoff_frame, from_=6, to=2000, orient='horizontal')
    #    slider_exposure.grid(row=0,column=3,padx=10,pady=10,sticky='we')
    #
    #    entry_exposure = tk.Entry(cameraonoff_frame)
    #    entry_exposure.grid(row=0,column=4,padx=10,pady=10,sticky='we')
    #    entry_exposure.delete(0, 'end')
    #    entry_exposure.insert(0, '')
    #    entry_exposure.bind("<Return>", get_exposure)
    #
    #    button_exposure = tk.Button(cameraonoff_frame,
    #                       text="Set",
    #                       bg="white",
    #                       fg="black",
    #                       command=set_exposure)
    #    button_exposure.grid(row=0,column=5,padx=10,pady=10,sticky='we')

    root.rowconfigure(0, weight=1)
    root.columnconfigure(0, weight=1)

    tab_ops.rowconfigure(0, weight=1)
    tab_ops.columnconfigure(0, weight=1)

    tab_config.rowconfigure(0, weight=1)
    tab_config.columnconfigure(0, weight=1)

    main_frame = tk.Frame(tab_ops)
    main_frame.grid(row=1, column=0, sticky='nswe')

    main_frame.grid_rowconfigure(0, weight=1)
    main_frame.grid_rowconfigure(1, weight=1)
    main_frame.grid_rowconfigure(2, weight=1)
    main_frame.grid_rowconfigure(3, weight=1)
    main_frame.grid_columnconfigure(0, weight=1)
    main_frame.grid_columnconfigure(1, weight=1)
    main_frame.grid_columnconfigure(2, weight=1)

    imageplots_frame = tk.Frame(main_frame)
    imageplots_frame.grid(row=0, column=0, sticky='nswe')

    imageplots_frame.grid_rowconfigure(0, weight=1)
    imageplots_frame.grid_columnconfigure(0, weight=1)

    oscillation_frame = tk.Frame(main_frame)
    oscillation_frame.grid(row=0, column=1, sticky='nswe')

    oscillation_frame.grid_rowconfigure(0, weight=1)
    oscillation_frame.grid_columnconfigure(0, weight=1)

    buttons_frame = tk.Frame(main_frame)
    buttons_frame.grid(row=0, column=2, sticky='nswe')

    buttons_frame.grid_rowconfigure(0, weight=1)
    buttons_frame.grid_columnconfigure(0, weight=1)

    piezosteps_frame = tk.Frame(buttons_frame)
    piezosteps_frame.grid(row=0, column=0, sticky='nswe')

    piezosteps_frame.grid_rowconfigure(0, weight=1)
    piezosteps_frame.grid_columnconfigure(0, weight=1)

    preperation_frame = tk.Frame(buttons_frame, borderwidth=1, relief='solid')
    preperation_frame.grid(row=1, column=0, sticky='nswe')

    preperation_frame.grid_rowconfigure(0, weight=1)
    preperation_frame.grid_columnconfigure(0, weight=1)

    measurement_frame = tk.Frame(buttons_frame, borderwidth=1, relief='solid')
    measurement_frame.grid(row=2, column=0, sticky='nswe', pady=10)

    measurement_frame.grid_rowconfigure(0, weight=1)
    measurement_frame.grid_columnconfigure(0, weight=1)

    selections_frame = tk.Frame(main_frame)
    selections_frame.grid(row=2, column=0, sticky='nswe')

    selections_frame.grid_rowconfigure(0, weight=1)
    selections_frame.grid_columnconfigure(0, weight=1)
    """
    CREATE CANVAS FOR IMAGE DISPLAY
    """
    #    sections = cfg.load_config_fromtkText('ALL', text_config.get('1.0', 'end'))
    #    h = eval(sections[2]['height'])
    #    w = eval(sections[2]['width'])
    dpi = 96.0
    #    f = Figure(figsize=(w/dpi,h/dpi))
    f = Figure(figsize=(500 / dpi, 500 / dpi), dpi=96)
    f.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0)
    ax = f.add_subplot(111)
    ax.set_axis_off()
    img = Image.frombytes('L', (500, 500), b'\x00' * 250000)
    ax.imshow(img, cmap='gray', vmin=0, vmax=65535)
    livefeed_canvas = FigureCanvasTkAgg(f, master=imageplots_frame)
    livefeed_canvas.get_tk_widget().grid(row=0, column=0, sticky='nswe')
    livefeed_canvas.draw()
    """
    TOOLBAR - IMAGE SHOW
    """
    toolbarimshowFrame = tk.Frame(master=imageplots_frame)
    toolbarimshowFrame.grid(row=1, column=0)
    toolbarimshow = NavigationToolbar2Tk(livefeed_canvas, toolbarimshowFrame)
    toolbarimshow.update()

    # create global list of different plotting colors
    colors = lspec()
    """
    CREATE CANVAS FOR POI AND MIDLINE PLOTTING
    """
    #    dpi = 96
    #    fig = Figure(figsize=(imageplots_frame.winfo_height()/dpi, imageplots_frame.winfo_width()/3/dpi))
    fig = Figure(tight_layout=True)
    a = fig.add_subplot(111)
    canvas_plot = FigureCanvasTkAgg(fig, master=imageplots_frame)
    canvas_plot.get_tk_widget().grid(row=0, column=1, sticky='nswe')
    """
    TOOLBAR - LINE PLOT
    """
    toolbarplotFrame = tk.Frame(master=imageplots_frame)
    toolbarplotFrame.grid(row=1, column=1)
    toolbarplot = NavigationToolbar2Tk(canvas_plot, toolbarplotFrame)
    toolbarplot.update()

    label_preparation = tk.Label(preperation_frame,
                                 text='P R E P A R A T I O N')
    label_preparation.grid(row=0, columnspan=2)
    """
    POINTS OF INTEREST SELECTION BUTTONS
    """
    button_POIenable = tk.Button(preperation_frame,
                                 text="Select POI",
                                 fg="gold2",
                                 bg='grey18',
                                 command=enable_POIsel)
    button_POIenable.grid(row=1, column=0, padx=0, pady=10, sticky='ew')

    button_POIdisable = tk.Button(preperation_frame,
                                  text="OK!",
                                  fg="black",
                                  command=disable_POIsel)
    button_POIdisable.grid(row=1, column=1, padx=0, pady=0, sticky='ew')
    """
    REGION OF INTEREST SELECT BUTTONS
    """
    button_ROIenable = tk.Button(preperation_frame,
                                 text="Select ROI",
                                 fg="chocolate2",
                                 bg='grey18',
                                 command=enable_ROIsel)
    button_ROIenable.grid(row=2, column=0, padx=0, pady=0, sticky='ew')

    #    button_ROIdisable = tk.Button(preperation_frame,
    #                       text="OK!",
    #                       fg="black",
    #                       command=disable_ROIsel)
    #    button_ROIdisable.grid(row=2,column=1,padx=0,pady=0,sticky='ew')

    simplebuttons_frame = tk.Frame(selections_frame)
    simplebuttons_frame.grid(row=0, column=0, sticky='nswe')

    selections_frame.grid_rowconfigure(0, weight=1)
    selections_frame.grid_columnconfigure(0, weight=1)
    """
    OSCILLATE PIEZO BUTTON
    """
    button_oscillate = tk.Button(preperation_frame,
                                 text="Oscillate Piezo",
                                 fg="black",
                                 command=lambda: create_oscillationprocess(
                                     piezostep_event, stdout_queue))
    button_oscillate.grid(row=3, column=0, padx=0, pady=0, sticky='ew')
    """
    PLOT MIDLINE BUTTON
    """
    button_oscillate = tk.Button(
        preperation_frame,
        text="Display intensity across\nhor/ntal line",
        fg="black",
        command=lambda: toggle_displayMidline(plotmidline_event))
    button_oscillate.grid(row=3, column=1, padx=0, pady=0, sticky='ew')
    """
    OPTION LIST FOR NUMBER OF PIEZO STEPS - LABEL
    """
    label_piezosteps = tk.Label(piezosteps_frame, text='Piezo steps:')
    label_piezosteps.grid(row=0, column=0, padx=0, pady=0, sticky='ew')
    """
    OPTION LIST FOR NUMBER OF PIEZO STEPS - VALUES
    """
    piezosteps_options = ['100', '200', '300', '400', '500', '600']
    piezosteps_var = tk.StringVar()
    piezosteps_var.set(piezosteps_options[-1])  # default value

    optionmenu_piezosteps = tk.OptionMenu(piezosteps_frame, piezosteps_var,
                                          *piezosteps_options)
    optionmenu_piezosteps.grid(row=0, column=1, padx=0, pady=0, sticky='ew')

    label_measurement = tk.Label(measurement_frame,
                                 text='M E A S U R E M E N T')
    label_measurement.grid(row=0, columnspan=2)
    """
    EXECUTE PIEZO SCAN FOR INTERFEROGRAM STACK CAPTURING
    """
    button_oscillate = tk.Button(measurement_frame,
                                 text="Piezo scan\nCapture interferograms",
                                 fg="khaki1",
                                 bg='grey18',
                                 command=prepare_piezoscan)
    button_oscillate.grid(row=1,
                          column=0,
                          columnspan=2,
                          padx=0,
                          pady=10,
                          sticky='ew')

    CheckVar = tk.StringVar()
    CheckVar.set('m')
    C1 = tk.Radiobutton(measurement_frame,
                        text='Measurement',
                        variable=CheckVar,
                        value='m')
    C2 = tk.Radiobutton(measurement_frame,
                        text='Calibration',
                        variable=CheckVar,
                        value='c')
    C1.grid(row=2, column=0, padx=0, pady=0, sticky='we')
    C2.grid(row=2, column=1, padx=0, pady=0, sticky='we')
    """
    EXECUTE PIEZO SCAN FOR CALIBRATION PROCESS
    """
    button_oscillate = tk.Button(measurement_frame,
                                 text="Calibrate",
                                 fg="khaki1",
                                 bg='grey18',
                                 command=prepare_calibrate)
    button_oscillate.grid(row=5, column=0, padx=0, pady=0, sticky='ew')
    """
    INSERT MEASUREMENT IMAGE STACK FOLDER NAME
    """
    label_folder = tk.Label(measurement_frame,
                            text='Measurement\nfolder name:')
    label_folder.grid(row=3, column=0, padx=0, pady=0, sticky='ew')

    entry_measfolder = tk.Entry(measurement_frame)
    entry_measfolder.grid(row=3, column=1, padx=0, pady=0, sticky='we')
    entry_measfolder.delete(0, 'end')
    entry_measfolder.insert(0, 'stack')
    entry_measfolder.bind("<Return>", setMeasurementFolderName)
    """
    INSERT CALIBRATION IMAGE STACK FOLDER NAME
    """
    label_folder = tk.Label(measurement_frame,
                            text='Calibration\nfolder name:')
    label_folder.grid(row=4, column=0, padx=0, pady=0, sticky='ew')

    entry_calibfolder = tk.Entry(measurement_frame)
    entry_calibfolder.grid(row=4, column=1, padx=0, pady=0, sticky='we')
    entry_calibfolder.delete(0, 'end')
    entry_calibfolder.insert(0, 'stack')
    entry_calibfolder.bind("<Return>", setCalibrationFolderName)
    """
    EXECUTE INTERFEROGRAM STACK ANALYSIS FOR SURFACE ELEVATION MAP EXTRACTION
    """
    button_oscillate = tk.Button(measurement_frame,
                                 text="Analyze",
                                 fg="khaki1",
                                 bg='grey18',
                                 command=prepare_analysis)
    button_oscillate.grid(row=5, column=1, padx=0, pady=0, sticky='ew')

    #    """
    #    DISPLAY POI INTERFEROGRAMS
    #    """
    #    button_toggleDispInterf = tk.Button(oscillation_frame,
    #                       text="Display/Hide\nPOI Interferograms",
    #                       fg="black",
    #                       command=toggleDispInterf_threaded)
    #    button_toggleDispInterf.grid(row=2,column=1,padx=0,pady=0,sticky='ew')
    """
    LIVE FEED BUTTON
    """
    button_live = tk.Button(simplebuttons_frame,
                            text="Live!",
                            fg="red",
                            bg='grey18',
                            command=lambda: notify_beginlive(beginlive_event))
    button_live.grid(row=0, column=0, padx=10, pady=10, sticky='ew')
    """
    STOP LIVE BUTTON
    """
    button_reset = tk.Button(simplebuttons_frame,
                             text="Stop Live Feed",
                             fg="red",
                             bg='grey18',
                             command=lambda: notify_stoplive(stoplive_event))
    button_reset.grid(row=0, column=1, padx=10, pady=10, sticky='we')
    """""" """
     || || ||
     || || ||
     VV VV VV
    """ """"""
    piezo_dispaxis = np.loadtxt('Mapping_Steps_Displacement_2.txt')
    """
    OUTPUT TEXT
    """
    redirectstdout_frame = tk.Frame(main_frame)
    redirectstdout_frame.grid(row=3, column=0, sticky='nswe')

    output_text = ScrolledText(redirectstdout_frame,
                               bg='gray18',
                               fg='thistle1',
                               width=75,
                               height=10)
    output_text.see('end')
    output_text.grid(row=0, padx=10, pady=10, sticky='nswe')
    """
    CLEAR OUTPUT TEXT BOX
    """
    button_cleartbox = tk.Button(redirectstdout_frame,
                                 text="Clear",
                                 fg="lawn green",
                                 bg='grey18',
                                 command=clear_outputtext)
    button_cleartbox.grid(row=0, column=1, padx=10, pady=10, sticky='we')
    """
    RECTANGLE SELECTOR OBJECT - for ROI selection
    """
    toggle_selector_RS = RectangleSelector(
        ax,
        line_select_callback,
        drawtype='box',
        useblit=True,
        button=[1, 3],  # don't use middle button
        minspanx=5,
        minspany=5,
        spancoords='pixels',
        interactive=True)
    toggle_selector_RS.set_active(False)
    toggle_selector_RS.set_visible(False)
    """
    STDOUT REDIRECTION
    """
    sys.stdout = StdoutQueue(stdout_queue)
    sys.stderr = StdoutQueue(stdout_queue)

    # Instantiate and start the text monitor
    monitor = Thread(target=text_catcher, args=(output_text, stdout_queue))
    monitor.daemon = True
    monitor.start()

    #    sys.stdout = StdoutRedirector(output_text)
    #    sys.stderr = StderrRedirector(output_text)

    root.protocol("WM_DELETE_WINDOW", on_close)
    root.config(menu=menubar)
    root.mainloop()
Ejemplo n.º 41
0
class Thumbnail:
    """Interactive thumbnail used to select regions for ground-truth annotation.

    Construction loads the probability table, builds an empty ground-truth
    table, reads a downscaled, contrast-stretched thumbnail and opens a
    blocking matplotlib window (see :meth:`plot_thumbnail`). Press ``a``/``A``
    to toggle the rectangle selector, draw a box, then click *Annotate*.
    When the window is closed the ground-truth table is cleaned up and
    written to ``<output_dir>/groundtruth_table.csv``.
    """

    # columns every probability table must contain ('ID' becomes the index)
    REQUIRED_COLUMNS = ('ID', 'centroid_x', 'centroid_y', 'xmin', 'ymin',
                        'xmax', 'ymax')
    # geometry columns kept next to the selected channel columns
    GEOMETRY_COLUMNS = ('centroid_x', 'centroid_y', 'xmin', 'ymin', 'xmax',
                        'ymax')

    def __init__(self, image, downscale, channel_names, probability_table,
                 output_dir):
        """
        Args:
            image: path of the image; only its basename is used, to locate
                the memmap file under ``<output_dir>/memmap``.
            downscale: integer stride used to subsample the thumbnail; box
                coordinates are multiplied by it to get full-resolution pixels.
            channel_names: channel image paths; their basenames without
                extension must match (case-insensitively) probability-table
                column names.
            probability_table: path (or buffer) readable by ``pd.read_csv``.
            output_dir: directory holding ``memmap/`` and receiving outputs.
        """
        # [x0, y0] -> coordinates of the top-left corner of the selected box
        # [x1, y1] -> coordinates of the bottom-right corner of the selected box
        # (full-resolution pixels, see onselect)
        self.x0 = None
        self.x1 = None
        self.y0 = None
        self.y1 = None

        # handles for fig, ax, rectangle selector
        self.fig = None  # figure object
        self.ax = None  # axes object
        self.RS = None  # rectangle selector object
        self.annotate_btn = None  # 'Annotate' button; kept on self so it stays alive
        self.dsr = downscale  # DownScale Ratio
        self.thumbnail_width = 3  # figure width in inches

        self.channel_fnames = channel_names
        self.channels = [
            os.path.splitext(os.path.split(channel)[1])[0]
            for channel in self.channel_fnames
        ]

        # read_image needs output_dir, so set it before loading anything
        self.output_dir = output_dir

        # read probability table
        self.probability_table = self.read_probability_table(probability_table)

        # generate groundtruth table based on probability table and set values of channels to NaN
        self.create_groundtruth_table()

        # read thumbnail image
        self.image = None
        self.image_size = None

        self.read_image(image)

        self.plot_thumbnail()

    def read_image(self, image_filename):
        """Load, downscale and contrast-stretch the thumbnail image.

        Sets ``self.image_size`` to the full-resolution ``shape[::-1]``
        (``(width, height)`` for a 2-D image) and ``self.image`` to the
        subsampled, rescaled array.
        """
        # NOTE(review): `memmap` is a project helper defined elsewhere; assumed
        # to return a 2-D array-like for the given path — confirm.
        self.image = memmap(
            os.path.join(self.output_dir, 'memmap',
                         os.path.split(image_filename)[1]))
        self.image_size = self.image.shape[::-1]  # (width, height)

        # downscale image by keeping every dsr-th row and column
        self.image = self.image[::self.dsr, ::self.dsr]

        # adjust intensity to 2% and 98% percentile
        p2, p98 = np.percentile(self.image, (2, 98))
        self.image = rescale_intensity(self.image, in_range=(p2, p98))

    def read_probability_table(self, probability_table_fname):
        """Read the probability table and keep only the columns used here.

        Side effect: sets ``self.column_names`` to the table's column names
        matching ``self.channels`` (case-insensitive, first match wins).

        Returns:
            DataFrame indexed by 'ID' with the geometry columns
            (``GEOMETRY_COLUMNS``) followed by the selected channel columns.

        Raises:
            ValueError: if a required column or a channel column is missing.
        """
        probability_table = pd.read_csv(probability_table_fname)

        # explicit check (not assert) so it also runs under `python -O`
        missing = [col for col in self.REQUIRED_COLUMNS
                   if col not in probability_table.columns]
        if missing:
            raise ValueError("table must contain 'ID', 'centroid_x', "
                             "'centroid_y', 'xmin', 'ymin', 'xmax', 'ymax' "
                             "(missing: {})".format(missing))

        # set ID as index
        probability_table.set_index('ID', inplace=True)

        # get column names of the selected channels (case-insensitive match)
        lowered = list(probability_table.columns.str.lower())
        self.column_names = []
        for channel in self.channels:
            key = channel.lower()
            if key not in lowered:
                raise ValueError(
                    "probability table has no column for channel "
                    "'{}'".format(channel))
            self.column_names.append(probability_table.columns[lowered.index(key)])

        # Keep the geometry columns by name rather than position: any extra
        # column placed before them would otherwise shift what is kept.
        columns_to_keep = list(self.GEOMETRY_COLUMNS) + self.column_names
        return probability_table.loc[:, columns_to_keep]

    def create_groundtruth_table(self):
        """Copy the probability table and blank out the channel columns (NaN)."""
        self.groundtruth_table = self.probability_table.copy()

        # set the values of the channel columns to nan
        self.groundtruth_table.loc[:, self.column_names] = np.nan

    def postprocess_groundtruth_table(self):
        """Keep only annotated rows that have exactly one class assigned.

        Rows with any NaN (not annotated) are dropped, values are cast to
        int, and rows with no non-zero channel (no biomarker) or more than one
        (dual biomarker) are removed.
        """
        # drop rows with na and change to int
        table = self.groundtruth_table.dropna().astype(int)

        # number of channels marked non-zero for each row
        n_assigned = (table[self.column_names] != 0).sum(axis=1)

        # Positional boolean mask: unlike drop-by-label it stays correct even
        # if the ID index contains duplicates.
        self.groundtruth_table = table[n_assigned.to_numpy() == 1]

    def plot_thumbnail(self):
        """Show the thumbnail with a rectangle selector and an Annotate button.

        Blocks until the window is closed, then post-processes the
        ground-truth table and writes it to
        ``<output_dir>/groundtruth_table.csv``.
        """
        # plot figure and set size (keep the image aspect ratio)
        fig_size_inch = self.thumbnail_width, self.thumbnail_width * self.image.shape[
            0] / self.image.shape[1]
        self.fig, self.ax = plt.subplots(figsize=fig_size_inch, dpi=150)

        # plot image
        plt.ion()
        plt.imshow(self.image, cmap='gray')
        plt.axis('off')
        plt.tight_layout()

        # add function when any key pressed
        plt.connect('key_press_event', self.toggle_selector)

        # create rectangle selector (inactive until 'a' is pressed)
        rectprops = dict(facecolor='white',
                         edgecolor='white',
                         linewidth=2,
                         alpha=0.2,
                         fill=True)
        self.RS = RectangleSelector(self.ax,
                                    self.onselect,
                                    drawtype='box',
                                    rectprops=rectprops,
                                    interactive=True)
        self.RS.set_active(False)

        # add button to initiate annotator; matplotlib widgets must stay
        # referenced to remain responsive, so keep it on the instance
        axbtn = plt.axes([0.7, 0.05, 0.2, 0.075])
        self.annotate_btn = Button(axbtn, 'Annotate')
        self.annotate_btn.on_clicked(self.generate_annotator)

        # show plot
        plt.show(block=True)

        # postprocess and save groundtruth table
        self.postprocess_groundtruth_table()
        self.groundtruth_table.to_csv(
            os.path.join(self.output_dir, 'groundtruth_table.csv'))

        # temp
        # check the saved table
        from utils import center_image
        centers = self.groundtruth_table[['centroid_x', 'centroid_y']].values
        for ch in self.column_names:
            center_image(os.path.join(self.output_dir, ch + '.tif'),
                         centers[self.groundtruth_table[ch].values == 1, :],
                         self.image_size)

    def generate_annotator(self, event):
        """Button callback: annotate the selected box and merge the labels.

        Does nothing (besides printing a hint) if no box has been selected.
        """
        if any(v is None for v in (self.x0, self.y0, self.x1, self.y1)):
            print('No box selected: press "a" to enable the selector and '
                  'draw a box first.')
            return

        annotation = Annotator(self.channel_fnames, self.probability_table,
                               self.output_dir)
        new_table = annotation.annotate([self.x0, self.y0], [self.x1, self.y1])

        # add the new table to the groundtruth table
        self.groundtruth_table.loc[
            new_table.index,
            self.column_names] = new_table.loc[:, self.column_names]

    def onselect(self, eclick, erelease):
        """RectangleSelector callback: store the box in full-resolution pixels.

        Thumbnail coordinates are truncated to int, then multiplied by the
        downscale ratio; corners are normalized so (x0, y0) <= (x1, y1).
        """
        self.x0 = int(min(eclick.xdata, erelease.xdata)) * self.dsr
        self.x1 = int(max(eclick.xdata, erelease.xdata)) * self.dsr
        self.y0 = int(min(eclick.ydata, erelease.ydata)) * self.dsr
        self.y1 = int(max(eclick.ydata, erelease.ydata)) * self.dsr
        print('selected box = [{}, {}, {}, {}]'.format(self.x0, self.x1,
                                                       self.y0, self.y1))

    def toggle_selector(self, event):
        """Key-press callback: 'a'/'A' toggles the rectangle selector."""
        if event.key not in ('A', 'a'):
            return
        if self.RS.active:
            print(' RectangleSelector deactivated.')
            self.RS.set_active(False)
        else:
            print(' RectangleSelector activated.')
            self.RS.set_active(True)
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    selector_rectprops = dict(facecolor='red',
                              edgecolor='black',
                              alpha=0.5,
                              fill=True)
    selector_lineprops = dict(color='black',
                              linestyle='-',
                              linewidth=2,
                              alpha=0.5)

    def __init__(self,
                 data=None,
                 bands=None,
                 classes=None,
                 source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed. with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display.

        Keyword arguments:

            `colors` (array of 3-tuples):

                Optional class colors (values in [0, 255]). If omitted,
                ``spectral.spy_colors`` is used.

            Any other keyword that can be provided to
            :func:`~spectral.graphics.graphics.get_rgb` or
            :func:`matplotlib.imshow`.
        '''

        import spectral
        from spectral import settings
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None

        # Defaults must be assigned before set_data/set_classes run below;
        # assigning them afterwards would discard user-supplied class colors.
        self.class_colors = spectral.spy_colors
        self._class_alpha = settings.imshow_class_alpha

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        # `colors` only concerns class display; keep it out of the kwargs that
        # set_data stores for the RGB layer's imshow call.
        colors = kwargs.pop('colors', None)
        if colors is not None:
            self.class_colors = colors

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises:

            `ValueError`: if the first two dimensions of `data` differ from
            the image shape set by previously given data or classes. The
            view is left unchanged in that case.
        '''
        from .graphics import _get_rgb_kwargs

        # Validate before touching any state so a rejected array does not
        # leave the view holding data that disagrees with _image_shape.
        if self._image_shape is not None and \
           data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')

        self.data = data
        self.bands = bands

        # Split off keywords meant for get_rgb; the remainder go to imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        self.imshow_data_kwargs.update(kwargs)
        # interpolation is stored via the `interpolation` attribute rather
        # than kept in the imshow kwargs
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs['interpolation']
            self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Store the options used when converting data to RGB for display.

        Every keyword must be one accepted by
        :func:`~spectral.graphics.graphics.get_rgb`; the first unknown keyword
        raises `ValueError`. If the view is already shown, the RGB data is
        regenerated and the display refreshed.
        '''
        from .graphics import _get_rgb_kwargs

        unknown = [key for key in kwargs if key not in _get_rgb_kwargs]
        if unknown:
            raise ValueError('Unexpected keyword: {0}'.format(unknown[0]))
        self.rgb_kwargs = dict(kwargs)
        if not self.is_shown:
            return
        self._update_data_rgb()
        self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.

        Calls ``get_rgb_meta`` with the current data, bands and RGB options,
        storing the array in ``self.data_rgb`` and its metadata in
        ``self.data_rgb_meta``.
        '''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used (imshow ignores `cmap`
        # for 3-channel RGB arrays and only applies it to 2-D data).
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.

        Raises:

            `ValueError`: if the first two dimensions of `classes` differ
            from the previously set image shape. The view is left unchanged
            in that case.
        '''
        from .graphics import _get_rgb_kwargs
        if classes is None:
            self.classes = None
            return
        # Validate before mutating so a rejected array leaves the view intact.
        if self._image_shape is not None and \
           classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')
        self.classes = classes
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        if colors is not None:
            self.class_colors = colors

        # get_rgb options are irrelevant for the class layer; keep the rest
        # for imshow.
        kwargs = {k: v for k, v in kwargs.items() if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        # interpolation is stored via the `interpolation` attribute rather
        # than kept in the imshow kwargs
        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs['interpolation']
            self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Draw the view in a (new or numbered) matplotlib figure.

        Arguments:

            `mode` (str):

                One of "data" (RGB data), "classes" (indexed class colors) or
                "overlay" (class colors over the data RGB). When omitted,
                the mode is chosen from whatever data the view holds.

            `fignum` (int):

                Number of the matplotlib figure to draw into. When omitted, a
                new figure is created.

        Only the first call has an effect; later calls warn and return.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            warnings.warn(UserWarning('ImageView.show should only be called once.'))
            return

        set_mpl_interactive()

        figure_kwargs = {}
        if fignum is not None:
            figure_kwargs['num'] = fignum
        figure_size = settings.imshow_figure_size
        if figure_size is not None:
            figure_kwargs['figsize'] = figure_size
        plt.figure(**figure_kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is not None:
            self.set_display_mode(mode)
        else:
            self._guess_mode()

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.

        Connects the mouse and keyboard handlers, subscribes to the shared
        'spy_classes_modified' event and, unless
        ``settings.imshow_enable_rectangle_selector`` is False, creates an
        inactive RectangleSelector used for interactive class labeling. If the
        selector cannot be created, ``self.selector`` is None and a warning is
        issued.
        '''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: adopt the event's class array if this
        # view has none yet, then redraw.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()

        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops = \
                                                  self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # NOTE(review): newer matplotlib releases removed the `drawtype`
            # and `rectprops` arguments, which would also land here.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`.

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`.
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away.
            self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
          (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Draw the RGB data layer with ``plt.imshow`` (once per view).

        Warns and returns if the layer already exists; raises an Exception if
        no RGB data is set. If the view already has axes, their figure is
        made current first; otherwise the new image's axes become
        ``self.axes``.
        '''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            warnings.warn(UserWarning(
                'ImageView.show_data should only be called once.'))
            return
        if self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        had_axes = self.axes is not None
        if had_axes:
            # Reuse the figure that already belongs to this view.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if not had_axes:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.

        Draws the class array with ``plt.imshow`` using a ListedColormap
        built from ``self.class_colors`` (0-255 triplets scaled to [0, 1])
        and ``NoNorm``, so class indices are used directly as colormap
        indices. Warns and returns if the class layer already exists; raises
        an Exception if no class array is set.
        '''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        # NoNorm passes values through unchanged, so vmin/vmax have no effect;
        # newer matplotlib raises if they are combined with a norm instance.
        kwargs.pop('vmin', None)
        kwargs.pop('vmax', None)
        kwargs.update({
            'cmap': cm,
            'norm': NoNorm(),
            'interpolation': self._interpolation
        })
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Redraw both image layers from the current arrays (no-op until shown).

        Existing layers receive fresh pixel data and the current interpolation.
        A layer required by the current display mode that does not exist yet
        is created via `show_classes` / `show_data`. The canvas is redrawn last.
        '''
        if not self.is_shown:
            return
        self._update_class_rgb()

        if self.class_axes is None:
            if self.display_mode in ('classes', 'overlay'):
                self.show_classes()
        else:
            self.class_axes.set_data(self.class_rgb)
            self.class_axes.set_interpolation(self._interpolation)

        if self.data_axes is None:
            if self.display_mode in ('data', 'overlay'):
                self.show_data()
        else:
            self.data_axes.set_data(self.data_rgb)
            self.data_axes.set_interpolation(self._interpolation)

        self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Rebuild `class_rgb` from `classes` for the current display mode.

        Outside "overlay" mode this is a plain array copy. In "overlay" mode
        it is a masked array in which class value 0 is masked out.
        '''
        mode = self.display_mode
        if mode != 'overlay':
            self.class_rgb = np.array(self.classes)
            return
        labels = self.classes
        self.class_rgb = np.ma.array(labels, mask=(labels == 0))

    def set_display_mode(self, mode):
        '''Select which image layers are displayed.

        Arguments:

            `mode` (str):

                One of "data" (data layer only), "classes" (class layer only,
                fully opaque) or "overlay" (data layer plus the class layer
                drawn with alpha `class_alpha`).

        Raises:

            `ValueError` if `mode` is not one of the three names above.

        May create the class layer via `show_classes` and always finishes
        with `refresh()`.
        '''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        # The data layer is visible in "data" and "overlay" modes.
        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        # The class layer is visible in "classes" and "overlay" modes.
        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set (no class layer exists yet).
            # NOTE(review): this runs even before the view is shown (open_zoom
            # calls set_display_mode before show()) — confirm that is intended.
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            # Opaque when shown alone; translucent (class_alpha) otherwise.
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''Transparency of the class layer in overlay mode (0 to 1 inclusive).'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        out_of_range = alpha < 0 or alpha > 1
        if out_of_range:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        layer = self.class_axes
        if layer is not None:
            layer.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''Name of the matplotlib interpolation used to draw image pixels.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        unchanged = interpolation == self._interpolation
        if unchanged:
            return
        self._interpolation = interpolation
        if self.is_shown:
            # Apply to every existing layer, then redraw.
            for layer in (self.data_axes, self.class_axes):
                if layer is not None:
                    layer.set_interpolation(interpolation)
            self.refresh()

    def set_title(self, s):
        '''Set the axes title to `s` and redraw; ignored until the view is shown.'''
        if not self.is_shown:
            return
        self.axes.set_title(s)
        self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view. Defaults to
                `spectral.settings.imshow_zoom_pixel_width`.

        The new figure is square with side `settings.imshow_zoom_figure_width`
        (inches) when that setting is not None. The zoomed view copies this
        view's data, bands, RGB options, classes, class colors, imshow options
        and display mode, and uses 'nearest' interpolation.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            # Square figure; matplotlib's figsize is (width, height) in inches.
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        # Build the zoomed view from the same inputs and options as this one.
        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        # Explicit extent (left, right, bottom, top): pixel centers fall on
        # integer coordinates with row 0 at the top, so the axis limits set
        # below are in pixel units.
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.set_display_mode(self.display_mode)
        # NOTE(review): presumably shares the parent's callback registry so
        # callbacks registered on the parent also apply here — confirm.
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        # Initial window: `size` x `size` pixels from the top-left corner
        # (y limits reversed so row 0 stays at the top).
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        # Callback that (per the docstring) pans this view on ctrl-left-click
        # in the parent; stored on the view, presumably to keep it alive.
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Re-center the view on pixel (row, col) while keeping the zoom level.

        Raises:

            `Exception` if the view has not been shown yet.
        '''
        axes = self.axes
        if axes is None:
            raise Exception('Cannot pan image until it is shown.')
        x_lo, x_hi = axes.get_xlim()
        y_lo, y_hi = axes.get_ylim()
        # Signed half-extents, so an inverted y axis stays inverted.
        half_w = (x_hi - x_lo) / 2.0
        half_h = (y_hi - y_lo) / 2.0
        axes.set_xlim(col - half_w, col + half_w)
        axes.set_ylim(row - half_h, row + half_h)
        axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zoom the view in or out about its current center.

        Arguments:

            `scale` (positive number):

                Magnification factor; values > 1 zoom in, values < 1 zoom out.

        Raises:

            `Exception` if the view has not been shown yet (same message style
            as `pan_to`), `ValueError` if `scale` is not positive.
        '''
        if self.axes is None:
            raise Exception('Cannot zoom image until it is shown.')
        if scale <= 0:
            # Zero would divide by zero; a negative value would flip the axes.
            raise ValueError('Zoom scale must be positive: ' + repr(scale))
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        # Signed half-extents, so an inverted y axis stays inverted.
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Format the coordinate text shown in the window for point (x, y).

        Returns "pixel=[row,col]" (followed by " class=N" when class values
        are set and readable) for points inside the image, or "" outside it.
        '''
        (nrows, ncols) = self._image_shape
        # Pixels are centered on integer coordinates and span +/-0.5.
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: e.g. the class array may be smaller than the
                # image or hold a value '%d' cannot format; just omit the class.
                # (Unlike a bare except, this lets KeyboardInterrupt through.)
                pass
        return s

    def __str__(self):
        '''Return a short multi-line summary of the view's display settings.

        Lists the displayed bands and per-channel RGB limits when they are
        present in `data_rgb_meta`, plus the interpolation ("<default>" when
        unset).
        '''
        meta = self.data_rgb_meta
        row = '  {0:<20}:  {1}\n'
        parts = ['ImageView object:\n']
        if 'bands' in meta:
            parts.append(row.format("Display bands", meta['bands']))
        interp = self.interpolation
        if interp is None:
            interp = "<default>"
        parts.append(row.format("Interpolation", interp))
        if 'rgb range' in meta:
            parts.append('  {0:<20}:\n'.format("RGB data limits"))
            for (c, r) in zip('RGB', meta['rgb range']):
                parts.append('    {0}: {1}\n'.format(c, str(r)))
        return ''.join(parts)

    def __repr__(self):
        '''Return the same multi-line summary that ``str()`` produces.'''
        summary = str(self)
        return summary
class MplInteraction(object):
    """Base helper that manages zoom/pan interactions on a matplotlib Figure.

    Zoom and pan callbacks are registered by event name
    (`_add_connection_zoom` / `_add_connection_pan`) and only attached to the
    canvas by `connect_zoom` / `connect_pan`, so each mode can be switched on
    and off independently. Zoom-reset commands are recorded in an undo
    history, and the undo button state is published as 'setStateUndo'.
    """

    def __init__(self, figure):
        """Initializer
        :param Figure figure: The matplotlib figure to attach the behavior to.
        """
        # Plain bookkeeping first, so __del__/disconnect still work if creating
        # the canvas below raises.
        self._cids_zoom = []
        self._cids_pan = []
        self._cids = []
        self._cids_callback_zoom = {}
        self._cids_callback_pan = {}
        self._callback_rectangle = None
        self._rectangle_selector = None
        self._xLimits = None
        self._yLimits = None

        self._fig_ref = weakref.ref(figure)
        self.canvas = FigureCanvas(figure)

        #Create invokers
        self._invokerZoom = UndoHistoryZoomInvoker(figure.canvas)
        self._invokerPan = UndoHistoryPanInvoker(figure.canvas)

    def _add_connection(self, event_name, callback):
        """Connect `callback` to `event_name` until `disconnect` is called.

        Unlike the zoom/pan callbacks, this connection is made immediately.
        """
        cid = self.canvas.mpl_connect(event_name, callback)
        self._cids.append(cid)

    def create_rectangle_ax(self, ax):
        """Create the (initially hidden) rubber-band RectangleSelector on `ax`.

        The selector calls the callback previously registered with
        `_add_rectangle_callback`. `connect_zoom` / `disconnect_zoom` show and
        hide it once it exists.
        """
        rectprops = dict(facecolor = None, edgecolor = 'black', alpha = 1,
         fill=False, linewidth = 1, linestyle = '-')
        self._rectangle_selector = RectangleSelector(ax, self._callback_rectangle,
                                           drawtype='box', useblit=True,
                                           rectprops = rectprops,
                                           button=[1, 3],  # don't use middle button
                                           minspanx=5, minspany=5,
                                           spancoords='pixels',
                                           interactive=False)

        self._rectangle_selector.set_visible(False)

    def set_axes_limits(self, xLimits, yLimits):
        """
        Store the initial axis limits; zoom-reset commands restore the view
        to these limits.
        """
        self._xLimits = xLimits
        self._yLimits = yLimits

    def __del__(self):
        # Tolerate a partially-initialized instance (e.g. __init__ raised
        # before _fig_ref was assigned) so finalization does not raise.
        if getattr(self, '_fig_ref', None) is not None:
            self.disconnect()

    def _add_rectangle_callback(self, callback):
        """
        Register the callback used by the rectangle selector.

        The callback is created together with the zoom event, while the axes
        only exist later, so the callback is stored here first and used when
        `create_rectangle_ax` builds the selector.
        """
        self._callback_rectangle = callback

    def _add_connection_zoom(self, event_name, callback):
        """Register a zoom-mode callback; it is connected by `connect_zoom`.
        :param str event_name: The matplotlib event name to connect to.
        :param callback: The callback to register to this event.
        """
        # One callback per event name: a later registration replaces it.
        self._cids_callback_zoom[event_name] = callback

    def _add_connection_pan(self, event_name, callback):
        """Register a pan-mode callback; it is connected by `connect_pan`.
        :param str event_name: The matplotlib event name to connect to.
        :param callback: The callback to register to this event.
        """
        # One callback per event name: a later registration replaces it.
        self._cids_callback_pan[event_name] = callback

    def disconnect_zoom(self):
        """
        Disconnect all zoom events and disable the rectangle selector
        """
        if self._fig_ref is not None:
            figure = self._fig_ref()
            if figure is not None:
                for cid in self._cids_zoom:
                    figure.canvas.mpl_disconnect(cid)
                self._disable_rectangle()

        self._cids_zoom.clear()

    def disconnect_pan(self):
        """
        Disconnect all pan events
        """
        if self._fig_ref is not None:
            figure = self._fig_ref()
            if figure is not None:
                for cid in self._cids_pan:
                    figure.canvas.mpl_disconnect(cid)
        self._cids_pan.clear()

    def _disable_rectangle(self):
        """Hide and deactivate the rectangle selector, if it has been created."""
        if self._rectangle_selector is not None:
            self._rectangle_selector.set_visible(False)
            self._rectangle_selector.set_active(False)

    def _enable_rectangle(self):
        """Show and activate the rectangle selector, if it has been created."""
        if self._rectangle_selector is not None:
            self._rectangle_selector.set_visible(True)
            self._rectangle_selector.set_active(True)

    def connect_zoom(self):
        """
        Connect all registered zoom callbacks to the canvas and enable the
        rectangle selector
        """
        for event_name, callback in self._cids_callback_zoom.items():
            cid = self.canvas.mpl_connect(event_name, callback)
            self._cids_zoom.append(cid)

        self._enable_rectangle()

    def connect_pan(self):
        """
        Connect all registered pan callbacks to the canvas
        """
        for event_name, callback in self._cids_callback_pan.items():
            cid = self.canvas.mpl_connect(event_name, callback)
            self._cids_pan.append(cid)

    def disconnect(self):
        """Disconnect interaction from Figure."""
        if self._fig_ref is not None:
            figure = self._fig_ref()
            if figure is not None:
                for cid in self._cids_zoom:
                    figure.canvas.mpl_disconnect(cid)
                for cid in self._cids_pan:
                    figure.canvas.mpl_disconnect(cid)
                for cid in self._cids:
                    figure.canvas.mpl_disconnect(cid)
            self._fig_ref = None
        # The ids are no longer valid; drop them so they are not reused or
        # disconnected twice.
        self._cids_zoom.clear()
        self._cids_pan.clear()
        self._cids.clear()

    @property
    def figure(self):
        """The Figure this interaction is connected to or
        None if not connected."""
        return self._fig_ref() if self._fig_ref is not None else None


    def undo_last_action(self):
        """
        Undo the most recent zoom command.

        The first command in the history is the initial zoom reset (see
        `_add_initial_zoom_reset`). Once at most one command remains there is
        nothing meaningful left to undo, so the history is cleared and the
        undo button is disabled by publishing 'setStateUndo' with state=False.
        """
        self._invokerZoom.undo()
        if self._invokerZoom.command_list_length() <= 1:
            self._invokerZoom.clear_command_list()
            pub.sendMessage('setStateUndo', state = False)

    def add_zoom_reset(self):
        """Record (and execute) a reset to the stored initial axis limits."""
        if self._invokerZoom.command_list_length() == 0:
            #Send the signal to change undo button state
            pub.sendMessage('setStateUndo', state = True)

        zoomResetCommand = ZoomResetCommand(self.figure, self._xLimits, self._yLimits)
        self._invokerZoom.command(zoomResetCommand)
        self._draw_idle()

    def _add_initial_zoom_reset(self):
        """Record a reset to the initial limits only if the history is empty."""
        if self._invokerZoom.command_list_length() == 0:
            #Send the signal to change undo button state
            pub.sendMessage('setStateUndo', state = True)

            zoomResetCommand = ZoomResetCommand(self.figure, self._xLimits, self._yLimits)
            self._invokerZoom.command(zoomResetCommand)
            self._draw_idle()


    def clear_commands(self):
        """
        Delete all commands
        """
        self._invokerZoom.clear_command_list()

    def _draw_idle(self):
        """Request a redraw the next time the GUI is idle."""
        self.canvas.draw_idle()

    def _draw(self):
        """Convenience method to redraw the figure immediately."""
        self.canvas.draw()
Ejemplo n.º 44
0
class View(wx.Frame):
    """Main window of the Coral X-Ray Viewer.

    Builds the menu bar, toolbar and status bar, hosts the matplotlib canvas
    inside a scrolled window, and forwards user events to `controller`
    (zoom actions go to a `zoom_controller.Controller`).
    """

    def __init__(self, controller, model):
        self.controller = controller
        self.model = model
        self.zoom_controller = zoom_controller.Controller(controller, self, model)
        self.aspect = 1.0
        self.toolbar_ids = {}
        self.menubar_ids = {}
        self.connect_ids = []
        self.ov_axes = ''
        self.toggle_selector = None
        self.figure = None

        wx.Frame.__init__(self,
                          parent=None,
                          title="Coral X-Ray Viewer",
                          size=(850, 750),
                          pos=(0,0))
        self.SetMinSize((100, 100))

        self.scroll = wx.ScrolledWindow(self, -1)
        self.scroll.SetBackgroundColour('grey')

        self.create_menubar()
        self.create_toolbar()
        self.create_statusbar()

        self.scroll.Bind(wx.EVT_SCROLLWIN, self.controller.on_scroll) # scroll event
        self.scroll.Bind(wx.EVT_SCROLL, self.on_scroll)
        self.Bind(wx.EVT_SIZE, self.controller.on_resize)
        self.Bind(wx.EVT_ACTIVATE, self.controller.cleanup)
        self.Bind(wx.EVT_CLOSE, self.controller.on_quit)
        self.Bind(wx.EVT_CONTEXT_MENU, self.controller.on_show_popup)

        self.Center()
        self.Show()

    def on_scroll(self, event):
        """Let wx process the scroll, then tell the controller the view changed."""
        event.Skip()
        self.controller.state_changed(True)

    def create_menubar(self):
        """Build the menu bar from menu_names() paired with menu_options()."""
        self.menubar = wx.MenuBar()
        for name, options in zip(self.menu_names(), self.menu_options()):
            menu = wx.Menu()
            for each in options:
                self.add_menu_option(menu, *each)
            self.menubar.Append(menu, name)
        self.SetMenuBar(self.menubar)

    def add_menu_option(self, menu, label, accel, handler, enabled, has_submenu, submenu):
        """Append one menu_options() entry to `menu`.

        An empty `label` adds a separator. Otherwise the new item's id is
        stored in menubar_ids[label] and `handler` is bound to it. Keyboard
        shortcuts come from the TAB-separated accelerator in the label (e.g.
        "Ctrl+O"); `accel` is kept for signature compatibility only.
        """
        if not label:
            menu.AppendSeparator()
        else:
            menu_id = wx.NewId()
            self.menubar_ids[label] = menu_id
            if has_submenu:
                if label == 'Filter Plugins':
                    option = menu.AppendMenu(menu_id, label, self.plugin_submenu())
                else:
                    option = menu.AppendMenu(menu_id, label, submenu)
            else:
                option = menu.Append(menu_id, label)
            option.Enable(enabled)
            self.Bind(wx.EVT_MENU, handler, option)

    def menu_names(self):
        """Menu titles, in the same order as the groups of menu_options()."""
        return ('File', 'Tools', 'Help')

    def menu_options(self):
        """ ('TEXT', (ACCELERATOR), HANDLER, ENABLED, HAS_SUBMENU, SUBMENU METHOD """
        return ( [ # File
                  ('&Open...\tCtrl+O', (wx.ACCEL_CTRL, 'O'), self.controller.on_open, True, False, None),
                  ('&Save\tCtrl+S', (wx.ACCEL_CTRL, 'S'), self.controller.on_save, False, False, None),
                  ('Save As...', (), self.controller.on_save_as, False, False, None),
                  ('', '', '', True, False, None),
                  ('Export', (), self.controller.on_export, False, False, None),
                  ('', '', '', True, False, None),
                  ('&Quit\tCtrl+Q', (wx.ACCEL_CTRL, 'Q'), self.controller.on_quit, True, False, None)
                  ],
                 [ # Tools
                  ('Image Overview', (), self.controller.on_overview, False, False, None),
                  ('Image Information', (), self.controller.on_image_info, False, False, None),
                  ('', '', '', True, False, None),
#                  ('Rotate Image', (), self.controller.on_rotate_image, False, False, None),
                  ('Pan Image', (), self.controller.on_pan_image_menu, False, False, None),
                  ('Rotate Image', (), self.controller.on_rotate, False, False, None),
                  ('Zoom In', (), self.zoom_controller.on_zoom_in_menu, False, False, None),
                  ('Zoom Out', (), self.zoom_controller.on_zoom_out, False, False, None),
                  ('', '', '', True, False, None),
#                  ('Adjust Contrast', (), self.controller.on_contrast, False, False, None),
#                  ('', '', '', True, False, None),
                  ('Adjust Target Area', (), self.controller.on_coral_menu, False, False, None),
                  ('', '', '', True, False, None),
                  ('Filtered Overlays', (), self.controller.on_overlay, False, False, None),
                  ('Filter Plugins', (), self.controller.on_plugin, True, True, None),
                  ('', '', '', True, False, None),
                  ('Adjust Calibration Region', (), self.controller.on_calibrate_menu, False, False, None),
                  ('Set Calibration Parameters', (), self.controller.on_density_params, False, False, None),
                  ('Show Density Chart', (), self.controller.on_show_density_chart, False, False, None),
                  ('', '', '', True, False, None),
                  ('Draw Polylines', (), self.controller.on_polyline_menu, False, False, None),
                  ],
                 [ # Help
                  ('Help', (), self.controller.on_help, True, False, None),
                  ('About', (), self.controller.on_about, True, False, None)
                  ]
                )

    def plugin_submenu(self):
        """ Creates the plugin submenu in the menubar which displays all plugins
        and allows the user to specify a secondary plugin directory.
        """
        menu = wx.Menu()

        # An "Add Plugin" item (copying a plugin into the default directory
        # via controller.on_add_plugin) is currently disabled.

        # Set directory where extra plugins are held
        props = wx.MenuItem(menu, wx.ID_ANY, 'Set Directory')
        menu.AppendItem(props)
        self.Bind(wx.EVT_MENU, self.controller.on_plugin_properties, props)

        menu.AppendSeparator()

        # Read the user-configured extra plugin directory from ~/.cxvrc.xml.
        # os.path.join replaces the hard-coded '\.' (an invalid escape that
        # also produced a wrong path on non-Windows systems).
        home = os.path.expanduser('~')
        xml = xml_controller.Controller(os.path.join(home, '.cxvrc.xml'))
        xml.load_file()

        # Prefer ~/plugins when it exists, otherwise the application's own
        # plugins folder.
        user_plugins = os.path.join(home, "plugins")
        if os.path.exists(user_plugins):
            default_dir = user_plugins
        else:
            default_dir = os.path.join(self.get_main_dir(), "plugins")

        extra_dir = xml.get_plugin_directory()
        if extra_dir is None or extra_dir == "":
            directory = [default_dir]
        else:
            directory = [default_dir, extra_dir]

        # Load the plugins from the specified plugin directory/s.
        manager = PluginManager()
        manager.setPluginPlaces(directory)
        manager.setPluginInfoExtension('plugin')
        manager.collectPlugins()

        for plugin in manager.getAllPlugins():
            item = wx.MenuItem(menu, wx.ID_ANY, plugin.name)
            menu.AppendItem(item)
            self.better_bind(wx.EVT_MENU, item, self.controller.on_about_filter, plugin)

        return menu

    def better_bind(self, evt_type, instance, handler, *args, **kwargs):
        """Bind `handler(event, *args, **kwargs)` to `evt_type` on `instance`."""
        self.Bind(evt_type, lambda event: handler(event, *args, **kwargs), instance)

    def create_toolbar(self):
        """Create the frame toolbar from toolbar_data()."""
        self.toolbar = self.CreateToolBar()
        for each in self.toolbar_data():
            self.add_tool(self.toolbar, *each)
        self.toolbar.Realize()

    def add_tool(self, toolbar, tool_type, label, bmp, handler, enabled):
        """Add one toolbar_data() entry to `toolbar`.

        `tool_type` is 'separator', 'control' (then `label` is the wx control
        to embed), 'toggle' (a check tool) or 'simple'. For the two tool kinds
        `bmp` is an image path relative to the application directory and the
        tool id is stored in toolbar_ids[label].

        Raises ValueError for any other `tool_type`.
        """
        if tool_type == 'separator':
            toolbar.AddSeparator()
        elif tool_type == 'control':
            toolbar.AddControl(label)
        else:
            # os.path.join avoids a root-anchored path ("/images/...") when
            # get_main_dir() returns '' (argv[0] without a directory part).
            bmp = wx.Image(os.path.join(self.get_main_dir(), bmp), wx.BITMAP_TYPE_ANY).ConvertToBitmap()
            tool_id = wx.NewId()
            self.toolbar_ids[label] = tool_id
            if tool_type == 'toggle':
                tool = toolbar.AddCheckTool(tool_id, bmp, wx.NullBitmap, label, '')
            elif tool_type == 'simple':
                tool = toolbar.AddSimpleTool(tool_id, bmp, label, '')
            else:
                raise ValueError('Unknown tool type: %r' % (tool_type,))
            toolbar.EnableTool(tool_id, enabled)
            self.Bind(wx.EVT_MENU, handler, tool)

    def toolbar_data(self):
        """Return the toolbar entries; also creates the zoom-level combobox."""
        aspects = ['100%', '75%', '50%', '25%', '10%', 'Zoom to fit']
        # NOTE(review): wx only sends EVT_TEXT_ENTER to a combobox created with
        # wx.TE_PROCESS_ENTER, which this style lacks, so that binding is
        # likely inert — confirm whether Enter should apply a typed value.
        self.aspect_cb = wx.ComboBox(self.toolbar, -1, '100%',
                                     choices=aspects,
                                     style=wx.CB_DROPDOWN)
        self.aspect_cb.SetValue('Zoom to fit')
        self.Bind(wx.EVT_COMBOBOX, self.controller.on_aspect, self.aspect_cb)
        self.Bind(wx.EVT_TEXT_ENTER, self.controller.on_aspect, self.aspect_cb)
        self.aspect_cb.Disable()
        return (# tool type, description text, icon directory, handler
                ('simple', '&Open...\tCtrl+O', 'images' + os.sep + 'open.png', self.controller.on_open, True),
                ('simple', '&Save\tCtrl+S', 'images' + os.sep + 'save.png', self.controller.on_save, False),
                ('separator', '', '', '', ''),
                ('simple', 'Image Overview', 'images' + os.sep + 'overview.png', self.controller.on_overview, False),
                ('simple', 'Image Information', 'images' + os.sep + 'info.png', self.controller.on_image_info, False),
                ('separator', '', '', '', ''),
#                ('simple', 'Rotate Image', 'images' + os.sep + 'rotate_counter-clock.png', self.controller.on_rotate_image, False),
                ('toggle', 'Pan Image', 'images' + os.sep + 'cursor_hand.png', self.controller.on_pan_image, False),
                ('simple', 'Rotate Image', 'images' + os.sep + 'rotate_image.png', self.controller.on_rotate, False),
                ('toggle', 'Zoom In', 'images' + os.sep + 'zoom_in_toolbar.png', self.zoom_controller.on_zoom_in, False),
                ('simple', 'Zoom Out', 'images' + os.sep + 'zoom_out_toolbar.png', self.zoom_controller.on_zoom_out, False),
                ('control', self.aspect_cb, '', '', ''),
                ('separator', '', '', '', ''),
#                ('simple', 'Adjust Contrast', 'images' + os.sep + 'contrast.png', self.controller.on_contrast, False),
#                ('separator', '', '', '', ''),
                ('toggle', 'Adjust Target Area', 'images' + os.sep + 'coral.png', self.controller.on_coral, False),
#                ('simple', 'Lock Target Area', 'images' + os.sep + 'lock_coral.png', self.controller.on_lock_coral, False),
#                ('separator', '', '', '', ''),
                ('simple', 'Filtered Overlays', 'images' + os.sep + 'overlay.png', self.controller.on_overlay, False),
                ('separator', '', '', '', ''),
                ('toggle', 'Adjust Calibration Region', 'images' + os.sep + 'calibrate.png', self.controller.on_calibrate, False),
                ('toggle', 'Set Calibration Parameters', 'images' + os.sep + 'density.png', self.controller.on_density_params, False),
                ('simple', 'Show Density Chart', 'images' + os.sep + 'chart_line.png', self.controller.on_show_density_chart, False),
                ('separator', '', '', '', ''),
                ('toggle', 'Draw Polylines', 'images' + os.sep + 'polyline.png', self.controller.on_polyline, False),
               )

    def create_statusbar(self):
        """Create a two-field status bar with equal proportional widths."""
        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetFieldsCount(2)
        self.statusbar.SetStatusWidths([-5, -5])

    def mpl_bindings(self):
        """Connect every (event, handler) pair from mpl_binds() to the canvas."""
        for each in self.mpl_binds():
            self.connect(*each)

    def connect(self, event, handler):
        """Connect one matplotlib event and remember its id for disconnect()."""
        cid = self.canvas.mpl_connect(event, handler)
        self.connect_ids.append(cid)

    def disconnect(self):
        """Disconnect all matplotlib events connected through connect()."""
        for cid in self.connect_ids:
            self.canvas.mpl_disconnect(cid)
        self.connect_ids = []

    # matplotlib events
    def mpl_binds(self):
        """(matplotlib event name, controller handler) pairs for the canvas."""
        return [
                ('motion_notify_event', self.controller.on_mouse_motion),
                ('figure_leave_event', self.controller.on_figure_leave),
                ('button_press_event', self.controller.on_mouse_press),
                ('button_release_event', self.controller.on_mouse_release),
                ('key_press_event', self.controller.on_key_press),
                ('key_release_event', self.controller.on_key_release)
                ]

    def init_plot(self, new):
        """Draw the model's image into a full-figure axes and wire up events.

        When `new` is true a new Figure and wx canvas are created inside the
        scrolled window first.
        NOTE(review): when `new` is false another axes, another set of mpl
        bindings and another RectangleSelector are added to the existing
        figure without removing the old ones — confirm the caller clears the
        figure / calls disconnect() beforehand.
        """
        y, x = self.model.get_image_shape()
        if new:
            self.figure = Figure(figsize=(x*2/72.0, y*2/72.0), dpi=72)
            self.canvas = FigureCanvasWxAgg(self.scroll, -1, self.figure)
            self.canvas.SetBackgroundColour('grey')
        self.axes = self.figure.add_axes([0.0, 0.0, 1.0, 1.0])
        self.axes.set_axis_off()
        self.axes.imshow(self.model.get_image(), aspect='auto') # aspect='auto' sets image aspect to match the size of axes
        self.axes.set_autoscale_on(False)   # do not apply autoscaling on plot commands - VERY IMPORTANT!
        self.mpl_bindings()
        # Window height divided by image height (rows); presumably used by
        # controller.resize_image() to fit the image to the window.
        y, = self.scroll.GetSizeTuple()[-1:]
        iHt, = self.model.get_image_shape()[:-1]
        self.aspect = (float(y)/float(iHt))
        self.controller.resize_image()

        # Set the RectangleSelector so that the user can drag zoom when enabled
        rectprops = dict(facecolor='white', edgecolor = 'white', alpha=0.25, fill=True)
        self.toggle_selector = RectangleSelector(self.axes,
                                        self.zoom_controller.on_zoom,
                                        drawtype='box',
                                        useblit=True,
                                        rectprops=rectprops,
                                        button=[1], # Left mouse button
                                        minspanx=1, minspany=1,
                                        spancoords='data')
        self.toggle_selector.set_active(False)

    def main_is_frozen(self):
        """True when running from a frozen executable (py2exe / freeze)."""
        return (hasattr(sys, "frozen") or # new py2exe
            hasattr(sys, "importers") or # old py2exe
            imp.is_frozen("__main__")) # tools/freeze

    def get_main_dir(self):
        """Directory of the executable (frozen) or of the started script.

        May be '' when the script was started without a directory part.
        """
        if self.main_is_frozen():
            return os.path.dirname(sys.executable)
        return os.path.dirname(sys.argv[0])
Ejemplo n.º 45
0
class RectSelector:
    """Rubber-band rectangle selection on a matplotlib Axes.

    After each drag, the selection is stored in `rect` as
    ``(x, y, width, height)`` in integer data coordinates. The corners are
    ordered, so width and height are never negative.

    When `mode` is ``'peak'``, the code also looks for the maximum of the
    first image in the axes inside the selection. If found, it is marked
    with a circle and stored in `peakpos` as ``(x, y)``.

    If a `callback` has been assigned, it is called as
    ``callback(rect, peakpos)`` after every selection.
    """
    def __init__(self, ax, canvas):
        self.rectProps  = dict(facecolor='red', edgecolor = 'white',
                 alpha=0.5, fill=True)
        self.indicatorProps = dict(facecolor='white', edgecolor='black', alpha=0.5, fill=True)
        self.__selector = RectangleSelector(ax, self._on_select, drawtype='box', rectprops=self.rectProps)
        self.__axes     = ax
        self.__canvas   = canvas
        self.mode = None
        # mode:
        #   None or 'rect': get the selected rect region
        #   'peak': get the peak point in the selected rect region
        self.__rect = None
        self.__peakpos = None
        self.__callback = None

    @property
    def callback(self):
        """Function called as ``callback(rect, peakpos)`` after each selection."""
        return self.__callback

    @callback.setter
    def callback(self, val):
        if not callable(val):
            raise ValueError('callback must be callable, got %r' % (val,))
        self.__callback = val

    @property
    def rect(self):
        """Last selection as (x, y, width, height), or None before any."""
        return self.__rect

    @property
    def peakpos(self):
        """(x, y) of the peak from the last selection, or None.

        None before any selection, when the last selection was not made in
        'peak' mode, or when it did not overlap the image.
        """
        return self.__peakpos

    def _on_select(self, epress, erelease):
        """RectangleSelector callback: record the selection, notify callback."""
        # Order the corners so a drag in any direction gives a non-negative
        # width/height (and a non-empty slice below).
        x0, x1 = sorted((int(epress.xdata), int(erelease.xdata)))
        y0, y1 = sorted((int(epress.ydata), int(erelease.ydata)))
        start   = (x0, y0)
        stop    = (x1, y1)
        self.__rect = start + (stop[0]-start[0], stop[1]-start[1])
        # Do not report a peak left over from an earlier 'peak' selection.
        self.__peakpos = None

        if self.mode == 'peak':
            ax      = self.__axes
            data_matrix  = ax.axes.get_images()[0].get_array()
            # Clamp at 0: a negative numpy slice start would wrap around.
            col0 = max(start[0], 0)
            row0 = max(start[1], 0)
            # Rows are y, columns are x; the stop corner is inclusive.
            clip_matrix  = data_matrix[row0:(stop[1]+1), col0:(stop[0]+1)]
            if clip_matrix.size:
                # First (row-major) position of the maximum, shifted back to
                # full-image coordinates.
                peak_pos     = nonzero(clip_matrix == clip_matrix.max())
                peak_pos     = (peak_pos[1][0] + col0, peak_pos[0][0] + row0)
                self.__peakpos = peak_pos
                circle      = Circle(peak_pos, 4, **self.indicatorProps)
                ax.add_patch(circle)
                self.__canvas.draw()

        # The callback is optional; calling None would raise TypeError.
        if self.__callback is not None:
            self.__callback(self.__rect, self.__peakpos)

    def activate(self):
        """Enable rectangle selection."""
        self.__selector.set_active(True)

    def deactivate(self):
        """Disable rectangle selection."""
        self.__selector.set_active(False)

    @property
    def is_active(self):
        """Whether the underlying RectangleSelector is active."""
        return self.__selector.active