Ejemplo n.º 1
0
class SelectFromCollection(object):
    """Select indices from a matplotlib collection using `RectangleSelector`.

    Selected indices are saved in the `ind` attribute. This tool highlights
    selected points by fading out the non-selected ones (i.e., reducing their
    alpha values). If your collection has alpha < 1, this tool will
    permanently alter them.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).

    `onselect` accepts both callback signatures used by matplotlib selector
    widgets: ``onselect(verts)`` (lasso-style path vertices) and
    ``onselect(eclick, erelease)`` (`RectangleSelector`, used here).

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to interact with.

    collection : :class:`matplotlib.collections.Collection` subclass
        Collection you want to select from.

    alpha_other : 0 <= float <= 1
        To highlight a selection, this tool sets all selected points to an
        alpha value of 1 and non-selected points to `alpha_other`.
    """

    def __init__(self, ax, collection, alpha_other=0.3):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # Ensure that we have separate colors for each object
        self.fc = collection.get_facecolors()
        if len(self.fc) == 0:
            raise ValueError('Collection must have a facecolor')
        elif len(self.fc) == 1:
            self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1)

        # Changed relative to the original example (which used a
        # LassoSelector). RectangleSelector calls onselect(eclick, erelease),
        # a form onselect() handles; the attribute keeps its historical name.
        self.lasso = RectangleSelector(ax, onselect=self.onselect)
        self.ind = []

    def onselect(self, verts, erelease=None):
        """Select the points inside a lasso path or a dragged rectangle.

        Parameters
        ----------
        verts : sequence of (x, y), or mouse event
            The vertices of a lasso path when called as ``onselect(verts)``;
            the press event when called by `RectangleSelector` as
            ``onselect(eclick, erelease)``.
        erelease : mouse event, optional
            Release event of a rectangle selection.
        """
        if erelease is None:
            path = Path(verts)
            inside = [path.contains_point(xy) for xy in self.xys]
        else:
            eclick = verts
            corners = (eclick.xdata, eclick.ydata,
                       erelease.xdata, erelease.ydata)
            # Events outside the axes carry no data coordinates.
            if any(c is None for c in corners):
                return
            xmin, xmax = sorted((eclick.xdata, erelease.xdata))
            ymin, ymax = sorted((eclick.ydata, erelease.ydata))
            xys = np.asarray(self.xys, dtype=float).reshape(-1, 2)
            inside = ((xys[:, 0] >= xmin) & (xys[:, 0] <= xmax) &
                      (xys[:, 1] >= ymin) & (xys[:, 1] <= ymax))
        self.ind = np.nonzero(inside)[0]
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Disconnect the selector and restore full opacity for all points."""
        self.lasso.disconnect_events()
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()
Ejemplo n.º 2
0
class regions(object):
    '''
    A class used for manual inspection and processing of all picks for the user.

    Examples:

    regions.chooseRectangles():
     - lets the user choose several rectangular regions in the plot

    regions.choosePolygon():
     - lets the user choose a polygonal region in the plot (left click adds
       a corner, right click finishes the polygon)

    regions.plotTracesInActiveRegions():
     - creates plots (shot.plot_traces) for all traces in the active regions (i.e. chosen by e.g. chooseRectangles)

    regions.setAllActiveRegionsForDeletion():
     - highlights all shots in the active regions for deletion

    regions.deleteAllMarkedPicks():
     - deletes the picks (pick flag set to False) for all shots set for deletion

    regions.deselectSelection(number):
     - deselects the region of number = number

    '''
    def __init__(self, ax, cbar, survey, qt_interface=False):
        self.ax = ax
        self.cbar = cbar
        # Value passed as colorByVal to survey.plotAllPicks on refresh.
        self.cbv = 'log10SNR'
        self._xlim0 = self.ax.get_xlim()
        self._ylim0 = self.ax.get_ylim()
        self._xlim = self.ax.get_xlim()
        self._ylim = self.ax.get_ylim()
        self.survey = survey
        self.shot_dict = self.survey.getShotDict()
        self._x0 = []
        self._y0 = []
        self._x1 = []
        self._y1 = []
        # Corners of the polygon currently being drawn.
        self._polyx = []
        self._polyy = []
        self._allpicks = None
        # Active selections: key -> {'shots', 'selection', 'xvalues', 'yvalues'}
        self.shots_found = {}
        # shotnumber -> list of traceIDs whose picks are marked for deletion
        self.shots_for_deletion = {}
        self._generateList()
        self.qt_interface = qt_interface
        if not qt_interface:
            self.buttons = {}
            self._addButtons()
        self.addTextfield()
        self.drawFigure()

    def _generateList(self):
        '''
        Caches a sorted list of (distance, pick, shotnumber, traceID,
        pickflag) tuples for every trace of every shot.
        '''
        allpicks = []
        for shot in self.shot_dict.values():
            for traceID in shot.getTraceIDlist():
                allpicks.append(
                    (shot.getDistance(traceID),
                     shot.getPickIncludeRemoved(traceID), shot.getShotnumber(),
                     traceID, shot.getPickFlag(traceID)))

        allpicks.sort()
        self._allpicks = allpicks

    def getShotDict(self):
        return self.shot_dict

    def getShotsForDeletion(self):
        return self.shots_for_deletion

    def _onselect_clicks(self, eclick, erelease):
        '''eclick and erelease are matplotlib events at press and release'''
        print('region selected x0, y0 = (%3s, %3s), x1, y1 = (%3s, %3s)' %
              (eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata))
        x0 = min(eclick.xdata, erelease.xdata)
        x1 = max(eclick.xdata, erelease.xdata)
        y0 = min(eclick.ydata, erelease.ydata)
        y1 = max(eclick.ydata, erelease.ydata)

        shots, numtraces = self.findTracesInShotDict((x0, x1), (y0, y1))
        self.printOutput('Found %d traces in rectangle: %s' %
                         (numtraces, shots))
        key = self.getKey()
        self.shots_found[key] = {
            'shots': shots,
            'selection': 'rect',
            'xvalues': (x0, x1),
            'yvalues': (y0, y1)
        }
        self.markRectangle((x0, x1), (y0, y1), key)
        if not self.qt_interface:
            self.disconnectRect()

    def _onselect_verts(self, verts):
        # Only the first vertex of each lasso stroke becomes a polygon corner.
        x = verts[0][0]
        y = verts[0][1]
        self._polyx.append(x)
        self._polyy.append(y)

        self.drawPolyLine()

    def _onpress(self, event):
        # A right click (button 3) finishes the polygon selection.
        if event.button == 3:
            self.disconnectPoly()
            self.printOutput('Disconnected polygon selection')

    def addTextfield(self, xpos=0, ypos=0.95, width=1, height=0.03):
        '''
        Adds an ax for text output to the plot.
        '''
        self.axtext = self.ax.figure.add_axes([xpos, ypos, width, height])
        self.axtext.xaxis.set_visible(False)
        self.axtext.yaxis.set_visible(False)

    def writeInTextfield(self, text=None):
        self.setXYlim(self.ax.get_xlim(), self.ax.get_ylim())
        self.axtext.clear()
        self.axtext.text(0.01,
                         0.5,
                         text,
                         verticalalignment='center',
                         horizontalalignment='left')
        self.drawFigure()

    def _addButtons(self):
        xpos1 = 0.13
        xpos2 = 0.6
        dx = 0.06
        self.addButton('Rect',
                       self.chooseRectangles,
                       xpos=xpos1,
                       color='white')
        self.addButton('Poly',
                       self.choosePolygon,
                       xpos=xpos1 + dx,
                       color='white')
        self.addButton('Plot',
                       self.plotTracesInActiveRegions,
                       xpos=xpos1 + 2 * dx,
                       color='yellow')
        self.addButton('SNR',
                       self.refreshLog10SNR,
                       xpos=xpos1 + 3 * dx,
                       color='cyan')
        self.addButton('PE',
                       self.refreshPickerror,
                       xpos=xpos1 + 4 * dx,
                       color='cyan')
        self.addButton('SPE',
                       self.refreshSPE,
                       xpos=xpos1 + 5 * dx,
                       color='cyan')
        self.addButton('DesLst',
                       self.deselectLastSelection,
                       xpos=xpos2 + dx,
                       color='green')
        self.addButton('SelAll',
                       self.setAllActiveRegionsForDeletion,
                       xpos=xpos2 + 2 * dx)
        self.addButton('DelAll',
                       self.deleteAllMarkedPicks,
                       xpos=xpos2 + 3 * dx,
                       color='red')

    def addButton(self, name, action, xpos, ypos=0.91, color=None):
        '''
        Adds a matplotlib Button labelled *name* at figure position
        (xpos, ypos) that calls *action* when clicked, and records it in
        self.buttons[name].
        '''
        from matplotlib.widgets import Button
        ax = self.ax.figure.add_axes([xpos, ypos, 0.05, 0.03])
        button = Button(ax, name, color=color, hovercolor='grey')
        button.on_clicked(action)
        self.buttons[name] = {
            'ax': ax,
            'button': button,
            'action': action,
            'xpos': xpos
        }

    def getKey(self):
        '''
        Returns the next free selection key (1 when there is no selection).
        '''
        # A dict_keys view never equals [], so test the dict's truthiness.
        if not self.shots_found:
            return 1
        return max(self.getShotsFound().keys()) + 1

    def drawPolyLine(self):
        self.setXYlim(self.ax.get_xlim(), self.ax.get_ylim())
        x = self._polyx
        y = self._polyy
        if len(x) >= 2 and len(y) >= 2:
            self.ax.plot(x[-2:], y[-2:], 'k', alpha=0.1, linewidth=1)
        self.drawFigure()

    def drawLastPolyLine(self):
        self.setXYlim(self.ax.get_xlim(), self.ax.get_ylim())
        x = self._polyx
        y = self._polyy
        if len(x) >= 2 and len(y) >= 2:
            self.ax.plot((x[-1], x[0]), (y[-1], y[0]), 'k', alpha=0.1)
        self.drawFigure()

    def finishPolygon(self):
        '''
        Closes the polygon being drawn, stores it as a new selection and
        reports the traces found inside it.
        '''
        self.drawLastPolyLine()
        x = self._polyx
        y = self._polyy
        self._polyx = []
        self._polyy = []

        # Finishing without any corner would otherwise crash in markPolygon.
        if not x or not y:
            self.printOutput('No polygon defined.')
            return

        key = self.getKey()
        self.markPolygon(x, y, key=key)

        shots, numtraces = self.findTracesInPoly(x, y)
        self.shots_found[key] = {
            'shots': shots,
            'selection': 'poly',
            'xvalues': x,
            'yvalues': y
        }
        self.printOutput('Found %d traces in polygon: %s' % (numtraces, shots))

    def printOutput(self, text):
        print(text)
        self.writeInTextfield(text)

    def chooseRectangles(self, event=None):
        '''
        Activates matplotlib widget RectangleSelector.
        '''
        from matplotlib.widgets import RectangleSelector
        if hasattr(self, '_cidPoly'):
            self.disconnectPoly()
        self.printOutput(
            'Select rectangle is active. Press and hold left mousebutton.')
        self._cidRect = self.ax.figure.canvas.mpl_connect(
            'button_press_event', self._onpress)
        self._rectangle = RectangleSelector(self.ax, self._onselect_clicks)
        return self._rectangle

    def choosePolygon(self, event=None):
        '''
        Activates matplotlib widget LassoSelector.
        '''
        from matplotlib.widgets import LassoSelector
        if hasattr(self, '_cidRect'):
            self.disconnectRect()
        self.printOutput(
            'Select polygon is active. Add points with leftclick. Finish with rightclick.'
        )
        self._cidPoly = self.ax.figure.canvas.mpl_connect(
            'button_press_event', self._onpress)
        self._lasso = LassoSelector(self.ax, self._onselect_verts)
        return self._lasso

    def disconnectPoly(self, event=None):
        if not hasattr(self, '_cidPoly'):
            self.printOutput('no poly selection found')
            return
        self.ax.figure.canvas.mpl_disconnect(self._cidPoly)
        del self._cidPoly
        self.finishPolygon()
        self._lasso.disconnect_events()
        print('disconnected poly selection\n')

    def disconnectRect(self, event=None):
        if not hasattr(self, '_cidRect'):
            self.printOutput('no rectangle selection found')
            return
        self.ax.figure.canvas.mpl_disconnect(self._cidRect)
        del self._cidRect
        self._rectangle.disconnect_events()
        print('disconnected rectangle selection\n')

    def deselectLastSelection(self, event=None):
        '''
        Deselects the selection with the highest key and refreshes the figure.
        '''
        # A dict_keys view never equals [], so test the dict's truthiness.
        if not self.shots_found:
            self.printOutput('No selection found.')
            return
        key = max(self.shots_found)
        self.deselectSelection(key)
        self.refreshFigure()

    def deselectSelection(self, key, color='green', alpha=0.1):
        '''
        Removes selection *key* from the active selections. If *color* is
        not None, the region is first redrawn with that color and alpha.
        '''
        if key not in self.shots_found:
            self.printOutput('No selection found.')
            return
        selection = self.shots_found[key]
        if color is not None:
            if selection['selection'] == 'rect':
                self.markRectangle(selection['xvalues'],
                                   selection['yvalues'],
                                   key=key,
                                   color=color,
                                   alpha=alpha,
                                   linewidth=1)
            elif selection['selection'] == 'poly':
                self.markPolygon(selection['xvalues'],
                                 selection['yvalues'],
                                 key=key,
                                 color=color,
                                 alpha=alpha,
                                 linewidth=1)
        self.shots_found.pop(key)
        self.printOutput('Deselected selection number %d' % key)

    def findTracesInPoly(self, x, y, picks='normal', highlight=True):
        '''
        Returns (shots_found, numtraces) for all picks with a set pick flag
        whose (distance, pick) lies inside the polygon with corners x, y.
        shots_found maps shotnumber -> list of traceIDs. Candidates are
        prefiltered with the polygon's bounding box.

        *picks* is not used by this method.
        Returns ({}, 0) if no polygon is defined.
        '''
        def insidePoly(x, y, pickX, pickY):
            # Even-odd ray casting: count crossings of a horizontal ray from
            # the pick towards +x with the polygon edges. Unlike summing
            # unsigned angles, this also works for concave polygons.
            inside = False
            xa, ya = x[-1], y[-1]
            for xb, yb in zip(x, y):
                if (ya > pickY) != (yb > pickY):
                    # The edge straddles pickY, so ya != yb (no zero division).
                    xcross = xa + (pickY - ya) * (xb - xa) / (yb - ya)
                    if pickX < xcross:
                        inside = not inside
                xa, ya = xb, yb
            return inside

        if len(x) == 0 or len(y) == 0:
            self.printOutput('No polygon defined.')
            return {}, 0

        shots_found = {}
        numtraces = 0
        x0 = min(x)
        x1 = max(x)
        y0 = min(y)
        y1 = max(y)

        shots, numtracesrect = self.findTracesInShotDict((x0, x1), (y0, y1),
                                                         highlight=False)
        for shotnumber, traceIDs in shots.items():
            shot = self.shot_dict[shotnumber]
            for traceID in traceIDs:
                if shot.getPickFlag(traceID):
                    pickX = shot.getDistance(traceID)
                    pickY = shot.getPick(traceID)
                    if insidePoly(x, y, pickX, pickY):
                        shots_found.setdefault(shotnumber, []).append(traceID)
                        if highlight:
                            self.highlightPick(shot, traceID)
                        numtraces += 1

        self.drawFigure()
        return shots_found, numtraces

    def findTracesInShotDict(self, xtup, ytup, picks='normal', highlight=True):
        '''
        Returns traces corresponding to a certain area in the plot with all picks over the distances.

        :param: picks, 'normal' skips picks whose flag is False,
                'includeCutOut' skips only picks whose flag is None
        :type: str

        :return: (shots_found, numtraces), shots_found maps
                 shotnumber -> list of traceIDs
        '''
        x0, x1 = xtup
        y0, y1 = ytup

        shots_found = {}
        numtraces = 0
        if picks == 'normal':
            pickflag = False
        elif picks == 'includeCutOut':
            pickflag = None
        else:
            raise ValueError("picks must be 'normal' or 'includeCutOut', "
                             "got %r" % (picks,))

        for line in self._allpicks:
            dist, pick, shotnumber, traceID, flag = line
            if flag == pickflag:
                continue
            if (x0 <= dist <= x1 and y0 <= pick <= y1):
                shots_found.setdefault(shotnumber, []).append(traceID)
                if highlight:
                    self.highlightPick(self.shot_dict[shotnumber], traceID)
                numtraces += 1

        self.drawFigure()
        return shots_found, numtraces

    def highlightPick(self, shot, traceID, annotations=True):
        '''
        Highlights a single pick for a shot(object)/shotnumber and traceID.
        If annotations == True: Displays shotnumber and traceID in the plot.
        '''
        if isinstance(shot, int):
            shot = self.survey.getShotDict()[shot]

        if not shot.getPickFlag(traceID):
            return

        self.ax.scatter(shot.getDistance(traceID),
                        shot.getPick(traceID),
                        s=50,
                        marker='o',
                        facecolors='none',
                        edgecolors='m',
                        alpha=1)
        if annotations:
            # Text passed positionally: the 's=' keyword was removed from
            # Axes.annotate in newer matplotlib releases.
            self.ax.annotate('s%s|t%s' % (shot.getShotnumber(), traceID),
                             xy=(shot.getDistance(traceID),
                                 shot.getPick(traceID)),
                             fontsize='xx-small')

    def highlightAllActiveRegions(self):
        '''
        Highlights all picks in all active regions.
        '''
        for selection in self.shots_found.values():
            for shotnumber, traceIDs in selection['shots'].items():
                for traceID in traceIDs:
                    self.highlightPick(self.shot_dict[shotnumber], traceID)
        self.drawFigure()

    def plotTracesInActiveRegions(self,
                                  event=None,
                                  keys='all',
                                  maxfigures=20,
                                  qt=False,
                                  qtMainwindow=None):
        '''
        Plots all traces in the active region or for all specified keys.

        :param: keys
        :type: int or list

        :param: maxfigures, maximum value of figures opened
        :type: int

        :param: qt, open Repicking_window instances instead of
                shot.plot_traces figures
        :type: bool

        :return: list of Repicking_window instances if qt, else None
        '''
        if qt:
            from PySide import QtGui
            from asp3d.gui.windows import Repicking_window
            repickingQt = []

        if keys == 'all':
            keys = self.shots_found.keys()
        elif isinstance(keys, int):
            keys = [keys]

        if len(self.shots_found) == 0:
            self.printOutput('No picks defined in that region(s)')
            return

        # Count every trace (so the total reported below is correct) but
        # only collect the first maxfigures of them for plotting.
        ntraces = 0
        traces2plot = []
        for shot in self.shot_dict.values():
            shotnumber = shot.getShotnumber()
            for key in keys:
                for traceID in self.shots_found[key]['shots'].get(
                        shotnumber, []):
                    ntraces += 1
                    if ntraces > maxfigures:
                        if ntraces == maxfigures + 1:
                            print(
                                'Maximum number of figures ({}) reached. Figure #{} for shot #{} '
                                'will not be opened.'.format(
                                    maxfigures, ntraces, shotnumber))
                        continue
                    traces2plot.append([shot, traceID])

        if ntraces > maxfigures and qt:
            reply = QtGui.QMessageBox.question(
                qtMainwindow, 'Message', 'Maximum number of opened figures is '
                '{}. Do you want to continue without plotting all '
                '{} figures?'.format(maxfigures, ntraces),
                QtGui.QMessageBox.Yes | QtGui.QMessageBox.No,
                QtGui.QMessageBox.No)
            if not reply == QtGui.QMessageBox.Yes:
                return

        for shot, traceID in traces2plot:
            if qt:
                repickingQt.append(
                    Repicking_window(qtMainwindow, shot, traceID))
            else:
                shot.plot_traces(traceID)

        if qt:
            return repickingQt

    def setAllActiveRegionsForDeletion(self, event=None):
        # Copy the keys: setRegionForDeletion removes them from shots_found.
        self.setRegionForDeletion(list(self.shots_found.keys()))

    def setRegionForDeletion(self, keys):
        '''
        Marks all traces of the selection(s) *keys* for deletion and
        deselects them (redrawn in red).
        '''
        if isinstance(keys, int):
            keys = [keys]

        for key in keys:
            for shotnumber, traceIDs in self.shots_found[key]['shots'].items():
                marked = self.shots_for_deletion.setdefault(shotnumber, [])
                for traceID in traceIDs:
                    if traceID not in marked:
                        marked.append(traceID)
            # Deselect exactly once: a second call would only report
            # 'No selection found.' because the key was already removed.
            self.deselectSelection(key, color='red', alpha=0.2)

        self.printOutput('Set region(s) %s for deletion' % keys)

    def markAllActiveRegions(self):
        for key, selection in self.shots_found.items():
            if selection['selection'] == 'rect':
                self.markRectangle(selection['xvalues'],
                                   selection['yvalues'],
                                   key=key)
            elif selection['selection'] == 'poly':
                self.markPolygon(selection['xvalues'],
                                 selection['yvalues'],
                                 key=key)

    def markRectangle(self,
                      xtup,
                      ytup,
                      key=None,
                      color='grey',
                      alpha=0.1,
                      linewidth=1):
        '''
        Mark a rectangular region on the axes.
        '''
        from matplotlib.patches import Rectangle
        x0, x1 = xtup
        y0, y1 = ytup
        self.ax.add_patch(
            Rectangle((x0, y0),
                      x1 - x0,
                      y1 - y0,
                      alpha=alpha,
                      facecolor=color,
                      linewidth=linewidth))
        if key is not None:
            self.ax.text(x0 + (x1 - x0) / 2, y0 + (y1 - y0) / 2, str(key))
        self.drawFigure()

    def markPolygon(self,
                    x,
                    y,
                    key=None,
                    color='grey',
                    alpha=0.1,
                    linewidth=1):
        '''
        Mark a polygonal region with corners x, y on the axes.
        '''
        from matplotlib.patches import Polygon
        # np.array(zip(...)) yields a 0-d object array on Python 3; build the
        # (N, 2) vertex array explicitly.
        poly = Polygon(np.column_stack((x, y)),
                       color=color,
                       alpha=alpha,
                       lw=linewidth)
        self.ax.add_patch(poly)
        if key is not None:
            self.ax.text(
                min(x) + (max(x) - min(x)) / 2,
                min(y) + (max(y) - min(y)) / 2, str(key))
        self.drawFigure()

    def clearShotsForDeletion(self):
        '''
        Clears the list of shots marked for deletion.
        '''
        self.shots_for_deletion = {}
        print('Cleared all shots that were set for deletion.')

    def getShotsFound(self):
        return self.shots_found

    def deleteAllMarkedPicks(self, event=None):
        '''
        Deletes all shots set for deletion.
        '''
        if not self.getShotsForDeletion():
            self.printOutput('No shots set for deletion.')
            return

        for shot in self.getShotDict().values():
            for shotnumber in self.getShotsForDeletion():
                if shot.getShotnumber() == shotnumber:
                    for traceID in self.getShotsForDeletion()[shotnumber]:
                        shot.removePick(traceID)
                        print(
                            "Deleted the pick for traceID %s on shot number %s"
                            % (traceID, shotnumber))
        self.clearShotsForDeletion()
        self.refreshFigure()

    def highlightPicksForShot(self, shot, annotations=False):
        '''
        Highlight all picks for a given shot (shot object or shotnumber).
        '''
        if isinstance(shot, int):
            # An int *is* the shotnumber; it has no getShotnumber() method.
            shot = self.survey.getShotDict()[shot]

        for traceID in shot.getTraceIDlist():
            if shot.getPickFlag(traceID):
                self.highlightPick(shot, traceID, annotations)

        self.drawFigure()

    def setXYlim(self, xlim, ylim):
        self._xlim, self._ylim = xlim, ylim

    def refreshLog10SNR(self, event=None):
        cbv = 'log10SNR'
        self.cbv = cbv
        self.refreshFigure(event, colorByVal=cbv)

    def refreshPickerror(self, event=None):
        cbv = 'pickerror'
        self.cbv = cbv
        self.refreshFigure(event, colorByVal=cbv)

    def refreshSPE(self, event=None):
        cbv = 'spe'
        self.cbv = cbv
        self.refreshFigure(event, colorByVal=cbv)

    def refreshFigure(self, event=None, colorByVal=None):
        '''
        Replots all picks (colored by *colorByVal*, default self.cbv) and
        redraws the active regions and their highlighted picks.
        '''
        if colorByVal is None:
            colorByVal = self.cbv
        else:
            self.cbv = colorByVal
        self.printOutput('Refreshing figure...')
        self.ax.clear()
        self.ax = self.survey.plotAllPicks(ax=self.ax,
                                           cbar=self.cbar,
                                           refreshPlot=True,
                                           colorByVal=colorByVal)
        self.setXYlim(self.ax.get_xlim(), self.ax.get_ylim())
        self.markAllActiveRegions()
        self.highlightAllActiveRegions()
        self.drawFigure()
        self.printOutput('Done!')

    def drawFigure(self, resetAxes=True):
        '''
        Redraws the canvas, restoring the stored axis limits if resetAxes.
        '''
        if resetAxes:
            self.ax.set_xlim(self._xlim)
            self.ax.set_ylim(self._ylim)
        self.ax.figure.canvas.draw()
class WidgetFigure(BasicFigure):
    """Figure with interactive line and rectangle selector widgets.

    Lines and rectangles are drawn with a right-button (``button=[3]``)
    `RectangleSelector` on `axes_selector`. The drawn objects and their
    names are kept in `drawn_lines`/`drawn_lines_names` and
    `drawn_patches`/`drawn_patches_names`.

    NOTE(review): the ``drawtype=`` argument of RectangleSelector is
    deprecated in newer matplotlib releases — confirm the matplotlib
    version this runs against.
    """
    nColorsFromColormap = Int(5)

    unlock_all_btn = Button('(Un-) Lock')
    widget_list = List()
    widget_sel = Enum(values='widget_list')
    widget_clear_btn = Button('Clear Current Widgets')

    drawn_patches = List()
    drawn_patches_names = List()

    drawn_lines = List()
    drawn_lines_names = List()

    def _widget_list_default(self):
        return ['Line Selector', 'Rectangle Selector']

    def _widget_sel_default(self):
        return 'Line Selector'

    def _widget_clear_btn_fired(self):
        """Clear the widgets belonging to the currently selected tool."""
        if self.widget_sel == self.widget_list[0]:
            self.clear_lines()

        if self.widget_sel == self.widget_list[1]:
            self.clear_patches()

    @on_trait_change(['widget_sel', 'axes'])
    def set_selector(self, widget):
        """Connect the selector matching *widget*.

        _line_selector() runs unconditionally: besides connecting the line
        selector it disconnects any active rectangle selector, so switching
        to 'Rectangle Selector' does not leave an old rectangle selector
        connected. Presumably *widget* is an axes object when the 'axes'
        trait fires (see TODO) — in that case the line selector is used.
        """
        print('Setting selector ...')
        # TODO: check if the argument of set_selector is an axes object
        self._line_selector()

        if widget == self.widget_list[1]:
            self._rec_selector()

    def act_all(self):
        """Call connect() on every drawn patch and line."""
        for i in self.drawn_patches:
            i.connect()
        for i in self.drawn_lines:
            i.connect()

    def _line_selector(self):
        """Lock rectangles, disconnect the rectangle selector and connect a
        line-drawing RectangleSelector as self.ls."""
        try:
            self.rs.disconnect_events()
            DraggableResizeableRectangle.lock = True
            print('Rectangles are locked')
        # Best effort: self.rs does not exist before the first rectangle
        # selector has been created.
        except Exception:
            print('Rectangles could not be locked')

        print(self.__class__.__name__, ": Connecting Line Selector")
        DraggableResizeableLine.lock = None
        self.ls = RectangleSelector(self.axes_selector,
                                    self.line_selector_func,
                                    drawtype='line',
                                    useblit=True,
                                    button=[3])

    def line_selector_func(self, eclick, erelease, cmap=mpl.cm.jet):
        """Callback of the line selector: add an AnnotatedLine from the
        press to the release position, colored from *cmap*."""
        print(self.__class__.__name__, "Line Selector:")
        print(self.__class__.__name__,
              "eclick: {} \n erelease: {}".format(eclick, erelease))
        print()

        x0, y0 = eclick.xdata, eclick.ydata
        x1, y1 = erelease.xdata, erelease.ydata

        cNorm = mpl.colors.Normalize(vmin=0, vmax=self.nColorsFromColormap)
        scalarMap = mpl.cm.ScalarMappable(norm=cNorm, cmap=cmap)
        color = scalarMap.to_rgba(len(self.drawn_lines) + 1)
        text = 'line ' + str(len(self.drawn_lines))
        line = AnnotatedLine(self.axes_selector,
                             x0,
                             y0,
                             x1,
                             y1,
                             text=text,
                             color=color)
        self.drawn_lines_names.append(line.text)
        self.drawn_lines.append(line)
        self.canvas.draw()

    def get_widget_line(self, line_name):
        """Return the drawn line whose text is *line_name*, or None."""
        return next(
            (line for line in self.drawn_lines if line.text == line_name),
            None)

    def clear_lines(self):
        """Remove all drawn selector lines and forget their names."""
        if len(self.drawn_lines) != 0:
            print(self.__class__.__name__, ": Clearing selection lines")
            for line in self.drawn_lines:
                try:
                    line.remove()
                except ValueError:
                    print(self.__class__.__name__, ": Line was not found.")
            self.drawn_lines = []
            self.drawn_lines_names = []

        self.canvas.draw()

    def _rec_selector(self):
        """Lock lines, disconnect the line selector and connect a
        box-drawing RectangleSelector as self.rs."""
        try:
            self.ls.disconnect_events()
            DraggableResizeableLine.lock = True
            print('Line Selector is locked')
        # Best effort: self.ls may not exist yet.
        except Exception:
            print('Line Selector could not be locked')
        DraggableResizeableRectangle.lock = None

        print(self.__class__.__name__, ": Connecting Rectangle Selector")

        self.rs = RectangleSelector(self.axes_selector,
                                    self.rectangle_selector_func,
                                    drawtype='box',
                                    useblit=True,
                                    button=[3])

    def rectangle_selector_func(self, eclick, erelease, cmap=mpl.cm.jet):
        """
            Usage:
            @on_trait_change('fig:selectionPatches:rectUpdated')
            function name:
                for p in self.fig.selectionPatches:
                    do p

        """
        print(self.__class__.__name__, "Rectangle Selector:")
        print(self.__class__.__name__,
              "eclick: {} \n erelease: {}".format(eclick, erelease))
        print()

        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata

        cNorm = mpl.colors.Normalize(vmin=0, vmax=self.nColorsFromColormap)
        scalarMap = mpl.cm.ScalarMappable(norm=cNorm, cmap=cmap)

        color = scalarMap.to_rgba(len(self.drawn_patches) + 1)

        self.an_rect = AnnotatedRectangle(self.axes_selector,
                                          x1,
                                          y1,
                                          x2,
                                          y2,
                                          'region ' +
                                          str(len(self.drawn_patches)),
                                          color=color)
        self.drawn_patches_names.append(self.an_rect.text)
        self.drawn_patches.append(self.an_rect)
        self.canvas.draw()

    def get_widget_patch(self, patch_name):
        """Return the drawn patch whose text is *patch_name*, or None."""
        return next(
            (rect for rect in self.drawn_patches if rect.text == patch_name),
            None)

    def clear_patches(self):
        """Remove all drawn selector patches and forget their names."""
        if len(self.drawn_patches) != 0:
            print(self.__class__.__name__, ": Clearing selection patches")
            for p in self.drawn_patches:
                try:
                    p.remove()
                except ValueError:
                    print(self.__class__.__name__, ": Patch was not found.")
            DraggableResizeableLine.reset_borders()
            DraggableResizeableRectangle.reset_borders()
            self.drawn_patches = []
            # Keep names in sync with the (now empty) patch list, as
            # clear_lines does for lines.
            self.drawn_patches_names = []
            self.canvas.draw()

    def clear_widgets(self):
        """Remove all drawn patches and lines."""
        self.clear_patches()
        self.clear_lines()

    def options_group(self):
        g = HGroup(
            UItem('options_btn'),
            UItem('clear_btn'),
            UItem('line_selector', visible_when='not img_bool'),
            UItem('copy_data_btn', visible_when='not img_bool'),
            HGroup(
                VGroup(
                    HGroup(
                        Item('normalize_bool', label='normalize'),
                        Item('log_bool', label='log scale'),
                        Item('cmap_selector',
                             label='cmap',
                             visible_when='img_bool'),
                        UItem('image_slider_btn', visible_when='img_bool'),
                        UItem('save_fig_btn'),
                    ), HGroup(
                        UItem('widget_sel'),
                        UItem('widget_clear_btn'),
                    ))),
        )
        return g
Ejemplo n.º 4
0
class SelectFromCollection(object):
    """Select points from a matplotlib collection using `RectangleSelector`.

    Selected indices are saved in the `ind` attribute. Selected points are
    drawn opaque red; all other points keep their original color but are
    faded out to `alpha_other`. Disconnecting restores the original colors
    with an alpha of 1, so if your collection has alpha < 1, this tool will
    permanently alter the alpha values.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to interact with.

    collection : :class:`matplotlib.collections.Collection` subclass
        Collection you want to select from. It may have a single face color
        or one face color per point (fewer colors are cycled over the points).

    alpha_other : 0 <= float <= 1
        To highlight a selection, this tool sets all selected points to an
        alpha value of 1 and non-selected points to `alpha_other`.

    Raises
    ------
    ValueError
        If the collection has no face color.
    """

    def __init__(self, ax, collection, alpha_other=0.3):
        self.ax = ax
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = np.array(collection.get_offsets())
        self.Npts = len(self.xys)

        # Ensure that we have separate colors for each object
        self.fc = self._expand_facecolors(collection.get_facecolors(),
                                          self.Npts)
        # Private copy of the original colors, used to restore points that
        # are no longer selected (and on disconnect).
        self.orig_fc = self.fc.copy()

        self.selector = RectangleSelector(ax,
                         onselect=self.onselect, useblit=False)
        self.ind = []

    @staticmethod
    def _expand_facecolors(fc, npts):
        """Return an ``(npts, 4)`` float copy of the face colors `fc`.

        A single color is repeated for every point; several colors are
        cycled over the points.

        Raises
        ------
        ValueError
            If `fc` is empty.
        """
        fc = np.array(fc, dtype=float)
        if len(fc) == 0:
            raise ValueError('Collection must have a facecolor')
        # np.resize repeats the rows cyclically and always returns a copy, so
        # the collection's own color array is never modified in place.
        return np.resize(fc, (npts, fc.shape[1]))

    def onselect(self, eclick, erelease):
        """Select the points inside the rectangle spanned by the two events."""
        lowerx, upperx = sorted((eclick.xdata, erelease.xdata))
        lowery, uppery = sorted((eclick.ydata, erelease.ydata))
        xs, ys = self.xys[:, 0], self.xys[:, 1]
        inside = ((xs >= lowerx) & (xs <= upperx)
                  & (ys >= lowery) & (ys <= uppery))
        self.ind = np.nonzero(inside)[0]
        # Start from the original colors and fade everything out ...
        self.fc[:] = self.orig_fc
        self.fc[:, -1] = self.alpha_other
        # ... then draw the selection opaque red.
        self.fc[self.ind] = (1, 0, 0, 1)
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Stop selecting and restore the original, fully opaque colors."""
        self.selector.disconnect_events()
        self.fc[:] = self.orig_fc
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()
Ejemplo n.º 5
0
class PatternSelectionWidget(QWidget, DockMixin):
    """A widget to select patterns in the image

    This widget consists of an :class:`EmbededMplCanvas` to display the
    template for the pattern and uses the :func:`skimage.feature.match_template`
    function to identify it in the :attr:`arr`

    See Also
    --------
    straditize.widget.selection_toolbar.SelectionToolbar.start_pattern_selection
    """

    #: The template to look for in the :attr:`arr`
    template = None

    #: The selector to select the template in the original image
    selector = None

    #: The extents of the :attr:`template` in the original image
    template_extents = None

    #: The matplotlib artist of the :attr:`template` in the
    #: :attr:`template_fig`
    template_im = None

    #: The :class:`EmbededMplCanvas` to display the :attr:`template`
    template_fig = None

    axes = None

    #: A QSlider to set the threshold for the template correlation
    sl_thresh = None

    #: The correlation of :attr:`template` and :attr:`arr` computed by
    #: :meth:`start_correlation` (None until a search has finished). The
    #: default is needed because a cancelled first search leaves it unset.
    _correlation = None

    _corr_plot = None

    key_press_cid = None

    def __init__(self, arr, data_obj, remove_selection=False, *args, **kwargs):
        """
        Parameters
        ----------
        arr: np.ndarray of shape ``(Ny, Nx)``
            The labeled selection array
        data_obj: straditize.label_selection.LabelSelection
            The data object whose image shall be selected
        remove_selection: bool
            If True, remove the selection on apply
        """
        super(PatternSelectionWidget, self).__init__(*args, **kwargs)
        self.arr = arr
        self.data_obj = data_obj
        self.remove_selection = remove_selection
        self.template = None

        # the figure to show the template
        self.template_fig = EmbededMplCanvas()
        # the button to select the template
        self.btn_select_template = QPushButton('Select a template')
        self.btn_select_template.setCheckable(True)
        # the checkbox to allow fractions of the template
        self.fraction_box = QGroupBox('Template fractions')
        self.fraction_box.setCheckable(True)
        self.fraction_box.setChecked(False)
        self.fraction_box.setEnabled(False)
        self.sl_fraction = QSlider(Qt.Horizontal)
        self.lbl_fraction = QLabel('0.75')
        self.sl_fraction.setValue(75)
        # the slider to select the increments of the fractions
        self.sl_increments = QSlider(Qt.Horizontal)
        self.sl_increments.setValue(3)
        self.sl_increments.setMinimum(1)
        self.lbl_increments = QLabel('3')
        # the button to perform the correlation
        self.btn_correlate = QPushButton('Find template')
        self.btn_correlate.setEnabled(False)
        # the button to plot the correlation
        self.btn_plot_corr = QPushButton('Plot correlation')
        self.btn_plot_corr.setCheckable(True)
        self.btn_plot_corr.setEnabled(False)
        # slider for subselection
        self.btn_select = QPushButton('Select pattern')
        self.sl_thresh = QSlider(Qt.Horizontal)
        self.lbl_thresh = QLabel('0.5')

        self.btn_select.setCheckable(True)
        self.btn_select.setEnabled(False)
        self.sl_thresh.setValue(75)
        self.sl_thresh.setVisible(False)
        self.lbl_thresh.setVisible(False)

        # cancel and close button
        self.btn_cancel = QPushButton('Cancel')
        self.btn_close = QPushButton('Apply')
        self.btn_close.setEnabled(False)

        vbox = QVBoxLayout()

        vbox.addWidget(self.template_fig)
        hbox = QHBoxLayout()
        hbox.addStretch(0)
        hbox.addWidget(self.btn_select_template)
        vbox.addLayout(hbox)

        fraction_layout = QGridLayout()
        fraction_layout.addWidget(QLabel('Fraction'), 0, 0)
        fraction_layout.addWidget(self.sl_fraction, 0, 1)
        fraction_layout.addWidget(self.lbl_fraction, 0, 2)
        fraction_layout.addWidget(QLabel('Increments'), 1, 0)
        fraction_layout.addWidget(self.sl_increments, 1, 1)
        fraction_layout.addWidget(self.lbl_increments, 1, 2)

        self.fraction_box.setLayout(fraction_layout)

        vbox.addWidget(self.fraction_box)
        vbox.addWidget(self.btn_correlate)
        vbox.addWidget(self.btn_plot_corr)
        vbox.addWidget(self.btn_select)
        thresh_box = QHBoxLayout()
        thresh_box.addWidget(self.sl_thresh)
        thresh_box.addWidget(self.lbl_thresh)
        vbox.addLayout(thresh_box)

        hbox = QHBoxLayout()
        hbox.addWidget(self.btn_cancel)
        hbox.addWidget(self.btn_close)
        vbox.addLayout(hbox)
        self.setLayout(vbox)

        self.btn_select_template.clicked.connect(
            self.toggle_template_selection)
        self.sl_fraction.valueChanged.connect(
            lambda i: self.lbl_fraction.setText(str(i / 100.)))
        self.sl_increments.valueChanged.connect(
            lambda i: self.lbl_increments.setText(str(i)))
        self.btn_correlate.clicked.connect(self.start_correlation)
        self.btn_plot_corr.clicked.connect(self.toggle_correlation_plot)
        self.sl_thresh.valueChanged.connect(
            lambda i: self.lbl_thresh.setText(str((i - 50) / 50.)))
        self.sl_thresh.valueChanged.connect(self.modify_selection)
        self.btn_select.clicked.connect(self.toggle_selection)

        self.btn_cancel.clicked.connect(self.cancel)
        self.btn_close.clicked.connect(self.close)

    def toggle_template_selection(self):
        """Enable or disable the template selection"""
        if (not self.btn_select_template.isChecked()
                and self.selector is not None):
            self.selector.set_active(False)
            for a in self.selector.artists:
                a.set_visible(False)
            self.btn_select_template.setText('Select a template')
        elif self.selector is not None and self.template_im is not None:
            self.selector.set_active(True)
            for a in self.selector.artists:
                a.set_visible(True)
            self.btn_select_template.setText('Apply')
        else:
            self.selector = RectangleSelector(self.data_obj.ax,
                                              self.update_image,
                                              interactive=True)
            if self.template_extents is not None:
                self.selector.draw_shape(self.template_extents)
            self.key_press_cid = self.data_obj.ax.figure.canvas.mpl_connect(
                'key_press_event', self.update_image)
            self.btn_select_template.setText('Cancel')
        self.data_obj.draw_figure()
        if self.template is not None:
            self.fraction_box.setEnabled(True)
            self.sl_increments.setMaximum(min(self.template.shape[:2]))
            self.btn_correlate.setEnabled(True)

    def update_image(self, *args, **kwargs):
        """Update the template image based on the :attr:`selector` extents"""
        if self.template_im is not None:
            self.template_im.remove()
            del self.template_im
        elif self.axes is None:
            self.axes = self.template_fig.figure.add_subplot(111)
            self.template_fig.figure.subplots_adjust(bottom=0.3)
        if not self.selector.artists[0].get_visible():
            self.template_extents = None
            self.template = None
            self.btn_select_template.setText('Cancel')
        else:
            self.template_extents = np.round(self.selector.extents).astype(int)
            x, y = self.template_extents.reshape((2, 2))
            if getattr(self.data_obj, 'extent', None) is not None:
                extent = self.data_obj.extent
                x -= int(min(extent[:2]))
                y -= int(min(extent[2:]))
            slx = slice(*sorted(x))
            sly = slice(*sorted(y))
            self.template = template = self.arr[sly, slx]
            if template.ndim == 3:
                self.template_im = self.axes.imshow(template)
            else:
                self.template_im = self.axes.imshow(template, cmap='binary')
            self.btn_select_template.setText('Apply')
        self.template_fig.draw()

    def start_correlation(self):
        """Look for the correlations of template and source

        A search that was cancelled in the progress dialog keeps the result
        of a previous search (if any); the plot and select buttons are only
        enabled once a correlation is available."""
        if self.fraction_box.isChecked():
            self._fraction = self.sl_fraction.value() / 100.
            increments = self.sl_increments.value()
        else:
            self._fraction = 0
            increments = 1
        corr = self.correlate_template(self.arr, self.template, self._fraction,
                                       increments)
        if corr is not None:
            self._correlation = corr
        enable = self._correlation is not None
        self.btn_plot_corr.setEnabled(enable)
        self.btn_select.setEnabled(enable)

    def toggle_selection(self):
        """Modifiy the selection (or not) based on the template correlation"""
        obj = self.data_obj
        if self.btn_select.isChecked():
            self._orig_selection_arr = obj._selection_arr.copy()
            self._selected_labels = obj.selected_labels
            self._select_cmap = obj._select_cmap
            self._select_norm = obj._select_norm
            self.btn_select.setText('Reset')
            self.btn_close.setEnabled(True)
            obj.unselect_all_labels()
            self.sl_thresh.setVisible(True)
            self.lbl_thresh.setVisible(True)
            self.modify_selection(self.sl_thresh.value())
        else:
            if obj._selection_arr is not None:
                obj._selection_arr[:] = self._orig_selection_arr
                obj._select_img.set_array(self._orig_selection_arr)
                obj.select_labels(self._selected_labels)
                obj._update_magni_img()
            del self._orig_selection_arr, self._selected_labels
            self.btn_select.setText('Select pattern')
            self.btn_close.setEnabled(False)
            self.sl_thresh.setVisible(False)
            self.lbl_thresh.setVisible(False)
            obj.draw_figure()

    def modify_selection(self, i):
        """Modify the selection based on the correlation threshold

        Parameters
        ----------
        i: int
            An integer between 0 and 100, the value of the :attr:`sl_thresh`
            slider"""
        if not self.btn_select.isChecked():
            return
        obj = self.data_obj
        val = (i - 50.) / 50.
        # select the values above 50
        if not self.remove_selection:
            # clear the selection
            obj._selection_arr[:] = obj._orig_selection_arr.copy()
            select_val = obj._selection_arr.max() + 1
            obj._selection_arr[self._correlation >= val] = select_val
        else:
            obj._selection_arr[:] = self._orig_selection_arr.copy()
            obj._selection_arr[self._correlation >= val] = -1
        obj._select_img.set_array(obj._selection_arr)
        obj._update_magni_img()
        obj.draw_figure()

    def correlate_template(self,
                           arr,
                           template,
                           fraction=False,
                           increment=1,
                           report=True):
        """Correlate a template with the `arr`

        This method uses the :func:`skimage.feature.match_template` function
        to find the given `template` in the source array `arr`.

        Parameters
        ----------
        arr: np.ndarray of shape ``(Ny,Nx)``
            The labeled selection array (see :attr:`arr`), the source of the
            given `template`
        template: np.ndarray of shape ``(nx, ny)``
            The template from ``arr`` that shall be searched
        fraction: float
            If not null, we will look through the given fraction of the
            template to look for partial matches as well
        increment: int
            The increment of the loop with the `fraction`.
        report: bool
            If True and `fraction` is not null, a QProgressDialog is opened
            to inform the user about the progress

        Returns
        -------
        np.ndarray or None
            The correlation with the shape of `arr` (0 outside of the
            selected part of :attr:`data_obj`), or None if the user
            cancelled the progress dialog

        Raises
        ------
        ValueError
            If nothing is selected in :attr:`data_obj`"""
        from skimage.feature import match_template
        mask = self.data_obj.selected_part
        x = mask.any(axis=0)
        if not x.any():
            raise ValueError("No data selected!")
        y = mask.any(axis=1)
        xmin = x.argmax()
        xmax = len(x) - x[::-1].argmax()
        ymin = y.argmax()
        ymax = len(y) - y[::-1].argmax()
        if arr.ndim == 3:
            mask = np.tile(mask[..., np.newaxis], (1, 1, arr.shape[-1]))
        src = np.where(mask[ymin:ymax, xmin:xmax], arr[ymin:ymax, xmin:xmax],
                       0)
        sny, snx = src.shape
        if not fraction:
            corr = match_template(src, template)
            full_shape = np.array(corr.shape)
        else:  # loop through the template to allow partial hatches
            shp = np.array(template.shape, dtype=int)[:2]
            ny, nx = shp
            fshp = np.round(fraction * shp).astype(int)
            fny, fnx = fshp
            it = list(
                product(range(0, fny, increment), range(0, fnx, increment)))
            ntot = len(it)
            full_shape = fshp - shp + src.shape
            corr = np.zeros(full_shape, dtype=float)
            if report:
                txt = 'Searching template...'
                dialog = QProgressDialog(txt, 'Cancel', 0, ntot)
                dialog.setWindowModality(Qt.WindowModal)
                t0 = dt.datetime.now()
            for k, (i, j) in enumerate(it):
                if report:
                    dialog.setValue(k)
                    if k and not k % 10:
                        passed = (dt.datetime.now() - t0).total_seconds()
                        dialog.setLabelText(txt + ' %1.0f seconds remaining' %
                                            ((passed * (ntot / k - 1.))))
                if report and dialog.wasCanceled():
                    return
                else:
                    y_end, x_start = fshp - (i, j) - 1
                    sly = slice(y_end, full_shape[0])
                    slx = slice(0, -x_start or full_shape[1])
                    corr[sly, slx] = np.maximum(
                        corr[sly, slx],
                        match_template(src, template[:-i or ny, j:]))
            if report:
                # The loop only goes up to ntot - 1; reaching the maximum
                # lets the (window modal) dialog reset and close itself.
                dialog.setValue(ntot)
        ret = np.zeros_like(arr, dtype=corr.dtype)
        dny, dnx = src.shape - full_shape
        for i, j in product(range(dny + 1), range(dnx + 1)):
            ret[ymin + i:ymax - dny + i, xmin + j:xmax - dnx + j] = np.maximum(
                ret[ymin + i:ymax - dny + i, xmin + j:xmax - dnx + j], corr)
        return np.where(mask, ret, 0)

    def toggle_correlation_plot(self):
        """Toggle the correlation plot between :attr:`template` and :attr:`arr`
        """
        obj = self.data_obj
        if self._corr_plot is None:
            self._corr_plot = obj.ax.imshow(
                self._correlation,
                extent=obj._select_img.get_extent(),
                zorder=obj._select_img.zorder + 0.1)
            self._corr_cbar = obj.ax.figure.colorbar(self._corr_plot,
                                                     orientation='vertical')
            self._corr_cbar.set_label('Correlation')
        else:
            for a in [self._corr_cbar, self._corr_plot]:
                try:
                    a.remove()
                except ValueError:
                    pass
            del self._corr_plot, self._corr_cbar
        obj.draw_figure()

    def to_dock(self,
                main,
                title=None,
                position=None,
                docktype='df',
                *args,
                **kwargs):
        """Dock this widget in `main`

        By default the widget is placed in the dock area of the help
        explorer. On the first call, toggling the dock's view action also
        triggers :meth:`maybe_tabify`."""
        if position is None:
            position = main.dockWidgetArea(main.help_explorer.dock)
        connect = self.dock is None
        ret = super(PatternSelectionWidget, self).to_dock(main,
                                                          title,
                                                          position,
                                                          docktype=docktype,
                                                          *args,
                                                          **kwargs)
        if connect:
            self.dock.toggleViewAction().triggered.connect(self.maybe_tabify)
        return ret

    def maybe_tabify(self):
        """Tabify the dock with the help explorer if both share an area"""
        main = self.dock.parent()
        if self.is_shown and main.dockWidgetArea(
                main.help_explorer.dock) == main.dockWidgetArea(self.dock):
            main.tabifyDockWidget(main.help_explorer.dock, self.dock)

    def cancel(self):
        """Undo a pending pattern selection and close the widget"""
        if self.btn_select.isChecked():
            self.btn_select.setChecked(False)
            self.toggle_selection()
        self.close()

    def close(self):
        """Remove the selector and correlation plot and close the widget

        The template selector is disconnected and its artists are removed,
        the key press callback is disconnected, the references to the data
        are dropped and the dock widget is removed from the main window."""
        from psyplot_gui.main import mainwindow
        if self.selector is not None:
            self.selector.disconnect_events()
            for a in self.selector.artists:
                try:
                    a.remove()
                except ValueError:
                    pass
            self.data_obj.draw_figure()
            del self.selector
        if self._corr_plot is not None:
            self.toggle_correlation_plot()
        if self.key_press_cid is not None:
            self.data_obj.ax.figure.canvas.mpl_disconnect(self.key_press_cid)
        for attr in ['data_obj', 'arr', 'template', 'key_press_cid']:
            try:
                delattr(self, attr)
            except AttributeError:
                pass
        mainwindow.removeDockWidget(self.dock)
        return super(PatternSelectionWidget, self).close()
Ejemplo n.º 6
0
class SelectFromCollection(object):
    """Select data points of several line series on a matplotlib axes.

    Depending on `method` the selection is made interactively with a
    `RectangleSelector` (``'manual'``), from a date range entered in a
    ``CalendarDialog`` (``'time'``), from the peaks found by ``find_peaks``
    with the parameters of a ``PeaksDialog`` (``'peaks'``), or with a
    `SpanSelector` that shows summary statistics of the spanned x range in a
    table (``'spanselector'``).

    Selected points are drawn as red ``'+'`` markers with the gid
    ``'selected_' + gid``. The indices selected in the most recently processed
    series are stored in the `ind` attribute.

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to interact with and to draw the selection on.

    xcollection, ycollection : sequence of arrays
        x and y values of every series (x may hold ``np.datetime64`` values).

    gid : sequence of str
        Group id of every series; used for the marker gids and table rows.

    parent : object
        Owner; its ``plotCanvas.fig1`` is searched for the line colors.

    method : {'manual', 'time', 'peaks', 'spanselector'}
        How the selection is made.
    """
    def __init__(self, ax, xcollection, ycollection, gid, parent, method):
        self.ax = ax
        self.canvas = ax.figure.canvas
        self.x = xcollection
        self.y = ycollection
        self.parent = parent
        self.gid = gid
        # Initialize before any selection so that the indices found by the
        # immediate 'time'/'peaks' selections below are not wiped out.
        self.ind = []

        if method == 'manual':
            self.lasso = RectangleSelector(ax, onselect=self.onselect)
        elif method == 'time':
            w = CalendarDialog(self.ax.get_xlim())
            values = w.getResults()
            if values:
                xmin = datetime(values[0][0], values[0][1], values[0][2])
                xmax = datetime(values[1][0], values[1][1], values[1][2])
                self.selector(
                    lim=[date2num(xmin),
                         date2num(xmax), -np.inf, np.inf])
                self.canvas.draw_idle()
        elif method == 'peaks':
            w = PeaksDialog()
            values = w.getResults()
            if values:
                self.selector(peaks=values)
                self.canvas.draw_idle()

        elif method == 'spanselector':

            self.span = SpanSelector(
                ax,
                self.printspanselector,
                "horizontal",
                useblit=True,
                rectprops=dict(alpha=0.5, facecolor="red"),
            )
            self.canvas.draw_idle()

    def printspanselector(self, xmin, xmax):
        """Show statistics of every series within ``[xmin, xmax]`` as a table.

        For each series the mean, min, max and the 25th/50th/75th percentiles
        (NaNs ignored) of the y values whose x lies in the span are drawn in a
        table above the axes, one row per series, colored like the plotted
        line with the same gid. The table of a previous span is replaced.
        """
        objs = []
        for child in self.parent.plotCanvas.fig1.get_children():
            if child.get_gid() == 'ax':
                objs = child.get_children()

        statf = ['mean', 'min', 'max', [25, 50, 75]]
        columnName = []
        for stat in statf:
            if isinstance(stat, str):
                columnName.append(stat)
            elif isinstance(stat, list):
                columnName.extend('P' + str(p) for p in stat)

        mat = []
        rowName = []
        color = []
        for i in range(0, len(self.x)):
            gid = self.gid[i]
            # Color of the plotted line with this gid; None keeps the default
            # text color if no such line exists.
            line_color = None
            for obj in objs:
                if hasattr(obj, 'get_xydata') and obj.get_gid() == gid:
                    line_color = obj.get_color()
                    break
            color.append(line_color)

            Y = self.y[i]

            rowName.append(gid)
            row = []
            X = self.x[i]
            if isinstance(X[0], np.datetime64):
                X = date2num(X)

            ind = np.nonzero(((X >= xmin) & (X <= xmax)))[0]
            for stat in statf:
                if isinstance(stat, str):
                    fct = getattr(np, 'nan' + stat)
                    row.append('%.2f' % fct(Y[ind]))
                elif isinstance(stat, list):
                    perc = list(np.nanpercentile(Y[ind], stat))
                    row += ['%.2f' % x for x in perc]

            mat.append(row)

        # Replace the table of a previous span instead of stacking tables.
        previous = getattr(self, '_stats_table', None)
        if previous is not None:
            try:
                previous.remove()
            except (ValueError, NotImplementedError):
                pass

        tb = self.ax.table(cellText=mat,
                           rowLabels=rowName,
                           colLabels=columnName,
                           loc='top',
                           cellLoc='center')
        self._stats_table = tb
        for k, cell in tb.get_celld().items():
            cell.set_edgecolor('black')
            cell.set_facecolor(self.ax.get_facecolor())
            if k[0] == 0:
                cell.set_text_props(weight='bold', color='w')
            elif color[k[0] - 1] is not None:
                cell.set_text_props(color=color[k[0] - 1])

        self.span.set_visible(False)

    def selector(self, lim=None, peaks=None):
        """Mark and store the selected points of every series.

        Parameters
        ----------
        lim : sequence of 4 floats, optional
            ``[x_a, x_b, y_a, y_b]``; points with x between ``x_a`` and
            ``x_b`` and y between ``y_a`` and ``y_b`` (in any order) are
            selected. Datetime x values are converted with ``date2num``.
        peaks : dict, optional
            Keyword arguments for ``find_peaks``; the peaks of every y series
            are selected (this overrides `lim`).
        """
        if lim:
            # The corners may come in any order, e.g. from a right-to-left
            # drag of the RectangleSelector.
            xmin, xmax = sorted(lim[0:2])
            ymin, ymax = sorted(lim[2:4])

        X = None
        for X, Y, gid in zip(self.x, self.y, self.gid):
            if isinstance(X[0], np.datetime64):
                X = date2num(X)

            if lim:
                self.ind = np.nonzero((X >= xmin) & (X <= xmax)
                                      & (Y >= ymin) & (Y <= ymax))[0]

            if peaks:
                self.ind = find_peaks(Y, **peaks)[0]

            self.ax.plot(X[self.ind], Y[self.ind], 'r+',
                         gid='selected_' + gid)

        # Without any series there is nothing to zoom to.
        if X is not None:
            self.ax.set_xlim(X[0], X[-1])

    def onselect(self, eclick, erelease):
        """RectangleSelector callback: select inside the dragged rectangle."""
        # if self.parent._active == "ZOOM" or self.parent._active == "PAN":
        #         return
        self.selector(
            lim=[eclick.xdata, erelease.xdata, eclick.ydata, erelease.ydata])

    def disconnect(self):
        """Disconnect the interactive selectors and redraw the canvas."""
        for name in ('lasso', 'span'):
            widget = getattr(self, name, None)
            if widget is not None:
                widget.disconnect_events()
        self.canvas.draw_idle()
class WidgetFigure(BasicFigure):
    nColorsFromColormap = Int(5)

    unlock_all_btn = Button('(Un-) Lock')
    widget_list = List()
    widget_sel = Enum(values='widget_list')
    widget_clear_btn = Button('Clear Current Widgets')

    drawn_patches = List()
    drawn_patches_names = List()
    patch_data = List  # [patch no, 2D Arr]

    drawn_lines = List()
    drawn_lines_names= List()
    drawn_lines_selector = Enum(values='drawn_lines_names')

    line_width = Range(0, 1000, 0, tooltip='average over number of lines in both directions.')
    line_interpolation_order = Range(0, 5, 1)
    line_data = List  # [line no, xvals, yvals]

    def _widget_list_default(self):
        """Default choices for ``widget_list``: the line and the box tool."""
        return ['Line Selector', 'Rectangle Selector']

    def _widget_sel_default(self):
        """Default for ``widget_sel``: start with the line selector."""
        return 'Line Selector'

    def _widget_clear_btn_fired(self):
        """Clear only the widgets of the currently selected tool.

        The first entry of ``widget_list`` stands for the lines, the second
        one for the rectangles.
        """
        current = self.widget_sel
        if current == self.widget_list[0]:
            self.clear_lines()
        if current == self.widget_list[1]:
            self.clear_patches()

    @on_trait_change('widget_sel')
    def set_selector(self, widget):
        """Activate the tool chosen in ``widget_sel``.

        `widget` is the new value of ``widget_sel``; the first entry of
        ``widget_list`` starts the line selector, the second one the
        rectangle selector.
        """
        tools = self.widget_list
        if widget == tools[0]:
            self._line_selector()
        if widget == tools[1]:
            self._rectangle_selector()

    def act_all(self):
        """Call ``connect()`` on every drawn patch and then every drawn line."""
        for patch in self.drawn_patches:
            patch.connect()
        for drawn_line in self.drawn_lines:
            drawn_line.connect()

    @on_trait_change('drawn_lines:lineReleased, drawn_lines:lineUpdated, line_width, line_interpolation_order, img_data[]')
    def get_line_data(self):
        """Sample ``img_data`` along every drawn line and store the profiles.

        For each name in ``drawn_lines_names`` the image is interpolated with
        ``ndimage.map_coordinates`` (order ``line_interpolation_order``) at
        ``int(length)`` points along the line. The profile is averaged over
        ``2 * line_width + 1`` parallel cuts, offset by whole steps of the
        line's (integer-truncated) normal vector. ``line_data`` receives one
        ``[xvals, mean_profile]`` array per line.
        """
        line_data = []
        for line in self.drawn_lines_names:
            x, y = self.get_widget_line(line).line.get_data()
            # Keep the true end points: they define the line's direction.
            x0, x1 = x[0], x[1]
            y0, y1 = y[0], y[1]

            len_line = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
            # np.linspace needs an integer sample count (a float raises a
            # TypeError in current numpy); truncation matches old numpy.
            n_samples = int(len_line)
            x_float = np.linspace(x0, x1, n_samples)
            y_float = np.linspace(y0, y1, n_samples)

            # The normal only depends on the end points, so compute it once.
            # Its components are truncated so offsets are whole pixels.
            n1, n2 = self.get_normal(x0, x1, y0, y1)
            n1 = int(n1)
            n2 = int(n2)

            data = [
                ndimage.map_coordinates(
                    self.img_data,
                    np.vstack((y_float + n2 * i, x_float + n1 * i)),
                    order=self.line_interpolation_order)
                for i in range(-self.line_width, self.line_width + 1)
            ]

            line_cut_mean = np.mean(data, axis=0)
            xvals = np.arange(0, line_cut_mean.shape[0], 1)
            line_data.append(np.array([xvals, line_cut_mean]))

        self.line_data = line_data

    @staticmethod
    def get_normal(x1, x2, y1, y2):
        """
        Return a normal vector of the straight line through (x1, y1) and (x2, y2).

        The vector is scaled so that one component equals 1 and the other
        has an absolute value of at most 1. For a horizontal line
        (y1 == y2) the result is (0.0, 1.0).

        thx Denis!

        :param x1: float
        :param x2: float
        :param y1: float
        :param y2: float
        :return: tuple (n1, n2) of floats
        """
        dx = float(x1 - x2)
        dy = float(y1 - y2)
        if dy == 0:
            return 0.0, 1.0
        slope_ratio = -dx / dy
        if abs(slope_ratio) > 1.0:
            return 1.0 / slope_ratio, 1.0
        return 1.0, slope_ratio

    @on_trait_change('line_width')
    def _update_line_width(self):
        """Visualize the averaging width on the currently selected line.

        The line named by ``line_selector`` (``'line 0'`` if it is None) is
        drawn ``2 * line_width + 1`` wide with an alpha that decreases
        exponentially with the width. Nothing happens if no such line exists.
        """
        line_name = 'line 0' if self.line_selector is None else self.line_selector
        # Look the line up once by name (the old code passed the found line
        # object to get_widget_line again, which always returned None).
        line = self.get_widget_line(line_name)
        if line is None:
            return

        line.drl.line.set_linewidth(2 * self.line_width + 1)  # how to avoid this?!
        line.drl.line.set_alpha((np.exp(-self.line_width / 20)))  # exponentially decreasing alpha
        self.draw()


    def _line_selector(self):
        """Switch to the line tool.

        Disconnects the rectangle selector (if any), locks the draggable
        rectangles, unlocks the draggable lines and starts a line-drawing
        ``RectangleSelector`` on the right mouse button.
        """
        try:
            self.rs.disconnect_events()
            DraggableResizeableRectangle.lock = True
            print('Rectangles are locked')
        except Exception:
            # Best effort: e.g. ``self.rs`` does not exist before the first
            # rectangle selector was created. (Was a bare ``except:``, which
            # also swallowed KeyboardInterrupt/SystemExit.)
            print('Rectangles could not be locked')

        print(self.__class__.__name__, ": Connecting Line Selector")
        DraggableResizeableLine.lock = None
        self.ls = RectangleSelector(self.axes_selector, self.line_selector_func, drawtype='line', useblit=True, button=[3])

    def line_selector_func(self, eclick, erelease, cmap=mpl.cm.jet):
        """Callback of the line selector: draw an annotated line.

        The line runs from the press to the release position, is labelled
        ``'line <n>'`` (n = number of lines drawn so far) and colored from
        `cmap`. It is appended to ``drawn_lines``, its label to
        ``drawn_lines_names``, and the canvas is redrawn.
        """
        cls_name = self.__class__.__name__
        print(cls_name, "Line Selector:")
        print(cls_name, "eclick: {} \n erelease: {}".format(eclick, erelease))
        print()

        start_x, start_y = eclick.xdata, eclick.ydata
        end_x, end_y = erelease.xdata, erelease.ydata

        norm = mpl.colors.Normalize(vmin=0, vmax=self.nColorsFromColormap)
        mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        n_lines = len(self.drawn_lines)
        color = mapper.to_rgba(n_lines + 1)
        annotated = AnnotatedLine(self.axes_selector, start_x, start_y,
                                  end_x, end_y,
                                  text='line ' + str(n_lines), color=color)
        self.drawn_lines_names.append(annotated.text)
        self.drawn_lines.append(annotated)
        self.canvas.draw()

    def get_widget_line(self, line_name):
        """Return the first drawn line whose ``text`` is `line_name`, or None."""
        return next(
            (candidate for candidate in self.drawn_lines
             if candidate.text == line_name),
            None)

    def clear_lines(self):
        """Remove all drawn lines from the axes and forget them.

        Lines whose artist is already gone (``remove`` raising
        ``ValueError``) are reported and skipped. The canvas is redrawn in
        any case.
        """
        # The "Clearing" message used to be printed twice (once even when
        # there was nothing to clear); it is now printed once when needed.
        if self.drawn_lines:
            print(self.__class__.__name__, ": Clearing selection lines")
            for l in self.drawn_lines:
                try:
                    l.remove()
                except ValueError:
                    print(self.__class__.__name__, ": Line was not found.")
            self.drawn_lines = []
            self.drawn_lines_names = []

        self.canvas.draw()

    @on_trait_change('drawn_patches:rectUpdated')
    def calculate_picture_region_sum(self):
        """Cut the image region under every drawn rectangle.

        Negative widths/heights are normalized so the slice always runs from
        the smaller to the larger coordinate. A zero-sized rectangle is
        reported but still included. ``patch_data`` receives one
        ``[img_data slice, [x1, x2, y1, y2]]`` entry per patch (integer
        bounds).
        """
        regions = []
        for patch in self.drawn_patches:
            rect = patch.rectangle
            x_lo, y_lo = rect.get_xy()
            width = rect.get_width()
            height = rect.get_height()
            x_hi = x_lo + width
            y_hi = y_lo + height

            if width < 0:
                x_lo, x_hi = x_hi, x_lo
            if height < 0:
                y_lo, y_hi = y_hi, y_lo
            if width == 0 or height == 0:
                print('Zero Patch dimension')

            # data & extent
            bounds = [int(x_lo), int(x_hi), int(y_lo), int(y_hi)]
            cut = self.img_data[bounds[2]:bounds[3], bounds[0]:bounds[1]]
            regions.append([cut, bounds])

        self.patch_data = regions

    def _rectangle_selector(self):
        """Switch interaction from the line selector to the rectangle selector.

        Tries to disconnect the line selector ``self.ls`` and set
        ``DraggableResizeableLine.lock = True``. If that fails, a message is
        printed and the method continues. It then sets
        ``DraggableResizeableRectangle.lock = None`` and stores a new
        right-button (``button=[3]``) ``RectangleSelector`` in ``self.rs``,
        with ``rectangle_selector_func`` as its callback.
        """
        try:
            self.ls.disconnect_events()
            DraggableResizeableLine.lock = True
            print('Line Selector is locked')
        except Exception:
            # Best effort, e.g. if ``self.ls`` was never created. Catch
            # ``Exception`` rather than using a bare ``except`` so that
            # KeyboardInterrupt/SystemExit still propagate.
            print('Line Selector could not be locked')
        DraggableResizeableRectangle.lock = None

        print(self.__class__.__name__, ": Connecting Rectangle Selector")

        # ``drawtype='box'`` was the default. The keyword was deprecated in
        # Matplotlib 3.5 and removed in 3.7, so it is omitted for compatibility.
        self.rs = RectangleSelector(self.axes_selector, self.rectangle_selector_func,
                                    useblit=True, button=[3])


    def rectangle_selector_func(self, eclick, erelease, cmap=mpl.cm.jet):
        """Rectangle-selector callback: add an annotated, colour-coded region.

        Parameters
        ----------
        eclick, erelease :
            Mouse press / release events; their ``xdata``/``ydata`` give the
            two opposite corners of the rectangle in data coordinates.
        cmap :
            Colormap used to pick the new rectangle's colour.

        The rectangle is labelled ``'region N'`` (N = number of patches drawn
        before it) and coloured by mapping ``N + 1`` through *cmap*,
        normalised to ``[0, self.nColorsFromColormap]``. It is kept in
        ``self.an_rect``, appended to ``self.drawn_patches``, and its label is
        appended to ``self.drawn_patches_names``. Finally the canvas is
        redrawn.

        Usage: react to edits of the drawn rectangles with a trait listener
        that walks the patches, e.g.::

            @on_trait_change('fig:selectionPatches:rectUpdated')
            def handler(self):
                for p in self.fig.selectionPatches:
                    ...  # use p

        NOTE(review): this class's own listener is registered on
        ``'drawn_patches:rectUpdated'``; confirm which trait path external
        consumers should subscribe to.
        """
        owner = self.__class__.__name__
        print(owner, "Rectangle Selector:")
        print(owner, "eclick: {} \n erelease: {}".format(eclick, erelease))
        print()

        corner_x1, corner_y1 = eclick.xdata, eclick.ydata
        corner_x2, corner_y2 = erelease.xdata, erelease.ydata

        existing = len(self.drawn_patches)
        mapper = mpl.cm.ScalarMappable(
            norm=mpl.colors.Normalize(vmin=0, vmax=self.nColorsFromColormap),
            cmap=cmap,
        )
        region_colour = mapper.to_rgba(existing + 1)
        label = 'region {}'.format(existing)

        self.an_rect = AnnotatedRectangle(self.axes_selector,
                                          corner_x1, corner_y1, corner_x2, corner_y2,
                                          label, color=region_colour)
        self.drawn_patches_names.append(self.an_rect.text)
        self.drawn_patches.append(self.an_rect)
        self.canvas.draw()

    def get_widget_patch(self, patch_name):
        """Return the first drawn rectangle whose ``text`` equals *patch_name*.

        Parameters
        ----------
        patch_name : str
            Label to look up, e.g. ``'region 0'``.

        Returns
        -------
        The first matching entry of ``self.drawn_patches``, or ``None`` when
        no drawn rectangle carries that label.
        """
        return next(
            (rect for rect in self.drawn_patches if rect.text == patch_name),
            None,
        )

    def clear_patches(self):
        """Remove every drawn selection rectangle and reset the related state.

        Each patch's ``remove()`` is called; a ``ValueError`` from it is
        reported and the patch is skipped. Then ``reset_borders()`` is called
        on ``DraggableResizeableLine`` and ``DraggableResizeableRectangle``,
        both ``self.drawn_patches`` and ``self.drawn_patches_names`` are
        emptied, and the canvas is redrawn. If no patches are drawn, nothing
        happens at all, not even the redraw.
        """
        if self.drawn_patches:
            print(self.__class__.__name__, ": Clearing selection patches")
            for patch in self.drawn_patches:
                try:
                    patch.remove()
                except ValueError:
                    print(self.__class__.__name__, ": Patch was not found.")
            DraggableResizeableLine.reset_borders()
            DraggableResizeableRectangle.reset_borders()
            self.drawn_patches = []
            # Keep the label list in step with the patch list. New patches
            # are numbered from len(self.drawn_patches) ('region N'), so stale
            # labels left behind after a clear would collide with new ones.
            self.drawn_patches_names = []
            self.canvas.draw()

    def clear_widgets(self):
        """Clear all drawn selection widgets: rectangles first, then lines."""
        for clear in (self.clear_patches, self.clear_lines):
            clear()

    def options_group(self):
        """Build the tabbed TraitsUI options panel.

        Returns
        -------
        HGroup
            An ``HGroup`` wrapping a tabbed ``VGroup`` with two tabs:

            * ``'Basic'``: option/clear/save buttons, normalize and log-scale
              toggles. The line selector and copy-data button are shown only
              when ``img_bool`` is false; the colormap selector and image
              slider only when it is true.
            * ``'Widgets'``: widget selector and clear button, the
              drawn-lines selector, line width and interpolation order.
        """
        basic_tab = HGroup(
            UItem('options_btn'),
            UItem('clear_btn'),
            UItem('line_selector', visible_when='not img_bool'),
            UItem('copy_data_btn', visible_when='not img_bool'),
            Item('normalize_bool', label='normalize'),
            Item('log_bool', label='log scale'),
            Item('cmap_selector', label='cmap', visible_when='img_bool'),
            UItem('image_slider_btn', visible_when='img_bool'),
            UItem('save_fig_btn'),
            label='Basic',
        )
        widgets_tab = HGroup(
            UItem('widget_sel'),
            UItem('widget_clear_btn'),
            Item('drawn_lines_selector', label='drawn lines',
                 tooltip='select here line for line property (e.g. width) changes'),
            Item('line_width', tooltip='average over number of lines in both directions.'),
            Item('line_interpolation_order'),
            label='Widgets',
        )
        tabs = VGroup(basic_tab, widgets_tab, layout='tabbed')
        return HGroup(tabs)