Пример #1
0
    def CreateDisplay(self):
        """Build the operation selector and the parameter sliders, then show the figure.

        A radio-button group chooses between the "Search" and "Insert"
        operations and reports clicks to ``self.OnOperationTypeSelect``.
        Three sliders select an *index* (0 .. len - 1) into
        ``self.AtomicSize``, ``self.Capacity`` and ``self.GridSize``; each
        slider passes itself as the first argument of its handler
        (``setAS_slider``, ``setC_slider``, ``setGS_slider``) via ``partial``.

        The created widgets are stored in ``self._displayWidgets``: matplotlib
        widgets only stay responsive while a reference to them is kept, and
        ``plt.show()`` returns immediately in interactive mode, so local
        variables alone are not enough.
        """
        rax = plt.axes([0.025, 0.8, 0.15, 0.15])
        radioSelectOperation = RadioButtons(rax, ("Search", "Insert"),
                                            active=0)
        radioSelectOperation.on_clicked(self.OnOperationTypeSelect)
        # Re-activating the already selected entry fires the 'clicked'
        # callback, so the initial operation is applied through the handler.
        radioSelectOperation.set_active(0)

        axAS = plt.axes([0.25, 0.20, 0.65, 0.03])
        axC = plt.axes([0.25, 0.15, 0.65, 0.03])
        axGS = plt.axes([0.25, 0.10, 0.65, 0.03])

        # (axes, label, list the slider indexes into, display format, handler)
        sliderSpecs = [
            (axAS, 'AtomicSize', self.AtomicSize, "%1.2f", self.setAS_slider),
            (axC, 'Capacity', self.Capacity, "%i", self.setC_slider),
            (axGS, 'GridSize', self.GridSize, "%i", self.setGS_slider),
        ]
        sliders = []
        for ax, label, values, fmt, handler in sliderSpecs:
            slider = Slider(ax,
                            label,
                            0,
                            len(values) - 1,
                            valinit=0,
                            valfmt=fmt)
            slider.on_changed(partial(handler, slider))
            # set_val fires the on_changed callback, so every handler runs
            # once with index 0 before the window is shown.
            slider.set_val(0.0)
            sliders.append(slider)

        # Keep the widgets alive for as long as this object lives.
        self._displayWidgets = [radioSelectOperation] + sliders
        plt.show()
Пример #2
0
class financeViewer():
    """Interactive matplotlib viewer for a table of bank transactions.

    Call ``showPlots(cleanDf)`` with a DataFrame holding at least the columns
    ``date`` (integers in ``YYYYMMDD`` form), ``sum``, ``balance`` and
    ``comment``.  The figure consists of:

    * ``ax1`` (top): the whole period, either as income/expense bars
      (``mode == 'transaction'``) or as a balance step plot
      (``mode == 'balance'``); drag on it to select a period;
    * ``ax2`` (bottom left): bars of the selected period; click a bar to see
      its details;
    * ``ax3`` (bottom right): text box for details and help;
    * ``ax4`` (middle right): area holding the radio buttons and buttons.
    """

    def __init__(self):
        # style of the title boxes of the axes
        self.box = dict(facecolor='blue',
                        pad=3,
                        alpha=0.2,
                        boxstyle="Round4,pad=0.3")
        # info text of a clicked bar: date, sum, comment
        self.transactionString = """Date: {}
                            Sum: {} HUF
                            Comment: {}"""
        self.initString = """
        Select a period to inspect transactions
        using your mouse, or change the settings
        """

        # y-axis scales of the top (ax1) and bottom (ax2) plots
        self.scale1 = 'log'
        self.scale2 = 'log'
        self.mode = 'transaction'  # the other mode is balance mode, modifies the top plot

        # indexes into self.pdRange of the selected period; None = no selection
        self.start, self.end = None, None

    def createFigure(self):
        """Create the figure and the four axes on a 3x2 grid."""
        # disable toolbar
        matplotlib.rcParams['toolbar'] = 'None'
        self.fig = plt.figure(figsize=(WIDTH, WIDTH / G_RAT),
                              facecolor=LIGHT_GREY)

        self.gsp = gridspec.GridSpec(nrows=3,
                                     ncols=2,
                                     wspace=0.05,
                                     hspace=0.45,
                                     width_ratios=[G_RAT, 1],
                                     height_ratios=[(1 + G_RAT) / G_RAT, G_RAT,
                                                    1])

        self.ax1 = plt.subplot(self.gsp[0, :])   # whole first row
        self.ax2 = plt.subplot(self.gsp[1:, 0])  # rows 2-3, left column
        self.ax3 = plt.subplot(self.gsp[2, 1])   # row 3, right column
        self.ax4 = plt.subplot(self.gsp[1, 1])   # row 2, right column

    def drawAxes(self):
        """Draw the initial content of all four axes."""
        for ax in [self.ax1, self.ax2, self.ax3, self.ax4]:
            ax.set_facecolor(GREY)

        # big plot: whole duration
        self.plotAx1()

        # zoom plot: selected duration
        self.plotAx2()

        # info plot: text box without ticks
        self.txt = self.ax3.text(0.1,
                                 0.5,
                                 self.initString,
                                 horizontalalignment='left',
                                 verticalalignment='center',
                                 fontsize=13,
                                 color='black',
                                 wrap=True)
        self.ax3.set_xticks([])
        self.ax3.set_yticks([])
        self.ax3.set_title('info about the transactions', bbox=self.box)

        # place of the buttons (added in makeButtons)
        self.ax4.set_xticks([])
        self.ax4.set_yticks([])

    def on_plot_hover(self, event):
        """Highlight the clicked bar of the zoom plot and show its details.

        Despite its name this handler is connected to ``button_press_event``
        in ``showPlots``.  The bar under the cursor gets the light variant of
        its colour and its date, sum and comment are written to the info box;
        every other bar is restored to its normal colour.
        """
        if not event.inaxes: return
        if event.inaxes != self.ax2: return

        for bar in self.ax2.patches:
            x0, y0 = bar.get_x(), bar.get_y()
            insideBar = (x0 < event.xdata < x0 + bar.get_width()
                         and y0 < event.ydata < y0 + bar.get_height())
            if not insideBar:
                # also resets bars highlighted by an earlier click, including
                # those hit in x but missed in y
                bar.set_facecolor(dark2normal[bar.get_edgecolor()])
                continue

            bar.set_facecolor(dark2light[bar.get_edgecolor()])
            date_ordinal, y = self.ax2.transData.inverted().transform(
                [event.x, event.y]) + 0.5

            # convert the numeric date into a datetime
            transDate = num2date(date_ordinal).strftime(DATEFORMATSTRING)
            pdDate = num2date(date_ordinal).strftime('%Y%m%d')
            try:
                # look up the transaction by day and absolute amount
                comment = self.cleanDf.loc[
                    (self.cleanDf['date'] == int(pdDate)) &
                    (abs(self.cleanDf['sum']) == bar.get_height()),
                    'comment'].iloc[0]
            except (IndexError, KeyError):
                # no matching row (or no such column)
                comment = 'Record not found'

            newStr = self.transactionString.format(transDate,
                                                   bar.get_height(), comment)
            self.txt.set_text(newStr)
        self.fig.canvas.draw()

    def reset_button_on_clicked(self, mouse_event):
        self.plotAx2()

    def balanceView_button_on_clicked(self, mouse_event):
        self.txt.set_text('Not implemented yet')

    def transView_button_on_clicked(self, mouse_event):
        self.txt.set_text('Not implemented yet')

    def plotAx2(self):
        """Redraw the zoom plot (ax2) for the selected period.

        Without a selection (``self.start is None``) the whole
        ``self.pdRange`` is shown.  Income bars are green; each expense layer
        of ``self.expenseXs``/``self.expenseYs`` is red and stacked on top of
        the previous layers.
        """
        self.ax2.cla()
        self.ax2.set_title('Selected duration', bbox=self.box)
        if self.start is not None:
            startDate = self.pdRange[self.start]
            endDate = self.pdRange[self.end]
            currentRange = pd.date_range(
                start=startDate,
                end=endDate,
                periods=None,
                freq='D',
            )
            # keep only the income bars inside the selected days
            indexes = [
                idx for idx, day in enumerate(self.incomeX)
                if len(np.where(currentRange == day)[0])
            ]
            currIncomeX = np.array(self.incomeX)[indexes]
            currIncomeY = np.array(self.incomeY)[indexes]
        else:
            currentRange = self.pdRange
            currIncomeX = self.incomeX
            currIncomeY = self.incomeY

        # accumulated expense height per day of currentRange
        baseArray = np.zeros(len(currentRange), dtype=float)

        self.ax2.bar(currIncomeX,
                     currIncomeY,
                     color=GREEN,
                     edgecolor=DARK_GREEN)

        for expenseX, expenseY in zip(self.expenseXs, self.expenseYs):
            # position within currentRange of every visible expense of the layer
            currBottomIdxs = []
            indexes = []
            for idx, day in enumerate(expenseX):
                matches = np.where(currentRange == day)[0]
                if len(matches):
                    currBottomIdxs.append(matches[0])
                    indexes.append(idx)

            expenseX = np.array(expenseX)[indexes]
            expenseY = np.array(expenseY)[indexes]
            bottom = baseArray[currBottomIdxs]
            self.ax2.bar(expenseX,
                         expenseY,
                         bottom=bottom,
                         color=RED,
                         edgecolor=DARK_RED)
            # raise the base for the next layer
            baseArray[currBottomIdxs] += expenseY

        if self.start is not None and self.end - self.start <= 4:
            # short selections: one tick per day
            self.ax2.xaxis.set_major_locator(DayLocator())

        self.ax2.xaxis.set_major_formatter(DATEFROMAT)
        # NOTE(review): `nonposy` was renamed `nonpositive` in matplotlib 3.3
        self.ax2.set_yscale(self.scale2, nonposy='clip')
        self.ax2.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
        plt.setp(self.ax2.xaxis.get_majorticklabels(), rotation=LABEL_ROTATION)

    def plotAx1(self):
        """Redraw the top plot (ax1) for the current ``self.mode``.

        Also recreates the horizontal SpanSelector that calls ``onselect``.

        Raises:
            ValueError: if ``self.mode`` is neither 'transaction' nor 'balance'.
        """
        self.ax1.cla()
        self.ax1.set_title('Whole duration', bbox=self.box)

        if self.mode == 'transaction':
            self.plotAx1_transaction()
        elif self.mode == 'balance':
            self.plotAx1_balance()
        else:
            raise ValueError('selected mode not supported:  %s' % self.mode)

        # stored on self: the selector stops working if it is garbage-collected
        self.span = SpanSelector(self.ax1,
                                 self.onselect,
                                 'horizontal',
                                 rectprops=dict(alpha=0.3, facecolor=RED))

        self.ax1.xaxis.set_major_formatter(DATEFROMAT)
        # NOTE(review): `nonposy` was renamed `nonpositive` in matplotlib 3.3
        self.ax1.set_yscale(self.scale1, nonposy='clip')
        self.ax1.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
        plt.setp(self.ax1.xaxis.get_majorticklabels(), rotation=LABEL_ROTATION)

    def plotAx1_balance(self):
        """Draw the balance after each transaction as a step plot on ax1."""
        self.ax1.step(self.pdDates, self.balance, marker="d", color=DARK_RED)

    def plotAx1_transaction(self):
        """Draw income bars and stacked expense bars for the whole range on ax1."""
        self.ax1.bar(self.incomeX,
                     self.incomeY,
                     color=GREEN,
                     edgecolor=DARK_GREEN)

        # accumulated expense height per day of self.pdRange
        baseArray = np.zeros(len(self.pdRange), dtype=float)
        for expenseX, expenseY in zip(self.expenseXs, self.expenseYs):
            # position of every expense of this layer within pdRange
            currBottomIdxs = [
                np.where(self.pdRange == day)[0][0] for day in expenseX
            ]
            bottom = baseArray[currBottomIdxs]
            self.ax1.bar(expenseX,
                         expenseY,
                         bottom=bottom,
                         color=RED,
                         edgecolor=DARK_RED)
            # raise the base for the next layer
            baseArray[currBottomIdxs] += expenseY

    def onselect(self, xmin, xmax):
        """SpanSelector callback: select a period on the top plot.

        The selected x range is converted to calendar days and then to
        indexes of ``self.pdRange``.  An end outside the range is clamped to
        the first/last day; if neither end lies inside the range nothing
        happens.  The balance summary is written to the info box and the zoom
        plot is redrawn.
        """
        dayMin, dayMax = sorted((int(xmin - 0.5), int(xmax + 0.5)))
        # xmin, xmax are day numbers counted from 0001-01-01.
        # NOTE(review): this matches matplotlib's legacy date epoch; since
        # matplotlib 3.3 the default epoch is 1970 — confirm for your version.
        yearZero = datetime.datetime.strptime('0001/01/01', "%Y/%m/%d")
        startDate = yearZero + timedelta(days=dayMin)
        endDate = yearZero + timedelta(days=dayMax)
        st = str(startDate)[:10]
        nd = str(endDate)[:10]

        stIdx, = np.where(self.pdRange.values == np.datetime64(st))
        endIdx, = np.where(self.pdRange.values == np.datetime64(nd))

        # np.where returns index arrays: test their size, not their truth
        # value (array([0]) is falsy, which dropped selections touching day 0).
        if stIdx.size and endIdx.size:
            # start and endpoints in range
            stIdx, endIdx = stIdx[0], endIdx[0]
        elif stIdx.size:
            # endpoint out of range
            stIdx, endIdx = stIdx[0], len(self.pdRange) - 1
        elif endIdx.size:
            # startpoint out of range
            stIdx, endIdx = 0, endIdx[0]
        else:
            # start and endpoints are out of range
            return

        self.start, self.end = stIdx, endIdx

        ist = int(st.replace("-", ""))
        ind = int(nd.replace("-", ""))

        selectedBalance = self.balance[(self.dateAxis > ist)
                                       & (self.dateAxis < ind)]
        if len(selectedBalance):
            selectionString = """
        Selection: {} - {}
        Starting balance: {} HUF
        Final balance: {} HUF
        Difference: {} HUF
        """.format(st, nd, selectedBalance[0], selectedBalance[-1],
                   selectedBalance[-1] - selectedBalance[0])
        else:
            # indexing an empty selection used to raise IndexError here
            selectionString = """
        Selection: {} - {}
        No transactions in the selected period
        """.format(st, nd)
        self.txt.set_text(selectionString)

        self.plotAx2()
        self.fig.canvas.draw()

    def makeButtons(self):
        """Place the radio buttons and the Reset/About buttons inside ax4.

        The widgets are stored on ``self`` so that they stay responsive.
        """
        # position of the axis that contains the buttons
        pos = self.ax4.get_position()
        self.ax4.set_title('plot properties', bbox=self.box)
        rowNr, colNr = 2, 2
        buttonwidth = 0.13
        buttonheight = 0.07
        # gaps between the buttons (Vspace: horizontal, Hspace: vertical)
        Vspace = (pos.width - colNr * buttonwidth) / (colNr + 1)
        Hspace = (pos.height - rowNr * buttonheight) / (rowNr + 1)

        # radio buttons
        scaleSelectorAx1 = self.fig.add_axes([
            pos.x0 + Vspace, pos.y0 + 2 * Hspace + buttonheight, buttonwidth,
            buttonheight
        ],
                                             facecolor=PURPLE)
        scaleSelectorAx2 = self.fig.add_axes(
            [pos.x0 + Vspace, pos.y0 + Hspace, buttonwidth, buttonheight],
            facecolor=PURPLE)
        modeSelectorAx1 = self.fig.add_axes([
            pos.x0 + 2 * Vspace + buttonwidth,
            pos.y0 + 2 * Hspace + buttonheight, buttonwidth, buttonheight
        ],
                                            facecolor=PURPLE)

        scaleSelectorAx1.set_title('top plot scale', fontsize=12)
        scaleSelectorAx2.set_title('bottom plot scale', fontsize=12)
        modeSelectorAx1.set_title('top plot mode', fontsize=12)

        self.scaleSelector1 = RadioButtons(scaleSelectorAx1,
                                           ('logaritmic', 'linear'))
        self.scaleSelector2 = RadioButtons(scaleSelectorAx2,
                                           ('logaritmic', 'linear'))
        self.modeSelector = RadioButtons(modeSelectorAx1,
                                         ('transaction view', 'balance view'))

        for button in [
                self.scaleSelector1, self.scaleSelector2, self.modeSelector
        ]:
            for circle in button.circles:  # adjust radius here. The default is 0.05
                circle.set_radius(0.09)
                circle.set_edgecolor('black')

        # small buttons, side by side in the lower right cell
        resetAx = self.fig.add_axes([
            pos.x0 + 2 * Vspace + buttonwidth, pos.y0 + Hspace,
            buttonwidth / 2, buttonheight
        ])
        helpAx = self.fig.add_axes([
            pos.x0 + 2 * Vspace + 1.5 * buttonwidth, pos.y0 + Hspace,
            buttonwidth / 2, buttonheight
        ])
        self.resetBtn = Button(resetAx,
                               'Reset',
                               color=PURPLE,
                               hovercolor=DARK_RED)
        self.helpBtn = Button(helpAx,
                              'About',
                              color=PURPLE,
                              hovercolor=DARK_RED)

    def resetClicked(self, event):
        """Restore default scales, mode and selection, then redraw."""
        self.scale1 = 'log'
        self.scale2 = 'log'
        self.mode = 'transaction'
        self.start = None
        self.end = None
        self.plotAx1()
        self.plotAx2()
        # sync the radio buttons with the reset state
        self.scaleSelector1.set_active(0)
        self.scaleSelector2.set_active(0)
        self.modeSelector.set_active(0)
        self.fig.canvas.draw()

    def helpClicked(self, event):
        """Show the 'About' text in the info box."""
        helpText = """Go to
        github.com/Wheele9/transaction-viewer
        to get the latest version, 
        to create an issue or pull request.
        Feel free to contact me."""
        self.txt.set_text(helpText)

    def modeButtonClicked(self, label):
        """Radio callback of the top-plot mode selector.

        'balance view' switches ax1 to the balance step plot with a linear
        y scale; 'transaction view' switches back to the bar plot.  Selecting
        the current mode does nothing.

        Raises:
            ValueError: for an unknown label.
        """
        if label == 'balance view':
            if self.mode == 'balance': return
            self.mode = 'balance'
            # NOTE(review): scaleSelector1 is not updated here, so it may
            # still display 'logaritmic' while ax1 is linear.
            self.scale1 = 'linear'
            self.plotAx1()
        elif label == 'transaction view':
            if self.mode == 'transaction': return
            self.mode = 'transaction'
            self.plotAx1()
        else:
            raise ValueError('could not find %s' % label)
        self.fig.canvas.draw()

    def scaleButton1Clicked(self, label):
        """Radio callback of the top-plot scale selector.

        Selecting the scale that is already active does nothing.

        Raises:
            ValueError: for an unknown label.
        """
        if label == 'linear':
            if self.scale1 == 'linear': return
            self.scale1 = 'linear'
            self.plotAx1()
        elif label == 'logaritmic':
            # the stored value is matplotlib's scale name 'log', not the label
            if self.scale1 == 'log': return
            self.scale1 = 'log'
            self.plotAx1()
        else:
            raise ValueError('could not find %s' % label)
        self.fig.canvas.draw()

    def scaleButton2Clicked(self, label):
        """Radio callback of the bottom-plot scale selector.

        Selecting the scale that is already active does nothing.

        Raises:
            ValueError: for an unknown label.
        """
        if label == 'linear':
            if self.scale2 == 'linear': return
            self.scale2 = 'linear'
            self.plotAx2()
        elif label == 'logaritmic':
            # the stored value is matplotlib's scale name 'log', not the label
            if self.scale2 == 'log': return
            self.scale2 = 'log'
            self.plotAx2()
        else:
            raise ValueError('could not find %s' % label)
        self.fig.canvas.draw()

    def connectButtons(self):
        """Attach the click handlers to all widgets created in makeButtons."""
        self.scaleSelector1.on_clicked(self.scaleButton1Clicked)
        self.scaleSelector2.on_clicked(self.scaleButton2Clicked)
        self.modeSelector.on_clicked(self.modeButtonClicked)

        self.resetBtn.on_clicked(self.resetClicked)
        self.helpBtn.on_clicked(self.helpClicked)

    def calculateAttributes(self):
        """Extract column arrays from ``self.cleanDf`` and build the daily range.

        ``pdRange`` covers every calendar day from the first to the last row's
        date (assumes rows are sorted by date — TODO confirm with the loader).
        """
        self.balance = self.cleanDf['balance'].values
        self.dateAxis = self.cleanDf['date'].values
        self.transactions = self.cleanDf['sum'].values
        self.pdDates = [
            pd.to_datetime(str(date), format='%Y%m%d')
            for date in self.dateAxis
        ]

        start = self.pdDates[0]
        end = self.pdDates[-1]
        self.pdRange = pd.date_range(
            start=start,
            end=end,
            periods=None,
            freq='D',
        )

    def separateTransactions(self):
        """Split transactions into income bars and stacked expense layers.

        Sets ``incomeX``/``incomeY`` (dates and sums of positive
        transactions) and ``expenseXs``/``expenseYs``.  Layer ``k`` holds, for
        every day with at least ``k + 1`` transactions, that day's
        ``(k + 1)``-th transaction (in table order) if it is not positive,
        negated to a positive height.  Stacking the layers yields one bar
        segment per expense on days with several transactions.
        """
        pdDatesArr = np.array(self.pdDates)
        values, counts = np.unique(pdDatesArr, return_counts=True)

        expenseXs, expenseYs = [], []
        incomeX, incomeY = [], []

        for freq in range(1, max(counts) + 1):
            smallX, smallY = [], []
            for val, cnt in zip(values, counts):
                if cnt < freq:
                    continue
                # the freq-th transaction of this day
                index = np.where(pdDatesArr == val)[0][freq - 1]
                if self.transactions[index] > 0:
                    incomeX.append(val)
                    incomeY.append(self.transactions[index])
                else:
                    smallX.append(val)
                    smallY.append(-self.transactions[index])
            expenseXs.append(smallX)
            expenseYs.append(smallY)

        self.expenseXs = expenseXs
        self.expenseYs = expenseYs
        self.incomeX = incomeX
        self.incomeY = incomeY

    def showPlots(self, cleanDf):
        """Prepare the data of ``cleanDf``, build the figure and show it."""
        self.cleanDf = cleanDf
        self.calculateAttributes()
        self.separateTransactions()

        self.createFigure()
        self.drawAxes()

        self.fig.canvas.mpl_connect('button_press_event', self.on_plot_hover)
        self.fig.subplots_adjust(left=0.06, bottom=0.07, right=0.97, top=0.95)

        self.makeButtons()
        self.connectButtons()

        plt.show()
Пример #3
0
class MaskEditor(ImageWindow):
    '''
<Basic Actions>
a/d: switch to previous/next view mode
Ctrl + s: save and exit
Ctrl + z: undo the last action

<Brush Actions>
mouse right + dragging = paint with brush
mouse wheel up/down: increase/decrease brush radius
w/s: change brush type (current brush type is shown on the upper panel)

<Grabcut Actions>
Ctrl + g: run grabcut with current mask

<Threshold Actions>
Drag sliders at the bottom: adjust thresholds for H, S, V channel pixel values
Double click on HSV panel: reset threshold values
q/e: switch to previous/next HSV panel view
    '''

    def __init__(self, img: np.ndarray, mask: np.ndarray, win_title=None):
        super().__init__(win_title, (0.05, 0.18, 0.9, 0.7))
        axcolor = 'lightgoldenrodyellow'
        self.disable_callbacks()

        self.src = img.copy()
        self.src_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

        self.mask_src = preprocess_mask(mask)
        self.gc_mask = np.copy(self.mask_src)
        self.viewmode = ViewMode.MASKED_IMAGE
        self.brush_iptr = BrushInterpreter()
        self.history_mgr = MaskEditHistoryManager()
        self.save_result = False

        self.pixel_panel = self.fig.add_axes((0.7, 0.9, 0.08, 0.05))
        self.pixel_panel.imshow(255 * np.ones((5, 8, 3), np.uint8))
        hide_axes_labels(self.pixel_panel)
        for pos in ['left', 'top', 'right', 'bottom']:
            self.pixel_panel.spines[pos].set_color('none')

        unit = 0.06
        # Create brush panel
        self.brush_panel = self.fig.add_axes((unit, 0.9, len(BrushType) * unit, unit))
        hide_axes_labels(self.brush_panel)
        self.brush_panel.imshow(np.array([[BrushType.val2color(i) for i in range(len(BrushType))]]))
        self.brush_indicator = []
        self.update_brush_panel()

        # Create largest component panel
        self.lc_panel = self.fig.add_axes(((len(BrushType) + 1) * unit, 0.9, unit, unit))
        hide_axes_labels(self.lc_panel)
        self.lc_panel.text(
            -0.45, -0.7,
            'Largest Component Only',
            bbox=dict(
                linewidth=1,
                edgecolor='goldenrod',
                facecolor='none',
                alpha=1.0
            )
        )
        self.lc_panel.imshow(np.array([[[255, 255, 224]]], dtype=np.uint8))
        self.lc_switch = RadioButtons(self.lc_panel, ('off', 'on'))
        self.lc_switch.on_clicked(lambda x: self.update_mask() or self.history_mgr.add_switch_history('lc') or self.display())

        # Create fill holes panel
        self.fh_panel = self.fig.add_axes(((len(BrushType) + 3) * unit, 0.9, unit, unit))
        hide_axes_labels(self.fh_panel)
        self.fh_panel.text(
            -0.45, -0.7,
            'Fill Holes',
            bbox=dict(
                linewidth=1,
                edgecolor='goldenrod',
                facecolor='none',
                alpha=1.0
            )
        )
        self.fh_panel.imshow(np.array([[[255, 255, 224]]], dtype=np.uint8))
        self.fh_switch = RadioButtons(self.fh_panel, ('off', 'on'))
        self.fh_switch.on_clicked(lambda x: self.update_mask() or self.history_mgr.add_switch_history('fh') or self.display())

        # Create component area threshold panel
        self.area_panel = self.fig.add_axes(((len(BrushType) + 5) * unit, 0.9 + 0.3 * unit, 3 * unit, 0.3 * unit), facecolor=axcolor)
        hide_axes_labels(self.area_panel)
        self.area_panel.text(
            -0.45, 1.7,
            'Filter Small Components by Area',
            bbox=dict(
                linewidth=1,
                edgecolor='goldenrod',
                facecolor='none',
                alpha=1.0
            )
        )
        self.area_slider = Slider(self.area_panel, '', 0, 100, valinit=0, color=(0, 1, 0, 0.3), valfmt='%0.2f%% of the largest component')
        self.on_area_adjust = False

        # Create threshold sliders
        self.upper_names = ['Upper H', 'Upper S', 'Upper V']
        self.lower_names = ['Lower H', 'Lower S', 'Lower V']
        self.min_values = np.array([0, 0, 0])
        self.max_values = np.array([359, 255, 255])
        self.thresh_sliders = {}
        self.thresh_slider_panels = []
        self.on_thresh_adjust = False

        for i in range(3):
            lower_slider_ax = self.fig.add_axes((0.25, 0.12 - (2 * i) * 0.018 - i * 0.008, 0.7, 0.01), facecolor=axcolor)
            self.thresh_sliders[self.lower_names[i]] = Slider(
                lower_slider_ax, self.lower_names[i],
                self.min_values[i], self.max_values[i],
                valinit=self.min_values[i], color=(0, 1, 0, 0.3),
                valfmt='%d'
            )
            self.thresh_sliders[self.lower_names[i]].valtext.set_text(str(int(self.min_values[i])))
            self.thresh_slider_panels.append(lower_slider_ax)

            upper_slider_ax = self.fig.add_axes((0.25, 0.12 - (2 * i + 1) * 0.018 - i * 0.008, 0.7, 0.01), facecolor=axcolor)
            self.thresh_sliders[self.upper_names[i]] = Slider(
                upper_slider_ax, self.upper_names[i],
                self.min_values[i], self.max_values[i],
                valinit=self.max_values[i], color=(0, 1, 0, 0.3),
                valfmt='%d'
            )
            self.thresh_sliders[self.upper_names[i]].valtext.set_text(str(int(self.max_values[i])))
            self.thresh_slider_panels.append(upper_slider_ax)
        self.enable_callbacks()

        # Create hue-saturation panel
        self.hs_panel = self.fig.add_axes((0.008, 0.01, 0.15, 0.15), projection='polar', facecolor=axcolor)
        self.hs_plot_mode = HSPlotMode.ALL
        self.hs_region_iptr = PolarDragInterpreter()
        self.arc_regions = []
        self.temp_arc_regions = []
        hide_axes_labels(self.hs_panel)

        # Create value panel
        self.v_panel = self.fig.add_axes((0.16, 0.01, 0.02, 0.15), facecolor=axcolor)
        hide_axes_labels(self.v_panel, 'x')

        self.plot_hs_range()
        self.plot_thresh_regions()
        self.display()
        self.root.focus_force()

    def mainloop(self):
        super().mainloop()
        return self.gc_mask if self.save_result else None

    def enable_callbacks(self):
        super().enable_callbacks()
        if hasattr(self, 'sliders'):
            for i in range(3):
                self.thresh_sliders[self.lower_names[i]].set_active(True)
                self.thresh_sliders[self.upper_names[i]].set_active(True)

    def disable_callbacks(self):
        super().disable_callbacks()
        if hasattr(self, 'sliders'):
            for i in range(3):
                self.thresh_sliders[self.lower_names[i]].set_active(False)
                self.thresh_sliders[self.upper_names[i]].set_active(False)

    def run_grabcut(self):
        try:
            gc_mask = grabcut(self.src, cv2.GC_INIT_WITH_MASK, mask=self.gc_mask)
            self.history_mgr.add_grabcut_history(self.mask_src)
            self.mask_src = gc_mask
            self.update_mask()
            self.display()
        except ValueError as e:
            self.show_message(str(e), 'Warning')

    @property
    def lower_thresh(self):
        return [int(self.thresh_sliders[self.lower_names[i]].val) for i in range(3)]

    @property
    def upper_thresh(self):
        return [int(self.thresh_sliders[self.upper_names[i]].val) for i in range(3)]

    @property
    def h_range(self):
        return HRange(self.lower_thresh[0], self.upper_thresh[0])

    @property
    def s_range(self):
        return SRange(self.lower_thresh[1], self.upper_thresh[1])

    @property
    def v_range(self):
        return VRange(self.lower_thresh[2], self.upper_thresh[2])

    @property
    def min_area(self):
        return self.area_slider.val

    def update_mask(self):
        thresh_mask = threshold_hsv(self.src_hsv, self.h_range, self.s_range, self.v_range) // 255

        self.gc_mask = merge_gc_mask(self.mask_src, thresh_mask)

        for brush_trace in self.history_mgr.brush_traces():
            for brush_touch in brush_trace:
                self.gc_mask = apply_brush_touch(self.gc_mask, brush_touch)

        if self.fill_holes:
            self.gc_mask = fill_holes_gc(self.gc_mask)

        if self.largest_component_only:
            component_mask = (largest_connected_component(
                np.where(self.gc_mask % 2 == 1, 255, 0).astype(np.uint8)
            ) // 255).astype(np.uint8)
        else:
            component_mask = (filter_by_area(
                np.where(self.gc_mask % 2 == 1, 255, 0).astype(np.uint8),
                self.min_area / 100
            ) // 255).astype(np.uint8)

        self.gc_mask = merge_gc_mask(self.gc_mask, component_mask)

    def update_brush_panel(self):
        for item in self.brush_indicator:
            item.remove()
        self.brush_indicator.clear()
        brush_id = self.brush_iptr.brush.value['val']
        self.brush_indicator = [
            self.brush_panel.text(
                i - 0.45, -0.7,
                BrushType.val2name(i),
                fontweight='bold' if brush_id == i else 'normal',
                bbox=dict(
                    linewidth=3 if brush_id == i else 1,
                    facecolor='w',
                    edgecolor=BrushType.val2color(i) / 255,
                    alpha=1.0
                )
            ) for i in range(len(BrushType))
        ]
        self.refresh()

    @property
    def largest_component_only(self):
        return self.lc_switch.value_selected == 'on'

    @property
    def fill_holes(self):
        return self.fh_switch.value_selected == 'on'

    def set_sliders(self, lower_thresh, upper_thresh, write_history=True):
        for i, vals in enumerate(zip(lower_thresh, upper_thresh)):
            self.thresh_sliders[self.lower_names[i]].set_val(vals[0])
            self.thresh_sliders[self.upper_names[i]].set_val(vals[1])
        if write_history:
            self.history_mgr.add_thresh_history(self.lower_thresh, self.upper_thresh)
        self.update_mask()
        self.refresh()

    def display(self):
        if self.viewmode == ViewMode.MASK:
            rgb_mask = np.zeros((*self.gc_mask.shape, 3), dtype=np.uint8)
            for val in range(4):
                rgb_mask[self.gc_mask == val] = BrushType.val2color(val)
            self.set_image(rgb_mask)
            self.ax.set_title('Object Mask')
        elif self.viewmode == ViewMode.MASKED_IMAGE:
            self.set_image(cv2.bitwise_and(self.src, self.src, mask=np.where(self.gc_mask % 2 == 1, 255, 0).astype(np.uint8)))
            self.ax.set_title('Foreground')
        elif self.viewmode == ViewMode.INVERSE_MASKED_IMAGE:
            self.set_image(cv2.bitwise_and(self.src, self.src, mask=np.where(self.gc_mask % 2 == 0, 255, 0).astype(np.uint8)))
            self.ax.set_title('Background')
        elif self.viewmode == ViewMode.MASK_OVERLAY:
            overlayed = np.copy(self.src)
            for i in range(len(BrushType)):
                overlay_mask(overlayed, np.where(self.gc_mask == i, 255, 0).astype(np.uint8), BrushType.val2color(i) / 255)
            self.set_image(overlayed)
            self.ax.set_title('Mask Overlayed Image')
        else:
            raise ValueError('Invalid viewmode: {}'.format(self.viewmode))

    def add_arc_region(self, rect, temporary=False):
        if not rect.is_empty():
            theta = np.arange(rect.left, rect.right, np.pi / 180)
            r_bottom = rect.bottom * np.ones_like(theta)
            r_top = rect.top * np.ones_like(theta)

            whole_theta_range = (rect.right - rect.left) % (2 * np.pi) == 0

            if temporary:
                if not whole_theta_range:
                    self.temp_arc_regions += self.hs_panel.plot(
                        [rect.left, rect.left], [rect.bottom, rect.top], color='k', linestyle='--'
                    )
                    self.temp_arc_regions += self.hs_panel.plot(
                        [rect.right, rect.right], [rect.bottom, rect.top], color='k', linestyle='--'
                    )

                self.temp_arc_regions += self.hs_panel.plot(theta, r_top, color='k', linestyle='--')
                self.temp_arc_regions += self.hs_panel.plot(theta, r_bottom, color='k', linestyle='--')
            else:
                if not whole_theta_range:
                    self.arc_regions += self.hs_panel.plot(
                        [rect.left, rect.left], [rect.bottom, rect.top], color='k', linestyle='-'
                    )
                    self.arc_regions += self.hs_panel.plot(
                        [rect.right, rect.right], [rect.bottom, rect.top], color='k', linestyle='-'
                    )
                self.arc_regions += self.hs_panel.plot(theta, r_top, color='k', linestyle='-')
                self.arc_regions += self.hs_panel.plot(theta, r_bottom, color='k', linestyle='-')

            self.hs_panel.set_rmax(1.2)
            self.refresh()

    def clear_arc_regions(self):
        for i in range(len(self.arc_regions)):
            self.arc_regions[i].remove()

        self.arc_regions.clear()
        self.refresh()

    def clear_temp_arc_regions(self):
        for i in range(len(self.temp_arc_regions)):
            self.temp_arc_regions[i].remove()

        self.temp_arc_regions.clear()
        self.refresh()

    def plot_hs_range(self):
        """Redraw the hue-saturation polar scatter for the current ``hs_plot_mode``.

        In ``HSPlotMode.ALL`` a synthetic HSV palette (value fixed at 1) is drawn
        on a regular (theta, r) grid. In the other modes the distinct HSV values
        of ``src_hsv`` are drawn: for the whole image (``IMAGE``), or only for
        pixels whose ``gc_mask`` label is even (``BACKGROUND``) or odd
        (otherwise). Hue maps to the angle and saturation to the radius.
        """
        self.hs_panel.clear()
        if self.hs_plot_mode == HSPlotMode.ALL:
            self.hs_panel.set_title('Hue-Saturation\n        Disc', loc='left')
            # visualize the HSV palette
            h_range = HRange()
            s_range = SRange()

            r_val = s_range.get_ranges(1 / 32)[0]
            theta_val = h_range.get_ranges(np.pi / 60)[0]

            # every (theta, r) combination of the two sampled ranges
            r = np.tile(r_val, len(theta_val))
            theta = np.repeat(theta_val, len(r_val))
            self.hs_panel.scatter(
                theta, r,
                c=colors.hsv_to_rgb(np.clip(np.transpose([theta / (2 * np.pi), r, np.ones_like(theta)]), 0, 1)),
                s=30 * r,
                alpha=0.8
            )
        else:
            if self.hs_plot_mode == HSPlotMode.IMAGE:
                img = np.copy(self.src_hsv)
                self.hs_panel.set_title('Image HSV\nDistribution', loc='left')
            elif self.hs_plot_mode == HSPlotMode.BACKGROUND:
                # even gc_mask labels are treated as background
                img = cv2.bitwise_and(self.src_hsv, self.src_hsv, mask=np.where(self.gc_mask % 2 == 0, 255, 0).astype(np.uint8))
                self.hs_panel.set_title('Background HSV\nDistribution', loc='left')
            else:
                # odd gc_mask labels are treated as foreground
                img = cv2.bitwise_and(self.src_hsv, self.src_hsv, mask=np.where(self.gc_mask % 2 == 1, 255, 0).astype(np.uint8))
                self.hs_panel.set_title('Foreground HSV\nDistribution', loc='left')
            hsv_pixels = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
            hsv_pixels = np.unique(hsv_pixels, axis=0)
            # Quantise H to multiples of 3 and S to multiples of 8 to thin out the
            # scatter. np.int was removed in NumPy 1.24; the builtin int is the
            # equivalent dtype.
            hsv_pixels = (np.round(hsv_pixels / [3, 8, 1]) * np.array([3, 8, 1])).astype(int)
            hsv_pixels = np.unique(hsv_pixels, axis=0)

            # hue is scaled from 0..179 (presumably OpenCV's 8-bit hue range) to radians
            theta = 2 * hsv_pixels[:, 0] / 179 * np.pi
            r = hsv_pixels[:, 1] / 255
            v = hsv_pixels[:, 2] / 255
            self.hs_panel.scatter(
                theta, r,
                c=colors.hsv_to_rgb(np.clip(np.transpose([theta / (2 * np.pi), r, v]), 0, 1)),
                s=30 * r,
                alpha=0.8
            )

        self.hs_panel.set_rmax(1.2)

    def plot_thresh_regions(self):
        """Redraw the current threshold selection on both HSV panels.

        The H/S thresholds are drawn as arc regions on the hue-saturation panel.
        The V threshold is drawn on the value panel as a grey ramp (one row per
        value level, ten columns wide) that is opaque only inside ``v_range``.
        """
        self.clear_arc_regions()
        arcs = get_arc_regions(self.h_range, self.s_range)
        for arc in arcs:
            self.add_arc_region(arc)
        self.hs_panel.set_rmax(1.2)

        self.v_panel.clear()
        levels = range(VRange.MAX + 1)
        grey_hsv = np.array([[[0, 0, level]] * 10 for level in levels]).astype(np.uint8)
        rgb = cv2.cvtColor(grey_hsv, cv2.COLOR_HSV2RGB)
        opacity = np.array(
            [[[255 if level in self.v_range else 0]] * 10 for level in levels]
        ).astype(np.uint8)
        rgba = np.dstack((rgb, opacity))
        self.v_panel.imshow(rgba / 255)

    # NOTE(review): on_caps_lock_off is defined elsewhere; by its name it
    # presumably suppresses this handler while Caps Lock is on — confirm.
    @on_caps_lock_off
    def on_key_press(self, event):
        """Dispatch keyboard shortcuts after the base-class handler has run.

        d / a   -- step ``viewmode`` forward / backward and redisplay
        w / s   -- step ``brush_iptr.brush`` forward / backward and update the
                   brush panel
        ctrl+s  -- close; ``save_result`` is True only if any history exists
        space   -- run GrabCut
        ctrl+z  -- undo the most recent history entry (or show a message when
                   the history is empty)
        q / e   -- step ``hs_plot_mode`` backward / forward and redraw the
                   hue-saturation panel and threshold regions
        """
        super().on_key_press(event)
        if event.key == 'd':
            self.viewmode += 1
            self.display()
        elif event.key == 'a':
            self.viewmode -= 1
            self.display()
        elif event.key == 'w':
            self.brush_iptr.brush += 1
            self.update_brush_panel()
        elif event.key == 's':
            self.brush_iptr.brush -= 1
            self.update_brush_panel()
        elif event.key == 'ctrl+s':
            self.save_result = len(self.history_mgr) > 0
            self.close()
        elif event.key == ' ':
            self.run_grabcut()
        elif event.key == 'ctrl+z':
            if len(self.history_mgr) > 0:
                action_name, data = self.history_mgr.pop()
                if action_name == 'grabcut':
                    # restore the mask saved before the GrabCut run
                    self.mask_src = data
                elif action_name == 'thresh':
                    # data holds the previous (lower, upper) thresholds; the
                    # trailing False presumably stops set_sliders from
                    # recording a new history entry — confirm.
                    self.set_sliders(*data, False)
                    self.plot_hs_range()
                    self.plot_thresh_regions()
                elif action_name == 'area':
                    self.area_slider.set_val(data)
                elif action_name == 'lc':
                    # Flip the switch to its other option (assumes option 0
                    # is 'off'). NOTE(review): the extra pop() presumably
                    # discards the entry the switch callback pushes when
                    # set_active fires — confirm.
                    idx = 0 if self.lc_switch.value_selected == 'off' else 1
                    self.lc_switch.set_active(1 - idx)
                    self.history_mgr.pop()
                elif action_name == 'fh':
                    # same flip-and-discard pattern as 'lc' above
                    idx = 0 if self.fh_switch.value_selected == 'off' else 1
                    self.fh_switch.set_active(1 - idx)
                    self.history_mgr.pop()
                self.update_mask()
                self.display()
            else:
                self.show_message('No history to recover', 'Guide')
        elif event.key == 'q':
            self.hs_plot_mode -= 1
            self.plot_hs_range()
            self.plot_thresh_regions()
        elif event.key == 'e':
            self.hs_plot_mode += 1
            self.plot_hs_range()
            self.plot_thresh_regions()

    def on_mouse_press(self, event):
        """Start the interaction matching the pressed button and target axes.

        * right button without a modifier on the image: begin a brush stroke;
        * left button on a threshold slider: hide the arc regions and flag a
          threshold adjustment in progress;
        * left button on the area slider: flag an area adjustment in progress;
        * on the hue-saturation panel: a left double-click pushes
          ``min_values``/``max_values`` into the sliders when they differ from
          the last recorded thresholds; a right press starts a region drag.
        """
        super().on_mouse_press(event)

        pos = self.get_axes_coordinates(event)
        target = event.inaxes
        button = event.button

        if event.key is None and button == 3 and target is self.ax:
            self.brush_iptr.start_dragging(pos)
            return

        if button == 1 and target in self.thresh_slider_panels:
            self.clear_arc_regions()
            self.on_thresh_adjust = True
            return

        if button == 1 and target is self.area_panel:
            self.on_area_adjust = True
            return

        if target is not self.hs_panel:
            return

        if event.dblclick and button == 1:
            thresholds_changed = (
                not np.array_equal(self.history_mgr.last_lower_thresh, self.min_values)
                or not np.array_equal(self.history_mgr.last_upper_thresh, self.max_values)
            )
            if thresholds_changed:
                self.clear_arc_regions()
                self.set_sliders(self.min_values, self.max_values)
                self.plot_thresh_regions()
                self.update_mask()
                self.display()
        elif button == 3:
            self.clear_arc_regions()
            fpos = self.get_axes_coordinates(event, dtype=float)
            modifier = event.key
            if modifier in ['control', 'ctrl+shift'] and fpos is not None:
                # anchor the drag at radius 0 for the control variants
                anchor = Point(fpos.x, 0, dtype=float)
                self.hs_region_iptr.start_dragging(anchor, modifier == 'control')
            else:
                self.hs_region_iptr.start_dragging(fpos, modifier != 'shift')

    def on_mouse_move(self, event):
        """Update hover feedback and any in-progress drag as the mouse moves.

        Over the image, the pixel coordinates and HSV values under the cursor
        are written into ``pixel_panel``. Unless ``control`` is held, a
        translucent brush preview is drawn and an active brush drag is
        extended. Elsewhere, an active threshold adjustment redraws the arc
        regions, and an active hue-saturation drag redraws the temporary
        region preview.
        """
        super().on_mouse_move(event)
        p = self.get_axes_coordinates(event)
        if event.inaxes is self.ax and p in self.img_rect:
            pixel = self.src_hsv[p.y][p.x]
            # Convert channels to Python ints before doubling the hue. If
            # src_hsv is uint8 (as the 0..179 hue scaling used elsewhere in
            # this class suggests), NumPy >= 2 keeps ``2 * pixel[0]`` in uint8
            # and hues >= 128 wrap around, printing a wrong H value.
            hue, sat, val = int(pixel[0]), int(pixel[1]), int(pixel[2])
            self.transient_patches.append(
                self.pixel_panel.text(
                    0, 5,
                    # H is doubled, presumably to show degrees (0..358)
                    'x: {}, y: {}, H: {}, S: {}, V: {}'
                        .format(p.x, p.y, 2 * hue, sat, val),
                    bbox=dict(
                        linewidth=1,
                        edgecolor='none',
                        facecolor='none',
                        alpha=1.0
                    )
                )
            )
        if event.inaxes is self.ax and event.key != 'control':
            self.add_transient_patch(BrushTouch(
                p, self.brush_iptr.radius, True, self.brush_iptr.brush
            ).patch(alpha=0.3))

            if self.brush_iptr.on_dragging:
                trace = self.brush_iptr.get_trace(p)
                self.add_patch(trace.patch())
        elif self.on_thresh_adjust:
            self.plot_thresh_regions()
        elif self.hs_region_iptr.on_dragging:
            self.clear_temp_arc_regions()
            p = self.get_axes_coordinates(event, dtype=float)
            self.hs_region_iptr.update(p)
            self.add_arc_region(self.hs_region_iptr.rect, temporary=True)

    def on_mouse_release(self, event):
        """Complete the drag or slider adjustment begun on press, then redraw.

        * right-button brush drag: when a release coordinate is available the
          stroke is extended to it and stored in the history; the drawn
          patches are then cleared and the drag is finished;
        * left-button threshold / area slider drag: the in-progress flag is
          reset and the new value is recorded in the history;
        * hue-saturation region drag: a non-empty selected polar rectangle is
          converted into lower/upper slider thresholds (V bounds are kept).

        In every case the mask, the image view and both HSV panels are
        redrawn afterwards.
        """
        super().on_mouse_release(event)
        pos = self.get_axes_coordinates(event)

        if self.brush_iptr.on_dragging and event.button == 3:
            if pos is not None:
                self.brush_iptr.get_trace(pos)
                self.history_mgr.add_brush_touch_history(self.brush_iptr.history())
                self.brush_iptr.clear()
            self.clear_patches()
            self.clear_transient_patch()
            self.brush_iptr.finish_dragging(pos)
        elif self.on_thresh_adjust and event.button == 1:
            self.on_thresh_adjust = False
            self.history_mgr.add_thresh_history(self.lower_thresh, self.upper_thresh)
        elif self.on_area_adjust and event.button == 1:
            self.on_area_adjust = False
            self.history_mgr.add_area_history(self.min_area)
        elif self.hs_region_iptr.on_dragging:
            fpos = self.get_axes_coordinates(event, dtype=float)
            self.hs_region_iptr.finish_dragging(fpos)
            self.clear_temp_arc_regions()

            region = self.hs_region_iptr.rect
            if not region.is_empty():
                def to_byte(fraction):
                    # scale a 0..1 radius to 0..255 and clamp into that range
                    return min(255, max(0, int(round(255 * fraction))))

                # Only the lower hue is wrapped into [0, 360); presumably an
                # upper bound beyond 360 encodes a range crossing 0 — confirm.
                hue_lo = int(round(np.rad2deg(region.left).item())) % 360
                hue_hi = int(round(np.rad2deg(region.right).item()))
                self.set_sliders(
                    [hue_lo, to_byte(region.top), self.v_range.min],
                    [hue_hi, to_byte(region.bottom), self.v_range.max],
                )

        self.update_mask()
        self.display()
        self.plot_hs_range()
        self.plot_thresh_regions()

    def on_scroll(self, event):
        """Resize the brush with the scroll wheel while over the image.

        Scrolling up grows ``brush_iptr.radius`` by one. Scrolling down shrinks
        it by one while it is at least 1. Ignored while ``control`` is held.
        The brush preview is redrawn at the cursor afterwards.
        """
        super().on_scroll(event)
        if event.inaxes is self.ax and event.key != 'control':
            # matplotlib's MouseEvent.step is positive for "up" and negative
            # for "down", but its magnitude is not always 1 (fast wheels,
            # trackpads), so compare the sign instead of testing == 1.
            if event.step > 0:
                self.brush_iptr.radius += 1
            elif self.brush_iptr.radius >= 1:
                self.brush_iptr.radius -= 1
            p = self.get_axes_coordinates(event)
            self.clear_transient_patch()
            self.add_transient_patch(BrushTouch(
                p, self.brush_iptr.radius, True, self.brush_iptr.brush
            ).patch(alpha=0.3))
            self.refresh()
Пример #4
0
class SubjectPerformancesVisualizer:
    """ Display all performances of a subject, with radio buttons for browsing all subjects. """
    def __init__(self, expe, default_subject_index, show_radio_selector=True):
        """Build the 4-row figure and the subject selector, then plot one subject.

        Args:
            expe: experiment object; its ``subjects`` and ``global_params``
                are read.
            default_subject_index: index of the subject selected initially
                (selecting it triggers the first ``update_plot``).
            show_radio_selector: if False, the selector axes start at the
                figure's right edge (outside the visible area) and every
                plotted subject is also saved to a PDF by ``update_plot``.
        """
        self.expe = expe
        self.fig, self.axes = plt.subplots(nrows=4,
                                           ncols=1,
                                           sharex=True,
                                           sharey=False,
                                           figsize=(7, 9))
        self.fig.subplots_adjust(bottom=0.15)

        # radio buttons for going through all subjects
        self.show_radio_selector = show_radio_selector
        radio_selector_x = 0.75 if show_radio_selector else 1.0
        self.fig.subplots_adjust(right=radio_selector_x - 0.05)
        # add the selector to *this* figure explicitly, not to whichever
        # figure pyplot considers current
        self.widget_ax = self.fig.add_axes([radio_selector_x, 0.1, 0.6, 0.8],
                                           frameon=True)
        self.radio_buttons = RadioButtons(
            self.widget_ax,
            tuple(str(i) for i in range(len(self.expe.subjects))))
        if show_radio_selector:
            self.fig.text(0.9, 0.92, 'Subject:', ha='center', fontsize='10')
        self.radio_buttons.on_clicked(self.on_radio_button_changed)
        self.radio_buttons.set_active(default_subject_index)

    def close(self):
        """Clear and close this visualizer's figure."""
        self.fig.clear()
        plt.close(self.fig)

    def on_radio_button_changed(self, label):
        """Radio-button callback: ``label`` is the subject index as a string."""
        subject_index = int(label)
        subject = self.expe.subjects[subject_index]
        self.update_plot(subject)

    @staticmethod
    def _scatter_pair(ax, synths_ids, values):
        """Scatter both columns of a per-synth 2-column array onto ``ax``.

        Column 0 is drawn with square markers and column 1 with diamond
        markers (the legend on the error panel names them 'Sliders' and
        'Interp.').
        """
        ax.scatter(synths_ids, values[:, 0], marker='s')
        ax.scatter(synths_ids, values[:, 1], marker='D')

    def update_plot(self, subject):
        """Redraw all four panels for ``subject``.

        Panels, top to bottom: search duration D, normalized error E,
        in-game score and adjusted score, each plotted per synth ID. The
        y-limits hide the -1 values that mark invalid entries. When the radio
        selector is hidden, the figure is also saved as
        ``Perf_subject_XX.pdf``.
        """
        # target this visualizer's figure, not pyplot's "current" one, so the
        # title lands in the right place when several figures are open
        self.fig.suptitle(
            "Durations $D$, errors $E$ and performances $S$ for subject #" +
            str(subject.index))

        synths_ids = self.expe.global_params.get_synths_ids()
        ax_d, ax_e, ax_s, ax_adj = self.axes

        ax_d.clear()
        ax_d.set(ylabel="Search duration $D$")
        self._scatter_pair(ax_d, synths_ids, subject.d)
        ax_d.set_ylim([0, subject.global_params.allowed_time
                       ])  # hides the -1 invalid values

        ax_e.clear()
        ax_e.set(ylabel="Normalized error $E$")
        self._scatter_pair(ax_e, synths_ids, subject.e_norm1)
        ax_e.set_ylim([0, 1])  # hides the -1 invalid values
        ax_e.legend(['Sliders', 'Interp.'], loc="best")

        ax_s.clear()
        ax_s.set(ylabel="Score {}".format(
            perfeval.get_perf_eval_name(perfeval.EvalType.INGAME)))
        self._scatter_pair(ax_s, synths_ids, subject.s_ingame)
        ax_s.set_ylim([0, 1])  # hides the -1 invalid values
        ax_s.set_xlim([min(synths_ids) - 0.5, max(synths_ids) + 0.5])
        ax_s.xaxis.set_ticks(synths_ids)

        ax_adj.clear()
        s_adj = subject.get_s_adjusted(
            adjustment_type=perfeval.EvalType.ADJUSTED
        )  # default best adjustment type
        ax_adj.set(ylabel="Score {}".format(
            perfeval.get_perf_eval_name(perfeval.EvalType.ADJUSTED)),
                   xlabel="Synth ID")
        self._scatter_pair(ax_adj, synths_ids, s_adj)
        ax_adj.set_ylim([0, 1])  # hides the -1 invalid values
        ax_adj.set_xlim([min(synths_ids) - 0.5, max(synths_ids) + 0.5])
        ax_adj.xaxis.set_ticks(synths_ids)

        # request a redraw of this figure (plt.draw() would redraw whichever
        # figure pyplot considers current)
        self.fig.canvas.draw_idle()
        if not self.show_radio_selector:
            figurefiles.save_in_subjects_folder(
                self.fig, "Perf_subject_{:02d}.pdf".format(subject.index))
Пример #5
0
class FreesurferReviewInterface(BaseReviewInterface):
    """Custom interface for rating the quality of Freesurfer parcellation."""
    def __init__(self,
                 fig,
                 axes_seg,
                 rating_list=cfg.default_rating_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 alpha_seg=cfg.default_alpha_seg):
        """Set up the rating radio buttons and the overlay transparency slider.

        Parameters
        ----------
        fig : figure passed through to the base interface.
        axes_seg : mutable collection of overlay artists (anything with
            ``set_alpha``); kept by reference as ``overlaid_artists``.
        rating_list : labels offered by the rating radio buttons.
        next_button_callback, quit_button_callback : callables invoked by the
            "next" / "quit" keyboard shortcuts.
        alpha_seg : initial overlay alpha, also used as the slider's start value.
        """

        super().__init__(fig, axes_seg, next_button_callback,
                         quit_button_callback)

        self.rating_list = rating_list
        # no rating until the user selects one; get_ratings() returns None until then
        self.user_rating = None

        self.overlaid_artists = axes_seg
        self.latest_alpha_seg = alpha_seg
        # alpha restored when toggle_overlay() turns the overlay back on;
        # initialised here so toggling after the slider was dragged to 0 does
        # not raise AttributeError
        self.prev_alpha_seg = alpha_seg
        self.prev_axis = None
        self.prev_ax_pos = None
        self.zoomed_in = False
        self.add_radio_buttons()
        self.add_alpha_slider()

        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback

        # this list of artists is populated later;
        # keeping it makes it handy to clean them all up
        self.data_handles = list()

        self.unzoomable_axes = [
            self.radio_bt_rating.ax, self.text_box.ax, self.bt_next.ax,
            self.bt_quit.ax
        ]

    def add_radio_buttons(self):
        """Create the rating radio buttons (``active=None``: none preselected)."""

        ax_radio = plt.axes(cfg.position_radio_buttons,
                            facecolor=cfg.color_rating_axis,
                            aspect='equal')
        self.radio_bt_rating = RadioButtons(ax_radio,
                                            self.rating_list,
                                            active=None,
                                            activecolor='orange')
        self.radio_bt_rating.on_clicked(self.save_rating)
        for txt_lbl in self.radio_bt_rating.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        # NOTE(review): RadioButtons.circles is deprecated in newer matplotlib
        # releases (3.7+); confirm against the matplotlib version in use.
        for circ in self.radio_bt_rating.circles:
            circ.set(radius=0.06)

    def add_alpha_slider(self):
        """Controls the transparency level of overlay"""

        # alpha slider; starts at the alpha actually applied to the overlay so
        # the displayed value matches what the user sees
        ax_slider = plt.axes(cfg.position_slider_seg_alpha,
                             facecolor=cfg.color_slider_axis)
        self.slider = Slider(ax_slider,
                             label='transparency',
                             valmin=0.0,
                             valmax=1.0,
                             valinit=self.latest_alpha_seg,
                             valfmt='%1.2f')
        self.slider.label.set_position((0.99, 1.5))
        self.slider.on_changed(self.set_alpha_value)

    def save_rating(self, label):
        """Update the rating"""

        self.user_rating = label

    def get_ratings(self):
        """Returns the final selected rating (None if none was chosen)."""

        return self.user_rating

    def allowed_to_advance(self):
        """
        Method to ensure work is done for current iteration,
        before allowing the user to advance to next subject.

        Returns False unless a rating has been selected.
        """

        return self.radio_bt_rating.value_selected is not None

    def reset_figure(self):
        "Resets the figure to prepare it for display of next subject."

        self.clear_data()
        self.clear_radio_buttons()
        self.clear_notes_annot()

    def clear_data(self):
        """clearing all data/image handles"""

        if self.data_handles:
            for artist in self.data_handles:
                artist.remove()
            # resetting it
            self.data_handles = list()

        # this is populated for each unit during display
        self.overlaid_artists.clear()

    def clear_radio_buttons(self):
        """Clears the radio button selection."""

        # enabling default rating encourages lazy advancing without review
        # self.radio_bt_rating.set_active(cfg.index_freesurfer_default_rating)
        for index, label in enumerate(self.radio_bt_rating.labels):
            if label.get_text() == self.radio_bt_rating.value_selected:
                self.radio_bt_rating.circles[index].set_facecolor(
                    cfg.color_rating_axis)
                break
        self.radio_bt_rating.value_selected = None

    def clear_notes_annot(self):
        """clearing notes and annotations"""

        self.text_box.set_val(cfg.textbox_initial_text)
        # text is matplotlib artist
        self.annot_text.remove()

    def on_mouse(self, event):
        """Callback for mouse events.

        A click outside the control axes restores a previously zoomed axis.
        A right click toggles the overlay; a double click on a data axis zooms
        it to ``cfg.zoomed_position``.
        """

        if self.prev_axis is not None:
            # include all the non-data axes here (so they wont be zoomed-in)
            if event.inaxes not in self.unzoomable_axes:
                self.prev_axis.set_position(self.prev_ax_pos)
                self.prev_axis.set_zorder(0)
                self.prev_axis.patch.set_alpha(0.5)
                self.zoomed_in = False

        # right click toggles overlay
        if event.button == 3:
            self.toggle_overlay()
        # double click to zoom in to any axis
        elif event.dblclick and event.inaxes is not None and \
            event.inaxes not in self.unzoomable_axes:
            # zoom axes full-screen
            self.prev_ax_pos = event.inaxes.get_position()
            event.inaxes.set_position(cfg.zoomed_position)
            event.inaxes.set_zorder(1)  # bring forth
            event.inaxes.set_facecolor('black')  # black
            event.inaxes.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = event.inaxes

        plt.draw()

    def on_keyboard(self, key_in):
        """Callback to handle keyboard shortcuts to rate and advance."""

        # ignore keyboard key_in when mouse within Notes textbox
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        # Advancing, quitting, rating and toggling are mutually exclusive:
        # a "next" key must not fall through to the rating shortcuts.
        if key_pressed in ['right', ' ', 'space']:
            self.next_button_callback()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.quit_button_callback()
        elif key_pressed in cfg.default_rating_list_shortform:
            self.user_rating = cfg.map_short_rating[key_pressed]
            index_to_set = cfg.default_rating_list.index(self.user_rating)
            self.radio_bt_rating.set_active(index_to_set)
        elif key_pressed == 't':
            self.toggle_overlay()

    def toggle_overlay(self):
        """Toggles the overlay by setting alpha to 0 and back."""

        if self.latest_alpha_seg != 0.0:
            self.prev_alpha_seg = self.latest_alpha_seg
            self.latest_alpha_seg = 0.0
        else:
            self.latest_alpha_seg = self.prev_alpha_seg
        self.update()

    def set_alpha_value(self, latest_value):
        """Use the slider to set alpha."""

        self.latest_alpha_seg = latest_value
        self.update()

    def update(self):
        """updating seg alpha for all axes"""

        for art in self.overlaid_artists:
            art.set_alpha(self.latest_alpha_seg)

        # update figure
        plt.draw()
class plot_2D:
    '''
    2D axis plot for both toric/planar lattices.

    Plots the qubits as circles, including the errors that occur on these qubits.
    Plots the stabilizers, their measurement state and the matching between the stabilizers.

    Many plot parameters, including colors of the plot, linewidths, scatter sizes are defined here.

    '''
    def __init__(self, graph, z=0, from3D=0, **kwargs):
        '''
        Create the figure, the navigation widgets and the initial plot of layer ``z``.

        graph:   lattice object; its ``size`` is copied and it is kept as ``self.graph``.
        z:       layer passed to ``init_plot``.
        from3D:  stored as ``self.from3D``.
        kwargs:  every item is set as an attribute on ``self``.
        '''

        self.size = graph.size
        self.graph = graph
        self.from3D = from3D

        self.qsize = 0.15
        self.pick = 5

        self.alpha = 0.4
        self.alpha2 = 0.3

        # kwargs override the attributes set above. They are applied BEFORE
        # the colour table below, so passing a colour attribute here has no
        # effect. NOTE(review): ``plot_size`` (used below) and ``linewidth``
        # (used by init_legend) get no default in this method, so presumably
        # they must be supplied via kwargs — confirm.
        for key, value in kwargs.items():
            setattr(self, key, value)

        # Define colors
        self.cw = [1, 1, 1]  # white (fill used by on_pick for error-free qubits)
        self.cl = [0.8, 0.8, 0.8]  # Line color
        self.cc = [0.7, 0.7, 0.7]  # Qubit color
        self.ec = [0.3, 0.3, 0.3]  # Erasure color
        self.cx = [0.9, 0.3, 0.3]  # X error color
        self.cz = [0.5, 0.5, 0.9]  # Z error color
        self.cy = [0.9, 0.9, 0.5]  # Y error color
        self.cX = [0.9, 0.7, 0.3]  # X quasiparticle color
        self.cZ = [0.3, 0.9, 0.3]  # Z quasiparticle color
        self.Cx = [0.5, 0.1, 0.1]  # darker X shade (usage not visible here)
        self.Cz = [0.1, 0.1, 0.5]  # darker Z shade (usage not visible here)
        self.cE = [0.9, 0.3, 0.7]  # Erasure color
        self.C1 = [self.cx, self.cz]
        self.C2 = [self.cX, self.cZ]
        self.LS = ["-", "--"]
        self.LS2 = [":", "--"]

        # history[i] maps an artist to the attribute dict it has at iteration i
        # (filled by new_attributes, replayed by draw_next / draw_prev).
        # dd is presumably collections.defaultdict — confirm at the file's imports.
        self.history = dd(dict)
        self.iter = 0  # index of the latest iteration
        self.iter_names = ["Initial"]
        self.iter_plot = 0  # index of the iteration currently displayed
        self.recent = 0  # set by draw_prev; while set, waitforkeypress keeps waiting

        # Order matters below: the pyplot calls act on the current figure/axes.
        self.f = plt.figure(figsize=(self.plot_size, self.plot_size))
        plt.ion()
        plt.cla()
        plt.show()
        plt.axis("off")
        self.ax = plt.axes([0.075, 0.1, 0.7, 0.85])
        self.ax.set_aspect("equal")

        self.canvas = self.f.canvas
        self.canvas.callbacks.connect('pick_event', self.on_pick)

        self.prev_button = Button(plt.axes([0.75, 0.025, 0.125, 0.05]),
                                  "Previous")
        self.next_button = Button(plt.axes([0.9, 0.025, 0.075, 0.05]), "Next")
        self.prev_button.on_clicked(self.draw_prev)
        self.next_button.on_clicked(self.draw_next)
        # selects what a pick does in on_pick: print info, or toggle X / Z / erasure
        self.rax = plt.axes([0.9, 0.1, 0.075, 0.125])
        self.radio_button = RadioButtons(self.rax, ("info", "X", "Z", "E"))

        self.ax_text = plt.axes([0.025, 0.025, 0.7, 0.05])
        plt.axis("off")  # applies to ax_text, the axes just created
        self.text = self.ax_text.text(0.5,
                                      0.5,
                                      "",
                                      fontsize=10,
                                      va="center",
                                      ha="center",
                                      transform=self.ax_text.transAxes)

        self.init_plot(z)

        self.radio_button.set_active(0)
        # the selector starts hidden
        plt.setp(self.rax, visible=0)

    def on_pick(self, event):
        '''
        Pick event handler for the plots.

        With the "info" radio option selected, prints the result of the picked
        object's ``picker()`` method. With "X", "Z" or "E" selected, toggles the
        X error, the Z error or the erasure flag of the picked qubit and records
        the resulting style of its patch via ``new_attributes``.
        '''
        artist = event.artist
        radiovalue = self.radio_button.value_selected

        if radiovalue == "info":
            print("picked", artist.object.picker())
            return

        qubit = artist.object
        # Debounce: the 3D scatter workaround swaps visibility and the
        # ``picker`` attribute between plot objects, after which a second pick
        # event can be registered for the same click. Ignore picks of the same
        # qubit that arrive within 0.1 s of the previous one.
        prev_time = getattr(qubit, "pick_time", None)
        qubit.pick_time = time()
        if prev_time and qubit.pick_time - prev_time < 0.1:
            return

        if radiovalue == "X":
            qubit.E[0].state = not qubit.E[0].state
        elif radiovalue == "Z":
            qubit.E[1].state = not qubit.E[1].state
        elif radiovalue == "E":
            qubit.erasure = not qubit.erasure

        # without error attributes the qubit is drawn as an unfilled white circle
        attr_dict = self.get_error_attr(qubit) or dict(fill=0, facecolor=self.cw)
        # erased qubits get a dotted outline
        attr_dict["linestyle"] = ":" if qubit.erasure else "-"

        # attr_dict is never empty here, so the change is always recorded
        self.new_attributes(qubit.pg, attr_dict)

    '''
    #########################################################################
                            Waiting functions
    '''

    def draw_plot(self):
        '''
        Show the current iteration's name on the figure, log it, blit the main
        axes and then block until the user presses a key.
        '''
        label = self.iter_names[self.iter]
        self.text.set_text(label)
        pr.printlog(f"{label} plotted.")
        main_box = self.ax.bbox
        self.canvas.blit(main_box)
        self.waitforkeypress()

    def waitforkeypress(self):
        '''
        Pause the script until the user presses a key on the plot.

        ``plt.waitforbuttonpress(-1)`` waits with no timeout. It returns True
        for a key press and False for a mouse click, so mouse clicks (e.g. on
        the Previous/Next buttons) keep the loop waiting. While ``self.recent``
        is set (the user has stepped back with ``draw_prev``), key presses are
        ignored as well until ``draw_next`` reaches the latest iteration again.
        '''
        while True:
            key_pressed = plt.waitforbuttonpress(-1)
            if key_pressed and not self.recent:
                break

    '''
    #########################################################################
                            Playback functions
    '''

    def new_iter(self, name):
        '''
        Register a new plot iteration labelled ``name``: both the latest
        iteration index and the displayed iteration index advance by one.
        '''
        self.iter += 1
        self.iter_plot += 1
        self.iter_names.append(name)

    def draw_next(self, event=None):
        '''
        Step the display forward one iteration and re-apply the attributes
        recorded for it. When the latest iteration is reached, ``recent`` is
        cleared. ``event`` is the unused button-click event.
        '''
        if self.iter_plot >= self.iter:
            if self.iter_plot == self.iter:
                print("Can't go further!")
            return

        self.iter_plot += 1
        label = self.iter_names[self.iter_plot]
        self.text.set_text(label)
        for artist, changes in self.history[self.iter_plot].items():
            self.change_attributes(artist, changes)
        self.canvas.blit(self.ax.bbox)
        print("Drawing next: {}".format(label))
        if self.iter_plot == self.iter:
            self.recent = 0

    def draw_prev(self, event=None):
        '''
        Step the display back one iteration, set ``recent`` (the user is now
        browsing history) and re-apply the attributes recorded for the
        iteration stepped back to. ``event`` is the unused button-click event.
        '''
        if self.iter_plot < 1:
            print("Can't go back further!")
            return

        self.recent = 1
        self.iter_plot -= 1
        label = self.iter_names[self.iter_plot]
        self.text.set_text(label)
        for artist, changes in self.history[self.iter_plot].items():
            self.change_attributes(artist, changes)
        self.canvas.blit(self.ax.bbox)

        print("Drawing previous: {}".format(label))

    '''
    #########################################################################
                            Change attribute functions
    '''

    def get_nested_np_color(self, array):
        '''
        Normalise a colour value so old and new values can be compared.

        numpy arrays (recognised by type name, so numpy need not be imported)
        are converted with ``tolist()``. Nested lists are then unwrapped to
        their first innermost list; for example, an (N, 4) RGBA array becomes
        its first RGBA row. Empty lists (e.g. from a (0, 4) colour array) are
        returned as ``[]`` instead of raising IndexError. Any other value is
        returned unchanged.
        '''
        value = array.tolist() if type(array).__name__ == "ndarray" else array
        if not isinstance(value, list):
            return value
        # descend through first elements while they are themselves lists;
        # the emptiness check guards against IndexError on empty lists
        while value and isinstance(value[0], list):
            value = value[0]
        return value

    def new_attributes(self, obj, attr_dict, overwrite=False):
        '''
        Apply ``attr_dict`` to ``obj`` and record the change in ``self.history``.

        For every attribute whose (normalised) new value differs from its old
        value:
          * the old value is stored in ``history[iter - 1][obj]`` (what to
            restore when stepping back to the previous iteration);
          * the new value is stored in ``history[iter][obj]`` and applied to
            ``obj`` immediately via ``change_attributes``.

        obj:        plot artist to restyle.
        attr_dict:  attribute name -> new value (passed to ``plt.getp``/``setp``).
        overwrite:  when True and ``obj`` already has an entry for the previous
                    iteration, old values are taken from that entry (falling
                    back to the live object), and the entry is replaced rather
                    than merged.
        '''
        # history is indexed by iteration; reading these keys creates empty
        # entries if absent (assuming history is a defaultdict, as set in __init__)
        prev_changes = self.history[self.iter - 1]
        next_changes = self.history[self.iter]

        prev, next = {}, {}

        if not overwrite or obj not in prev_changes:
            # old values come from the artist's current state
            for key, value in attr_dict.items():

                value = self.get_nested_np_color(value)
                old_value = self.get_nested_np_color(plt.getp(obj, key))

                if old_value != value:
                    prev[key] = old_value
                    next[key] = value
        else:
            # old values come from what was already recorded for the previous
            # iteration, so repeated changes within one iteration keep the
            # original pre-iteration value
            old_dict = prev_changes[obj]
            for key, value in attr_dict.items():
                value = self.get_nested_np_color(value)
                old_value = old_dict[
                    key] if key in old_dict else self.get_nested_np_color(
                        plt.getp(obj, key))
                if old_value != value:
                    prev[key] = old_value
                    next[key] = value
        # NOTE(review): without overwrite=True, a second change to the same obj
        # in one iteration reads the already-changed live values and merges them
        # into history[iter - 1] below, so stepping back restores the
        # intermediate value rather than the pre-iteration one — confirm callers
        # pass overwrite=True in that case.
        if prev:
            if overwrite or obj not in prev_changes:
                prev_changes[obj] = prev
            else:
                prev_changes[obj].update(prev)

        if next:
            next_changes[obj] = next
            self.change_attributes(obj, next)

    def change_attributes(self, object, attr_dict):
        '''
        Set every attribute in ``attr_dict`` (skipped when it is empty) on the
        given artist with ``plt.setp``, then redraw that artist on ``self.ax``.
        '''
        target = object
        if attr_dict:
            plt.setp(target, **attr_dict)
        self.ax.draw_artist(target)

    '''
    #########################################################################
                            Initialize axes
    '''

    def legend_circle(self,
                      label,
                      mfc=None,
                      marker="o",
                      mec="k",
                      ms=10,
                      color="w",
                      lw=0,
                      mew=2,
                      ls="-"):
        '''
        Build a single-point Line2D proxy (a circle marker by default) to be
        used as a legend handle labelled ``label``.
        '''
        line_style = dict(lw=lw, ls=ls, color=color)
        marker_style = dict(marker=marker, mec=mec, mew=mew, mfc=mfc, ms=ms)
        return Line2D([0], [0], label=label, **line_style, **marker_style)

    def init_legend(self, x, y, items=None, loc="upper right"):
        '''
        Initialize the legend of the plot and title the axes with the graph's
        class name.

        Entries for the qubit, the X/Y/Z errors and the vertex/plaquette
        stabilizers are always added; extra legend handles can be appended via
        ``items``. ``(x, y)`` is passed as ``bbox_to_anchor`` and ``loc`` as the
        legend location.
        '''
        # None instead of a shared mutable [] default
        if items is None:
            items = []

        self.ax.set_title("{} lattice".format(self.graph.__class__.__name__))

        le_qubit = self.legend_circle("Qubit", mfc=self.cc, mec=self.cc)
        # each error label uses its own colour: cx = X, cy = Y, cz = Z
        le_xer = self.legend_circle("X-error", mfc=self.cx, mec=self.cx)
        le_yer = self.legend_circle("Y-error", mfc=self.cy, mec=self.cy)
        le_zer = self.legend_circle("Z-error", mfc=self.cz, mec=self.cz)
        le_ver = self.legend_circle("Vertex",
                                    ls="-",
                                    lw=self.linewidth,
                                    color=self.cX,
                                    mfc=self.cX,
                                    mec=self.cX,
                                    marker="|")
        le_pla = self.legend_circle("Plaquette",
                                    ls="--",
                                    lw=self.linewidth,
                                    color=self.cZ,
                                    mfc=self.cZ,
                                    mec=self.cZ,
                                    marker="|")

        self.lh = [le_qubit, le_xer, le_yer, le_zer, le_ver, le_pla] + list(items)

        self.ax.legend(handles=self.lh, bbox_to_anchor=(x, y), loc=loc, ncol=1)

    def init_axis(self, min, max):
        '''
        Prepare the 2D axes for drawing the lattice.

        Applies the same [min, max] limits to x and y, inverts the y axis,
        hides all four spines and finally switches the axis off.

        param: min      lower limit for both axes (name shadows the builtin;
                        kept for interface compatibility)
        param: max      upper limit for both axes
        '''
        axes = self.ax
        for apply_limits in (axes.set_xlim, axes.set_ylim):
            apply_limits(min, max)
        axes.invert_yaxis()
        for side in ("top", "right", "bottom", "left"):
            axes.spines[side].set_visible(False)
        plt.axis("off")

    '''
    #########################################################################
                            Initialize plot
    '''

    def init_plot(self, z=0):
        '''
        Draw the initial 2D picture of one layer of the toric/planar lattice.

        Stabilizers (and, when the graph has a ``B`` attribute, its open
        boundaries) are drawn as line segments, qubits as Circle patches, and
        the legend is created with an extra "Erasure" entry. The canvas is
        drawn afterwards; ``draw_plot`` is only called when ``self.from3D``
        is falsy.

        param: z        layer to plot, defaults to 0
        '''
        plt.sca(self.ax)
        self.init_axis(-.25, self.size - .25)

        # stabilizers first, then open boundaries when the lattice has them
        for stabilizer in self.graph.S[z].values():
            self.plot_stab(stabilizer, alpha=self.alpha)
        if hasattr(self.graph, 'B'):
            for boundary in self.graph.B[z].values():
                self.plot_stab(boundary, alpha=self.alpha)

        for qubit in self.graph.Q[z].values():
            self.plot_qubit(qubit)

        erasure_handle = self.legend_circle(
            "Erasure", mfc="w", marker="$\u25CC$", mec=self.cc, mew=1, ms=12)
        self.init_legend(1.3, 0.95, items=[erasure_handle])

        self.canvas.draw()
        if not self.from3D:
            self.draw_plot()

    '''
    #########################################################################
                            Helper plot functions
    '''

    def draw_line(self, X, Y, color="w", lw=2, ls="-", alpha=1, **kwargs):
        '''
        Plot a line on the 2D axes and return the created Line2D.

        Exists mainly to provide default styling. Additional keyword
        arguments (callers in this class pass ``Z``) are accepted but ignored
        by this 2D implementation.

        param: X, Y     x and y coordinates of the line's points
        param: color    line colour
        param: lw       line width
        param: ls       line style; defaults to a solid line ("-"). The former
                        default of 2 is not a valid matplotlib linestyle.
        param: alpha    transparency
        returns: the matplotlib Line2D object that was added to ``self.ax``
        '''
        return self.ax.plot(X, Y, c=color, lw=lw, ls=ls, alpha=alpha)[0]

    def plot_stab(self, stab, alpha=1):
        '''
        Draw a stabilizer as short line segments towards its neighbors.

        Directions are visited in the order of ``self.graph.dirs`` and only
        those present in ``stab.neighbors`` are drawn, so the same code works
        for planar and toric lattices. Stabilizers of type 1 are shifted by
        half a unit in x and y and drawn dashed; type 0 is drawn solid.
        Each segment is stored in ``stab.pg`` under its direction and gets an
        ``object`` attribute pointing back at ``stab``.

        param: stab         graph stabilizer object
        param: alpha        alpha applied to every segment of this stabilizer
        '''
        (kind, row, col), layer = stab.sID, stab.z
        shift = .5 * kind
        row += shift
        col += shift
        style = "--" if kind != 0 else "-"

        # (x start, x end, y start, y end) offsets of each half-edge
        reach = {
            "w": (-.25, 0, 0, 0),
            "e": (0, .25, 0, 0),
            "n": (0, 0, -.25, 0),
            "s": (0, 0, 0, .25),
        }

        stab.pg = {}
        for direction in self.graph.dirs:
            if direction not in stab.neighbors:
                continue
            offsets = reach.get(direction)
            # NOTE(review): a direction missing from ``reach`` keeps the
            # previous X/Y (or raises UnboundLocalError if it comes first);
            # presumably only n/s/e/w occur here -- confirm.
            if offsets is not None:
                dx0, dx1, dy0, dy1 = offsets
                X, Y = [col + dx0, col + dx1], [row + dy0, row + dy1]

            segment = self.draw_line(X,
                                     Y,
                                     Z=layer * self.z_distance,
                                     color=self.cl,
                                     lw=self.linewidth,
                                     ls=style,
                                     alpha=alpha)
            stab.pg[direction] = segment
            segment.object = stab

    def plot_qubit(self, qubit):
        '''
        Draw a qubit as an unfilled, pickable Circle patch.

        Qubits with ``td == 0`` are centred at (x + .5, y), all others at
        (x, y + .5). The patch is stored on ``qubit.pg``, added to the axes
        and given an ``object`` attribute pointing back at the qubit.

        param: qubit        graph qubit object
        '''
        td, row, col = qubit.qID
        if td == 0:
            center = (col + .5, row)
        else:
            center = (col, row + .5)
        patch = plt.Circle(center,
                           self.qsize,
                           edgecolor=self.cc,
                           fill=False,
                           lw=self.linewidth,
                           picker=self.pick)
        qubit.pg = patch
        self.ax.add_artist(patch)
        patch.object = qubit

    '''
    #########################################################################
                            Plotting functions
    '''

    def get_error_attr(self, qubit):
        '''
        Return the plot attributes that visualise a qubit's Pauli error.

        An X error alone is filled with ``self.cx``, a Z error alone with
        ``self.cz`` and both together (a Y error) with ``self.cy``; the edge
        colour is always ``self.cc``. Returns an empty dict when the qubit
        has no error.
        '''
        has_x = qubit.E[0].state
        has_z = qubit.E[1].state

        if not (has_x or has_z):
            return {}
        if has_x and has_z:
            fill_colour = self.cy
        elif has_x:
            fill_colour = self.cx
        else:
            fill_colour = self.cz
        return dict(fill=1, facecolor=fill_colour, edgecolor=self.cc)

    def plot_erasures(self, z=0, draw=True):
        """
        Mark the erased qubits of layer ``z`` with a dotted edge.

        Starts a new "Erasure" history step when ``z == 0``. Every qubit of
        ``self.graph.Q[z]`` receives its Pauli-error styling from
        ``get_error_attr``; erased qubits (``qubit.erasure`` truthy) also get
        a dotted (":") line style. Qubits with nothing to change are skipped.

        :param z        layer whose qubits are updated, defaults to 0
                        (previously ignored: layer 0 was always used)
        :param draw     call ``draw_plot`` afterwards when True
        """
        if z == 0:
            self.new_iter("Erasure")

        for qubit in self.graph.Q[z].values():
            attr_dict = self.get_error_attr(qubit)

            if qubit.erasure:
                attr_dict.update(linestyle=":")

            if attr_dict:
                self.new_attributes(qubit.pg, attr_dict)

        if draw:
            self.draw_plot()

    def plot_errors(self, z=0, plot_qubits=False, draw=True):
        """
        Fill the qubits of layer ``z`` that carry a Pauli error.

        When ``z == 0`` a new history step is started, named "Result" if
        ``plot_qubits`` is set and "Errors" otherwise. With ``plot_qubits``,
        qubits without an error are explicitly reset to an unfilled circle
        with facecolor ``self.cw``.

        :param z            layer whose qubits are updated, defaults to 0
        :param plot_qubits  also restyle the qubits that have no error
        :param draw         call ``draw_plot`` afterwards when True
        """
        if z == 0:
            self.new_iter("Result" if plot_qubits else "Errors")

        for qubit in self.graph.Q[z].values():
            patch = qubit.pg
            style = self.get_error_attr(qubit)
            if not style:
                if not plot_qubits:
                    continue
                style = dict(fill=0, facecolor=self.cw)
            self.new_attributes(patch, style)

        if draw:
            self.draw_plot()

    def plot_syndrome(self, z=0, draw=True):
        """
        Colour the segments of every stabilizer in layer ``z`` whose parity is set.

        The colour is ``self.C2`` indexed by the stabilizer type (first item
        of ``stab.sID``). A new "Syndrome" history step is started when
        ``z == 0``.

        :param z        layer whose stabilizers are checked, defaults to 0
        :param draw     call ``draw_plot`` afterwards when True
        """
        if z == 0:
            self.new_iter("Syndrome")

        for stab in self.graph.S[z].values():
            stab_type, _, _ = stab.sID
            if not stab.parity:
                continue
            for direction in self.graph.dirs:
                if direction not in stab.neighbors:
                    continue
                self.new_attributes(stab.pg[direction], dict(color=self.C2[stab_type]))

        if draw:
            self.draw_plot()

    def plot_lines(self, matchings):
        """
        Draw one randomly coloured line per matching, between its two vertices.

        Each entry of ``matchings`` is unpacked as ``(_, _, v0, v1)``. The line
        is shifted by half a unit for type-1 vertices and styled with
        ``self.LS2[type]``, where the type is taken from ``v1``. The line is
        hidden in the previous history step and shown in the new "Matching"
        step.

        :param matchings    iterable of 4-tuples whose last two items are the matched vertices
        """
        shifts = [0, .5]
        self.new_iter("Matching")

        for _, _, start, end in matchings:

            rgb = [0.2 + 0.8 * random.random() for _ in range(3)]

            (_, y0, x0), z0 = start.sID, start.z
            (kind, y1, x1), z1 = end.sID, end.z
            shift, style = shifts[kind], self.LS2[kind]

            xs = [x0 + shift, x1 + shift]
            ys = [y0 + shift, y1 + shift]
            spacing = self.z_distance
            zs = [(z0 - .5) * spacing, (z1 - .5) * spacing]
            segment = self.draw_line(xs,
                                     ys,
                                     Z=zs,
                                     color=rgb,
                                     lw=self.linewidth,
                                     ls=style,
                                     alpha=self.alpha2)

            self.history[self.iter - 1][segment] = dict(visible=0)
            self.history[self.iter][segment] = dict(visible=1)

        self.draw_plot()

    def plot_final(self):
        """
        Show the final state of layer 0 after decoding.

        Starts a "Final" history step and recolours the edge of every qubit
        flipped by the matching: ``self.cx`` for an X flip only, ``self.cz``
        for a Z flip only and ``self.cy`` for both. After drawing, it calls
        ``plot_errors(plot_qubits=True)``, which starts a "Result" step and
        restyles every qubit's fill.
        """
        plt.sca(self.ax)
        self.new_iter("Final")

        for qubit in self.graph.Q[0].values():
            patch = qubit.pg
            x_flip = qubit.E[0].matching
            z_flip = qubit.E[1].matching

            if not (x_flip or z_flip):
                continue
            if x_flip and z_flip:
                edge = self.cy
            elif x_flip:
                edge = self.cx
            else:
                edge = self.cz
            self.new_attributes(patch, dict(edgecolor=edge))

        self.draw_plot()
        self.plot_errors(plot_qubits=True)
Пример #7
0
class ObjFindGUI:
    """
    GUI to interactively identify object traces. The GUI can be run within
    PypeIt during data reduction, or as a standalone script outside of
    PypeIt. To initialise the GUI, call the initialise() function in this
    file.
    """
    def __init__(self, canvas, image, frame, det, sobjs, left, right, obj_trace, trace_models,
                 axes, profdict, slit_ids=None, printout=False, runtime=False):
        """Controls for the interactive Object ID tasks in PypeIt.

        The main goal of this routine is to interactively add/delete/modify
        object apertures.

        The constructor stores its inputs, loads the object traces (from
        ``obj_trace`` when ``sobjs`` is None, otherwise from ``sobjs``),
        blanks several default Matplotlib key bindings so the keys can be
        used by :meth:`operations`, connects the draw/mouse/keyboard
        callbacks to ``canvas``, draws the canvas and builds the menu.

        Args:
            canvas (Matplotlib figure canvas):
                The canvas on which all axes are contained
            image (AxesImage):
                The image plotted to screen
            frame (ndarray):
                The image data; axis 0 is treated as the spectral and
                axis 1 as the spatial direction
            det (int):
                Detector to add a slit on
            sobjs (SpecObjs, None):
                An instance of the SpecObjs class
            left (numpy.ndarray):
                Slit left edges
            right (numpy.ndarray):
                Slit right edges
            obj_trace (dict):
                Result of
                :func:`pypeit.scripts.object_finding.parse_traces`.
            trace_models (dict):
                Dictionary with the object, standard star, and slit
                trace models (keys 'object', 'std' and 'slit')
            axes (dict):
                Dictionary of Matplotlib axes instances. The keys used by
                this class are 'main' (image panel with the traces),
                'info' (information box) and 'profile' (object profile)
            profdict (dict):
                Dictionary containing profile information: 'profile' (the
                profile line) and 'fwhm' (the left/right lines displaying
                the FWHM)
            slit_ids (list, None):
                List of slit ID numbers
            printout (bool):
                Should the results be printed to screen
            runtime (bool):
                Is the GUI being launched during data reduction?
        """
        # Store the inputs and some derived quantities
        self._det = det
        self.image = image
        self.frame = frame
        self.nspec, self.nspat = frame.shape[0], frame.shape[1]
        self._spectrace = np.arange(self.nspec)
        self.profile = profdict
        self._printout = printout
        self._runtime = runtime
        self._slit_ids = slit_ids

        self.left = left
        self.right = right
        self.obj_trace = obj_trace
        self.trace_models = trace_models

        self.axes = axes
        self.specobjs = sobjs
        self.objtraces = []  # plot handles of the object traces (see draw_objtraces)
        self._object_traces = ObjectTraces()
        self.anchors = []  # plot handles of the manual anchors (see draw_anchors)
        self._obj_idx = -1  # index of the selected object trace; -1 means none selected
        self._spatpos = np.arange(frame.shape[1])[np.newaxis, :].repeat(frame.shape[0], axis=0)  # Spatial coordinate (as the frame shape)
        # Presumably (re)initialises self._mantrace, which draw_anchors reads -- defined elsewhere
        self.empty_mantrace()
        # Load the object traces from the parsed trace dict, or from the SpecObjs when given
        if sobjs is None:
            self._object_traces.from_dict(self.obj_trace, det)
        else:
            self._object_traces.from_specobj(sobjs, det)

        # Unset some of the matplotlib keymaps
        matplotlib.pyplot.rcParams['keymap.fullscreen'] = ''        # toggling fullscreen (Default: f, ctrl+f)
        #matplotlib.pyplot.rcParams['keymap.home'] = ''              # home or reset mnemonic (Default: h, r, home)
        matplotlib.pyplot.rcParams['keymap.back'] = ''              # forward / backward keys to enable (Default: left, c, backspace)
        matplotlib.pyplot.rcParams['keymap.forward'] = ''           # left handed quick navigation (Default: right, v)
        #matplotlib.pyplot.rcParams['keymap.pan'] = ''              # pan mnemonic (Default: p)
        matplotlib.pyplot.rcParams['keymap.zoom'] = ''              # zoom mnemonic (Default: o)
        matplotlib.pyplot.rcParams['keymap.save'] = ''              # saving current figure (Default: s)
        matplotlib.pyplot.rcParams['keymap.quit'] = ''              # close the current figure (Default: ctrl+w, cmd+w)
        matplotlib.pyplot.rcParams['keymap.grid'] = ''              # switching on/off a grid in current axes (Default: g)
        matplotlib.pyplot.rcParams['keymap.yscale'] = ''            # toggle scaling of y-axes ('log'/'linear') (Default: l)
        matplotlib.pyplot.rcParams['keymap.xscale'] = ''            # toggle scaling of x-axes ('log'/'linear') (Default: L, k)
        # NOTE(review): 'keymap.all_axes' no longer exists in newer Matplotlib
        # releases, where this assignment may raise KeyError -- confirm against
        # the Matplotlib version in use.
        matplotlib.pyplot.rcParams['keymap.all_axes'] = ''          # enable all axes (Default: a)

        # Initialise the main canvas tools
        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        canvas.mpl_connect('button_release_event', self.button_release_callback)
        canvas.mpl_connect('motion_notify_event',  self.mouse_move_callback)
        self.canvas = canvas

        # Interaction variables
        self._respreq = [False, None]  # Does the user need to provide a response before any other operation will be permitted? Once the user responds, the second element of this array provides the action to be performed.
        self._qconf = False  # Confirm quit message
        self._changes = False
        self._use_updates = True
        self._trcmthd = 'object'
        self.mmx, self.mmy = 0, 0
        self._inslit = 0  # Which slit is the mouse in

        # Draw the spectrum (this emits a draw event, so draw_callback caches the background)
        self.canvas.draw()

        # Initialise buttons and menu options
        self._ax_meth_default = 'Object'
        # label -> [radio button index, key into trace_models]
        self._methdict = dict({'Object': [0, 'object'],
                               'Standard Star': [1, 'std'],
                               'Slit Edges': [2, 'slit']})
#                               'Manual': [3, 'manual']
        self.initialise_menu()

    def print_help(self):
        """Print usage instructions and the available key bindings.

        Key bindings and their descriptions are read from the module-level
        ``operations`` mapping (key -> description).
        """
        bindings = operations
        intro = (
            "===============================================================",
            "Add/remove object traces until you are happy with the resulting",
            "traces. When you've finished, click one of the exit buttons on",
            "the right side of the page. If you click 'Continue (and save changes)'",
            "the object traces will be printed to the terminal, where you can",
            "copy them into your .pypeit file.",
            "",
            "thick coloured dashed lines = object traces",
            "thick coloured solid line   = currently selected object trace",
            "thin green/blue lines       = slit edges",
            "",
            "Meanings of the different coloured dashed lines:",
            " green = user-defined object trace",
            " blue  = trace automatically generated with PypeIt",
            " red   = trace automatically generated with PypeIt (deleted)",
            "===============================================================",
            "       OBJECT ID OPERATIONS",
        )
        for text in intro:
            print(text)
        for key in bindings.keys():
            print("{0:6s} : {1:s}".format(key, bindings[key]))
        print("---------------------------------------------------------------")

    def initialise_menu(self):
        """Create the two exit buttons and the trace-method radio selector.

        The selector is activated on the first method, in the order Object,
        Standard Star, Slit Edges, whose trace model is available. If none is
        available, ``self._ax_meth_default`` keeps its current value.
        """
        axcolor = 'lightgoldenrodyellow'

        # Button that asks to exit using the updated object traces
        cont_axes = plt.axes([0.82, 0.85, 0.15, 0.05])
        self._ax_cont = Button(cont_axes, "Continue (and save changes)", color=axcolor, hovercolor='y')
        self._ax_cont.on_clicked(self.button_cont)

        # Button that asks to exit using the original object traces
        exit_axes = plt.axes([0.82, 0.79, 0.15, 0.05])
        self._ax_exit = Button(exit_axes, "Continue (don't save changes)", color=axcolor, hovercolor='y')
        self._ax_exit.on_clicked(self.button_exit)

        # Radio buttons to choose the trace method ('Manual' is currently disabled)
        method_axes = plt.axes([0.82, 0.59, 0.15, 0.15], facecolor=axcolor)
        method_axes.set_title("Select trace method:")
        self._ax_meth = RadioButtons(method_axes, ('Object', 'Standard Star', 'Slit Edges'))
        self._ax_meth.on_clicked(self.radio_meth)

        # Default to the first method that actually has a trace model
        candidates = (('Object', 'object'), ('Standard Star', 'std'), ('Slit Edges', 'slit'))
        for label, model_key in candidates:
            if self.trace_models[model_key]['trace_model'] is not None:
                self._ax_meth_default = label
                break

        self._ax_meth.set_active(self._methdict[self._ax_meth_default][0])

    def radio_meth(self, label, infobox=True):
        """Callback run when a different trace method is selected.

        Unless the method is 'manual', it must have a trace model. If it has
        none, the user is told and the selector falls back to
        ``self._ax_meth_default``. Otherwise the new method is reported in
        the info box when ``infobox`` is True.

        Args:
            label (str): The label of the radio button that was clicked
            infobox (bool): Report the newly selected method in the info box
        """
        chosen = self._methdict[label][1]
        if chosen:
            self._trcmthd = chosen

        unavailable = (self._trcmthd != "manual"
                       and self.trace_models[self._trcmthd]['trace_model'] is None)
        if unavailable:
            self.update_infobox(message="That option is not available - changing to default",
                                yesno=False)
            fallback_index, fallback_method = self._methdict[self._ax_meth_default]
            self._ax_meth.set_active(fallback_index)
            self._trcmthd = fallback_method
            return

        if infobox:
            self.update_infobox(message="Trace method set to: {0:s}".format(label), yesno=False)

    def button_cont(self, event):
        """Callback of the 'Continue (and save changes)' button.

        It does not exit directly: it registers a pending "exit_update"
        question and asks the user to confirm it in the info box.
        """
        question = "Are you sure you want to exit and use the updated object traces?"
        self._respreq = [True, "exit_update"]
        self.update_infobox(message=question, yesno=True)

    def button_exit(self, event):
        """Callback of the 'Continue (don't save changes)' button.

        It does not exit directly: it registers a pending "exit_restore"
        question and asks the user to confirm it in the info box.
        """
        question = "Are you sure you want to exit and use the original object traces?"
        self._respreq = [True, "exit_restore"]
        self.update_infobox(message=question, yesno=True)

    def replot(self):
        """Restore the cached background, redraw traces and anchors, then redraw the canvas."""
        canvas = self.canvas
        canvas.restore_region(self.background)
        self.draw_objtraces()
        self.draw_anchors()
        canvas.draw()

    def draw_objtraces(self):
        """Remove the previously drawn object traces and plot them again.

        Colour encodes the trace status (index ``_add_rm``): DodgerBlue for
        original, LimeGreen for added and red for deleted traces. The selected
        trace (``self._obj_idx``) is drawn solid with width 4, all others
        dashed with width 3.
        """
        for handles in self.objtraces:
            handles.pop(0).remove()
        self.objtraces = []

        palette = ['DodgerBlue', 'LimeGreen', 'r']  # [original, added, deleted]
        traces = self._object_traces
        for idx in range(traces.nobj):
            colour = palette[traces._add_rm[idx]]
            selected = idx == self._obj_idx
            handle = self.axes['main'].plot(traces._trace_spat[idx],
                                            traces._trace_spec[idx],
                                            color=colour,
                                            linestyle='-' if selected else '--',
                                            linewidth=4 if selected else 3,
                                            alpha=0.5)
            self.objtraces.append(handle)

    def draw_anchors(self):
        """Remove and redraw the manual-tracing artists.

        The fitted manual trace is drawn as a green line when it exists, and
        the anchor points are always drawn on top as red circles.
        """
        for handles in self.anchors:
            handles.pop(0).remove()
        self.anchors = []

        manual = self._mantrace
        if manual["spat_trc"] is not None:
            fitted = self.axes['main'].plot(manual["spat_trc"], manual["spec_trc"],
                                            'g-', linewidth=3, alpha=0.5)
            self.anchors.append(fitted)
        points = self.axes['main'].plot(manual["spat_a"], manual["spec_a"], 'ro', alpha=0.5)
        self.anchors.append(points)

    def draw_profile(self):
        """Draw the object profile

        With no object selected (``self._obj_idx == -1``) the profile curve is
        reset to zeros. Otherwise the two FWHM marker lines are moved to
        -FWHM/2 and +FWHM/2, the curve is replaced by ``make_objprofile()``,
        and the profile axes are rescaled: x to [-FWHM, +FWHM], y to the
        profile range padded by 10% on each side.
        """
        if self._obj_idx == -1:
            # get_xdata is a method and must be called (accessing ``.size`` on
            # the bound method raised AttributeError); len() works whether
            # the stored x data is a list or an array.
            npts = len(self.profile['profile'].get_xdata())
            self.profile['profile'].set_ydata(np.zeros(npts))
        else:
            fwhm = self._object_traces._fwhm[self._obj_idx]
            # Plot the extent of the FWHM
            # NOTE(review): passing a scalar to set_xdata is deprecated in recent
            # Matplotlib releases; presumably these are vertical marker lines -- confirm.
            self.profile['fwhm'][0].set_xdata(-fwhm/2.0)
            self.profile['fwhm'][1].set_xdata(+fwhm/2.0)
            # Update the data shown
            objprof = self.make_objprofile()
            self.profile['profile'].set_ydata(objprof)
            self.axes['profile'].set_xlim([-fwhm, +fwhm])
            omin, omax = objprof.min(), objprof.max()
            pad = 0.1*(omax-omin)
            self.axes['profile'].set_ylim([omin-pad, omax+pad])

    def draw_callback(self, event):
        """Run every time the canvas is drawn: cache the clean background of
        the main axes, then draw the object traces on top.

        Args:
            event (Event): The Matplotlib draw event (not used)
        """
        main_bbox = self.axes['main'].bbox
        self.background = self.canvas.copy_from_bbox(main_bbox)
        self.draw_objtraces()

    def get_ind_under_point(self, event):
        """Select the object trace closest to the cursor.

        For each trace, the smallest squared distance between the cursor
        (``event.xdata``, ``event.ydata``) and the trace points is computed.
        The trace with the lowest value is selected, provided that value is
        below ``self._spatpos.shape[0]**2``. The index is stored in
        ``self._obj_idx`` (-1 when nothing qualifies). When a trace is
        selected, the profile panel is redrawn.

        Args:
            event (Event): Matplotlib event instance containing information about the event
        """
        threshold = self._spatpos.shape[0] ** 2
        self._obj_idx = -1
        traces = self._object_traces
        for idx in range(traces.nobj):
            sqdist = (event.xdata - traces._trace_spat[idx]) ** 2 + \
                     (event.ydata - traces._trace_spec[idx]) ** 2
            closest = np.min(sqdist)
            if closest < threshold:
                threshold = closest
                self._obj_idx = idx
        if self._obj_idx != -1:
            self.draw_profile()

    def get_axisID(self, event):
        """Return a numeric ID for the axes in which an event occurred.

        Args:
            event (Event): Matplotlib event instance containing information about the event

        Returns:
            int, None: 0 for the main panel, 1 for the info box, 2 for the
            profile panel, None for anything else
        """
        for axis_id, name in enumerate(('main', 'info', 'profile')):
            if event.inaxes == self.axes[name]:
                return axis_id
        return None

    def mouse_move_callback(self, event):
        """Remember the cursor position while it moves over the main axes.

        The data coordinates are stored in ``self.mmx``/``self.mmy``; they
        are read later by e.g. ``add_anchor`` and ``get_slit``. Motion outside
        the main axes is ignored.
        """
        # The previously computed (and never used) axis ID has been dropped.
        if event.inaxes is not None and event.inaxes == self.axes['main']:
            self.mmx, self.mmy = event.xdata, event.ydata

    def button_press_callback(self, event):
        """Record state at mouse-button press for use when it is released.

        Presses outside any axes, or while the toolbar is in a non-empty
        mode, are ignored. The left button (1) sets ``self._addsub`` to 1 and
        the right button (3) sets it to 0; other buttons leave it untouched.
        For presses in the main or profile axes, the pixel position is stored
        in ``self._start``.

        Args:
            event (Event): Matplotlib event instance containing information about the event
        """
        if event.inaxes is None or self.canvas.toolbar.mode != "":
            return
        addsub = {1: 1, 3: 0}.get(event.button)
        if addsub is not None:
            self._addsub = addsub
        if event.inaxes == self.axes['main'] or event.inaxes == self.axes['profile']:
            self._start = [event.x, event.y]

    def button_release_callback(self, event):
        """What to do when the mouse button is released

        The action depends on the axes under the cursor:

        * info box: ``event.xdata`` in (0.8, 0.9) answers "y" and
          ``event.xdata >= 0.9`` answers "n". Presumably this is where the
          yes/no labels of the info box are drawn; confirm. The answer is
          passed to :meth:`operations` and the info box is reset. Other
          positions are ignored.
        * profile panel: a click (release at the press pixel) passes the
          clicked x value to ``set_fwhm``.
        * any other axes: ignored while a yes/no answer is pending or a
          toolbar mode is active. Otherwise, a click in the main panel
          selects the nearest object trace. A drag is detected but currently
          does nothing. Finally the object traces are redrawn over the cached
          background.

        Args:
            event (Event): Matplotlib event instance containing information about the event
        """
        if event.inaxes is None:
            return
        if event.inaxes == self.axes['info']:
            # Answer a pending yes/no question by clicking in the info box
            if (event.xdata > 0.8) and (event.xdata < 0.9):
                answer = "y"
            elif event.xdata >= 0.9:
                answer = "n"
            else:
                return
            self.operations(answer, -1)
            self.update_infobox(default=True)
            return
        elif event.inaxes == self.axes['profile']:
            # A click without dragging (same pixel as the press) sets the FWHM.
            # NOTE(review): self._start is only assigned in button_press_callback;
            # if no press was recorded before this release, this raises AttributeError.
            if (event.x == self._start[0]) and (event.y == self._start[1]):
                self.set_fwhm(event.xdata)
                self.update_infobox(message="FWHM updated for the selected object", yesno=False)
            return
        elif self._respreq[0]:
            # The user is trying to do something before they have responded to a question
            return
        if self.canvas.toolbar.mode != "":
            # Ignore releases while a toolbar mode is active
            return
        # Handle a click or drag in the remaining axes
        axisID = self.get_axisID(event)
        if axisID is not None:
            # get_axisID only returns 0, 1 or 2, so this test always passes
            if axisID <= 2:
                self._end = [event.x, event.y]
                if (self._end[0] == self._start[0]) and (self._end[1] == self._start[1]):
                    # The mouse button was pressed (not dragged)
                    self.get_ind_under_point(event)
                    self.update_infobox(message="Object selected", yesno=False)
                    pass
                elif self._end != self._start:
                    # The mouse button was dragged
                    if axisID == 0:
                        # Order start/end (list comparison: x first, then y)
                        if self._start > self._end:
                            tmp = self._start
                            self._start = self._end
                            self._end = tmp
                        # Now do something (no drag action is implemented yet)
                        pass
        # Now plot
        self.canvas.restore_region(self.background)
        self.draw_objtraces()
        self.canvas.draw()

    def key_press_callback(self, event):
        """Forward a key press, with the ID of its axes, to :meth:`operations`.

        Key presses outside any axes, or inside the information box, are
        ignored.

        Args:
            event (Event): Matplotlib event instance containing information about the event
        """
        outside = not event.inaxes
        if outside or event.inaxes == self.axes['info']:
            return
        axis_id = self.get_axisID(event)
        self.operations(event.key, axis_id)

    def operations(self, key, axisID):
        """Canvas operations

        Args:
            key (str): Which key has been pressed
            axisID (int): The index of the axis where the key has been pressed (see get_axisID)
        """
        # Check if the user really wants to quit
        if key == 'q' and self._qconf:
            if self._changes:
                self.update_infobox(message="WARNING: There are unsaved changes!!\nPress q again to exit", yesno=False)
                self._qconf = True
            else:
                msgs.bug("Need to change this to kill and return the results to PypeIt")
                plt.close()
        elif self._qconf:
            self.update_infobox(default=True)
            self._qconf = False

        # Manage responses from questions posed to the user.
        if self._respreq[0]:
            if key != "y" and key != "n":
                return
            else:
                # Switch off the required response
                self._respreq[0] = False
                # Deal with the response
                if self._respreq[1] == "delete_object" and key == "y":
                    self.delete_object()
                elif self._respreq[1] == "clear_anchors" and key == "y":
                    self.empty_mantrace()
                    self.replot()
                elif self._respreq[1] == "exit_update" and key == "y":
                    self._use_updates = True
                    self.print_pypeit_info()
                    self.operations("qu", None)
                elif self._respreq[1] == "exit_restore" and key == "y":
                    self._use_updates = False
                    self.operations("qr", None)
                else:
                    return
            # Reset the info box
            self.update_infobox(default=True)
            return

        if key == '?':
            self.print_help()
        elif key == 'a':
            self.add_object()
        elif key == 'c':
            if axisID == 0:
                # If this is pressed on the main window
                self.recenter()
#        elif key == 'c':
#            self._respreq = [True, "clear_anchors"]
#            self.update_infobox(message="Are you sure you want to clear the anchors", yesno=True)
        elif key == 'd':
            if self._obj_idx != -1:
                self._respreq = [True, "delete_object"]
                self.update_infobox(message="Are you sure you want to delete this object trace", yesno=True)
#        elif key == 'm':
#            if self._trcmthd != 'manual':
#                self.update_infobox(message="To add an anchor point, set the 'manual' trace method", yesno=False)
#            else:
#                self.add_anchor()
#        elif key == 'n':
#            self.remove_anchor()
        elif key == 'qu' or key == 'qr':
            if self._changes:
                self.update_infobox(message="WARNING: There are unsaved changes!!\nPress q again to exit", yesno=False)
                self._qconf = True
            else:
                plt.close()
#        elif key == '+':
#            if self._mantrace["polyorder"] < 10:
#                self._mantrace["polyorder"] += 1
#                self.update_infobox(message="Polynomial order = {0:d}".format(self._mantrace["polyorder"]), yesno=False)
#                self.fit_anchors()
#            else:
#                self.update_infobox(message="Polynomial order must be <= 10", yesno=False)
#        elif key == '-':
#            if self._mantrace["polyorder"] > 1:
#                self._mantrace["polyorder"] -= 1
#                self.update_infobox(message="Polynomial order = {0:d}".format(self._mantrace["polyorder"]), yesno=False)
#                self.fit_anchors()
#            else:
#                self.update_infobox(message="Polynomial order must be >= 1", yesno=False)
        self.replot()

    def add_anchor(self):
        """Append the current cursor position as a manual anchor point, then refit."""
        spat_points = self._mantrace['spat_a']
        spec_points = self._mantrace['spec_a']
        spat_points.append(self.mmx)
        spec_points.append(self.mmy)
        self.fit_anchors()

    def remove_anchor(self):
        """Delete the manual anchor point nearest to the cursor (if any), then refit.

        Distance is the squared Euclidean distance to (``self.mmx``,
        ``self.mmy``); on ties the earliest anchor is removed.
        """
        spat = self._mantrace['spat_a']
        spec = self._mantrace['spec_a']
        if len(spat) != 0:
            def sqdist(i):
                return (spat[i] - self.mmx) ** 2 + (spec[i] - self.mmy) ** 2
            nearest = min(range(len(spat)), key=sqdist)
            del spat[nearest]
            del spec[nearest]
        self.fit_anchors()

    def fit_anchors(self):
        """Fit a polynomial spatial(spectral) through the manual anchors and replot.

        When there are not more anchors than ``polyorder``, a hint is shown in
        the info box instead of fitting.  Otherwise the fitted polynomial is
        evaluated on ``spec_trc`` and stored in ``spat_trc``.  The canvas is
        replotted either way, because an anchor may just have been added or
        removed.
        """
        mt = self._mantrace
        order = mt['polyorder']
        if len(mt['spat_a']) > order:
            coeff = np.polyfit(mt['spec_a'], mt['spat_a'], order)
            mt['spat_trc'] = np.polyval(coeff, mt['spec_trc'])
        else:
            self.update_infobox(message="You need to select more trace points before manually adding\n" +
                                        "a manual object trace. To do this, use the 'm' key", yesno=False)
        self.replot()

    def get_slit(self):
        """Find the slit that the mouse is currently in.

        The mouse's spectral row is ``int(self.mmy)``.  The first slit whose
        left/right edges at that row strictly bracket ``self.mmx`` is stored in
        ``self._inslit``.  If no slit contains the mouse, ``self._inslit`` keeps
        its previous value.
        """
        row = int(self.mmy)
        hit = next((slit for slit in range(self.left.shape[1])
                    if self.left[row, slit] < self.mmx < self.right[row, slit]), None)
        if hit is not None:
            self._inslit = hit

    def add_object(self):
        """Add a new object trace passing through the current mouse position.

        Where the trace comes from depends on ``self._trcmthd``:

        * ``'manual'``: the fitted manual trace (``_mantrace['spat_trc']``).
          This requires more anchors than ``polyorder``; otherwise a hint is
          shown and nothing is added.  The manual tracing is reset afterwards.
        * ``'slit'``: the column of the slit under the mouse, taken from
          ``trace_models['slit']``.
        * anything else: ``trace_models[self._trcmthd]['trace_model']``.

        Non-manual traces are shifted spatially so they pass through
        (``mmx``, ``mmy``).  The FWHM is copied from the first existing object,
        or is 2 when there are none.
        """
        method = self._trcmthd
        mt = self._mantrace
        if method == 'manual' and len(mt['spat_a']) <= mt['polyorder']:
            self.update_infobox(message="You need to select more trace points before manually adding\n" +
                                        "a manual object trace. To do this, use the 'm' key", yesno=False)
            return

        spec_vec = mt['spec_trc']
        if method == 'manual':
            trace_model = mt['spat_trc'].copy()
            # The manual trace has been consumed; start a fresh one
            self.empty_mantrace()
        else:
            if method == 'slit':
                self.get_slit()
                trace_model = self.trace_models[method]['trace_model'][:, self._inslit].copy()
            else:
                trace_model = self.trace_models[method]['trace_model'].copy()
            # Shift the template so that it passes through the clicked point
            trace_model += self.mmx - np.interp(self.mmy, spec_vec, trace_model)

        traces = self._object_traces
        fwhm = traces._fwhm[0] if traces.nobj != 0 else 2

        traces.add_object(self._det, self.mmx, self.mmy, trace_model, self._spectrace.copy(), fwhm)

# TODO: This method must have been defunct because it uses elements of
# _trcdict that don't exist (as far as I can tell) and instantiates a
# SpecObj from the specobjs module, which hasn't existed for a while.
#
#    def add_object_sobj(self):
#        """Add an object to specobjs
#        """
#        if self._trcmthd == 'manual' and len(self._mantrace['spat_a']) <= self._mantrace['polyorder']:
#            self.update_infobox(message="You need to select more trace points before manually adding\n" +
#                                        "a manual object trace. To do this, use the 'm' key", yesno=False)
#            return
#        # Add an object trace
#        spec_vec = self._mantrace['spec_trc']
#        if self._trcmthd == 'manual':
#            trace_model = self._mantrace['spat_trc'].copy()
#            # Now empty the manual tracing
#            self.empty_mantrace()
#        else:
#            trace_model = self._trcdict["trace_model"][self._trcmthd]["trace_model"].copy()
#            spat_0 = np.interp(self.mmy, spec_vec, trace_model)
#            shift = self.mmx - spat_0
#            trace_model += shift
#        xsize = self._trcdict["slit_righ"] - self._trcdict["slit_left"]
#        nsamp = np.ceil(xsize.max())
#        # Extract the SpecObj parameters
#        par = self._trcdict['sobj_par']
#        # Create a SpecObj
#        thisobj = specobjs.SpecObj(par['frameshape'], par['slit_spat_pos'], par['slit_spec_pos'],
#                                   det=par['det'], setup=par['setup'], slitid=par['slitid'],
#                                   orderindx=par['orderindx'], objtype=par['objtype'])
#        thisobj.hand_extract_spat = self.mmx
#        thisobj.hand_extract_spec = self.mmy
#        thisobj.hand_extract_det = par['det']
#        thisobj.hand_extract_fwhm = None
#        thisobj.hand_extract_flag = True
#        f_ximg = RectBivariateSpline(spec_vec, np.arange(self.nspat), par["ximg"])
#        thisobj.spat_fracpos = f_ximg(thisobj.hand_extract_spec, thisobj.hand_extract_spat,
#                                      grid=False)  # interpolate from ximg
#        thisobj.smash_peakflux = np.interp(thisobj.spat_fracpos * nsamp, np.arange(nsamp),
#                                           self._trcdict['profile'])  # interpolate from fluxconv
#        # assign the trace
#        thisobj.trace_spat = trace_model
#        thisobj.trace_spec = spec_vec
#        thisobj.spat_pixpos = thisobj.trace_spat[self.nspec//2]
#        thisobj.set_idx()
#        if self._object_traces.nobj != 0:
#            thisobj.fwhm = self._object_traces._fwhm[0]
#        else:  # Otherwise just use the fwhm parameter input to the code (or the default value)
#            thisobj.fwhm = 2
#        # Finally, add new object
#        self.specobjs.add_sobj(thisobj)

    def delete_object(self):
        """Delete the currently selected object trace and replot.

        After deletion the selected-object index is reset to -1.
        """
        traces = self._object_traces
        traces.delete_object(self._obj_idx)
        self._obj_idx = -1
        self.replot()

    def delete_object_sobj(self):
        """Remove the selected entry from ``self.specobjs``.

        The selected-object index is then reset to -1.  Unlike
        ``delete_object``, no replot is triggered here.
        """
        selected = self._obj_idx
        self.specobjs.remove_sobj(selected)
        self._obj_idx = -1

    def print_pypeit_info(self):
        """Print the text the user should insert into their .pypeit file.

        Output is produced only when ``_object_traces._add_rm`` contains 1.
        """
        traces = self._object_traces
        if 1 not in traces._add_rm:
            return
        msgs.info("Include the following info in the manual_extract column in your .pypeit file:\n")
        print(traces.get_pypeit_string())

    def recenter(self):
        """Pan the main axes so the mouse position becomes the view centre.

        The current x/y extents, including their sign (i.e. an inverted axis
        stays inverted), are preserved; only the centre moves.
        """
        ax = self.axes['main']
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()
        half_width = 0.5*(x_hi - x_lo)
        half_height = 0.5*(y_hi - y_lo)
        ax.set_xlim([self.mmx - half_width, self.mmx + half_width])
        ax.set_ylim([self.mmy - half_height, self.mmy + half_height])

    def make_objprofile(self):
        """Generate an object profile from the selected object's trace.

        Pixel offsets from the trace (``_spatpos`` minus the trace, broadcast
        along the spatial axis) within 4 x FWHM are histogrammed, weighted by
        ``self.frame``.  The bins are centred on the x data of the existing
        profile line, and the sum is divided by the number of pixels used.

        Returns:
            numpy.ndarray: profile values, one per bin.
        """
        traces = self._object_traces
        sel = self._obj_idx
        offsets = self._spatpos - traces._trace_spat[sel][:, np.newaxis]
        near = np.where(np.abs(offsets) < 4*traces._fwhm[sel])
        centres = self.profile['profile'].get_xdata()
        half_bin = 0.5*(centres[1] - centres[0])
        edges = np.append(centres[0] - half_bin, centres + half_bin)
        summed, _ = np.histogram(offsets[near], bins=edges, weights=self.frame[near])
        return summed/near[0].size

    def set_fwhm(self, xdata):
        """Set the FWHM of the selected object from a click in the profile panel.

        The clicked x coordinate is treated as a half-width, so the stored FWHM
        is ``2*|xdata|``.  The profile and the main canvas are then redrawn.

        Args:
            xdata (float): The x coordinate selected by the user
        """
        new_fwhm = 2.0*np.abs(xdata)
        self._object_traces._fwhm[self._obj_idx] = new_fwhm
        self.draw_profile()
        self.replot()

    def get_specobjs(self):
        """Get the updated version of SpecObjs.

        Returns:
            SpecObjs or None: ``self.specobjs`` when ``self._use_updates`` is
            true (a reminder is logged that it has not been updated yet),
            otherwise None.
        """
        if not self._use_updates:
            return None
        msgs.work("Have not updated SpecObjs yet")
        return self.specobjs

    def update_infobox(self, message="Press '?' to list the available options",
                       yesno=True, default=False):
        """Send a new message to the information window at the top of the canvas.

        Args:
            message (str): Message to be displayed
            yesno (bool): When true, green "YES" and red "NO" boxes are drawn
                on the right-hand side of the info panel.
            default (bool): When true, ``message`` and ``yesno`` are ignored and
                the default help prompt is shown instead.
        """
        panel = self.axes['info']
        centred = dict(horizontalalignment='center', verticalalignment='center')
        panel.clear()

        if default:
            panel.text(0.5, 0.5, "Press '?' to list the available options",
                       transform=panel.transAxes, **centred)
            self.canvas.draw()
            return

        panel.text(0.5, 0.5, message, transform=panel.transAxes, **centred)
        if yesno:
            panel.fill_between([0.8, 0.9], 0, 1, facecolor='green', alpha=0.5, transform=panel.transAxes)
            panel.fill_between([0.9, 1.0], 0, 1, facecolor='red', alpha=0.5, transform=panel.transAxes)
            panel.text(0.85, 0.5, "YES", transform=panel.transAxes, **centred)
            panel.text(0.95, 0.5, "NO", transform=panel.transAxes, **centred)
        panel.set_xlim((0, 1))
        panel.set_ylim((0, 1))
        self.canvas.draw()

    def empty_mantrace(self):
        """Reset ``self._mantrace`` to an empty manual-tracing state.

        There are no anchors, no fitted trace, and a polynomial order of 0.
        The spectral sampling ``spec_trc`` is ``arange(self.nspec)``.
        """
        self._mantrace = {
            'spat_a': [],
            'spec_a': [],
            'spat_trc': None,
            'spec_trc': np.arange(self.nspec),
            'polyorder': 0,
        }
Пример #8
0
class Curator:
    """
    matplotlib display of scrolling image data for curating threads (ROIs).

    Each thread can be labelled 'keep' or 'trash' (threads that have merely
    been visited are marked 'seen').  The labels, plus the last viewed index
    under the key ``'last'``, are written as JSON to
    ``e.root + 'extractor-objects/curate.json'`` by ``log_curate``, which is
    registered with ``atexit``.

    Parameters
    ----------
    e : extractor
        extractor object containing a full set of infilled threads and time
        series; its ``spool``, ``timeseries``, ``im``, ``root`` and ``t``
        attributes are read.
    window : int, optional
        Size of the sub-image window around the ROI (passed to ``subaxis``).
        Default is 100.

    Attributes
    ----------
    ind : int
        thread indexing
    t : int
        time point currently displayed
    min : int
        lower end of the display range
    max : int
        upper end of the display range
    """

    # Labels of the Keep/Trash check boxes, in the order they are created
    _KEEP_LABELS = ('Keep', 'Trash')

    def __init__(self, e, window=100):
        # get info from extractors
        self.s = e.spool
        self.timeseries = e.timeseries
        self.tf = e.im
        self.tf.t = 0
        self.window = window
        ## num neurons
        self.numneurons = len(self.s.threads)

        self.path = e.root + 'extractor-objects/curate.json'
        self.ind = 0
        self._load_curate()

        # internal display state:
        # pointstate: 0 = single ROI, 1 = ROIs in the same Z, 2 = all ROIs
        # show_settings: filter used by Next/Previous (see show())
        # showmip: 0 = single Z plane, 1 = max-intensity projection
        self.pointstate = 0
        self.show_settings = 0
        self.showmip = 0

        ## index for which time point to display
        self.t = 0

        ### First frame of the current thread
        self.update_im()

        ## Display range
        self.min = np.min(self.im)
        self.max = np.max(self.im)

        ## maximum t
        self.tmax = e.t

        self.restart()
        atexit.register(self.log_curate)

    def _load_curate(self):
        """Load saved labels from ``self.path`` and restore the last viewed index.

        A missing, unreadable or malformed file starts a fresh session
        (``{'0': 'seen'}`` at index 0).  A valid file whose ``'last'`` entry is
        missing or invalid keeps its labels and starts at index 0.
        """
        try:
            with open(self.path) as f:
                curate = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSONDecodeError
            curate = None

        if not isinstance(curate, dict):
            self.curate = {'0': 'seen'}
            self.ind = 0
            return

        self.curate = curate
        try:
            self.ind = int(curate['last'])
        except (KeyError, TypeError, ValueError):
            self.ind = 0
        if self.numneurons:
            # A stale index (e.g. fewer threads than last session) must not overflow
            self.ind %= self.numneurons

    def _position(self):
        """Position of the current thread at the current time point.

        Index 0 is used as the Z plane; indices 2 and 1 are plotted as x and y.
        """
        return self.s.threads[self.ind].get_position_t(self.t)

    def _other_points(self):
        """Positions of the extra (blue) ROIs for the current ``pointstate``.

        Returns None for 'Single', the ROIs in the current thread's Z plane for
        'Same Z', and all ROIs at time ``t`` for 'All'.
        """
        if self.pointstate == 1:
            return self.s.get_positions_t_z(self.t, self._position()[0])
        if self.pointstate == 2:
            return self.s.get_positions_t(self.t)
        return None

    def _draw_points(self):
        """(Re)create the blue neighbour markers and the red current-ROI marker on ax1."""
        others = self._other_points()
        if others is None:
            self.point1 = None
        else:
            self.point1 = self.ax1.scatter(others[:, 2], others[:, 1], c='b', s=10)
        pos = self._position()
        self.thispoint = self.ax1.scatter(pos[2], pos[1], c='r', s=10)

    def _title_text(self):
        """Figure title: current thread index and its Z plane."""
        return 'Series=' + str(self.ind) + ', Z=' + str(int(self._position()[0]))

    def _normalized_timeseries(self):
        """Current thread's timeseries rescaled to [0, 1].

        A constant trace is returned as zeros instead of dividing 0/0.
        """
        trace = self.timeseries[:, self.ind]
        lo = np.min(trace)
        span = np.max(trace) - lo
        if span == 0:
            return np.zeros(trace.shape)
        return (trace - lo) / span

    def restart(self):
        """Build the figure, plots and widgets, then enter ``plt.show()``."""
        ## Figure to display
        self.fig = plt.figure()

        ## grid object for complicated subplot handling
        self.grid = plt.GridSpec(4, 2, wspace=0.1, hspace=0.2)

        ### First subplot: whole image with red dot over ROI
        self.ax1 = plt.subplot(self.grid[:3, 0])
        plt.subplots_adjust(bottom=0.4)
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_points()
        self.ax1.axis('off')

        ### Second subplot: some window around the ROI
        plt.subplot(self.grid[:3, 1])
        plt.subplots_adjust(bottom=0.4)

        self.subim, self.offset = subaxis(self.im, self._position(), self.window)

        self.img2 = plt.imshow(self.get_subim_display(),
                               cmap='gray',
                               vmin=0,
                               vmax=1)
        self.point2 = plt.scatter(self.window / 2 + self.offset[0],
                                  self.window / 2 + self.offset[1],
                                  c='r',
                                  s=40)

        self.title = self.fig.suptitle(self._title_text())
        plt.axis("off")

        ### Third subplot: plotting the timeseries
        self.timeax = plt.subplot(self.grid[3, :])
        plt.subplots_adjust(bottom=0.4)
        self.timeplot, = self.timeax.plot(self._normalized_timeseries())
        plt.axis("off")

        ### Axis for scrolling through t
        self.tr = plt.axes([0.2, 0.15, 0.3, 0.03],
                           facecolor='lightgoldenrodyellow')
        self.s_tr = Slider(self.tr,
                           'Timepoint',
                           0,
                           self.tmax - 1,
                           valinit=0,
                           valstep=1)
        self.s_tr.on_changed(self.update_t)

        ### Axes for setting the min/max display range
        self.minr = plt.axes([0.2, 0.2, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.sminr = Slider(self.minr,
                            'R Min',
                            0,
                            np.max(self.im),
                            valinit=self.min,
                            valstep=1)
        self.maxr = plt.axes([0.2, 0.25, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.smaxr = Slider(self.maxr,
                            'R Max',
                            0,
                            np.max(self.im) * 4,
                            valinit=self.max,
                            valstep=1)
        self.sminr.on_changed(self.update_mm)
        self.smaxr.on_changed(self.update_mm)

        ### Buttons for next/previous thread
        self.axprev = plt.axes([0.62, 0.20, 0.1, 0.075])
        self.axnext = plt.axes([0.75, 0.20, 0.1, 0.075])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        #### Radio buttons: which ROIs to overlay on the left image
        self.pointsax = plt.axes([0.75, 0.10, 0.1, 0.075])
        self.pointsbutton = RadioButtons(self.pointsax,
                                         ('Single', 'Same Z', 'All'))
        self.pointsbutton.set_active(self.pointstate)
        self.pointsbutton.on_clicked(self.update_pointstate)

        #### Radio buttons: single Z plane or MIP on the left
        self.mipax = plt.axes([0.62, 0.10, 0.1, 0.075])
        self.mipbutton = RadioButtons(self.mipax, ('Single Z', 'MIP'))
        self.mipbutton.set_active(self.showmip)
        self.mipbutton.on_clicked(self.update_mipstate)

        ### Keep / Trash check boxes
        self.keepax = plt.axes([0.87, 0.20, 0.075, 0.075])
        self.keep_button = CheckButtons(self.keepax, list(self._KEEP_LABELS),
                                        [False, False])
        self.keep_button.on_clicked(self.keep)
        # Show the saved label of the thread we start on (e.g. after resuming)
        self.update_buttons()

        ### Filter deciding which threads Next/Previous visit
        self.showax = plt.axes([0.87, 0.10, 0.075, 0.075])
        self.showbutton = RadioButtons(
            self.showax, ('All', 'Unlabelled', 'Kept', 'Trashed'))
        self.showbutton.set_active(self.show_settings)
        self.showbutton.on_clicked(self.show)

        plt.show()

    def __del__(self):
        # Best-effort autosave.  Note that atexit.register(self.log_curate)
        # holds a reference to this instance, so in practice the atexit hook
        # (not this method) is what saves.  Skip if __init__ never got far
        # enough to have anything to save.
        if hasattr(self, 'curate') and hasattr(self, 'path'):
            self.log_curate()

    def update_im(self):
        """Load ``self.im``: the MIP at time ``t``, or the thread's own Z plane."""
        if self.showmip:
            self.im = np.max(self.tf.get_t(self.t), axis=0)
        else:
            self.im = self.tf.get_tbyf(self.t, int(self._position()[0]))

    def get_im_display(self):
        """Full image rescaled so that [min, max] maps to [0, 1]."""
        return (self.im - self.min) / (self.max - self.min)

    def get_subim_display(self):
        """Sub-image rescaled so that [min, max] maps to [0, 1]."""
        return (self.subim - self.min) / (self.max - self.min)

    def update_figures(self):
        """Refresh both images, the ROI markers and the title, then redraw."""
        pos = self._position()
        self.subim, self.offset = subaxis(self.im, pos, self.window)
        self.img1.set_data(self.get_im_display())

        others = self._other_points()
        if others is not None:
            self.point1.set_offsets(np.array([others[:, 2], others[:, 1]]).T)
        self.thispoint.set_offsets([pos[2], pos[1]])
        self.ax1.axis('off')

        self.img2.set_data(self.get_subim_display())
        self.point2.set_offsets([
            self.window / 2 + self.offset[0], self.window / 2 + self.offset[1]
        ])
        self.title.set_text(self._title_text())
        plt.draw()

    def update_timeseries(self):
        """Replace the plotted timeseries with the current thread's."""
        self.timeplot.set_ydata(self._normalized_timeseries())
        plt.draw()

    def update_t(self, val):
        """Slider callback: show time point ``val``."""
        # Slider values are floats (e.g. 3.0); time is used as an index
        self.t = int(round(val))
        self.update_im()
        self.update_figures()

    def update_mm(self, val):
        """Slider callback: apply the new min/max display range."""
        self.min = self.sminr.val
        self.max = self.smaxr.val
        self.update_figures()

    def next(self, event):
        """Button callback: go to the next thread passing the current filter."""
        self.set_index_next()
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def prev(self, event):
        """Button callback: go to the previous thread passing the current filter."""
        self.set_index_prev()
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def log_curate(self):
        """Save the labels and the current index (as ``'last'``) to ``self.path``."""
        self.curate['last'] = self.ind
        with open(self.path, 'w') as fp:
            json.dump(self.curate, fp)

    def _set_keep_buttons_silently(self, target):
        """Set the Keep/Trash boxes to ``target`` without re-firing ``keep``."""
        buttons = self.keep_button
        previous = buttons.eventson
        buttons.eventson = False
        try:
            for i, (current, wanted) in enumerate(zip(buttons.get_status(), target)):
                if current != wanted:
                    buttons.set_active(i)
        finally:
            buttons.eventson = previous

    def keep(self, event):
        """CheckButtons callback: record 'keep' or 'trash' for the current thread.

        ``event`` is the label of the clicked box.  If the click leaves both
        boxes ticked, the other box is unticked so the labels stay mutually
        exclusive.  Unticking the only ticked box leaves the stored label
        unchanged.
        """
        status = list(self.keep_button.get_status())
        if sum(status) > 1:
            clicked = (self._KEEP_LABELS.index(event)
                       if event in self._KEEP_LABELS else None)
            self._set_keep_buttons_silently([i == clicked for i in range(len(status))])
            status = list(self.keep_button.get_status())
        if sum(status) == 1:
            self.curate[str(self.ind)] = 'keep' if status[0] else 'trash'

    def update_buttons(self):
        """Make the Keep/Trash boxes reflect the current thread's stored label."""
        label = self.curate.get(str(self.ind))
        target = [False] * len(self.keep_button.get_status())
        if label == 'keep':
            target[0] = True
        elif label == 'trash':
            target[1] = True
        self._set_keep_buttons_silently(target)

    def show(self, label):
        """RadioButtons callback: choose which threads Next/Previous visit."""
        d = {'All': 0, 'Unlabelled': 1, 'Kept': 2, 'Trashed': 3}
        self.show_settings = d[label]

    def _accepts(self, idx):
        """Whether thread ``idx`` passes the current show filter."""
        label = self.curate.get(str(idx))
        if self.show_settings == 0:
            return True
        if self.show_settings == 1:
            return label not in ('keep', 'trash')
        if self.show_settings == 2:
            return label == 'keep'
        return label == 'trash'

    def _step_index(self, step):
        """Move ``ind`` by ``step`` (wrapping) to the next thread passing the filter.

        If no thread matches within one full lap, the index ends one step away
        from where it started.
        """
        n = self.numneurons
        self.ind = (self.ind + step) % n
        for _ in range(n):
            if self._accepts(self.ind):
                break
            self.ind = (self.ind + step) % n

    def set_index_prev(self):
        """Step backwards to the previous thread passing the show filter."""
        self._step_index(-1)

    def set_index_next(self):
        """Step forwards to the next thread passing the show filter."""
        self._step_index(1)

    def update_curate(self):
        """Mark the current thread 'seen' unless it already has a label."""
        if self.curate.get(str(self.ind)) not in ['keep', 'seen', 'trash']:
            self.curate[str(self.ind)] = 'seen'

    def update_pointstate(self, label):
        """RadioButtons callback: choose which ROIs are overlaid on the left image."""
        d = {
            'Single': 0,
            'Same Z': 1,
            'All': 2,
        }
        self.pointstate = d[label]
        self.update_point1()
        self.update_figures()

    def update_point1(self):
        """Redraw the left image and its ROI markers from scratch."""
        self.ax1.clear()
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_points()
        # clear() restores the axis decorations; hide them on ax1 itself
        self.ax1.axis('off')

    def update_mipstate(self, label):
        """RadioButtons callback: switch between single Z plane and MIP."""
        d = {
            'Single Z': 0,
            'MIP': 1,
        }
        self.showmip = d[label]

        self.update_im()
        self.update_figures()
Пример #9
0
class ViewImageStack(object):  # interactive image viewer for a stack of images
    """Interactive matplotlib viewer for a stack of 2-D images.

    The scroll wheel steps through the images.  Text boxes set the colour
    range and jump to an image number, radio buttons switch between linear
    and log display, and buttons reset the view, save the current image as a
    PNG, or exit.  The constructor builds the figure and blocks in
    ``plt.show``.

    Args:
        chunk_size_current (int): number of images in the stack.
        img_data: image stack, indexed as ``img_data[i, :, :]``.
        trainid_data, pulseid_data, cellid_data, is_good_data: per-image
            metadata shown in the info text and used in saved file names.
        view_cmap: colormap passed to ``imshow``.
        view_clip (sequence): default (min, max) colour limits.
        view_figsz (tuple): figure size.
        dir_save (str): prefix (directory) prepended to saved PNG file names.
    """

    def __init__(self, chunk_size_current, img_data, trainid_data,
                 pulseid_data, cellid_data, is_good_data, view_cmap, view_clip,
                 view_figsz, dir_save):

        self.fg = plt.figure(figsize=view_figsz)
        self.cmap = view_cmap
        self.clmin = view_clip[0]
        self.clmax = view_clip[1]
        self.dirsave = dir_save

        self.imtot = chunk_size_current  # total number of images in the file
        self.X = img_data
        self.trainid_data = trainid_data
        self.pulseid_data = pulseid_data
        self.cellid_data = cellid_data
        self.imtype = is_good_data
        self.ind = 0  # initial image to show, e.g, 0 or self.imtot//2

        self.mindef = self.clmin  # default values used for error handling / reset
        self.maxdef = self.clmax
        clrinact = 'lightgray'

        # axes for all visual components ([left, bottom, width, height])
        self.ax = self.fg.add_subplot(111)
        self.cbax = self.fg.add_axes([0.85, 0.1, 0.03, 0.77])  # colorbar
        self.resax = self.fg.add_axes([0.05, 0.83, 0.1, 0.05])  # 'Reset' button
        self.rax = self.fg.add_axes([0.05, 0.72, 0.1, 0.1],
                                    facecolor=clrinact)  # lin/log radio buttons
        self.minax = self.fg.add_axes([0.04, 0.65, 0.12, 0.05])  # data range min
        self.maxax = self.fg.add_axes([0.04, 0.59, 0.12, 0.05])  # data range max
        self.imnax = self.fg.add_axes([0.07, 0.23, 0.1, 0.04],
                                      facecolor=clrinact)  # choose image number
        self.savebut = self.fg.add_axes([0.05, 0.30, 0.1, 0.05])  # 'Save image' button
        self.exitax = self.fg.add_axes([0.05, 0.93, 0.1, 0.05])  # 'Exit' button

        # visual components
        self.im = self.ax.imshow(self.X[self.ind, :, :],
                                 cmap=self.cmap,
                                 clim=(self.clmin, self.clmax))
        self.cls = plt.colorbar(self.im, cax=self.cbax)
        self.averbtn = Button(self.savebut, 'Save png', color=clrinact)
        self.resbt = Button(self.resax, 'Reset view', color=clrinact)
        self.exitbt = Button(self.exitax, 'Exit', color=clrinact)
        self.tbmin = TextBox(self.minax,
                             'min:',
                             initial=str(self.clmin),
                             color=clrinact)
        self.tbmax = TextBox(self.maxax,
                             'max:',
                             initial=str(self.clmax),
                             color=clrinact)
        self.rb = RadioButtons(self.rax, ('lin', 'log'),
                               active=0,
                               activecolor='red')
        self.imnum = TextBox(self.imnax,
                             'jump to \n image #',
                             initial=str(self.ind + 1),
                             color=clrinact)

        self.textID = self.fg.text(0.03, 0.085, self._info_text())

        # Keep the connection ids: mpl_disconnect needs ids, not callbacks
        canvas = self.fg.canvas
        self._cids = [
            canvas.mpl_connect('scroll_event', self.on_scroll),
            canvas.mpl_connect('button_press_event', self.on_press),
            canvas.mpl_connect('key_press_event', self.on_keypress),
        ]

        self.update()
        plt.show(block=True)

    def _info_text(self):
        """Metadata text for the image currently shown."""
        return ("Image information:\nTrainID: " + str(self.trainid_data[self.ind])
                + "\nPulseID: " + str(self.pulseid_data[self.ind])
                + "\nCellID: " + str(self.cellid_data[self.ind])
                + "\nImage type: " + str(self.imtype[self.ind]))

    def on_press(self, event):
        """Handle mouse clicks on the radio buttons and the Reset/Exit/Save buttons."""
        if event.inaxes == self.rax.axes:
            self.update()

        if event.inaxes == self.resax.axes:  # 'reset view' button: default datarange, linear scale, first image
            self.clmin = self.mindef
            self.clmax = self.maxdef
            self.tbmin.set_val(self.clmin)
            self.tbmax.set_val(self.clmax)
            self.rb.set_active(0)
            self.ind = 0
            print('updated data range is [%s , %s ]' %
                  (self.clmin, self.clmax))
            self.update()

        if event.inaxes == self.exitax:
            sys.exit(0)  # exit program

        if event.inaxes == self.savebut.axes:  # 'save image' button
            # Save only the image, its colorbar and their labels
            items = [
                self.ax, self.cbax,
                self.cbax.get_yaxis().get_label(),
                self.ax.get_xaxis().get_label(),
                self.ax.get_yaxis().get_label()
            ]
            bbox = Bbox.union([item.get_window_extent() for item in items])
            extent = bbox.transformed(self.fg.dpi_scale_trans.inverted())
            strim = 'img_trID_{:d}_plID_{:d}_ceID_{:d}_type_{:d}.png'.format(
                self.trainid_data[self.ind], self.pulseid_data[self.ind],
                self.cellid_data[self.ind], self.imtype[self.ind])
            self.fg.savefig(self.dirsave + strim, bbox_inches=extent)

    def on_keypress(self, event):
        """Apply typed values on Enter: the data range, and the image number.

        Invalid input is reported and the text boxes are reset to the values
        in use.  The data range is changed only if both bounds parse.
        """
        # one may face bad script behaviour here if a key like 'tab' has been pressed
        if event.key == 'enter':
            try:
                new_min = float(self.tbmin.text)
                new_max = float(self.tbmax.text)
            except ValueError:
                print('inappropriate values for data range specified')
                self.tbmin.set_val(self.clmin)
                self.tbmax.set_val(self.clmax)
            else:
                self.clmin, self.clmax = new_min, new_max
            self.update()

        if event.inaxes == self.imnax and event.key == 'enter':
            try:
                val = int(self.imnum.text)
            except ValueError:
                val = None
            if val is not None and 1 <= val <= self.imtot:
                self.ind = val - 1
            else:
                print('inappropriate image number specified')
                self.imnum.set_val(str(self.ind + 1))
            self.update()

    def on_scroll(self, event):
        """Scroll up for the next image, otherwise the previous one (wrapping)."""
        if event.button == 'up':
            self.ind = (self.ind + 1) % self.imtot
        else:
            self.ind = (self.ind - 1) % self.imtot
        self.update()

    def update(self):
        """Redraw the current image, title, colour limits and info text."""
        self.ax.set_title(
            'image # %s out of %s \n (use scroll wheel to navigate through images)'
            % (self.ind + 1, self.imtot))

        # choose linear or log scale
        if self.rb.value_selected == 'lin':
            self.im.set_data(self.X[self.ind, :, :])
        else:
            with np.errstate(divide='ignore', invalid='ignore'):
                self.im.set_data(np.log10(self.X[self.ind, :, :]))
        # NOTE(review): in log mode the limits are still the linear clmin/clmax
        self.im.set_clim(self.clmin, self.clmax)

        self.textID.set_text(self._info_text())

        # Colorbar.draw_all was deprecated in Matplotlib 3.6 and later removed;
        # newer colorbars follow set_clim through the mappable's callback.
        draw_all = getattr(self.cls, 'draw_all', None)
        if draw_all is not None:
            draw_all()
        self.im.axes.figure.canvas.draw()

    def __del__(self):
        # Disconnect the canvas callbacks by their connection ids.  Guarded so a
        # partially constructed or already-cleaned-up instance is a no-op.
        fg = getattr(self, 'fg', None)
        cids = getattr(self, '_cids', ())
        if fg is not None:
            for cid in cids:
                fg.canvas.mpl_disconnect(cid)
        self._cids = []
Пример #10
0
def interact2D(
    data: wt_data.Data,
    xaxis=0,
    yaxis=1,
    channel=0,
    cmap=None,
    local=False,
    use_imshow=False,
    verbose=True,
):
    """Interactive 2D plot of the dataset.
    Side plots show x and y projections of the slice (shaded gray).
    Left clicks on the main axes draw 1D slices on side plots at the coordinates selected.
    Right clicks remove the 1D slices.
    For 3+ dimensional data, sliders below the main axes are used to change which slice is viewed.
    Arrow keys move the crosshairs (or step the focused slider); tab / ctrl+tab cycle focus.

    Parameters
    ----------
    data : WrightTools.Data object
        Data to plot.
    xaxis : string, integer, or data.Axis object (optional)
        Expression or index of x axis. Default is 0.
    yaxis : string, integer, or data.Axis object (optional)
        Expression or index of y axis. Default is 1.
    channel : string, integer, or data.Channel object (optional)
        Name or index of channel to plot. Default is 0.
    cmap : string or cm object (optional)
        Name of colormap, or explicit colormap object.  Defaults to channel default.
    local : boolean (optional)
        Toggle plotting locally. Default is False.
    use_imshow : boolean (optional)
        If True, matplotlib imshow is used to render the 2D slice.
        Can give better performance, but is only accurate for
        uniform grids.  Default is False.
    verbose : boolean (optional)
        Toggle talkback. Default is True.

    Returns
    -------
    tuple
        ``(obj2D, sliders, crosshair_hline, crosshair_vline, radio, colorbar)``:
        the artist drawing the 2D slice, a dict of Slider widgets keyed by axis
        natural name, the horizontal and vertical crosshair lines, the
        global/local RadioButtons, and the colorbar.

    Raises
    ------
    DimensionalityError
        If the data has fewer than two dimensions.
    NotImplementedError
        If a non-plotted axis is multivariable and cannot drive a slider.
    """
    # avoid changing passed data object
    data = data.copy()
    # unpack
    data.prune(keep_channels=channel, verbose=False)
    channel = get_channel(data, channel)
    xaxis, yaxis = get_axes(data, [xaxis, yaxis])
    cmap = cmap if cmap is not None else get_colormap(channel.signed)
    current_state = SimpleNamespace()
    # create figure
    nsliders = data.ndim - 2
    if nsliders < 0:
        raise DimensionalityError(">= 2", data.ndim)
    # TODO: implement aspect; doesn't work currently because of our incorporation of colorbar
    fig, gs = create_figure(width="single", nrows=7 + nsliders, cols=[1, 1, 1, 1, 1, "cbar"])
    # create axes
    ax0 = plt.subplot(gs[1:6, 0:5])
    ax0.patch.set_facecolor("w")
    cax = plt.subplot(gs[1:6, -1])
    sp_x = add_sideplot(ax0, "x", pad=0.1)
    sp_y = add_sideplot(ax0, "y", pad=0.1)
    ax_local = plt.subplot(gs[0, 0], aspect="equal", frameon=False)
    ax_title = plt.subplot(gs[0, 3], frameon=False)
    ax_title.text(
        0.5,
        0.5,
        data.natural_name,
        fontsize=18,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax_title.transAxes,
    )
    ax_title.set_axis_off()
    # NOTE: there are more axes here for more buttons / widgets in future plans
    # create lines
    x_color = "#00BFBF"  # cyan with increased saturation
    y_color = "coral"
    line_sp_x = sp_x.plot([None], [None], visible=False, color=x_color, linewidth=2)[0]
    line_sp_y = sp_y.plot([None], [None], visible=False, color=y_color, linewidth=2)[0]
    crosshair_hline = ax0.plot([None], [None], visible=False, color=x_color, linewidth=2)[0]
    crosshair_vline = ax0.plot([None], [None], visible=False, color=y_color, linewidth=2)[0]
    current_state.xarg = xaxis.points.flatten().size // 2
    current_state.yarg = yaxis.points.flatten().size // 2
    # +1 / -1 so the arrow keys move the crosshair visually right/up even
    # when the axis points are stored in descending order
    xdir = 1 if xaxis.points.flatten()[-1] - xaxis.points.flatten()[0] > 0 else -1
    ydir = 1 if yaxis.points.flatten()[-1] - yaxis.points.flatten()[0] > 0 else -1
    current_state.bin_vs_x = True
    current_state.bin_vs_y = True
    # create buttons
    current_state.local = local
    radio = RadioButtons(ax_local, (" global", " local"))
    if local:
        radio.set_active(1)
    else:
        radio.set_active(0)
    # `RadioButtons.circles` no longer exists in newer matplotlib releases;
    # only shrink the markers when the attribute is available.
    for circle in getattr(radio, "circles", ()):
        circle.set_radius(0.14)
    # create sliders
    sliders = {}
    for axis in data.axes:
        if axis not in [xaxis, yaxis]:
            if axis.size > np.prod(axis.shape):
                raise NotImplementedError("Cannot use multivariable axis as a slider")
            # ~n == -(n + 1): sliders fill the bottom grid rows upwards
            slider_axes = plt.subplot(gs[~len(sliders), :]).axes
            slider = Slider(slider_axes, axis.label, 0, axis.points.size - 1, valinit=0, valstep=1)
            sliders[axis.natural_name] = slider
            slider.ax.vlines(
                range(axis.points.size - 1),
                *slider.ax.get_ylim(),
                colors="k",
                linestyle=":",
                alpha=0.5,
            )
            slider.valtext.set_text(gen_ticklabels(axis.points)[0])
    current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()])
    # initial xyz start are from zero indices of additional axes
    current_state.dat = data.chop(
        xaxis.natural_name,
        yaxis.natural_name,
        at=_at_dict(data, sliders, xaxis, yaxis),
        verbose=False,
    )[0]
    norm = get_norm(channel, current_state)

    gen_mesh = ax0.pcolormesh if not use_imshow else ax0.imshow
    obj2D = gen_mesh(
        current_state.dat,
        cmap=cmap,
        norm=norm,
        ylabel=yaxis.label,
        xlabel=xaxis.label,
    )
    # positional flag: the keyword was renamed from `b` to `visible` in newer matplotlib
    ax0.grid(True)
    # colorbar
    ticks = norm_to_ticks(norm)
    ticklabels = gen_ticklabels(ticks, channel.signed)
    colorbar = plot_colorbar(cax, cmap=cmap, label=channel.natural_name, ticks=ticks)
    colorbar.set_ticklabels(ticklabels)
    fig.canvas.draw_idle()

    def draw_sideplot_projections():
        """Shade the mean projections of the current 2D slice on the side plots."""
        arr = current_state.dat[channel.natural_name][:]
        # index of the array dimension along which each plotted axis varies
        xind = list(
            np.array(
                current_state.dat.axes[
                    current_state.dat.axis_expressions.index(xaxis.expression)
                ].shape
            )
            > 1
        ).index(True)
        yind = list(
            np.array(
                current_state.dat.axes[
                    current_state.dat.axis_expressions.index(yaxis.expression)
                ].shape
            )
            > 1
        ).index(True)
        if channel.signed:
            temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
            temp_arr[temp_arr < 0] = 0
            x_proj_pos = np.nanmean(temp_arr, axis=yind)
            y_proj_pos = np.nanmean(temp_arr, axis=xind)

            temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
            temp_arr[temp_arr > 0] = 0
            x_proj_neg = np.nanmean(temp_arr, axis=yind)
            y_proj_neg = np.nanmean(temp_arr, axis=xind)

            x_proj = np.nanmean(arr, axis=yind)
            y_proj = np.nanmean(arr, axis=xind)

            alpha = 0.4
            blue = "#517799"  # start with #87C7FF and change saturation
            red = "#994C4C"  # start with #FF7F7F and change saturation
            if current_state.bin_vs_x:
                x_proj_norm = max(np.nanmax(x_proj_pos), np.nanmax(-x_proj_neg))
                if x_proj_norm != 0:
                    x_proj_pos /= x_proj_norm
                    x_proj_neg /= x_proj_norm
                    x_proj /= x_proj_norm
                try:
                    sp_x.fill_between(xaxis.points, x_proj_pos, 0, color=red, alpha=alpha)
                    sp_x.fill_between(xaxis.points, 0, x_proj_neg, color=blue, alpha=alpha)
                    sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
                except ValueError:  # Input passed into argument is not 1-dimensional
                    current_state.bin_vs_x = False
                    sp_x.set_visible(False)
            if current_state.bin_vs_y:
                y_proj_norm = max(np.nanmax(y_proj_pos), np.nanmax(-y_proj_neg))
                if y_proj_norm != 0:
                    y_proj_pos /= y_proj_norm
                    y_proj_neg /= y_proj_norm
                    y_proj /= y_proj_norm
                try:
                    sp_y.fill_betweenx(yaxis.points, y_proj_pos, 0, color=red, alpha=alpha)
                    sp_y.fill_betweenx(yaxis.points, 0, y_proj_neg, color=blue, alpha=alpha)
                    sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_y = False
                    sp_y.set_visible(False)
        else:
            # NOTE(review): `norm` here resolves to the object returned by
            # get_norm() in the enclosing interact2D scope (it shadows any
            # module-level helper of the same name), so this passes
            # channel.signed as that object's second positional argument.
            # Confirm this is the intended scaling.
            if current_state.bin_vs_x:
                x_proj = np.nanmean(arr, axis=yind)
                x_proj = norm(x_proj, channel.signed)
                try:
                    sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_x = False
                    sp_x.set_visible(False)
            if current_state.bin_vs_y:
                y_proj = np.nanmean(arr, axis=xind)
                y_proj = norm(y_proj, channel.signed)
                try:
                    sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_y = False
                    sp_y.set_visible(False)

    draw_sideplot_projections()

    ax0.set_xlim(xaxis.points.min(), xaxis.points.max())
    ax0.set_ylim(yaxis.points.min(), yaxis.points.max())

    if channel.signed:
        sp_x.set_ylim(-1.1, 1.1)
        sp_y.set_xlim(-1.1, 1.1)

    def update_sideplot_slices():
        """Move the crosshairs and redraw the 1D cuts through the selected point."""
        # TODO:  if bins is only available along one axis, slicing should be valid along the other
        #   e.g., if bin_vs_y =  True, then assemble slices vs x
        #   for now, just uniformly turn off slicing
        if (not current_state.bin_vs_x) or (not current_state.bin_vs_y):
            return
        xlim = ax0.get_xlim()
        ylim = ax0.get_ylim()
        x0 = xaxis.points[current_state.xarg]
        y0 = yaxis.points[current_state.yarg]

        crosshair_hline.set_data(np.array([xlim, [y0, y0]]))
        crosshair_vline.set_data(np.array([[x0, x0], ylim]))

        # NOTE(review): `norm` below is the enclosing get_norm() result, as in
        # draw_sideplot_projections; confirm the intended slice scaling.
        at_dict = _at_dict(data, sliders, xaxis, yaxis)
        at_dict[xaxis.natural_name] = (x0, xaxis.units)
        side_plot_data = data.chop(yaxis.natural_name, at=at_dict, verbose=False)
        side_plot = side_plot_data[0][channel.natural_name].points
        side_plot = norm(side_plot, channel.signed)
        line_sp_y.set_data(side_plot, yaxis.points)
        side_plot_data.close()

        at_dict = _at_dict(data, sliders, xaxis, yaxis)
        at_dict[yaxis.natural_name] = (y0, yaxis.units)
        side_plot_data = data.chop(xaxis.natural_name, at=at_dict, verbose=False)
        side_plot = side_plot_data[0][channel.natural_name].points
        side_plot = norm(side_plot, channel.signed)
        line_sp_x.set_data(xaxis.points, side_plot)
        side_plot_data.close()

    def update_local(index):
        """Radio callback: switch between global and local color normalization."""
        if verbose:
            print("normalization:", index)
        current_state.local = radio.value_selected[1:] == "local"
        norm = get_norm(channel, current_state)
        obj2D.set_norm(norm)
        # Derive labels with norm_to_ticks (the helper used to place the
        # colorbar ticks and used by update_slider) so the number of labels
        # matches the number of ticks.
        ticks = norm_to_ticks(norm)
        ticklabels = gen_ticklabels(ticks, channel.signed)
        colorbar.set_ticklabels(ticklabels)
        fig.canvas.draw_idle()

    def update_slider(info, use_imshow=use_imshow):
        """Slider callback: re-chop the data at the new slider positions and redraw."""
        current_state.dat.close()
        current_state.dat = data.chop(
            xaxis.natural_name,
            yaxis.natural_name,
            at={
                a.natural_name: (a[:].flat[int(sliders[a.natural_name].val)], a.units)
                for a in data.axes
                if a not in [xaxis, yaxis]
            },
            verbose=False,
        )[0]
        for k, s in sliders.items():
            s.valtext.set_text(
                gen_ticklabels(data.axes[data.axis_names.index(k)].points)[int(s.val)]
            )
        if use_imshow:
            # index the chopped Data object, not the SimpleNamespace holding it
            transpose = _order_for_imshow(
                current_state.dat[xaxis.natural_name][:],
                current_state.dat[yaxis.natural_name][:],
            )
            obj2D.set_data(current_state.dat[channel.natural_name][:].transpose(transpose))
        else:
            obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
        norm = get_norm(channel, current_state)
        obj2D.set_norm(norm)

        ticks = norm_to_ticks(norm)
        ticklabels = gen_ticklabels(ticks, channel.signed)
        colorbar.set_ticklabels(ticklabels)
        # Remove the previous projection fills. `Axes.collections` cannot be
        # cleared in place on newer matplotlib, so remove each artist from a
        # copied list instead.
        for collection in list(sp_x.collections) + list(sp_y.collections):
            collection.remove()
        draw_sideplot_projections()
        if line_sp_x.get_visible() and line_sp_y.get_visible():
            update_sideplot_slices()
        fig.canvas.draw_idle()

    def update_crosshairs(xarg, yarg, hide=False):
        """Move the crosshairs to point indices (xarg, yarg), or hide them."""
        # if x0 is None or y0 is None:
        #    raise TypeError((x0, y0))
        # find closest x and y pts in dataset
        current_state.xarg = xarg
        current_state.yarg = yarg
        xedge = xarg in [0, xaxis.points.flatten().size - 1]
        yedge = yarg in [0, yaxis.points.flatten().size - 1]
        current_state.xpos = xaxis.points[xarg]
        current_state.ypos = yaxis.points[yarg]
        if not hide:  # update crosshairs and show
            if verbose:
                print(current_state.xpos, current_state.ypos)
            update_sideplot_slices()
            line_sp_x.set_visible(True)
            line_sp_y.set_visible(True)
            crosshair_hline.set_visible(True)
            crosshair_vline.set_visible(True)
            # thicker lines if on the axis edges
            crosshair_vline.set_linewidth(6 if xedge else 2)
            crosshair_hline.set_linewidth(6 if yedge else 2)
        else:  # do not update and hide crosshairs
            line_sp_x.set_visible(False)
            line_sp_y.set_visible(False)
            crosshair_hline.set_visible(False)
            crosshair_vline.set_visible(False)

    def update_button_release(info):
        """Mouse release: update focus; left click shows, right click hides the crosshairs."""
        current_state.focus(info.inaxes)
        if info.inaxes == ax0:
            xlim = ax0.get_xlim()
            ylim = ax0.get_ylim()
            x0, y0 = info.xdata, info.ydata
            # min/max so the bounds test also works when an axis is inverted
            if min(xlim) < x0 < max(xlim) and min(ylim) < y0 < max(ylim):
                xarg = np.abs(xaxis.points - x0).argmin()
                yarg = np.abs(yaxis.points - y0).argmin()
                if info.button == 1 or info.button is None:  # left click
                    update_crosshairs(xarg, yarg)
                elif info.button == 3:  # right click
                    update_crosshairs(xarg, yarg, hide=True)
        fig.canvas.draw_idle()

    def update_key_press(info):
        """Keyboard handler: arrows move crosshairs / sliders, tab cycles focus."""
        if info.key in ["left", "right", "up", "down"]:
            if current_state.focus.focus_axis != ax0:  # sliders
                if info.key in ["up", "down"]:
                    return
                slider = [
                    slider
                    for slider in sliders.values()
                    if slider.ax == current_state.focus.focus_axis
                ][0]
                new_val = slider.val + 1 if info.key == "right" else slider.val - 1
                new_val %= slider.valmax + 1
                slider.set_val(new_val)
            else:  # crosshairs
                dx = dy = 0
                if info.key == "left":
                    dx -= 1
                elif info.key == "right":
                    dx += 1
                elif info.key == "up":
                    dy += 1
                elif info.key == "down":
                    dy -= 1
                update_crosshairs(
                    (current_state.xarg + dx * xdir) % xaxis.points.flatten().size,
                    (current_state.yarg + dy * ydir) % yaxis.points.flatten().size,
                )
        elif info.key == "tab":
            current_state.focus("next")
        elif info.key == "ctrl+tab":
            current_state.focus("previous")
        else:
            mpl.backend_bases.key_press_handler(info, fig.canvas, fig.canvas.toolbar)
        fig.canvas.draw_idle()

    # replace matplotlib's default key handler; unhandled keys are forwarded
    # to it from update_key_press
    fig.canvas.mpl_disconnect(fig.canvas.manager.key_press_handler_id)
    fig.canvas.mpl_connect("button_release_event", update_button_release)
    fig.canvas.mpl_connect("key_press_event", update_key_press)
    radio.on_clicked(update_local)

    for slider in sliders.values():
        slider.on_changed(update_slider)

    return obj2D, sliders, crosshair_hline, crosshair_vline, radio, colorbar
Пример #11
0
class RunSim:
    ''' This class is the core of the program. It uses matplotlib to gather
    user input and display results, and also performs the requisite calculations.'''

    # Car-type RadioButton label -> category name passed to Driver.setCat.
    # Insertion order is the order in which the buttons are shown.
    CAR_CATEGORY_BY_LABEL = {
        'Average': 'average',
        'Small Sedan': 'smallSedan',
        'Medium Sedan': 'mediumSedan',
        'Large Sedan': 'largeSedan',
        '4WD/Sport': '4wdSport',
        'Minivan': 'minivan',
    }

    # Spending-bracket label (bike parts or shoes) -> representative dollar
    # amount passed to setSpend. Insertion order is the button order.
    SPEND_BY_LABEL = {
        '$0-25': 12.5,
        '$25-50': 37.5,
        '$50-100': 75,
        '$100-150': 125,
        '$150-200': 175,
        '>$200': 250,
    }

    # Sex RadioButton label -> code passed to person.setSex.
    SEX_BY_LABEL = {'Male': 'M', 'Female': 'F'}

    def __init__(self):
        ''' The constructor creates instances of the walker, biker, and driver
        objects from the 'modes' module, sets a default distance and trip
        number, and then calculates the time, cost, calories,
        and CO2 emissions for all modes of transit. Then it sets up graphs
        for displaying the information, as well as sliders, buttons, and
        RadioButtons for gathering user input, and finally shows the figure
        (blocking until the window is closed in non-interactive mode). The
        functions that those widgets execute are defined internally.
        '''
        # Create instances of the driver, biker, and walker objects
        # whose instance variables are used heavily in calculations.
        self.d = Driver()
        self.b = Biker()
        self.w = Walker()

        # Default values
        self.dist = 1
        self.trips = 1

        # Do initial calculations (calculate returns calcDict)
        self.calcDict = self.calculate()

        # Create the 14" by 10" figure object which everything is placed on
        self.fig = plt.figure(figsize=(14, 10))

        # Create 4 axes objects evenly spaced in a column. These will
        # hold the four graphs/figures (time, cost, calories, and CO2.)
        self.ax1 = self.fig.add_subplot(411)
        self.ax2 = self.fig.add_subplot(412)
        self.ax3 = self.fig.add_subplot(413)
        self.ax4 = self.fig.add_subplot(414)

        # Adjust the location of subplots (the axes holding graphs)
        self.fig.subplots_adjust(left=.74,
                                 right=.94,
                                 bottom=.18,
                                 top=.98,
                                 hspace=.25)

        ### Set up Buttons, RadioButtons, Sliders and their fcns! ###

        # The structure of setting up the temporary rax axes, then making
        # the RadioButton, then defining its function, then adding the
        # on_clicked statement is taken from this example:
        # http://matplotlib.org/examples/widgets/radio_buttons.html
        #
        # Widget axes are coloured with `facecolor=`; the older `axisbg=`
        # keyword was deprecated in matplotlib 2.0 and later removed.

        axcolor = 'lightgoldenrodyellow'

        def getRadioPosList(ax):
            ''' Return [x, y, width, height] for a unit RadioButton axes placed
            to the left of *ax* and vertically centered on it. '''
            # Width and height are the same for each radio button
            widthRax = 0.12
            heightRax = 0.10

            # Find the lower (left) x value of the adjacent graph
            # and construct the appropriate x value for the radio axes
            x0Ax = ax.get_position().get_points()[0, 0]
            xRax = x0Ax - (widthRax * 1.8)

            # Average the lower and upper y values of the adjacent graph and
            # subtract half the radio axes height to center it vertically
            y0Ax = ax.get_position().get_points()[0, 1]
            y1Ax = ax.get_position().get_points()[1, 1]
            yRax = ((y0Ax + y1Ax) / 2) - (heightRax / 2)

            return [xRax, yRax, widthRax, heightRax]

        def unitChange(label):
            ''' Shared callback for the four unit RadioButtons: updateGraph
            reads the newly selected unit back from each widget. '''
            self.updateGraph()
            plt.draw()

        # Unit Change RadioButton 1: Change Time Units
        rax = plt.axes(getRadioPosList(self.ax1), facecolor=axcolor)
        self.radioTime = RadioButtons(rax, ('Hours', 'Minutes', 'Audiobooks'))
        self.radioTime.on_clicked(unitChange)

        # Unit Change RadioButton 2: Change Money Units
        rax = plt.axes(getRadioPosList(self.ax2), facecolor=axcolor)
        self.radioCost = RadioButtons(rax, ('Dollars', 'Coffees'))
        self.radioCost.on_clicked(unitChange)

        # Unit Change RadioButton 3: Change calorie burn units
        rax = plt.axes(getRadioPosList(self.ax3), facecolor=axcolor)
        self.radioCal = RadioButtons(rax, ('Cal (total)', 'Cal (/hour)'))
        self.radioCal.on_clicked(unitChange)

        # Unit Change RadioButton 4: Change CO2 Emissions Units
        rax = plt.axes(getRadioPosList(self.ax4), facecolor=axcolor)
        self.radioCO2 = RadioButtons(rax, ('CO2 (lbs)', 'CO2 (trees)'))
        self.radioCO2.on_clicked(unitChange)

        # Sliders 1 and 2: Distance and Number of Trips
        # Axes and instance of slider for distance control
        axslideDist = plt.axes([0.17, 0.10, 0.77, 0.03], facecolor=axcolor)
        self.slideDist = Slider(axslideDist,
                                'Distance (miles)',
                                0.0,
                                100.0,
                                valinit=self.dist,
                                valfmt='%4.2f')
        # Axes and instance of slider for number of trips control
        axslideTrip = plt.axes([0.17, 0.05, 0.77, 0.03], facecolor=axcolor)
        self.slideTrip = Slider(axslideTrip,
                                'Trips',
                                0.0,
                                100.0,
                                valinit=self.trips,
                                valfmt='%4.2f')

        def sliderUpdate(val):
            ''' Read both sliders and redraw; updateGraph recomputes calcDict
            from the new distance and trip count. '''
            self.trips = self.slideTrip.val
            self.dist = self.slideDist.val
            self.updateGraph()

        self.slideTrip.on_changed(sliderUpdate)
        self.slideDist.on_changed(sliderUpdate)

        axcolor = 'gold'

        # Customization RadioButton 1: Car - Car Type
        rax = plt.axes([.06, .72, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Driving Info", fontsize=11)
        rax.text(0, 1.05, "Car Type", fontsize=11)
        self.radioCarType = RadioButtons(rax, tuple(self.CAR_CATEGORY_BY_LABEL))

        def carTypeChange(label):
            ''' Set the driver's car category via Driver.setCat, then redraw.'''
            category = self.CAR_CATEGORY_BY_LABEL.get(label)
            if category is None:
                print('Error!')
            else:
                self.d.setCat(category)
            self.updateGraph()
            plt.draw()

        self.radioCarType.on_clicked(carTypeChange)

        # Customization RadioButton 2: Bike - Spend on Bike
        rax = plt.axes([.26, .72, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Biking Info:", fontsize=11)
        rax.text(0, 1.05, "Spend on parts/maintenance ($/year)", fontsize=11)
        self.radioBikeSpend = RadioButtons(
            rax, tuple(self.SPEND_BY_LABEL), active=2)

        def bikeSpendChange(label):
            ''' Set the biker's yearly spend via Biker.setSpend, then redraw.'''
            spend = self.SPEND_BY_LABEL.get(label)
            if spend is None:
                print('Error!')
            else:
                self.b.setSpend(spend)
            self.updateGraph()
            plt.draw()

        self.radioBikeSpend.on_clicked(bikeSpendChange)

        # Customization RadioButton 3: Walk - Spend on Shoes
        rax = plt.axes([.06, .424, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Walking Info:", fontsize=11)
        rax.text(0, 1.05, "Spend on a new pair of shoes", fontsize=11)
        self.radioWalkSpend = RadioButtons(
            rax, tuple(self.SPEND_BY_LABEL), active=1)

        def walkSpendChange(label):
            ''' Set the walker's shoe spend via Walker.setSpend, then redraw.'''
            spend = self.SPEND_BY_LABEL.get(label)
            if spend is None:
                print('Error!')
            else:
                self.w.setSpend(spend)
            self.updateGraph()
            plt.draw()

        self.radioWalkSpend.on_clicked(walkSpendChange)

        # Customization RadioButton 4: Person - Sex
        rax = plt.axes([.26, .424, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Calorie Burn Info:", fontsize=11)
        rax.text(0, 1.05, "Sex", fontsize=11)
        self.radioPersonSex = RadioButtons(rax, ('Male', 'Female'), active=0)

        def personSexChange(label):
            ''' Set the sex of the person attached to the driver, biker, and
            walker objects, then redraw.'''
            sex = self.SEX_BY_LABEL.get(label)
            if sex is None:
                print('Error!')
            else:
                for mode in (self.d, self.b, self.w):
                    mode.person.setSex(sex, True)
            self.updateGraph()
            plt.draw()

        self.radioPersonSex.on_clicked(personSexChange)

        # Reset Button. Kept as an instance attribute: matplotlib widgets only
        # keep working while a reference to the widget object is held.
        axReset = plt.axes([0.17, 0.25, 0.15, 0.10])
        self.bReset = Button(axReset, 'Reset Defaults')

        def resetDefaults(event):
            ''' Resets all buttons/sliders to their default position,
            which triggers recalculations and redrawing of the
            graphs. This function is a little slow.'''
            self.slideDist.set_val(1)
            self.slideTrip.set_val(1)
            self.radioTime.set_active(0)
            self.radioCost.set_active(0)
            self.radioCal.set_active(0)
            self.radioCO2.set_active(0)
            self.radioCarType.set_active(0)
            self.radioBikeSpend.set_active(2)
            self.radioWalkSpend.set_active(1)
            self.radioPersonSex.set_active(0)
            plt.draw()

        self.bReset.on_clicked(resetDefaults)

        # Draw the initial graphs, then show the figure.
        self.updateGraph()
        plt.show()

    def calculate(self):
        ''' Compute time, cost, calorie burn, and CO2 figures for the current
        distance and trip count, from the driver, biker, and walker objects.

        Returns a dict whose list values are ordered [driver, biker, walker]:
        'time' (hours), 'cost' (dollars), 'cal' (total calories),
        'time-mins', 'time-audio' (audiobooks of 12.59 h), 'cost-coffee'
        (coffees at $2.60), 'cal-hour' (calories per hour) and 'cal-sansBMR'
        (always empty). 'CO2' (lbs emitted by driving) and 'CO2-tree'
        (trees needed to sequester it) are scalars.'''
        modes = (self.d, self.b, self.w)

        # Time in hours
        time = [self.dist * self.trips / m.getMPH() for m in modes]
        # Cost in US dollars
        cost = [m.getCost() * self.dist * self.trips for m in modes]
        # Calories burned per hour
        calHour = [m.person.getCal() for m in modes]
        # CO2 emissions in lbs (only driving emits)
        CO2 = self.d.getCO2() * (self.dist * self.trips)

        return {
            'time': time,
            'cost': cost,
            # Total calories burned over the whole travel time
            'cal': [c * t for c, t in zip(calHour, time)],
            # Time in minutes
            'time-mins': [t * 60 for t in time],
            # Time in audiobooks (based on avg len of 12.59 hours)
            # Note: avg length determined from sample of 25 bestsellers on Audible.com
            'time-audio': [t / 12.59 for t in time],
            # Cost in terms of coffees at Blue Mondays ($2.60 each)
            'cost-coffee': [c / 2.60 for c in cost],
            'cal-hour': calHour,
            # Never populated; kept so the dict's keys are unchanged
            'cal-sansBMR': [],
            'CO2': CO2,
            # A single tree planted thru americanforests.org sequesters 911
            # pounds of CO2; this is the number of trees one should plant to
            # sequester the carbon emitted by driving.
            'CO2-tree': (CO2 / 911),
        }

    def makeGraph(self, ax, data, ylab):
        ''' Clear *ax* and draw a 3-bar chart (Drive, Bike, Walk) of *data*
        with y-axis label *ylab* and each bar's value printed above it.
        Called by updateGraph once per bar-chart axes. '''

        ax.clear()

        N = 3  # 3 divisions of x axis
        maxNum = max(data)  # determine max y value

        ind = np.arange(N)  # the x locations for the groups
        width = 0.5  # the width of the bars

        ## the bars
        # align='edge' puts each bar's left edge at its index (the pre-2.0
        # matplotlib default); the x limits and tick positions below assume it.
        rects1 = ax.bar(ind, data, width, color=['cyan', 'yellow', 'magenta'],
                        align='edge')

        # axes and labels
        ax.set_xlim(-.1, len(ind) - .4)
        # 20% headroom for the value labels; when every bar is zero (e.g. the
        # distance slider at 0) fall back to a unit range instead of the
        # degenerate 0..0 limits.
        ax.set_ylim(0, maxNum + (maxNum / 10) * 2 if maxNum > 0 else 1)

        xTickMarks = ['Drive', 'Bike', 'Walk']
        ax.set_xticks(ind + (width / 2))
        ax.set_xticklabels(xTickMarks)
        ax.set_ylabel(ylab)

        def autolabel(rects):
            ''' Adds labels above bars. Code adapted from matplotlib demo code:
            http://matplotlib.org/examples/api/barchart_demo.html'''
            # attach some text labels
            for rect in rects:
                height = rect.get_height()
                ax.text(rect.get_x() + rect.get_width() / 2.,
                        1.05 * height,
                        '%4.2f' % (height),
                        fontsize=11,
                        ha='center',
                        va='bottom')

        autolabel(rects1)

    def showInfo(self, ax, data, msg):
        ''' The fourth subplot (axes) holds text instead of a bar plot;
        redraw it with *data* (formatted to 2 decimals) and caption *msg*.'''

        # Erase any existing information to avoid redrawing
        ax.clear()

        # Hide the tick labels (tick_params avoids matplotlib's warning about
        # fixed tick labels on an automatic locator)
        ax.tick_params(labelbottom=False, labelleft=False)

        ax.text(.08, .70, "By not driving...", fontsize=11)

        ax.text(.4,
                .45,
                '%4.2f' % (data),
                style='italic',
                bbox={
                    'facecolor': 'lightgreen',
                    'alpha': 0.65,
                    'pad': 10
                })
        ax.text(.08, .20, msg, fontsize=11)

    def updateGraph(self):
        ''' Called whenever the graphs need to be updated. It calls
        self.calculate to make sure self.calcDict is up to date, then uses the
        values of the radio buttons together with calcDict to choose which
        y values to pass to makeGraph for the three bar charts and which value
        to pass to showInfo.'''

        self.calcDict = self.calculate()

        if self.radioTime.value_selected == 'Hours':
            self.makeGraph(self.ax1, self.calcDict['time'], 'Time (Hours)')
        elif self.radioTime.value_selected == 'Minutes':
            self.makeGraph(self.ax1, self.calcDict['time-mins'],
                           'Time (Minutes)')
        elif self.radioTime.value_selected == 'Audiobooks':
            self.makeGraph(self.ax1, self.calcDict['time-audio'],
                           'Time (Audiobooks)')
        else:
            print('Error!')

        if self.radioCost.value_selected == 'Dollars':
            self.makeGraph(self.ax2, self.calcDict['cost'], 'Cost ($)')
        elif self.radioCost.value_selected == 'Coffees':
            self.makeGraph(self.ax2, self.calcDict['cost-coffee'],
                           'Cost (Coffees)')
        else:
            print('Error!')

        if self.radioCal.value_selected == 'Cal (total)':
            self.makeGraph(self.ax3, self.calcDict['cal'], 'Calories (total)')
        elif self.radioCal.value_selected == 'Cal (/hour)':
            self.makeGraph(self.ax3, self.calcDict['cal-hour'],
                           'Calories (/hour)')
        else:
            print('Error!')

        if self.radioCO2.value_selected == 'CO2 (lbs)':
            self.showInfo(self.ax4, self.calcDict['CO2'],
                          'Pounds of CO2 not emitted')
        elif self.radioCO2.value_selected == 'CO2 (trees)':
            self.showInfo(self.ax4, self.calcDict['CO2-tree'],
                          'Trees planted!')
        else:
            print('Error!')
Пример #12
0
def interact2D(data, xaxis=0, yaxis=1, channel=0, local=False, verbose=True):
    """ Interactive 2D plot of the dataset.
    Side plots show x and y projections of the slice (shaded gray).
    Left clicks on the main axes draw 1D slices on side plots at the coordinates selected.
    Right clicks remove the 1D slices.
    For 3+ dimensional data, sliders below the main axes are used to change which slice is viewed.

    Parameters
    ----------
    data : WrightTools.Data object
        Data to plot. A copy is made, so the passed object is not modified.
    xaxis : string, integer, or data.Axis object (optional)
        Expression or index of x axis. Default is 0.
    yaxis : string, integer, or data.Axis object (optional)
        Expression or index of y axis. Default is 1.
    channel : string, integer, or data.Channel object (optional)
        Name or index of channel to plot. Default is 0.
    local : boolean (optional)
        Toggle plotting locally. Default is False.
    verbose : boolean (optional)
        Toggle talkback. Default is True.

    Returns
    -------
    tuple
        ``(obj2D, sliders, side_plotter, crosshair_hline, crosshair_vline,
        radio, colorbar)``: the 2D mesh artist; a dict of Slider widgets keyed
        by axis natural name (one per non-plotted axis); the AxesWidget that
        receives mouse releases on the main axes; the two crosshair lines; the
        global/local RadioButtons; and the colorbar.
        NOTE(review): presumably returned so the caller can keep the widgets
        referenced (and therefore responsive); confirm with callers.

    Raises
    ------
    DimensionalityError
        If ``data`` has fewer than two dimensions.
    NotImplementedError
        If a non-plotted axis is a multivariable axis, which cannot be used as
        a slider.
    """
    # avoid changing passed data object
    data = data.copy()
    # unpack
    channel = get_channel(data, channel)
    # keep only the channel being plotted
    for ch in data.channel_names:
        if ch != channel.natural_name:
            data.remove_channel(ch, verbose=False)
    # drop every variable that is not referenced by any axis
    varis = list(data.variable_names)
    for ax in data.axes:
        for v in ax.variables:
            try:
                varis.remove(v.natural_name)
            except ValueError:
                pass  # Already removed, can't double count
    for v in varis:
        data.remove_variable(v, implied=False, verbose=False)
    xaxis, yaxis = get_axes(data, [xaxis, yaxis])
    cmap = get_colormap(channel)
    # mutable namespace shared with (and updated by) the nested callbacks below
    current_state = Bunch()
    # create figure
    # one slider per dimension beyond the two that are plotted
    nsliders = data.ndim - 2
    if nsliders < 0:
        raise DimensionalityError(">= 2", data.ndim)
    # TODO: implement aspect; doesn't work currently because of our incorporation of colorbar
    # grid layout: row 0 = normalization toggle + title, rows 1-5 = main axes
    # and colorbar, last `nsliders` rows = sliders (filled from the bottom up)
    fig, gs = create_figure(width="single", nrows=7 + nsliders, cols=[1, 1, 1, 1, 1, "cbar"])
    # create axes
    ax0 = plt.subplot(gs[1:6, 0:5])
    ax0.patch.set_facecolor("w")
    cax = plt.subplot(gs[1:6, -1])
    sp_x = add_sideplot(ax0, "x", pad=0.1)
    sp_y = add_sideplot(ax0, "y", pad=0.1)
    ax_local = plt.subplot(gs[0, 0], aspect="equal", frameon=False)
    ax_title = plt.subplot(gs[0, 3], frameon=False)
    ax_title.text(
        0.5,
        0.5,
        data.natural_name,
        fontsize=18,
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax_title.transAxes,
    )
    ax_title.set_axis_off()
    # NOTE: there are more axes here for more buttons / widgets in future plans
    # create lines
    x_color = "#00BFBF"  # cyan with saturation increased
    y_color = "coral"
    # hidden placeholder lines; their data is set by update_sideplot_slices
    line_sp_x = sp_x.plot([None], [None], visible=False, color=x_color)[0]
    line_sp_y = sp_y.plot([None], [None], visible=False, color=y_color)[0]
    crosshair_hline = ax0.plot([None], [None], visible=False, color=x_color)[0]
    crosshair_vline = ax0.plot([None], [None], visible=False, color=y_color)[0]
    # both start as None, since the placeholder lines hold [None] data
    current_state.xpos = crosshair_hline.get_ydata()[0]
    current_state.ypos = crosshair_vline.get_xdata()[0]
    # cleared by draw_sideplot_projections when a projection cannot be drawn
    current_state.bin_vs_x = True
    current_state.bin_vs_y = True
    # create buttons
    current_state.local = local
    # labels carry a leading space; update_local strips it with [1:]
    radio = RadioButtons(ax_local, (" global", " local"))
    if local:
        radio.set_active(1)
    else:
        radio.set_active(0)
    # NOTE(review): `RadioButtons.circles` is deprecated/removed in newer
    # matplotlib releases; confirm against the pinned matplotlib version.
    for circle in radio.circles:
        circle.set_radius(0.14)
    # create sliders
    sliders = {}
    for axis in data.axes:
        if axis not in [xaxis, yaxis]:
            if axis.size > np.prod(axis.shape):
                raise NotImplementedError("Cannot use multivariable axis as a slider")
            # ~n == -(n + 1): each new slider takes the next grid row up from the bottom
            slider_axes = plt.subplot(gs[~len(sliders), :]).axes
            # valstep=1 keeps the slider value on integer indices into axis.points
            slider = Slider(slider_axes, axis.label, 0, axis.points.size - 1, valinit=0, valstep=1)
            sliders[axis.natural_name] = slider
            # dotted guide lines at index positions 0 .. size-2
            slider.ax.vlines(
                range(axis.points.size - 1),
                *slider.ax.get_ylim(),
                colors="k",
                linestyle=":",
                alpha=0.5
            )
            # show the axis value for index 0 (matching valinit=0) instead of the raw index
            slider.valtext.set_text(gen_ticklabels(axis.points)[0])
    # initial xyz start are from zero indices of additional axes
    current_state.dat = data.chop(
        xaxis.natural_name,
        yaxis.natural_name,
        at=_at_dict(data, sliders, xaxis, yaxis),
        verbose=False,
    )[0]
    clim = get_clim(channel, current_state)
    # tick labels come from the computed limits, before the zero-width fallback below
    ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
    # a zero-width color range falls back to [-1, 1] (signed) or [0, 1]
    if clim[0] == clim[1]:
        clim = [-1 if channel.signed else 0, 1]
    # NOTE(review): ax0 presumably is a WrightTools axes whose pcolormesh accepts
    # a Data object plus xlabel/ylabel keywords; confirm with create_figure.
    obj2D = ax0.pcolormesh(
        current_state.dat,
        cmap=cmap,
        vmin=clim[0],
        vmax=clim[1],
        ylabel=yaxis.label,
        xlabel=xaxis.label,
    )
    # NOTE(review): matplotlib 3.5 renamed grid's `b` argument to `visible`;
    # confirm this still works with the pinned matplotlib version.
    ax0.grid(b=True)
    # colorbar
    # 11 evenly spaced ticks over the (possibly fallback) limits, relabelled below
    colorbar = plot_colorbar(
        cax, cmap=cmap, label=channel.natural_name, ticks=np.linspace(clim[0], clim[1], 11)
    )
    colorbar.set_ticklabels(ticklabels)
    fig.canvas.draw_idle()

    def draw_sideplot_projections():
        """Shade the side plots with the nanmean projections of the current slice.

        A ValueError from fill_between/fill_betweenx permanently disables the
        corresponding side plot (bin_vs_x / bin_vs_y set to False, plot hidden).
        """
        arr = current_state.dat[channel.natural_name][:]
        # index of the first array dimension along which the slice's x axis has
        # length > 1; averaging over yind (resp. xind) projects onto x (resp. y)
        xind = list(
            np.array(
                current_state.dat.axes[
                    current_state.dat.axis_expressions.index(xaxis.expression)
                ].shape
            )
            > 1
        ).index(True)
        yind = list(
            np.array(
                current_state.dat.axes[
                    current_state.dat.axis_expressions.index(yaxis.expression)
                ].shape
            )
            > 1
        ).index(True)
        if channel.signed:
            # positive part only (negatives zeroed), NaNs masked
            temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
            temp_arr[temp_arr < 0] = 0
            x_proj_pos = np.nanmean(temp_arr, axis=yind)
            y_proj_pos = np.nanmean(temp_arr, axis=xind)

            # negative part only (positives zeroed), NaNs masked
            temp_arr = np.ma.masked_array(arr, np.isnan(arr), copy=True)
            temp_arr[temp_arr > 0] = 0
            x_proj_neg = np.nanmean(temp_arr, axis=yind)
            y_proj_neg = np.nanmean(temp_arr, axis=xind)

            # net projection
            x_proj = np.nanmean(arr, axis=yind)
            y_proj = np.nanmean(arr, axis=xind)

            alpha = 0.4
            blue = "#517799"  # start with #87C7FF and change saturation
            red = "#994C4C"  # start with #FF7F7F and change saturation
            if current_state.bin_vs_x:
                # scale all three curves by the largest magnitude of either sign
                x_proj_norm = max(np.nanmax(x_proj_pos), np.nanmax(-x_proj_neg))
                if x_proj_norm != 0:
                    x_proj_pos /= x_proj_norm
                    x_proj_neg /= x_proj_norm
                    x_proj /= x_proj_norm
                try:
                    sp_x.fill_between(xaxis.points, x_proj_pos, 0, color=red, alpha=alpha)
                    sp_x.fill_between(xaxis.points, 0, x_proj_neg, color=blue, alpha=alpha)
                    sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
                except ValueError:  # Input passed into argument is not 1-dimensional
                    current_state.bin_vs_x = False
                    sp_x.set_visible(False)
            if current_state.bin_vs_y:
                y_proj_norm = max(np.nanmax(y_proj_pos), np.nanmax(-y_proj_neg))
                if y_proj_norm != 0:
                    y_proj_pos /= y_proj_norm
                    y_proj_neg /= y_proj_norm
                    y_proj /= y_proj_norm
                try:
                    sp_y.fill_betweenx(yaxis.points, y_proj_pos, 0, color=red, alpha=alpha)
                    sp_y.fill_betweenx(yaxis.points, 0, y_proj_neg, color=blue, alpha=alpha)
                    sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_y = False
                    sp_y.set_visible(False)
        else:
            if current_state.bin_vs_x:
                x_proj = np.nanmean(arr, axis=yind)
                x_proj = norm(x_proj, channel.signed)
                try:
                    sp_x.fill_between(xaxis.points, x_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_x = False
                    sp_x.set_visible(False)
            if current_state.bin_vs_y:
                y_proj = np.nanmean(arr, axis=xind)
                y_proj = norm(y_proj, channel.signed)
                try:
                    sp_y.fill_betweenx(yaxis.points, y_proj, 0, color="k", alpha=0.3)
                except ValueError:
                    current_state.bin_vs_y = False
                    sp_y.set_visible(False)

    draw_sideplot_projections()

    ax0.set_xlim(xaxis.points.min(), xaxis.points.max())
    ax0.set_ylim(yaxis.points.min(), yaxis.points.max())

    # signed projections are scaled into [-1, 1]; leave a small margin
    if channel.signed:
        sp_x.set_ylim(-1.1, 1.1)
        sp_y.set_xlim(-1.1, 1.1)

    def update_sideplot_slices():
        """Move the crosshair to (xpos, ypos) and draw the 1D slices through it."""
        # TODO:  if bins is only available along one axis, slicing should be valid along the other
        #   e.g., if bin_vs_y =  True, then assemble slices vs x
        #   for now, just uniformly turn off slicing
        if (not current_state.bin_vs_x) or (not current_state.bin_vs_y):
            return
        xlim = ax0.get_xlim()
        ylim = ax0.get_ylim()
        x0 = current_state.xpos
        y0 = current_state.ypos

        crosshair_hline.set_data(np.array([xlim, [y0, y0]]))
        crosshair_vline.set_data(np.array([[x0, x0], ylim]))

        # slice along y at x = x0 -> drawn on the y side plot
        at_dict = _at_dict(data, sliders, xaxis, yaxis)
        at_dict[xaxis.natural_name] = (x0, xaxis.units)
        side_plot_data = data.chop(yaxis.natural_name, at=at_dict, verbose=False)
        side_plot = side_plot_data[0][channel.natural_name].points
        side_plot = norm(side_plot, channel.signed)
        line_sp_y.set_data(side_plot, yaxis.points)
        # close the temporary collection returned by chop
        side_plot_data.close()

        # slice along x at y = y0 -> drawn on the x side plot
        at_dict = _at_dict(data, sliders, xaxis, yaxis)
        at_dict[yaxis.natural_name] = (y0, yaxis.units)
        side_plot_data = data.chop(xaxis.natural_name, at=at_dict, verbose=False)
        side_plot = side_plot_data[0][channel.natural_name].points
        side_plot = norm(side_plot, channel.signed)
        line_sp_x.set_data(xaxis.points, side_plot)
        side_plot_data.close()

    def update_local(index):
        """RadioButtons callback: switch between global and local color limits.

        `index` receives whatever RadioButtons passes (presumably the clicked
        label) and is only printed; the state is read from radio.value_selected.
        """
        if verbose:
            print("normalization:", index)
        current_state.local = radio.value_selected[1:] == "local"
        clim = get_clim(channel, current_state)
        ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
        colorbar.set_ticklabels(ticklabels)
        if clim[0] == clim[1]:
            clim = [-1 if channel.signed else 0, 1]
        obj2D.set_clim(*clim)
        fig.canvas.draw_idle()

    def update(info):
        """Shared callback for slider changes (numeric `info`) and mouse releases on ax0."""
        if isinstance(info, (float, int)):
            # slider moved: replace the displayed slice with the one at the new indices
            current_state.dat.close()
            current_state.dat = data.chop(
                xaxis.natural_name,
                yaxis.natural_name,
                at={
                    # slider value is an index into the axis' flattened values
                    a.natural_name: (a[:].flat[int(sliders[a.natural_name].val)], a.units)
                    for a in data.axes
                    if a not in [xaxis, yaxis]
                },
                verbose=False,
            )[0]
            # show axis values, not raw indices, next to each slider
            for k, s in sliders.items():
                s.valtext.set_text(
                    gen_ticklabels(data.axes[data.axis_names.index(k)].points)[int(s.val)]
                )
            obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
            clim = get_clim(channel, current_state)
            ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
            if clim[0] == clim[1]:
                clim = [-1 if channel.signed else 0, 1]
            obj2D.set_clim(*clim)
            colorbar.set_ticklabels(ticklabels)
            # remove the old shaded projections before drawing the new ones
            # NOTE(review): newer matplotlib makes Axes.collections read-only;
            # confirm .clear() still works with the pinned version.
            sp_x.collections.clear()
            sp_y.collections.clear()
            draw_sideplot_projections()
            if line_sp_x.get_visible() and line_sp_y.get_visible():
                update_sideplot_slices()
        if isinstance(info, mpl.backend_bases.MouseEvent) and info.inaxes == ax0:  # crosshairs
            x0 = info.xdata
            y0 = info.ydata
            if x0 is None or y0 is None:
                raise TypeError(info)
            xlim = ax0.get_xlim()
            ylim = ax0.get_ylim()
            # only react to releases strictly inside the current view limits
            if x0 > xlim[0] and x0 < xlim[1] and y0 > ylim[0] and y0 < ylim[1]:
                current_state.xpos = info.xdata
                current_state.ypos = info.ydata
                if info.button == 1 or info.button is None:  # left click
                    if verbose:
                        print(current_state.xpos, current_state.ypos)
                    update_sideplot_slices()
                    line_sp_x.set_visible(True)
                    line_sp_y.set_visible(True)
                    crosshair_hline.set_visible(True)
                    crosshair_vline.set_visible(True)
                elif info.button == 3:  # right click
                    line_sp_x.set_visible(False)
                    line_sp_y.set_visible(False)
                    crosshair_hline.set_visible(False)
                    crosshair_vline.set_visible(False)
        fig.canvas.draw_idle()

    # route mouse releases on the main axes to `update`
    side_plotter = plt.matplotlib.widgets.AxesWidget(ax0)
    side_plotter.connect_event("button_release_event", update)

    radio.on_clicked(update_local)

    for slider in sliders.values():
        slider.on_changed(update)

    return obj2D, sliders, side_plotter, crosshair_hline, crosshair_vline, radio, colorbar
Пример #13
0
class Simulator:
    """
    Interactive matplotlib simulator for a three-joint robotic arm.

    Only one simulator window can be open at a time (see ``running``). A mouse
    press inside the main axes, a drag while the button is held, or an arrow
    key rotates the selected joint (and every joint and limb further out than
    it) by the current speed, and steps the matching motor by the same angle.

    Key bindings (arrow_key_listener):
    c - reverse the rotation direction
    1 / 2 / 3 - select the wrist / elbow / shoulder
    arrows - rotate the selected joint toward the arrow's direction

    Class Variables:

    left_pressed - Set by a mouse press inside the main axes, cleared by a release inside them
    running - Determines whether or not to start a new Simulator
    simulator_plot_width - The width of the Simulator pyplot axis
    arm_limbs - Limbs listed in order (hand, forearm, arm)
    arm_joints - Joints listed in order (wrist, elbow, shoulder)
    motors - Motors indexed like arm_joints; each must provide step_motor_degrees(degrees)
    total_rotations - Accumulated rotation of each limb in radians, wrapped into [0, 2*pi)
    unlocked - Re-entrancy guard for drag-driven rotations
    firsts - Whether each joint has yet to make its first left/right arrow move
    simulator_lock - Re-entrancy guard for key-press handling

    Instance Variables:

    current_fig_axes - The axes of the current figure
    cw - 0 means clockwise, 1 means counterclockwise
    previous_arrow - 0 previous movement was not made by an arrow, 1 right, 2 left, 3 up, 4 down
    degree_quadrants - Quadrant boundaries (90, 180, 270, 360 degrees) in radians
    curr_joint_rotation - Index into arm_joints of the joint being rotated
    speed - Rotation per event, in radians

    Joints:
    wrist - Wrist of the arm
    elbow - Elbow of the arm
    shoulder - Rotating base portion of the arm, the shoulder

    Limbs:
    hand - Grabber portion of the arm
    forearm - Middle portion of the arm. Attaches wrist to elbow
    arm - Last segment of the arm. Attaches elbow to shoulder

    Limb Sizes:
    pvc_height - Height of all of the limbs
    hand_width - Width of the hand
    forearm_width - Width of the forearm and arm

    Radio Buttons:
    joint_selector - Select which joint to use
    clockwise_selector - Select clockwise or counterclockwise
    """

    left_pressed = False
    running = False
    simulator_plot_width = 1000

    arm_limbs = []
    arm_joints = []
    motors = []

    total_rotations = [0, 0, 0]
    unlocked = True

    firsts = [True, True, True]
    simulator_lock = True

    def __init__(self, motors):
        """Build the simulator window and block in ``plt.show()``.

        :param motors: three motor objects (wrist, elbow, shoulder order), each
            providing ``step_motor_degrees(degrees)``.

        Does nothing if a simulator window is already open.
        """
        if Simulator.running:
            return
        Simulator.running = True
        Simulator.motors = motors

        plot_width = Simulator.simulator_plot_width
        quadrant_width = plot_width / 2

        figure = plt.figure(num='Robotic Arm Simulator', figsize=(7, 6))
        figure.canvas.mpl_connect('button_press_event', self.mouse_clicked)
        figure.canvas.mpl_connect('button_release_event', self.mouse_released)
        figure.canvas.mpl_connect('close_event', Simulator.handle_close)
        figure.canvas.mpl_connect('motion_notify_event', self.mouse_dragged)
        figure.canvas.mpl_connect('key_press_event', self.arrow_key_listener)

        plt.axis([-quadrant_width, quadrant_width,
                  -quadrant_width, quadrant_width])
        plt.title('Simulator')
        # Leave room for the radio buttons on the left and for the speed
        # slider and angle read-outs underneath.
        plt.subplots_adjust(left=0.4, bottom=0.3)
        self.current_fig_axes = plt.gca()
        axes = self.current_fig_axes

        # Shapes, all proportional to the plot width
        wrist_radius = (1 / 40) * plot_width
        elbow_radius = (4.5 / 120) * plot_width  # the shoulder uses the same radius

        self.pvc_height = (2.5 / 60) * plot_width
        self.hand_width = (1 / 12) * plot_width
        self.forearm_width = (1 / 6) * plot_width  # the arm uses the same width

        # The arm starts stretched out to the left of the shoulder.
        shoulder_center = (0, 0)
        elbow_center = (shoulder_center[0] - self.forearm_width, 0)
        wrist_center = (elbow_center[0] - self.forearm_width, 0)

        # Key control variables
        self.arrow_key_pressed = False

        self.hand = self._default_limb(self.hand_width)
        self.forearm = self._default_limb(self.forearm_width)
        self.arm = self._default_limb(self.forearm_width)

        self.wrist = Simulator._add_joint(wrist_center, wrist_radius, 'r')
        self.elbow = Simulator._add_joint(elbow_center, elbow_radius, 'r')
        self.shoulder = Simulator._add_joint(shoulder_center, elbow_radius, 'r')

        # Limbs are added before the joints so they are drawn behind them.
        for patch in (self.hand, self.forearm, self.arm,
                      self.wrist, self.elbow, self.shoulder):
            axes.add_patch(patch)

        self._update_limb_positions()
        self.curr_joint_rotation = 0
        self.cw = 0
        self.previous_arrow = 0

        self.degree_quadrants = [
            Simulator.to_radians(degrees) for degrees in (90, 180, 270, 360)
        ]

        # List the joints and limbs in order, from the free end inwards
        Simulator.arm_joints = [self.wrist, self.elbow, self.shoulder]
        Simulator.arm_limbs = [self.hand, self.forearm, self.arm]

        axcolor = 'lightgoldenrodyellow'

        # Select joint
        joint_axes = plt.axes([0.05, 0.6, 0.17, 0.20], facecolor=axcolor)
        self.joint_selector = RadioButtons(joint_axes,
                                           ('wrist', 'elbow', 'shoulder'))
        self.joint_selector.on_clicked(self.set_rotation_point)

        # Select rotation direction
        direction_axes = plt.axes([0.05, 0.4, 0.15, 0.15], facecolor=axcolor)
        self.clockwise_selector = RadioButtons(direction_axes, ('CW', 'CCW'))
        self.clockwise_selector.on_clicked(self.set_clockwise)

        # Rotation per event: chosen in degrees on the slider, stored in radians
        speed_axes = plt.axes([0.40, 0.17, 0.50, 0.03])
        init_val = 16
        self.speed = Simulator.to_radians(init_val)
        self.speed_slider = Slider(speed_axes,
                                   'Speed\n(in degrees)',
                                   0.1,
                                   20,
                                   valinit=init_val)
        self.speed_slider.on_changed(self.update_speed)

        # Static labels for the angle read-outs
        for left, label in ((0.31, 'Wrist: '), (0.51, 'Elbow: '),
                            (0.71, 'Shoulder: ')):
            label_axes = plt.axes([left, 0.07, 0.1, 0.05])
            label_axes.text(0, 0, label, fontsize=12)
            label_axes.axis('off')

        # Angle read-outs, one per joint in arm_joints order
        initial_text = r'${:.0f}\degree$'.format(0)
        self.axboxes = []
        for left in (0.40, 0.61, 0.84):
            value_axes = plt.axes([left, 0.07, 0.1, 0.05])
            value_axes.text(0, 0, initial_text, fontsize=15)
            value_axes.axis('off')
            self.axboxes.append(value_axes)
        self.axbox_wrist, self.axbox_elbow, self.axbox_shoulder = self.axboxes

        # The widget axes created above became current; restore the arm axes.
        plt.sca(self.current_fig_axes)

        plt.show()

    def set_rotation_point(self, label):
        """joint_selector callback: choose the joint that gets rotated."""
        self.curr_joint_rotation = {'wrist': 0, 'elbow': 1}.get(label, 2)

    def set_clockwise(self, label):
        """clockwise_selector callback: 'CW' -> 0, anything else -> 1."""
        self.cw = 0 if label == 'CW' else 1

    def update_speed(self, speed):
        """speed_slider callback: store the new speed (given in degrees) in radians."""
        self.speed = Simulator.to_radians(speed)

    def _get_limb_origin(self, joint, limb_width):
        """Lower-left corner for a limb that ends at ``joint`` and extends left of it."""
        return (joint.center[0] - limb_width,
                joint.center[1] - self.pvc_height / 2)

    def _default_limb(self, width):
        """Create a gray limb rectangle of the given width (positioned later)."""
        return plt.Rectangle((0, 0),
                             height=self.pvc_height,
                             width=width,
                             fc='gray')

    @staticmethod
    def _add_joint(center, radius, color_str):
        """Create a joint circle."""
        return plt.Circle(center, radius=radius, fc=color_str)

    def _set_direction(self, direction):
        """Select direction 0 (CW) or 1 (CCW) in the radio buttons and in self.cw."""
        self.clockwise_selector.set_active(direction)
        self.cw = direction

    def _update_angles1(self, degrees):
        """Redraw every angle read-out; ``degrees`` holds one angle in radians per joint."""
        for index, angle in enumerate(degrees):
            box = self.axboxes[index]
            box.clear()
            box.text(0,
                     0,
                     r'${:.0f}\degree$'.format(Simulator.to_degrees(angle)),
                     fontsize=15)
            box.axis('off')

    def _update_angles(self, degrees):
        """Redraw, boxed, only the read-out of the currently selected joint."""
        boxes = {
            0: self.axbox_wrist,
            1: self.axbox_elbow,
            2: self.axbox_shoulder,
        }
        box = boxes.get(self.curr_joint_rotation)
        if box is None:
            print("Something went wrong in _update_angles")
            return
        box.clear()
        box.text(0,
                 0,
                 r'${:.0f}\degree$'.format(degrees),
                 fontsize=15,
                 bbox=dict(facecolor='none', edgecolor='black'))
        box.axis('off')

    def _update_limb_positions(self):
        """Re-anchor each limb rectangle to the joint it hangs from."""
        self.hand.set_xy(self._get_limb_origin(self.wrist, self.hand_width))
        self.forearm.set_xy(
            self._get_limb_origin(self.elbow, self.forearm_width))
        self.arm.set_xy(
            self._get_limb_origin(self.shoulder, self.forearm_width))

    def rotate_joints(self, radians):
        """Rotate the selected joint by ``radians`` in the selected direction.

        Every joint further out than the selected one swings around it, each
        affected limb's accumulated rotation is updated, the read-outs are
        refreshed, and the selected joint's motor is stepped by the same angle.
        """
        joint_index = self.curr_joint_rotation
        base = Simulator.arm_joints[joint_index]

        if self.cw == 0:
            radians = -radians

        # Handle joints first: swing the outer joints around the pivot.
        for outer_joint in Simulator.arm_joints[:joint_index]:
            outer_joint.center = Simulator.rotate(base.center,
                                                  outer_joint.center, radians)

        axes = self.current_fig_axes
        for i in range(joint_index + 1):
            Simulator.total_rotations[i] += radians
            joint = Simulator.arm_joints[i]
            transform = matplotlib.transforms.Affine2D().rotate_around(
                joint.center[0], joint.center[1], Simulator.total_rotations[i])
            Simulator.arm_limbs[i].set_transform(transform + axes.transData)

        self._update_limb_positions()

        # Wrap in place so the class-level list object stays the same.
        Simulator.total_rotations[:] = [
            self.simplify_radians(rotation)
            for rotation in Simulator.total_rotations
        ]

        self._update_angles1(Simulator.total_rotations)
        Simulator.motors[joint_index].step_motor_degrees(
            Simulator.to_degrees(radians))
        axes.figure.canvas.draw()

    @staticmethod
    def get_empty_transform():
        """Return an identity affine transform."""
        return matplotlib.transforms.Affine2D().rotate_around(0, 0, 0)

    def mouse_clicked(self, event):
        """A press inside the arm axes rotates once and starts drag rotation."""
        if event.inaxes == self.current_fig_axes:
            Simulator.left_pressed = True
            print("You pressed: ", event.x, event.y)
            self.rotate_joints(self.speed)
            self.previous_arrow = 0

    def mouse_released(self, event):
        """A release inside the arm axes stops drag rotation."""
        if event.inaxes == self.current_fig_axes:
            Simulator.left_pressed = False
            print("Released at: ", event.x, event.y)

    def mouse_dragged(self, event):
        """While a press is active, every motion event rotates once more."""
        if not (Simulator.left_pressed and Simulator.unlocked):
            return
        Simulator.unlocked = False
        try:
            self.previous_arrow = 0
            self.rotate_joints(self.speed)
        finally:
            # Always release the guard, or a single failure disables dragging.
            Simulator.unlocked = True

    @staticmethod
    def to_radians(degrees):
        """Convert degrees to radians."""
        return math.radians(degrees)

    @staticmethod
    def to_degrees(radians):
        """Convert radians to degrees."""
        return math.degrees(radians)

    @staticmethod
    def simplify_radians(radians):
        """Wrap an angle into [0, 2*pi) (up to floating-point rounding).

        Uses the exact full turn rather than the 6.28 approximation, which made
        the limb transforms drift away from the joint positions on each wrap.
        """
        return radians % (2 * math.pi)

    def reverse_clockwise(self):
        """Flip the rotation direction and mirror it in the radio buttons."""
        new_direction = 1 if self.cw == 0 else 0
        self.cw = new_direction
        self.clockwise_selector.set_active(new_direction)

    def opposite_arrow_pressed(self, current_arrow):
        """Reverse direction, rotate once, and remember the arrow that did it."""
        self.reverse_clockwise()
        self.rotate_joints(self.speed)
        self.previous_arrow = current_arrow

    def arrow_key_listener(self, event):
        """key_press_event handler; ignores keys while a previous one is being handled."""
        if not Simulator.simulator_lock:
            return
        Simulator.simulator_lock = False
        try:
            self._handle_key(event.key)
        finally:
            # Always release the guard, or a single failure disables the keyboard.
            Simulator.simulator_lock = True

    def _handle_key(self, key):
        """Apply one key press: 'c', '1'-'3', or an arrow key (see class docstring)."""
        if key == 'c':
            self.previous_arrow = 0
            self.reverse_clockwise()
            return
        if key in ('1', '2', '3'):
            joint_index = int(key) - 1
            self.curr_joint_rotation = joint_index
            self.joint_selector.set_active(joint_index)
            self.previous_arrow = 0
            return

        # key -> (arrow code, code of the opposite arrow, first-press mover)
        arrows = {
            'right': (1, 2, self.move_robot_right),
            'left': (2, 1, self.move_robot_left),
            'up': (3, 4, self.move_robot_up),
            'down': (4, 3, self.move_robot_down),
        }
        if key not in arrows:
            return
        code, opposite, move = arrows[key]
        if self.previous_arrow == opposite:
            self.opposite_arrow_pressed(code)
        elif self.previous_arrow != code:
            move()
            self.previous_arrow = code
        else:
            # Same arrow again: keep rotating in the direction already chosen.
            self.rotate_joints(self.speed)

    def move_robot_right(self):
        """Pick the direction that moves the selected limb right, then rotate."""
        joint_index = self.curr_joint_rotation
        rotation = Simulator.total_rotations[joint_index]
        if rotation == 0 and Simulator.firsts[joint_index]:
            direction = 0
            Simulator.firsts[joint_index] = False
        elif rotation < self.degree_quadrants[1]:
            direction = 1
        else:
            direction = 0
        self._set_direction(direction)
        self.rotate_joints(self.speed)

    def move_robot_left(self):
        """Pick the direction that moves the selected limb left, then rotate."""
        joint_index = self.curr_joint_rotation
        rotation = Simulator.total_rotations[joint_index]
        if rotation == 0 and Simulator.firsts[joint_index]:
            direction = 1
            Simulator.firsts[joint_index] = False
        elif rotation < self.degree_quadrants[1]:
            direction = 0
        else:
            direction = 1
        self._set_direction(direction)
        self.rotate_joints(self.speed)

    def move_robot_up(self):
        """Pick the direction that moves the selected limb up, then rotate."""
        rotation = Simulator.total_rotations[self.curr_joint_rotation]
        if self.degree_quadrants[0] <= rotation < self.degree_quadrants[2]:
            direction = 1
        else:
            direction = 0
        self._set_direction(direction)
        self.rotate_joints(self.speed)

    def move_robot_down(self):
        """Pick the direction that moves the selected limb down, then rotate."""
        rotation = Simulator.total_rotations[self.curr_joint_rotation]
        if self.degree_quadrants[0] <= rotation < self.degree_quadrants[2]:
            direction = 0
        else:
            direction = 1
        self._set_direction(direction)
        self.rotate_joints(self.speed)

    @staticmethod
    def handle_close(event):
        """close_event handler: reset the class state so a new Simulator can start."""
        Simulator.left_pressed = False
        Simulator.running = False

        Simulator.arm_limbs = []
        Simulator.arm_joints = []

        Simulator.total_rotations = [0, 0, 0]
        Simulator.unlocked = True
        Simulator.simulator_lock = True

        Simulator.firsts = [True, True, True]

        print("Closing")

    # https://stackoverflow.com/questions/34372480/rotate-point-about-another-point-in-degrees-python
    # Credit for this function to Mark Dickinson
    @staticmethod
    def rotate(origin, point, angle):
        """Rotate ``point`` counterclockwise by ``angle`` radians around ``origin``."""
        ox, oy = origin
        px, py = point
        cos_a, sin_a = math.cos(angle), math.sin(angle)

        qx = ox + cos_a * (px - ox) - sin_a * (py - oy)
        qy = oy + sin_a * (px - ox) + cos_a * (py - oy)
        return qx, qy
Пример #14
0
def main():
    '''
    Main function of the program.

    Reads the first image of the file named on the command line, logs its
    CD1_1..CD2_2 header keywords, its size and the background level and
    dispersion.  When the interactive flag is set, it opens a matplotlib
    window whose widgets drive an animated cluster search on the sub-image
    ``pixels[45:70, 40:65]``.

    The widget callbacks communicate through the module-level globals
    ``reg``, ``strategy``, ``in_cluster`` and ``in_scanning``.

    Returns:
        Always 0.
    '''

    global reg
    global strategy
    global in_cluster
    global in_scanning

    # process command-line options
    file_name, interactive = lib_args.get_args()
    header, pixels = lib_fits.read_first_image(file_name)

    logging.debug('cd1_1: %s, cd1_2: %s, cd2_1: %s, cd2_2: %s',
                  header['CD1_1'], header['CD1_2'], header['CD2_1'],
                  header['CD2_2'])
    logging.debug('height: %s, width: %s', pixels.shape[0], pixels.shape[1])

    # compute background
    background, dispersion, _ = lib_background.compute_background(pixels)
    logging.debug('background: %s, dispersion: %s', int(background),
                  int(dispersion))

    print('---------------------')
    # search for clusters in a sub-region of the image; the detection level
    # is `threshold` dispersions above the background
    threshold = 6.0

    # graphic output
    if interactive:

        # cluster central
        image = pixels[45:70, 40:65]

        reg = lib_cluster.Region(image, background + threshold * dispersion)

        strategy = '4centers'
        in_cluster = 1.0
        in_scanning = 0.1

        def select_strategy(value):
            '''RadioButtons callback: remember the chosen scanning strategy.'''
            global strategy
            # Previously print(('...', value)) printed a tuple repr.
            print('select strategy=', value)
            strategy = value

        def select_in_cluster(value):
            '''Slider callback: wait time used inside clusters.'''
            global in_cluster
            in_cluster = value

        def select_in_scanning(value):
            '''Slider callback: wait time used while scanning.'''
            global in_scanning
            in_scanning = value

        def start_animate(value):
            '''Button callback: run the animation with the current settings.'''
            print('animate with strategy=', strategy)
            reg.animate(in_cluster=in_cluster,
                        in_scanning=in_scanning,
                        strategy=strategy)

        strategies = ['random', 'all', 'center', '4centers']
        axis_strategy = plt.axes([0.2, 0.8, 0.15, 0.15])
        axis_strategy.set_title('Scanning strategy')
        strategy_widget = RadioButtons(axis_strategy, strategies)
        strategy_widget.set_active(strategies.index(strategy))
        strategy_widget.on_clicked(select_strategy)

        axis_in_cluster = plt.axes([0.2, 0.7, 0.65, 0.03])
        in_cluster_widget = Slider(axis_in_cluster,
                                   'wait in clusters',
                                   0.0,
                                   5.0,
                                   valinit=in_cluster)
        in_cluster_widget.on_changed(select_in_cluster)

        axis_in_scanning = plt.axes([0.2, 0.6, 0.65, 0.03])
        in_scanning_widget = Slider(axis_in_scanning,
                                    'wait in scanning',
                                    0.0,
                                    1.0,
                                    valinit=in_scanning)
        in_scanning_widget.on_changed(select_in_scanning)

        axis_animate = plt.axes([0.2, 0.5, 0.1, 0.03])
        in_animate = Button(axis_animate, 'Animate')
        in_animate.on_clicked(start_animate)

        # Blocks until the window is closed, which keeps the widget objects
        # above alive while they are in use.
        plt.show()

    return 0
Пример #15
0
class COCO_dataset_generator(object):
    """Interactive matplotlib GUI for drawing polygon annotations on images.

    One image of ``args['image_dir']`` is shown at a time.  Submitted polygons
    are accumulated in ``self.text`` and written to a ``.txt`` file next to
    the image when the user presses *Next*.

    Mouse (over the image axes):
        left click    -- add a vertex to the current polygon
        other button  -- close the current polygon (needs >= 3 vertices)
        left drag     -- draw a rectangle (drag > 10 data units on both axes)
        scroll        -- zoom around the cursor
    Keyboard (cursor over an axes):
        a -- select the existing polygon under the cursor for editing, or
             accept the polygon currently being edited
        r -- delete the existing polygon under the cursor
    """

    def __init__(self, fig, ax, args):
        """Create the widgets, connect the event handlers and show an image.

        Args:
            fig: matplotlib figure hosting the GUI.
            ax: axes on which images and polygons are drawn.
            args: mapping with ``'image_dir'`` (folder of ``*.jpg`` images, or
                ``*.png`` images when there is no jpg) and ``'class_file'``
                (text file with one class name per line).

        Raises:
            FileNotFoundError: if the folder holds no jpg or png image.
        """
        self.ax = ax
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        self.img_dir = args['image_dir']
        self.index = 0
        self.fig = fig
        self.polys = []
        self.zoom_scale = 1.2
        self.points, self.prev, self.submit_p, self.lines, self.circles = [], None, None, [], []
        # Patch of the polygon being drawn and the press position of a
        # potential rectangle drag; initialised so no handler can hit a
        # missing attribute.
        self.p = None
        self.r_x = self.r_y = None

        self.zoom_id = fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.click_id = fig.canvas.mpl_connect('button_press_event',
                                               self.onclick)
        self.clickrel_id = fig.canvas.mpl_connect('button_release_event',
                                                  self.onclick_release)
        self.keyboard_id = fig.canvas.mpl_connect('key_press_event',
                                                  self.onkeyboard)

        self.axradio = plt.axes([0.0, 0.0, 0.2, 1])
        self.axbringprev = plt.axes([0.3, 0.05, 0.17, 0.05])
        self.axreset = plt.axes([0.48, 0.05, 0.1, 0.05])
        self.axsubmit = plt.axes([0.59, 0.05, 0.1, 0.05])
        self.axprev = plt.axes([0.7, 0.05, 0.1, 0.05])
        self.axnext = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_bringprev = Button(self.axbringprev,
                                  'Bring Previous Annotations')
        self.b_bringprev.on_clicked(self.bring_prev)
        self.b_reset = Button(self.axreset, 'Reset')
        self.b_reset.on_clicked(self.reset)
        self.b_submit = Button(self.axsubmit, 'Submit')
        self.b_submit.on_clicked(self.submit)
        self.b_next = Button(self.axnext, 'Next')
        self.b_next.on_clicked(self.next)
        self.b_prev = Button(self.axprev, 'Prev')
        self.b_prev.on_clicked(self.previous)

        # Clicks inside these axes are widget clicks, not polygon vertices.
        self.button_axes = [
            self.axbringprev, self.axreset, self.axsubmit, self.axprev,
            self.axnext, self.axradio
        ]

        # Parallel lists: existing_polys[i] is the hit-test polygon,
        # existing_patches[i] its drawn collection, objects[i] its label.
        self.existing_polys = []
        self.existing_patches = []
        self.selected_poly = False
        self.objects = []
        # Hard-wired to True (args['feedback'] is ignored), so bring_prev()
        # currently does nothing.
        self.feedback = True  #args['feedback']

        self.right_click = False

        self.text = ''

        with open(args['class_file'], 'r') as f:
            self.class_names = [x.strip() for x in f.readlines()]

        self.radio = RadioButtons(self.axradio, self.class_names)
        # Index 0 is the background class, so radio index == class index - 1.
        self.class_names = ('BG', ) + tuple(self.class_names)

        self.img_paths = sorted(glob.glob(os.path.join(self.img_dir, '*.jpg')))
        if not self.img_paths:
            self.img_paths = sorted(
                glob.glob(os.path.join(self.img_dir, '*.png')))
        if not self.img_paths:
            raise FileNotFoundError(
                'no *.jpg or *.png images found in {}'.format(self.img_dir))

        # Resume: when the first image is already annotated, jump to the image
        # whose position matches the number of existing .txt files (minus
        # one), clamped so stray .txt files cannot push past the last image.
        if os.path.exists(self._txt_path()):
            n_txt = len(glob.glob(os.path.join(self.img_dir, '*.txt')))
            self.index = min(n_txt - 1, len(self.img_paths) - 1)
        # previous() never goes back further than this image.
        self.checkpoint = self.index

        self._show_current_image()

    def _txt_path(self):
        """Path of the annotation file paired with the current image."""
        # Both accepted extensions (jpg/png) are three characters long.
        return self.img_paths[self.index][:-3] + 'txt'

    def _clear_axes(self):
        """Remove everything from the image axes and hide the tick labels."""
        self.ax.clear()
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

    def _show_current_image(self, image=None):
        """Draw the current image and append its header block to self.text.

        The header is ``index``, absolute path and ``width height``, one per
        line, followed by a blank line.
        """
        path = self.img_paths[self.index]
        if image is None:
            image = plt.imread(path)
        self.ax.imshow(image, aspect='auto')
        with Image.open(path) as im:
            width, height = im.size
        self.text += '{}\n{}\n{} {}\n\n'.format(self.index,
                                                os.path.abspath(path), width,
                                                height)

    def _existing_index_at(self, event):
        """Index of the first existing polygon containing the event point."""
        for i, poly in enumerate(self.existing_polys):
            if poly.get_path().contains_point((event.xdata, event.ydata)):
                return i
        return None

    def _pop_existing(self, index):
        """Drop entry ``index`` from the parallel existing-polygon lists.

        Returns the removed ``(polygon, patch)`` pair.  ``self.objects`` is
        trimmed too (when long enough) so labels stay aligned with polygons.
        """
        polygon = self.existing_polys.pop(index)
        patch = self.existing_patches.pop(index)
        if index < len(self.objects):
            # Slicing works whether objects is a list or a tuple.
            self.objects = self.objects[:index] + self.objects[index + 1:]
        return polygon, patch

    def bring_prev(self, event):
        """Overlay the polygons saved for the previous image.

        Only active when ``self.feedback`` is false and there is a previous
        image.  The parsing of the ``.txt`` file is done by ``return_info``
        (defined elsewhere), assumed to return ``(vertex_lists, labels)``.
        """
        if self.feedback or self.index == 0:
            return

        poly_verts, self.objects = return_info(
            self.img_paths[self.index - 1][:-3] + 'txt')

        for num in poly_verts:
            self.existing_polys.append(
                Polygon(num, closed=True, alpha=0.5, facecolor='red'))

            pat = PatchCollection([Polygon(num, closed=True)],
                                  facecolor='green',
                                  linewidths=0,
                                  alpha=0.6)
            self.ax.add_collection(pat)
            self.existing_patches.append(pat)

    def points_to_polygon(self):
        """Return the flat ``[x0, y0, x1, y1, ...]`` points as an (N, 2) array."""
        return np.reshape(np.array(self.points),
                          (int(len(self.points) / 2), 2))

    def deactivate_all(self):
        """Disconnect every canvas event handler installed by this tool."""
        for cid in (self.zoom_id, self.click_id, self.clickrel_id,
                    self.keyboard_id):
            if cid is not None:
                self.fig.canvas.mpl_disconnect(cid)

    def onkeyboard(self, event):
        """Handle the 'a' (select / accept edit) and 'r' (remove) keys."""
        if not event.inaxes:
            return

        if event.key == 'a':
            if self.selected_poly:
                # Accept the edit: the edited vertices become the current
                # closed polygon, ready for submit().
                self.points = self.interactor.get_polygon().xy.flatten()
                self.interactor.deactivate()
                self.right_click = True
                self.selected_poly = False
                # Vertex clicks stay disabled until submit() reconnects them;
                # the old mpl_connect(self.click_id, ...) passed an integer
                # id where an event name is required.  Assigning `.color`
                # had no effect (and RGB must be 0-1), so use set_color.
                self.polygon.set_color((0.0, 1.0, 0.0))
                self.fig.canvas.draw()
            else:
                index = self._existing_index_at(event)
                if index is not None:
                    self.radio.set_active(
                        self.class_names.index(self.objects[index]) - 1)
                    self.polygon, patch = self._pop_existing(index)
                    patch.set_visible(False)
                    if self.click_id is not None:
                        self.fig.canvas.mpl_disconnect(self.click_id)
                    self.click_id = None
                    self.ax.add_patch(self.polygon)
                    self.fig.canvas.draw()
                    self.interactor = PolygonInteractor(
                        self.ax, self.polygon)
                    self.selected_poly = True

        elif event.key == 'r':
            index = self._existing_index_at(event)
            if index is not None:
                _, patch = self._pop_existing(index)
                patch.set_visible(False)
                patch.remove()

        self.fig.canvas.draw()

    def next(self, event):
        """Save the current annotations (if any) and show the next image.

        On the last image the handlers are disconnected and the same image is
        redisplayed.
        """
        # The header alone splits into 5 pieces; more means polygons exist.
        if len(self.text.split('\n')) > 5:

            print(self._txt_path())

            with open(self._txt_path(), "w") as text_file:
                text_file.write(self.text)

        if self.index < len(self.img_paths) - 1:
            self.index += 1
        else:
            print('all image labeled!, please close current window')
            self.deactivate_all()
            #exit()

        self._clear_axes()
        self.reset_all()
        self._show_current_image()

    def reset_all(self):
        """Forget all per-image drawing state.

        Called after the axes are cleared, so artists held from the previous
        image (including existing polygons) are dropped as well.
        """
        self.polys = []
        self.text = ''
        self.points, self.prev, self.submit_p, self.lines, self.circles = [], None, None, [], []
        self.p = None
        self.right_click = False
        self.existing_polys = []
        self.existing_patches = []
        self.objects = []

    def previous(self, event):
        """Step back one image (not past the checkpoint) and discard its saved
        annotation file so it can be redone."""
        if self.index > self.checkpoint:
            self.index -= 1

        # The file does not exist yet when e.g. Prev is pressed on the
        # checkpoint image before it was saved.
        txt_path = self._txt_path()
        if os.path.exists(txt_path):
            os.remove(txt_path)

        self._clear_axes()
        self.reset_all()
        self._show_current_image()

    def onclick(self, event):
        """Add a vertex (left button) or close the polygon (other buttons)."""
        if not event.inaxes:
            return
        if any(ax.in_axes(event) for ax in self.button_axes):
            return

        if event.button == 1:
            self.points.extend([event.xdata, event.ydata])

            circle = plt.Circle((event.xdata, event.ydata),
                                2.5,
                                color='black')
            self.ax.add_artist(circle)
            self.circles.append(circle)

            # Remember the first vertex as a possible rectangle corner.
            if len(self.points) < 4:
                self.r_x = event.xdata
                self.r_y = event.ydata
        elif len(self.points) > 5:
            # Close a polygon of at least 3 vertices and stop taking clicks.
            self.right_click = True
            self.fig.canvas.mpl_disconnect(self.click_id)
            self.click_id = None
            self.points.extend([self.points[0], self.points[1]])

        if len(self.points) > 2:
            # Dashed edge between the last two vertices.
            line = self.ax.plot([self.points[-4], self.points[-2]],
                                [self.points[-3], self.points[-1]], 'b--')
            self.lines.append(line)

        self.fig.canvas.draw()

        if len(self.points) > 4:
            # Replace the translucent fill of the in-progress polygon.
            if self.prev:
                self.prev.remove()
            self.p = PatchCollection(
                [Polygon(self.points_to_polygon(), closed=True)],
                facecolor='red',
                linewidths=0,
                alpha=0.4)
            self.ax.add_collection(self.p)
            self.prev = self.p

            self.fig.canvas.draw()

    def find_poly_area(self):
        """Return the area of the current polygon (shoelace formula).

        The closing vertex appended on right click contributes a zero term,
        so it does not affect the result.
        """
        coords = self.points_to_polygon()
        x, y = coords[:, 0], coords[:, 1]
        # Area is half the absolute cross sum; the previous extra "/ 2"
        # returned only half of the real area.
        return 0.5 * np.abs(
            np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

    def onclick_release(self, event):
        """Turn a left-button drag into a rectangle polygon.

        A rectangle is made when the release point lies more than 10 data
        units from the press point on both axes and at most one vertex has
        been placed.
        """
        if any(ax.in_axes(event)
               for ax in self.button_axes) or self.selected_poly:
            return
        # Releases outside any axes carry no data coordinates; nothing to do
        # either if no press position was recorded.
        if event.xdata is None or event.ydata is None or self.r_x is None:
            return
        # 10 pixels limit for rectangle creation
        if (np.abs(event.xdata - self.r_x) <= 10
                or np.abs(event.ydata - self.r_y) <= 10):
            return
        if len(self.points) >= 4:
            return

        self.right_click = True
        self.fig.canvas.mpl_disconnect(self.click_id)
        self.click_id = None
        bbox = [
            np.min([event.xdata, self.r_x]),
            np.min([event.ydata, self.r_y]),
            np.max([event.xdata, self.r_x]),
            np.max([event.ydata, self.r_y])
        ]
        self.r_x = self.r_y = None

        # Closed outline: four corners plus the first corner repeated.
        self.points = [
            bbox[0], bbox[1], bbox[0], bbox[3], bbox[2], bbox[3],
            bbox[2], bbox[1], bbox[0], bbox[1]
        ]
        self.p = PatchCollection(
            [Polygon(self.points_to_polygon(), closed=True)],
            facecolor='red',
            linewidths=0,
            alpha=0.4)
        self.ax.add_collection(self.p)
        self.fig.canvas.draw()

    def zoom(self, event):
        """Zoom the image axes around the cursor on scroll events."""
        # Only scrolls over the image axes; widget axes have other coords.
        if event.inaxes is not self.ax:
            return
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location

        if event.button == 'down':
            # deal with zoom in
            scale_factor = 1 / self.zoom_scale
        elif event.button == 'up':
            # deal with zoom out
            scale_factor = self.zoom_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the cursor at the same relative position in the view.
        relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim(
            [xdata - new_width * (1 - relx), xdata + new_width * (relx)])
        self.ax.set_ylim(
            [ydata - new_height * (1 - rely), ydata + new_height * (rely)])
        self.ax.figure.canvas.draw()

    def reset(self, event):
        """Discard the polygon being drawn and re-enable vertex clicks."""
        if self.click_id is None:
            self.click_id = self.fig.canvas.mpl_connect(
                'button_press_event', self.onclick)
        if len(self.points) > 5:
            for line in self.lines:
                line.pop(0).remove()
            for circle in self.circles:
                circle.remove()
            self.lines, self.circles = [], []
            if self.p is not None:
                self.p.remove()
            self.prev = self.p = None
            self.points = []
            # The discarded polygon can no longer be submitted.
            self.right_click = False

    def print_points(self):
        """Return the vertex coordinates as ``'%.2f '``-formatted text."""
        return ''.join('%.2f ' % coord for coord in self.points)

    def submit(self, event):
        """Record the closed polygon under the selected class.

        Appends ``label``, ``area`` and the point list to ``self.text``,
        draws the polygon in the submitted collection and re-enables vertex
        clicks.
        """
        if not self.right_click:
            print('Right click before submit is a must!!')
            return

        self.text += (self.radio.value_selected + '\n' +
                      '%.2f' % self.find_poly_area() + '\n' +
                      self.print_points() + '\n\n')
        self.right_click = False

        self.lines, self.circles = [], []
        # Connect only when disconnected, so onclick is never registered twice.
        if self.click_id is None:
            self.click_id = self.fig.canvas.mpl_connect(
                'button_press_event', self.onclick)

        self.polys.append(
            Polygon(self.points_to_polygon(),
                    closed=True,
                    color=np.random.rand(3),
                    alpha=0.4,
                    fill=True))
        if self.submit_p:
            self.submit_p.remove()
        self.submit_p = PatchCollection(self.polys,
                                        cmap=matplotlib.cm.jet,
                                        alpha=0.4)
        self.ax.add_collection(self.submit_p)
        self.points = []
class COCO_dataset_generator(object):
    """Interactive matplotlib GUI for drawing polygon annotations on images.

    Images (``*.jpg`` then ``*.png`` of ``args['image_dir']``) that already
    have a ``.txt`` annotation file are skipped.  Submitted polygons are
    accumulated in ``self.text`` and written next to the image on *Next*.
    With ``args['feedback']`` set, Mask R-CNN predictions for the first image
    are offered as editable existing polygons.

    Mouse (over the image axes):
        left click    -- add a vertex to the current polygon
        other button  -- close the current polygon (needs >= 3 vertices)
        left drag     -- draw a rectangle (drag > 10 data units on both axes)
        scroll        -- zoom around the cursor
    Keyboard (cursor over an axes):
        a -- select the existing polygon under the cursor for editing, or
             accept the polygon currently being edited
        r -- delete the existing polygon under the cursor
    """

    def __init__(self, fig, ax, args):
        """Create the widgets, connect the handlers and show the first
        unannotated image.

        Args:
            fig: matplotlib figure hosting the GUI.
            ax: axes on which images and polygons are drawn.
            args: mapping with ``'image_dir'``, ``'class_file'`` (one class
                name per line) and ``'feedback'``; when feedback is truthy
                also ``'maskrcnn_dir'`` and ``'weights_path'``.

        Raises:
            FileNotFoundError: if the folder holds no jpg or png image.
        """
        self.ax = ax
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        self.img_dir = args['image_dir']
        self.index = 0
        self.fig = fig
        self.polys = []
        self.zoom_scale = 1.2
        self.points, self.prev, self.submit_p, self.lines, self.circles = [], None, None, [], []
        # Patch of the polygon being drawn and the press position of a
        # potential rectangle drag; initialised so no handler can hit a
        # missing attribute.
        self.p = None
        self.r_x = self.r_y = None

        self.zoom_id = fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.click_id = fig.canvas.mpl_connect('button_press_event',
                                               self.onclick)
        self.clickrel_id = fig.canvas.mpl_connect('button_release_event',
                                                  self.onclick_release)
        self.keyboard_id = fig.canvas.mpl_connect('key_press_event',
                                                  self.onkeyboard)

        self.axradio = plt.axes([0.0, 0.0, 0.2, 1])
        self.axbringprev = plt.axes([0.3, 0.05, 0.17, 0.05])
        self.axreset = plt.axes([0.48, 0.05, 0.1, 0.05])
        self.axsubmit = plt.axes([0.59, 0.05, 0.1, 0.05])
        self.axprev = plt.axes([0.7, 0.05, 0.1, 0.05])
        self.axnext = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_bringprev = Button(self.axbringprev,
                                  'Bring Previous Annotations')
        self.b_bringprev.on_clicked(self.bring_prev)
        self.b_reset = Button(self.axreset, 'Reset')
        self.b_reset.on_clicked(self.reset)
        self.b_submit = Button(self.axsubmit, 'Submit')
        self.b_submit.on_clicked(self.submit)
        self.b_next = Button(self.axnext, 'Next')
        self.b_next.on_clicked(self.next)
        self.b_prev = Button(self.axprev, 'Prev')
        self.b_prev.on_clicked(self.previous)

        # Clicks inside these axes are widget clicks, not polygon vertices.
        self.button_axes = [
            self.axbringprev, self.axreset, self.axsubmit, self.axprev,
            self.axnext, self.axradio
        ]

        # Parallel lists: existing_polys[i] is the hit-test polygon,
        # existing_patches[i] its drawn collection, objects[i] its label.
        self.existing_polys = []
        self.existing_patches = []
        self.selected_poly = False
        self.objects = []
        self.feedback = args['feedback']

        self.right_click = False

        self.text = ''

        with open(args['class_file'], 'r') as f:
            self.class_names = [x.strip() for x in f.readlines()]

        self.radio = RadioButtons(self.axradio, self.class_names)
        # Index 0 is the background class, so radio index == class index - 1.
        self.class_names = ('BG', ) + tuple(self.class_names)

        self.img_paths = sorted(glob.glob(os.path.join(self.img_dir, '*.jpg')))
        self.img_paths = self.img_paths + sorted(
            glob.glob(os.path.join(self.img_dir, '*.png')))
        if not self.img_paths:
            raise FileNotFoundError(
                'no *.jpg or *.png images found in {}'.format(self.img_dir))

        # Skip images that already have an annotation file; the bound check
        # avoids an IndexError when every image is done.
        while (self.index < len(self.img_paths)
               and os.path.exists(self._txt_path())):
            self.index += 1
        if self.index >= len(self.img_paths):
            print("file is end.")
            sys.exit()

        # previous() never goes back further than this image.
        self.checkpoint = self.index

        image = plt.imread(self.img_paths[self.index])

        if args['feedback']:
            self._load_feedback_polygons(args, image)

        self._show_current_image(image)
        print("file name : {}".format(self.img_paths[self.index]))

    def _load_feedback_polygons(self, args, image):
        """Add Mask R-CNN predictions for ``image`` as existing polygons.

        ``args['maskrcnn_dir']`` is appended to ``sys.path`` so the Mask
        R-CNN modules (``model``, ``demo``) can be imported; weights come from
        ``args['weights_path']``.  Every contour of every predicted mask
        becomes an editable polygon labelled with its predicted class name.
        """
        sys.path.append(args['maskrcnn_dir'])
        import model as modellib
        from demo import BagsConfig
        from skimage.measure import find_contours

        config = BagsConfig()

        # Create model object in inference mode.
        model = modellib.MaskRCNN(
            mode="inference",
            model_dir='/'.join(args['weights_path'].split('/')[:-2]),
            config=config)

        # Load weights trained on MS-COCO
        model.load_weights(args['weights_path'], by_name=True)

        r = model.detect([image], verbose=0)[0]

        masks, class_ids = r['masks'], r['class_ids']

        # One iteration per detected instance.
        for i in range(r['rois'].shape[0]):
            label = self.class_names[class_ids[i]]
            mask = masks[:, :, i]

            # Pad to ensure proper polygons for masks that touch image edges.
            padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                                   dtype=np.uint8)
            padded_mask[1:-1, 1:-1] = mask
            for verts in find_contours(padded_mask, 0.5):
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                pat = PatchCollection([Polygon(verts, closed=True)],
                                      facecolor='green',
                                      linewidths=0,
                                      alpha=0.6)
                self.ax.add_collection(pat)
                self.objects.append(label)
                self.existing_patches.append(pat)
                self.existing_polys.append(
                    Polygon(verts,
                            closed=True,
                            alpha=0.25,
                            facecolor='red'))

    def _txt_path(self):
        """Path of the annotation file paired with the current image."""
        # Both accepted extensions (jpg/png) are three characters long.
        return self.img_paths[self.index][:-3] + 'txt'

    def _clear_axes(self):
        """Remove everything from the image axes and hide the tick labels."""
        self.ax.clear()
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

    def _show_current_image(self, image=None):
        """Draw the current image and append its header block to self.text.

        The header is ``index``, absolute path and ``width height``, one per
        line, followed by a blank line.
        """
        path = self.img_paths[self.index]
        if image is None:
            image = plt.imread(path)
        self.ax.imshow(image, aspect='auto')
        with Image.open(path) as im:
            width, height = im.size
        self.text += '{}\n{}\n{} {}\n\n'.format(self.index,
                                                os.path.abspath(path), width,
                                                height)

    def _existing_index_at(self, event):
        """Index of the first existing polygon containing the event point."""
        for i, poly in enumerate(self.existing_polys):
            if poly.get_path().contains_point((event.xdata, event.ydata)):
                return i
        return None

    def _pop_existing(self, index):
        """Drop entry ``index`` from the parallel existing-polygon lists.

        Returns the removed ``(polygon, patch)`` pair.  ``self.objects`` is
        trimmed too (when long enough) so labels stay aligned with polygons.
        """
        polygon = self.existing_polys.pop(index)
        patch = self.existing_patches.pop(index)
        if index < len(self.objects):
            # Slicing works whether objects is a list or a tuple.
            self.objects = self.objects[:index] + self.objects[index + 1:]
        return polygon, patch

    def bring_prev(self, event):
        """Overlay the polygons saved for the previous image.

        Only active without Mask R-CNN feedback and when a previous image
        exists.  The parsing of the ``.txt`` file is done by ``return_info``
        (defined elsewhere), assumed to return ``(vertex_lists, labels)``.
        """
        if self.feedback or self.index == 0:
            return

        poly_verts, self.objects = return_info(
            self.img_paths[self.index - 1][:-3] + 'txt')

        for num in poly_verts:
            self.existing_polys.append(
                Polygon(num, closed=True, alpha=0.5, facecolor='red'))

            pat = PatchCollection([Polygon(num, closed=True)],
                                  facecolor='green',
                                  linewidths=0,
                                  alpha=0.6)
            self.ax.add_collection(pat)
            self.existing_patches.append(pat)

    def points_to_polygon(self):
        """Return the flat ``[x0, y0, x1, y1, ...]`` points as an (N, 2) array."""
        return np.reshape(np.array(self.points),
                          (int(len(self.points) / 2), 2))

    def deactivate_all(self):
        """Disconnect every canvas event handler installed by this tool."""
        for cid in (self.zoom_id, self.click_id, self.clickrel_id,
                    self.keyboard_id):
            if cid is not None:
                self.fig.canvas.mpl_disconnect(cid)

    def onkeyboard(self, event):
        """Handle the 'a' (select / accept edit) and 'r' (remove) keys."""
        if not event.inaxes:
            return

        if event.key == 'a':
            if self.selected_poly:
                # Accept the edit: the edited vertices become the current
                # closed polygon, ready for submit().
                self.points = self.interactor.get_polygon().xy.flatten()
                self.interactor.deactivate()
                self.right_click = True
                self.selected_poly = False
                # Vertex clicks stay disabled until submit() reconnects them;
                # the old mpl_connect(self.click_id, ...) passed an integer
                # id where an event name is required.  Assigning `.color`
                # had no effect (and RGB must be 0-1), so use set_color.
                self.polygon.set_color((0.0, 1.0, 0.0))
                self.fig.canvas.draw()
            else:
                index = self._existing_index_at(event)
                if index is not None:
                    self.radio.set_active(
                        self.class_names.index(self.objects[index]) - 1)
                    self.polygon, patch = self._pop_existing(index)
                    patch.set_visible(False)
                    if self.click_id is not None:
                        self.fig.canvas.mpl_disconnect(self.click_id)
                    self.click_id = None
                    self.ax.add_patch(self.polygon)
                    self.fig.canvas.draw()
                    self.interactor = PolygonInteractor(
                        self.ax, self.polygon)
                    self.selected_poly = True

        elif event.key == 'r':
            index = self._existing_index_at(event)
            if index is not None:
                _, patch = self._pop_existing(index)
                patch.set_visible(False)
                patch.remove()

        self.fig.canvas.draw()

    def next(self, event):
        """Save the current annotations (if any) and show the next image
        that has no annotation file yet.

        Exits the program via ``sys.exit()`` when no such image is left.
        """
        # The header alone splits into 5 pieces; more means polygons exist.
        if len(self.text.split('\n')) > 5:

            print(self._txt_path())

            with open(self._txt_path(), "w") as text_file:
                text_file.write(self.text)

        # Advance past already-annotated images.  The last image is now
        # reachable (the old loop exited before ever showing it).
        self.index += 1
        while (self.index < len(self.img_paths)
               and os.path.exists(self._txt_path())):
            self.index += 1
        if self.index >= len(self.img_paths):
            print("file is end.")
            sys.exit()

        self._clear_axes()
        self.reset_all()
        self._show_current_image()
        print("file name : {}".format(self.img_paths[self.index]))

    def reset_all(self):
        """Forget all per-image drawing state.

        Called after the axes are cleared, so artists held from the previous
        image (including existing polygons) are dropped as well.
        """
        self.polys = []
        self.text = ''
        self.points, self.prev, self.submit_p, self.lines, self.circles = [], None, None, [], []
        self.p = None
        self.right_click = False
        self.existing_polys = []
        self.existing_patches = []
        self.objects = []

    def previous(self, event):
        """Step back one image (not past the checkpoint) and discard its saved
        annotation file so it can be redone."""
        if self.index > self.checkpoint:
            self.index -= 1

        # The file does not exist yet when Prev is pressed on the checkpoint
        # image before it was saved.
        txt_path = self._txt_path()
        if os.path.exists(txt_path):
            os.remove(txt_path)

        self._clear_axes()
        self.reset_all()
        self._show_current_image()

    def onclick(self, event):
        """Add a vertex (left button) or close the polygon (other buttons)."""
        if not event.inaxes:
            return
        if any(ax.in_axes(event) for ax in self.button_axes):
            return

        if event.button == 1:
            self.points.extend([event.xdata, event.ydata])

            circle = plt.Circle((event.xdata, event.ydata),
                                2.5,
                                color='black')
            self.ax.add_artist(circle)
            self.circles.append(circle)

            # Remember the first vertex as a possible rectangle corner.
            if len(self.points) < 4:
                self.r_x = event.xdata
                self.r_y = event.ydata
        elif len(self.points) > 5:
            # Close a polygon of at least 3 vertices and stop taking clicks.
            self.right_click = True
            self.fig.canvas.mpl_disconnect(self.click_id)
            self.click_id = None
            self.points.extend([self.points[0], self.points[1]])

        if len(self.points) > 2:
            # Dashed edge between the last two vertices.
            line = self.ax.plot([self.points[-4], self.points[-2]],
                                [self.points[-3], self.points[-1]], 'b--')
            self.lines.append(line)

        self.fig.canvas.draw()

        if len(self.points) > 4:
            # Replace the translucent fill of the in-progress polygon.
            if self.prev:
                self.prev.remove()
            self.p = PatchCollection(
                [Polygon(self.points_to_polygon(), closed=True)],
                facecolor='red',
                linewidths=0,
                alpha=0.4)
            self.ax.add_collection(self.p)
            self.prev = self.p

            self.fig.canvas.draw()

    def find_poly_area(self):
        """Return the area of the current polygon (shoelace formula).

        The closing vertex appended on right click contributes a zero term,
        so it does not affect the result.
        """
        coords = self.points_to_polygon()
        x, y = coords[:, 0], coords[:, 1]
        # Area is half the absolute cross sum; the previous extra "/ 2"
        # returned only half of the real area.
        return 0.5 * np.abs(
            np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

    def onclick_release(self, event):
        """Turn a left-button drag into a rectangle polygon.

        A rectangle is made when the release point lies more than 10 data
        units from the press point on both axes and at most one vertex has
        been placed.
        """
        if any(ax.in_axes(event)
               for ax in self.button_axes) or self.selected_poly:
            return
        # Releases outside any axes carry no data coordinates; nothing to do
        # either if no press position was recorded.
        if event.xdata is None or event.ydata is None or self.r_x is None:
            return
        # 10 pixels limit for rectangle creation
        if (np.abs(event.xdata - self.r_x) <= 10
                or np.abs(event.ydata - self.r_y) <= 10):
            return
        if len(self.points) >= 4:
            return

        self.right_click = True
        self.fig.canvas.mpl_disconnect(self.click_id)
        self.click_id = None
        bbox = [
            np.min([event.xdata, self.r_x]),
            np.min([event.ydata, self.r_y]),
            np.max([event.xdata, self.r_x]),
            np.max([event.ydata, self.r_y])
        ]
        self.r_x = self.r_y = None

        # Closed outline: four corners plus the first corner repeated.
        self.points = [
            bbox[0], bbox[1], bbox[0], bbox[3], bbox[2], bbox[3],
            bbox[2], bbox[1], bbox[0], bbox[1]
        ]
        self.p = PatchCollection(
            [Polygon(self.points_to_polygon(), closed=True)],
            facecolor='red',
            linewidths=0,
            alpha=0.4)
        self.ax.add_collection(self.p)
        self.fig.canvas.draw()

    def zoom(self, event):
        """Zoom the image axes around the cursor on scroll events."""
        # Only scrolls over the image axes; widget axes have other coords.
        if event.inaxes is not self.ax:
            return
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location

        if event.button == 'down':
            # deal with zoom in
            scale_factor = 1 / self.zoom_scale
        elif event.button == 'up':
            # deal with zoom out
            scale_factor = self.zoom_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the cursor at the same relative position in the view.
        relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim(
            [xdata - new_width * (1 - relx), xdata + new_width * (relx)])
        self.ax.set_ylim(
            [ydata - new_height * (1 - rely), ydata + new_height * (rely)])
        self.ax.figure.canvas.draw()

    def reset(self, event):
        """Discard the polygon being drawn and re-enable vertex clicks."""
        # Uses self.fig (the old code referenced an undefined global `fig`).
        if self.click_id is None:
            self.click_id = self.fig.canvas.mpl_connect(
                'button_press_event', self.onclick)
        if len(self.points) > 5:
            for line in self.lines:
                line.pop(0).remove()
            for circle in self.circles:
                circle.remove()
            self.lines, self.circles = [], []
            if self.p is not None:
                self.p.remove()
            self.prev = self.p = None
            self.points = []
            # The discarded polygon can no longer be submitted.
            self.right_click = False

    def print_points(self):
        """Return the vertex coordinates as ``'%.2f '``-formatted text."""
        return ''.join('%.2f ' % coord for coord in self.points)

    def submit(self, event):
        """Record the closed polygon under the selected class.

        Appends ``label``, ``area`` and the point list to ``self.text``,
        draws the polygon in the submitted collection and re-enables vertex
        clicks.
        """
        if not self.right_click:
            print('Right click before submit is a must!!')
            return

        self.text += (self.radio.value_selected + '\n' +
                      '%.2f' % self.find_poly_area() + '\n' +
                      self.print_points() + '\n\n')
        self.right_click = False

        self.lines, self.circles = [], []
        # Connect only when disconnected, so onclick is never registered
        # twice; uses self.fig instead of an undefined global `fig`.
        if self.click_id is None:
            self.click_id = self.fig.canvas.mpl_connect(
                'button_press_event', self.onclick)

        self.polys.append(
            Polygon(self.points_to_polygon(),
                    closed=True,
                    color=np.random.rand(3),
                    alpha=0.4,
                    fill=True))
        if self.submit_p:
            self.submit_p.remove()
        self.submit_p = PatchCollection(self.polys,
                                        cmap=matplotlib.cm.jet,
                                        alpha=0.4)
        self.ax.add_collection(self.submit_p)
        self.points = []
Пример #17
0
class RunSim:
    ''' This class is the core of the program. It uses matplotlib to gather
    user input and display results, and also performs the requisite
    calculations.'''

    # Car-type RadioButton labels mapped to the category passed to
    # Driver.setCat (insertion order is the order shown on screen).
    _CAR_CATEGORIES = {
        'Average': 'average',
        'Small Sedan': 'smallSedan',
        'Medium Sedan': 'mediumSedan',
        'Large Sedan': 'largeSedan',
        '4WD/Sport': '4wdSport',
        'Minivan': 'minivan',
    }

    # Spending-bracket RadioButton labels (shared by the bike and shoe
    # selectors) mapped to the dollar amount passed to setSpend.
    _SPEND_BY_LABEL = {
        '$0-25': 12.5,
        '$25-50': 37.5,
        '$50-100': 75,
        '$100-150': 125,
        '$150-200': 175,
        '>$200': 250,
    }

    def __init__(self):
        ''' The constructor creates instances of the walker, biker, and driver
        objects from the 'modes' module, sets a default distance and trip
        number, and then calculates the time, cost, calories,
        and CO2 emissions for all modes of transit. Then it sets up graphs
        for displaying the information, as well as sliders, buttons, and
        RadioButtons for gathering user input. The functions that those
        buttons execute are defined internally.
        '''
        # Mode objects whose instance variables drive every calculation.
        self.d = Driver()
        self.b = Biker()
        self.w = Walker()

        # Default values
        self.dist = 1
        self.trips = 1

        # Initial calculations (calculate returns calcDict)
        self.calcDict = self.calculate()

        # Figure of 14" by 10" that holds everything
        self.fig = plt.figure(figsize=(14, 10))

        # Four axes stacked in one column; they hold the time, cost,
        # calories and CO2 displays.
        self.ax1 = self.fig.add_subplot(411)
        self.ax2 = self.fig.add_subplot(412)
        self.ax3 = self.fig.add_subplot(413)
        self.ax4 = self.fig.add_subplot(414)

        # Push the graphs to the right-hand side to leave room for controls.
        self.fig.subplots_adjust(left=.74, right=.94, bottom=.18,
                                 top=.98, hspace=.25)

        ### Set up Buttons, RadioButtons, Sliders and their fcns! ###

        # The structure of setting up the temporary rax axes, then making
        # the RadioButton, then defining its function, then adding the
        # on_clicked statement is taken from this example:
        # http://matplotlib.org/examples/widgets/radio_buttons.html
        # ``facecolor`` replaces the ``axisbg`` keyword, which newer
        # matplotlib versions no longer accept.

        axcolor = 'lightgoldenrodyellow'

        def getRadioPosList(ax):
            ''' Return [x, y, width, height] for a unit RadioButton placed to
            the left of *ax* and vertically centred on it. '''
            widthRax = 0.12
            heightRax = 0.10

            # Row 0 is the lower-left corner, row 1 the upper-right corner.
            corners = ax.get_position().get_points()
            xRax = corners[0, 0] - (widthRax*1.8)
            yRax = ((corners[0, 1] + corners[1, 1])/2) - (heightRax/2)

            return [xRax, yRax, widthRax, heightRax]

        def unitChange(label):
            ''' Shared callback for the four unit RadioButtons: only the
            displayed units change, so just redraw. '''
            self.updateGraph()
            plt.draw()

        # Unit Change RadioButton 1: Change Time Units
        rax = plt.axes(getRadioPosList(self.ax1), facecolor=axcolor)
        self.radioTime = RadioButtons(rax, ('Hours', 'Minutes', 'Audiobooks'))
        self.radioTime.on_clicked(unitChange)

        # Unit Change RadioButton 2: Change Money Units
        rax = plt.axes(getRadioPosList(self.ax2), facecolor=axcolor)
        self.radioCost = RadioButtons(rax, ('Dollars', 'Coffees'))
        self.radioCost.on_clicked(unitChange)

        # Unit Change RadioButton 3: Change calorie burn units
        rax = plt.axes(getRadioPosList(self.ax3), facecolor=axcolor)
        self.radioCal = RadioButtons(rax, ('Cal (total)', 'Cal (/hour)'))
        self.radioCal.on_clicked(unitChange)

        # Unit Change RadioButton 4: Change CO2 Emissions Units
        rax = plt.axes(getRadioPosList(self.ax4), facecolor=axcolor)
        self.radioCO2 = RadioButtons(rax, ('CO2 (lbs)', 'CO2 (trees)'))
        self.radioCO2.on_clicked(unitChange)

        # Sliders 1 and 2: Distance and Number of Trips
        axslideDist = plt.axes([0.17, 0.10, 0.77, 0.03], facecolor=axcolor)
        self.slideDist = Slider(axslideDist, 'Distance (miles)', 0.0, 100.0,
                                valinit=self.dist, valfmt='%4.2f')
        axslideTrip = plt.axes([0.17, 0.05, 0.77, 0.03], facecolor=axcolor)
        self.slideTrip = Slider(axslideTrip, 'Trips', 0.0, 100.0,
                                valinit=self.trips, valfmt='%4.2f')

        def sliderUpdate(val):
            ''' Read both sliders and redraw; updateGraph recalculates
            calcDict itself. '''
            self.trips = self.slideTrip.val
            self.dist = self.slideDist.val
            self.updateGraph()
        self.slideTrip.on_changed(sliderUpdate)
        self.slideDist.on_changed(sliderUpdate)

        axcolor = 'gold'

        # Customization RadioButton 1: Car - Car Type
        rax = plt.axes([.06, .72, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Driving Info", fontsize=11)
        rax.text(0, 1.05, "Car Type", fontsize=11)
        self.radioCarType = RadioButtons(rax, tuple(self._CAR_CATEGORIES))

        def carTypeChange(label):
            ''' Set the driver's car category (Driver.setCat) from the
            selected label, then redraw. '''
            category = self._CAR_CATEGORIES.get(label)
            if category is None:
                print('Error!')
            else:
                self.d.setCat(category)
            self.updateGraph()
            plt.draw()
        self.radioCarType.on_clicked(carTypeChange)

        def makeSpendChange(mode):
            ''' Build a callback that applies the selected spending bracket
            to *mode* (the biker or walker) via setSpend, then redraws. '''
            def spendChange(label):
                spend = self._SPEND_BY_LABEL.get(label)
                if spend is None:
                    print('Error!')
                else:
                    mode.setSpend(spend)
                self.updateGraph()
                plt.draw()
            return spendChange

        # Customization RadioButton 2: Bike - Spend on Bike
        rax = plt.axes([.26, .72, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Biking Info:", fontsize=11)
        rax.text(0, 1.05, "Spend on parts/maintenance ($/year)", fontsize=11)
        self.radioBikeSpend = RadioButtons(rax, tuple(self._SPEND_BY_LABEL),
                                           active=2)
        self.radioBikeSpend.on_clicked(makeSpendChange(self.b))

        # Customization RadioButton 3: Walk - Spend on Shoes
        rax = plt.axes([.06, .424, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Walking Info:", fontsize=11)
        rax.text(0, 1.05, "Spend on a new pair of shoes", fontsize=11)
        self.radioWalkSpend = RadioButtons(rax, tuple(self._SPEND_BY_LABEL),
                                           active=1)
        self.radioWalkSpend.on_clicked(makeSpendChange(self.w))

        # Customization RadioButton 4: Person - Sex
        rax = plt.axes([.26, .424, .15, 0.2], facecolor=axcolor)
        rax.text(0, 1.15, "Customize Calorie Burn Info:", fontsize=11)
        rax.text(0, 1.05, "Sex", fontsize=11)
        self.radioPersonSex = RadioButtons(rax, ('Male', 'Female'), active=0)

        def personSexChange(label):
            ''' Set the sex of the person attached to the driver, biker and
            walker, then redraw. '''
            sex = {'Male': 'M', 'Female': 'F'}.get(label)
            if sex is None:
                print('Error!')
            else:
                for mode in (self.d, self.b, self.w):
                    mode.person.setSex(sex, True)
            self.updateGraph()
            plt.draw()
        self.radioPersonSex.on_clicked(personSexChange)

        # Reset Button. The button is kept on self: matplotlib widgets stop
        # responding once their object is garbage collected, and a local
        # name would be dropped as soon as __init__ returns when
        # plt.show() is non-blocking.
        axReset = plt.axes([0.17, 0.25, 0.15, 0.10])
        self.bReset = Button(axReset, 'Reset Defaults')

        def resetDefaults(event):
            ''' Reset all buttons/sliders to their default position; each
            change triggers its own recalculation and redraw, which is why
            this is a little slow. '''
            self.slideDist.set_val(1)
            self.slideTrip.set_val(1)
            self.radioTime.set_active(0)
            self.radioCost.set_active(0)
            self.radioCal.set_active(0)
            self.radioCO2.set_active(0)
            self.radioCarType.set_active(0)
            self.radioBikeSpend.set_active(2)
            self.radioWalkSpend.set_active(1)
            self.radioPersonSex.set_active(0)
            plt.draw()
        self.bReset.on_clicked(resetDefaults)

        # Draw the initial state and hand control to matplotlib.
        self.updateGraph()
        plt.show()

    def calculate(self):
        ''' Compute time, cost, calorie burn and CO2 figures from the
        driver, biker and walker objects and the current distance/trips.

        Returns a dict whose list values are ordered
        [0]=driver, [1]=biker, [2]=walker; 'CO2' and 'CO2-tree' are
        scalars taken from the driver only.'''
        modes = (self.d, self.b, self.w)
        miles = self.dist*self.trips

        # Time in hours
        time = [miles / mode.getMPH() for mode in modes]

        # Cost in US dollars
        cost = [mode.getCost()*self.dist*self.trips for mode in modes]

        # Calories burned per hour by each mode's person
        calHour = [mode.person.getCal() for mode in modes]

        # CO2 emissions in lbs
        co2 = self.d.getCO2()*miles

        return {
            'time': time,
            'cost': cost,
            # Total calories burned over the whole travel time
            'cal': [rate*hours for rate, hours in zip(calHour, time)],
            # Time in minutes
            'time-mins': [hours*60 for hours in time],
            # Time in audiobooks (avg length 12.59 hours, from a sample of
            # 25 bestsellers on Audible.com)
            'time-audio': [hours/12.59 for hours in time],
            # Cost in coffees at Blue Mondays ($2.60 each)
            'cost-coffee': [dollars/2.60 for dollars in cost],
            'cal-hour': calHour,
            # Never filled in; kept so callers still find the key.
            'cal-sansBMR': [],
            'CO2': co2,
            # A single tree planted through americanforests.org sequesters
            # 911 lbs of CO2, so this is the number of trees needed to
            # offset the driving emissions.
            'CO2-tree': co2 / 911,
        }

    @staticmethod
    def _ylim_top(data):
        ''' Upper y-limit for a bar chart of *data*: 20% headroom above the
        tallest bar, or 1.0 when nothing is above zero (e.g. a slider at 0),
        which would otherwise give a singular (0, 0) y-range. '''
        top = max(data)
        if top <= 0:
            return 1.0
        return top + (top/10)*2

    def makeGraph(self, ax, data, ylab):
        ''' Redraw one bar chart on *ax*: one bar per mode (Drive, Bike,
        Walk) with heights *data*, the y-axis labelled *ylab*, and each
        bar's value printed above it. Called by updateGraph. '''
        ax.clear()

        xTickMarks = ['Drive', 'Bike', 'Walk']
        ind = np.arange(len(xTickMarks))    # x locations of the bars
        width = 0.5                         # bar width

        # align='edge' puts each bar's left edge at its x location; the tick
        # positions and x-limits below assume that layout (matplotlib 2.0
        # changed the default alignment to 'center').
        rects = ax.bar(ind, data, width, color=['cyan', 'yellow', 'magenta'],
                       align='edge')

        ax.set_xlim(-.1, len(ind) - .4)
        ax.set_ylim(0, self._ylim_top(data))
        ax.set_xticks(ind + (width/2))
        ax.set_xticklabels(xTickMarks)
        ax.set_ylabel(ylab)

        # Value labels above the bars, adapted from the matplotlib demo:
        # http://matplotlib.org/examples/api/barchart_demo.html
        for rect in rects:
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., 1.05*height,
                    '%4.2f' % (height), fontsize=11,
                    ha='center', va='bottom')

    def showInfo(self, ax, data, msg):
        ''' The fourth subplot (axes) holds text instead of a bar plot:
        *data* is shown in a highlighted box above the caption *msg*.'''
        # Erase any existing information to avoid drawing over it
        ax.clear()

        # Remove tick labels
        ax.set_xticklabels('')
        ax.set_yticklabels('')

        ax.text(.08, .70, "By not driving...", fontsize=11)
        ax.text(.4, .45, '%4.2f' % (data), style='italic',
                bbox={'facecolor': 'lightgreen', 'alpha': 0.65, 'pad': 10})
        ax.text(.08, .20, msg, fontsize=11)

    def updateGraph(self):
        ''' Recalculate calcDict and redraw all four displays. The unit
        RadioButtons decide which calcDict entries are passed to makeGraph
        (three bar charts) and to showInfo (the CO2 text).'''
        self.calcDict = self.calculate()

        timeUnit = self.radioTime.value_selected
        if timeUnit == 'Hours':
            self.makeGraph(self.ax1, self.calcDict['time'], 'Time (Hours)')
        elif timeUnit == 'Minutes':
            self.makeGraph(self.ax1, self.calcDict['time-mins'],
                           'Time (Minutes)')
        elif timeUnit == 'Audiobooks':
            self.makeGraph(self.ax1, self.calcDict['time-audio'],
                           'Time (Audiobooks)')
        else:
            print('Error!')

        costUnit = self.radioCost.value_selected
        if costUnit == 'Dollars':
            self.makeGraph(self.ax2, self.calcDict['cost'], 'Cost ($)')
        elif costUnit == 'Coffees':
            self.makeGraph(self.ax2, self.calcDict['cost-coffee'],
                           'Cost (Coffees)')
        else:
            print('Error!')

        calUnit = self.radioCal.value_selected
        if calUnit == 'Cal (total)':
            self.makeGraph(self.ax3, self.calcDict['cal'], 'Calories (total)')
        elif calUnit == 'Cal (/hour)':
            self.makeGraph(self.ax3, self.calcDict['cal-hour'],
                           'Calories (/hour)')
        else:
            print('Error!')

        co2Unit = self.radioCO2.value_selected
        if co2Unit == 'CO2 (lbs)':
            self.showInfo(self.ax4, self.calcDict['CO2'],
                          'Pounds of CO2 not emitted')
        elif co2Unit == 'CO2 (trees)':
            self.showInfo(self.ax4, self.calcDict['CO2-tree'],
                          'Trees planted!')
        else:
            print('Error!')
Пример #18
0
        self.stplot[self.plotdict[self.curhist[0]]].set_visible(
            self.showexpected)
        self.ax.set_title(self.tstarts[self.curhist[0]] + r'$p\leq 2^{%d}$' %
                          (self.curhist[1] + 8) + ' with %d primes' %
                          (self.primespi[self.curhist[1] + 7]))
        plt.draw()


if __name__ == '__main__':
    # Script entry point: build the plot, then wire the Previous/Next buttons
    # and the two RadioButton groups to the Index callbacks.
    # NOTE(review): minp/maxp/ncurves are not used in this block; they are
    # presumably module globals read by plotmaker/Index -- confirm before
    # renaming or moving them.
    minp = 8
    maxp = 26
    ncurves = 7
    p = plotmaker()
    # Index supplies the next/prev/radiocurve/curveshow callbacks used below.
    callback = Index(p)
    axprev = plt.axes([0.22, 0.02, 0.07, 0.075])
    axnext = plt.axes([0.30, 0.02, 0.07, 0.075])
    # The Button objects stay bound to module-level names so they remain
    # alive (and responsive) while plt.show() runs.
    bnext = Button(axnext, 'Next')
    bnext.on_clicked(callback.next)
    bprev = Button(axprev, 'Previous')
    bprev.on_clicked(callback.prev)
    axcolor = 'lightgoldenrodyellow'
    # Curve selector: one entry per curve family.
    rax = plt.axes([0.01, 0.02, 0.20, 0.22], facecolor=axcolor)
    radio = RadioButtons(
        rax, ('Rk 0, CM', 'Rk 0, no CM', 'Rk 1, no CM', 'Rk 4, no CM',
              'Rk 14, no CM', 'Rk 0, CM over CM', 'Rk 1, CM over CM'))
    radio.on_clicked(callback.radiocurve)
    rax2 = plt.axes([0.22, 0.11, 0.15, 0.13], facecolor=axcolor)
    radio2 = RadioButtons(rax2, ('Show expected', 'Hide expected'))
    # Select 'Hide expected' before the callback is attached, so this initial
    # selection does not call callback.curveshow.
    radio2.set_active(1)
    radio2.on_clicked(callback.curveshow)
    plt.show()
Пример #19
0
class GUI():
    """Interactive Theis type-curve fitting window.

    The plot zone shows the measurement data shifted by ``x_bias`` /
    ``y_bias`` over the Theis type curve. The control zone adjusts those
    biases (sliders, +/- buttons, text boxes) or re-runs one of the AutoFit
    algorithms. The results zone shows the values computed by ``calResult``
    and the fit RMSE.
    """

    def __init__(self):
        self.axcolor = 'lightgrey'
        self.fig = plt.figure(figsize=(8, 6))
        # Re-entrancy guard: update_x/update_y call set_val on widgets whose
        # own callbacks lead straight back into update_x/update_y.
        self.updating_status = False
        self.main()

    def main(self):
        """Compute the initial fit, build every zone, wire callbacks, show."""
        self.calc_infos()
        self.init_figure()
        self.init_control_zone()
        self.init_result_zone()
        self.add_events()
        plt.show()

    def calc_infos(self, algorithm=0):
        """Run AutoFit with *algorithm* (index into ``self.algos``) and store
        the data, biases and fitted results it returns on self."""
        (self.t_r2, self.s, self.Q, self.x_bias, self.y_bias, self.RMSE,
         self.W_p, self.rac_u_p, self.s_p, self.t_r2_p, self.T,
         self.S) = run_AutoFit(algorithm)

    def init_figure(self):
        """Draw the plot zone: shifted measurements over the type curve."""
        ######## Plot Zone ########
        self.fig.add_subplot(1, 1, 1)
        plt.subplots_adjust(bottom=0.42, left=0.18, right=0.65)

        rec_u, W_std = W_u_std(1e-8, 10, 0.01)
        self.l, = plt.plot(self.t_r2 + self.x_bias,
                           self.s + self.y_bias,
                           's',
                           color='maroon',
                           markersize=5,
                           label='measurement data')
        plt.plot(rec_u,
                 W_std,
                 color='k',
                 linewidth=0.85,
                 label='theis type-curve')
        plt.xlabel(r'lg$(\frac{1}{u})$ ', fontsize=10)
        plt.ylabel(r'lg$(W)$', fontsize=10)
        plt.xlim(
            np.min(self.t_r2 + self.x_bias) - 1,
            np.max(self.t_r2 + self.x_bias) + 1)
        plt.ylim(
            np.min(self.s + self.y_bias) - 1,
            np.max(self.s + self.y_bias) + 0.5)
        plt.grid(True, ls='--', lw=.5, c='k', alpha=0.6)
        plt.legend(fontsize=8)
        plt.suptitle("Theis Type-curve Fitting")

    def init_control_zone(self):
        """Create the bias sliders, +/- buttons, text boxes, RESET button and
        the AutoFit algorithm selector."""
        ######## Control Zone ########
        contrl_ax = plt.axes([0.05, 0.05, 0.90, 0.27])
        plt.text(0.035, 0.85, 'Control Zone')
        contrl_ax.set_xticks([])
        contrl_ax.set_yticks([])

        # slider
        ax_xbias = plt.axes([0.14, 0.21, 0.35, 0.03], facecolor=self.axcolor)
        ax_ybias = plt.axes([0.14, 0.16, 0.35, 0.03], facecolor=self.axcolor)

        self.s_xbias = Slider(ax_xbias, 'x_bias', 0.0, 5, valinit=self.x_bias)
        self.s_ybias = Slider(ax_ybias, 'y_bias', 0.0, 5, valinit=self.y_bias)

        # buttons
        aftax = plt.axes([0.8, 0.06, 0.1, 0.04])
        self.reset_button = Button(aftax,
                                   'RESET',
                                   color=self.axcolor,
                                   hovercolor='mistyrose')

        x_bias_dec_tax = plt.axes([0.57, 0.21, 0.03, 0.03])
        self.x_bias_dec_button = Button(x_bias_dec_tax,
                                        '-',
                                        color=self.axcolor,
                                        hovercolor='mistyrose')

        x_bias_inc_tax = plt.axes([0.72, 0.21, 0.03, 0.03])
        self.x_bias_inc_button = Button(x_bias_inc_tax,
                                        '+',
                                        color=self.axcolor,
                                        hovercolor='mistyrose')

        y_bias_dec_tax = plt.axes([0.57, 0.16, 0.03, 0.03])
        self.y_bias_dec_button = Button(y_bias_dec_tax,
                                        '-',
                                        color=self.axcolor,
                                        hovercolor='mistyrose')

        y_bias_inc_tax = plt.axes([0.72, 0.16, 0.03, 0.03])
        self.y_bias_inc_button = Button(y_bias_inc_tax,
                                        '+',
                                        color=self.axcolor,
                                        hovercolor='mistyrose')

        # textbox
        xtextax = plt.axes([0.61, 0.21, 0.1, 0.03])
        ytextax = plt.axes([0.61, 0.16, 0.1, 0.03])
        self.t_xbias = TextBox(xtextax, ' ', str(float('%.2f' % self.x_bias)))
        self.t_ybias = TextBox(ytextax, ' ', str(float('%.2f' % self.y_bias)))

        # radio button
        rax = plt.axes([0.78, 0.12, 0.15, 0.17], facecolor=self.axcolor)
        plt.title('AutoFit Algorithms', fontsize=8)
        self.algos = np.array(['Traverse', 'Mass Center', 'Slope'])
        self.radio = RadioButtons(rax, self.algos, active=0)

    def init_result_zone(self):
        """Create the results panel text showing the current fit values."""
        ######## Result Zone ########
        re_ax = plt.axes([0.70, 0.37, 0.20, 0.50])
        plt.text(0.035, 0.95, 'Results Zone')
        re_ax.set_xticks([])
        re_ax.set_yticks([])

        self.W_l = plt.text(0.2, 0.85, r'$W: %.2f $' % (self.W_p))
        self.u_l = plt.text(0.2, 0.75, r'$1/u: %.2f $' % (self.rac_u_p))
        self.s_l = plt.text(0.2, 0.65, r'$s: %.2f $' % (self.s_p))
        self.tr2_l = plt.text(0.2, 0.55, r'$t/r^2: %.4f $' % (self.t_r2_p))
        self.T_l = plt.text(0.2, 0.45, r'$T: %.2f m^3/d$' % (self.T))
        self.S_l = plt.text(0.2, 0.35, r'$S: %.6f $' % (self.S))
        plt.text(0.035, 0.2, 'Statistics of Fitting')
        self.RMSE_l = plt.text(0.2, 0.1, r'$RMSE: %.4f$' % (self.RMSE))

    def add_events(self):
        """Connect every widget to its callback."""
        self.s_xbias.on_changed(self.update_x)
        self.s_ybias.on_changed(self.update_y)
        self.reset_button.on_clicked(self.reset)
        self.x_bias_dec_button.on_clicked(self.dec_x_bias)
        self.x_bias_inc_button.on_clicked(self.inc_x_bias)
        self.y_bias_dec_button.on_clicked(self.dec_y_bias)
        self.y_bias_inc_button.on_clicked(self.inc_y_bias)
        self.t_xbias.on_submit(self.submit_x)
        self.t_ybias.on_submit(self.submit_y)
        self.radio.on_clicked(self.chg_algorithm)

    def chg_result(self, t_r2, s, x_bias, y_bias):
        """Recompute the fitted parameters and RMSE and refresh the results
        zone text."""
        self.W_p, self.rac_u_p, self.s_p, self.t_r2_p, self.T, self.S = calResult(
            self.t_r2, self.s, self.Q, self.x_bias, self.y_bias, num=2)
        # Both biases are rounded to the 2 decimals the widgets show
        # ('%.2f'; the x term previously used '%2f', which keeps 6 decimals).
        self.W_bias = (W(1 / (10**(t_r2 + float('%.2f' % x_bias)))))
        self.RMSE = np.sqrt(
            np.mean((self.W_bias - 10**(s + float('%.2f' % y_bias)))**2))

        self.W_l.set_text(r'$W: %.2f $' % (self.W_p))
        self.u_l.set_text(r'$1/u: %.2f $' % (self.rac_u_p))
        self.s_l.set_text(r'$s: %.2f $' % (self.s_p))
        self.tr2_l.set_text(r'$t/r^2: %.4f $' % (self.t_r2_p))
        self.T_l.set_text(r'$T: %.2f m^3/d$' % (self.T))
        self.S_l.set_text(r'$S: %.6f $' % (self.S))
        self.RMSE_l.set_text(r'$RMSE: %.4f$' % (self.RMSE))

    @staticmethod
    def _round2(value):
        """Round *value* to the 2 decimals shown by the widgets."""
        return float('%.2f' % value)

    def update_x(self, x_bias, force=False):
        """Apply a new x_bias to the slider, text box, plotted data and
        results.

        Args:
            x_bias: requested horizontal shift; rounded to 2 decimals.
            force: refresh even when the rounded value equals the current
                one (needed after calc_infos has already stored new biases).
        """
        # Ignore the callbacks fired by our own set_val calls below.
        if self.updating_status:
            return

        x_bias = self._round2(x_bias)
        if x_bias == self.x_bias and not force:
            return

        self.updating_status = True
        try:
            self.x_bias = x_bias
            self.s_xbias.set_val(x_bias)
            self.t_xbias.set_val(x_bias)
            self.l.set_xdata(self.t_r2 + self.x_bias)

            self.fig.canvas.draw_idle()
            self.chg_result(self.t_r2, self.s, self.x_bias, self.y_bias)
        finally:
            # Always release the guard; otherwise a single failing update
            # would make every widget ignore input from then on.
            self.updating_status = False

    def update_y(self, y_bias, force=False):
        """Apply a new y_bias to the slider, text box, plotted data and
        results.

        Args:
            y_bias: requested vertical shift; rounded to 2 decimals.
            force: refresh even when the rounded value equals the current
                one (needed after calc_infos has already stored new biases).
        """
        # Ignore the callbacks fired by our own set_val calls below.
        if self.updating_status:
            return

        y_bias = self._round2(y_bias)
        if y_bias == self.y_bias and not force:
            return

        self.updating_status = True
        try:
            self.y_bias = y_bias
            self.s_ybias.set_val(y_bias)
            self.t_ybias.set_val(y_bias)
            self.l.set_ydata(self.s + self.y_bias)

            self.fig.canvas.draw_idle()
            self.chg_result(self.t_r2, self.s, self.x_bias, self.y_bias)
        finally:
            self.updating_status = False

    def reset(self, event):
        """RESET button: select the first algorithm again, which re-runs the
        fit (via chg_algorithm) and refreshes every widget."""
        self.radio.set_active(0)

    def dec_x_bias(self, event):
        """ decrease the x_bias using '-' button """
        self.update_x(self.x_bias - 0.01)

    def inc_x_bias(self, event):
        """ increase the x_bias using '+' button """
        self.update_x(self.x_bias + 0.01)

    def dec_y_bias(self, event):
        """ decrease the y_bias using '-' button """
        self.update_y(self.y_bias - 0.01)

    def inc_y_bias(self, event):
        """ increase the y_bias using '+' button """
        self.update_y(self.y_bias + 0.01)

    def submit_x(self, text):
        """Apply the x_bias typed in the text box. Non-numeric text is
        ignored and the box is reset to the current value."""
        try:
            value = float(text)
        except ValueError:
            self.t_xbias.set_val(self.x_bias)
            return
        self.update_x(value)

    def submit_y(self, text):
        """Apply the y_bias typed in the text box. Non-numeric text is
        ignored and the box is reset to the current value."""
        try:
            value = float(text)
        except ValueError:
            self.t_ybias.set_val(self.y_bias)
            return
        self.update_y(value)

    def chg_algorithm(self, label):
        """Re-run AutoFit with the algorithm selected on the radio buttons
        and show the biases it produces."""
        algorithm = np.argwhere(self.algos == label)[0][0]
        self.calc_infos(algorithm)
        # calc_infos has already stored the new biases on self, so a plain
        # update would see "no change" and skip the redraw; force it.
        self.update_x(self.x_bias, force=True)
        self.update_y(self.y_bias, force=True)
Пример #20
0
class ObjectLabeler():
    """Interactive matplotlib tool for drawing sex-labelled boxes on frames.

    Frames are the ``.jpg`` files in ``frameDirectory``, visited in a
    shuffled but reproducible order. Each box the user adds becomes one row
    of a temporary table (``f_dt``). Moving to the next frame appends those
    rows to ``dt`` and writes ``dt`` to ``annotationFile``. Frames the
    current user already has rows for are skipped.
    """

    # Columns of the temporary table holding the boxes of the current frame.
    _FRAME_COLUMNS = [
        'ProjectID', 'Framefile', 'Sex', 'Box', 'User', 'DateTime'
    ]
    # Columns of the accumulated table saved to annotationFile.
    _ALL_COLUMNS = [
        'ProjectID', 'Framefile', 'Nfish', 'Sex', 'Box', 'User', 'DateTime'
    ]

    def __init__(self, frameDirectory, annotationFile, number, projectID):
        """Load frames and existing annotations, then open the labeling window.

        Args:
            frameDirectory: directory containing the ``.jpg`` frames.
            annotationFile: CSV file that existing annotations are read from
                (if it exists) and new ones are saved to.
            number: close the window once this many frames have been
                annotated in this session.
            projectID: value written to the ``ProjectID`` column.

        Raises:
            AssertionError: if ``frameDirectory`` contains no ``.jpg`` frames.
        """
        self.frameDirectory = frameDirectory
        self.annotationFile = annotationFile
        self.number = number
        self.projectID = projectID

        # Skip macOS resource-fork files ('._name.jpg'). A fixed seed keeps
        # the visiting order identical between runs.
        self.frames = sorted(
            x for x in os.listdir(self.frameDirectory)
            if '.jpg' in x and '._' not in x)
        random.Random(4).shuffle(self.frames)
        assert len(self.frames) > 0, (
            'No .jpg frames found in ' + str(self.frameDirectory))

        # Index of the displayed frame, and the frames annotated during this
        # session (popped by "Previous Frame").
        self.frame_index = 0
        self.annotated_frames = []

        # NOTE(review): not read by this class. It may be used by Annotation;
        # confirm before removing.
        self.coords = ()

        if os.path.exists(self.annotationFile):
            self.dt = pd.read_csv(self.annotationFile, index_col=0)
        else:
            self.dt = pd.DataFrame(columns=self._ALL_COLUMNS)
        # Boxes of the frame on screen; merged into self.dt by _nextFrame.
        self.f_dt = pd.DataFrame(columns=self._FRAME_COLUMNS)

        self.user = os.getenv('USER')
        self.now = datetime.datetime.now()

        self.annotation = Annotation(self)

        # One "sex:box" line per box added on the current frame.
        self.annotation_text = ''

        self._createFigure()

    # ------------------------------------------------------------------
    # Frame bookkeeping helpers
    # ------------------------------------------------------------------
    def _userMask(self):
        """Return a boolean Series selecting rows of self.dt written by this user.

        When $USER is unset, ``self.user`` is None and the stored value is
        None/NaN, which ``==`` never matches, so ``isna()`` is used instead.
        """
        if self.user is None:
            return self.dt.User.isna()
        return self.dt.User == self.user

    def _isAnnotatedByUser(self, index):
        """Return True if self.dt already has rows by this user for frame *index*."""
        frame_rows = self.dt.Framefile == self.frames[index]
        return bool((frame_rows & self._userMask()).any())

    def _skipAnnotatedFrames(self):
        """Advance self.frame_index past frames this user already annotated.

        Stops at ``len(self.frames)`` when no unannotated frame remains.
        """
        while (self.frame_index < len(self.frames)
               and self._isAnnotatedByUser(self.frame_index)):
            self.frame_index += 1

    def _frameTitle(self):
        """Return the title shown above the current frame."""
        return ('Frame ' + str(self.frame_index) + ': ' +
                self.frames[self.frame_index])

    def _loadFrame(self):
        """Open the current frame and return it enhanced by the saturation slider."""
        self.img = Image.open(
            os.path.join(self.frameDirectory, self.frames[self.frame_index]))
        self.converter = ImageEnhance.Color(self.img)
        return self.converter.enhance(self.slid_saturation.val)

    def _showFrame(self):
        """Display the current frame with its title and redraw the canvas."""
        self.image_obj.set_array(self._loadFrame())
        self.ax_image.set_title(self._frameTitle())
        self.fig.canvas.draw()

    def _removeAnnotationPatches(self):
        """Remove the box patches drawn on the image, keeping the selector's own.

        Axes.patches is read-only from matplotlib 3.5 on, so each patch is
        removed individually instead of assigning an empty list.
        """
        keep = getattr(self, '_selectorPatches', [])
        for patch in list(self.ax_image.patches):
            if not any(patch is kept for kept in keep):
                patch.remove()

    # ------------------------------------------------------------------
    # Figure construction
    # ------------------------------------------------------------------
    def _createFigure(self):
        """Build the figure, widgets and callbacks, then block in plt.show()."""
        self.fig = fig = plt.figure(1, figsize=(10, 7))

        self.ax_image = fig.add_axes([0.05, 0.2, .8, 0.75])

        # Start on the first frame this user has not annotated yet.
        self._skipAnnotatedFrames()
        if self.frame_index >= len(self.frames):
            print('All frames have already been annotated by ' +
                  str(self.user))
            plt.close(fig)
            return

        self.ax_saturation = fig.add_axes([0.1, 0.08, 0.2, 0.03])
        self.slid_saturation = Slider(self.ax_saturation,
                                      'Saturation',
                                      0,
                                      10,
                                      valinit=1,
                                      valstep=.1)

        self.image_obj = self.ax_image.imshow(self._loadFrame())
        self.ax_image.set_title(self._frameTitle())

        # Selector for drawing bounding boxes with the left or right button.
        # 'box' is the default draw type; the drawtype keyword was removed
        # in matplotlib 3.7, so it is not passed.
        self.RS = RectangleSelector(
            self.ax_image,
            self._grabBoundingBox,
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)
        self.RS.set_active(True)
        # Patches present now belong to the selector itself. Anything added
        # later is an annotation and is removed when the frame changes.
        self._selectorPatches = list(self.ax_image.patches)

        # Sex radio buttons; the bold letter is the keyboard shortcut.
        self.ax_radio = fig.add_axes([0.85, 0.85, 0.125, 0.1])
        self.radio_names = [
            r"$\bf{M}$" + 'ale', r"$\bf{F}$" + 'emale', r"$\bf{U}$" + 'nknown'
        ]
        self.bt_radio = RadioButtons(self.ax_radio,
                                     self.radio_names,
                                     active=0,
                                     activecolor='blue')

        # Buttons for adding / clearing the current box.
        self.ax_boxAdd = fig.add_axes([0.85, 0.775, 0.125, 0.04])
        self.bt_boxAdd = Button(self.ax_boxAdd, r"$\bf{A}$" + 'dd Box')
        self.ax_boxClear = fig.add_axes([0.85, 0.725, 0.125, 0.04])
        self.bt_boxClear = Button(self.ax_boxClear, r"$\bf{C}$" + 'lear Box')

        # Buttons for resetting, saving or going back a frame.
        self.ax_frameClear = fig.add_axes([0.85, 0.375, 0.125, 0.04])
        self.bt_frameClear = Button(self.ax_frameClear,
                                    r"$\bf{R}$" + 'eset Frame')
        self.ax_frameAdd = fig.add_axes([0.85, 0.325, 0.125, 0.04])
        self.bt_frameAdd = Button(self.ax_frameAdd, r"$\bf{N}$" + 'ext Frame')
        self.ax_framePrevious = fig.add_axes([0.85, 0.275, 0.125, 0.04])
        self.bt_framePrevious = Button(self.ax_framePrevious,
                                       r"$\bf{P}$" + 'revious Frame')

        self.ax_quit = fig.add_axes([0.85, 0.175, 0.125, 0.04])
        self.bt_quit = Button(self.ax_quit, r"$\bf{Q}$" + 'uit and save')

        # Text areas: boxes of the current frame, box count, and errors.
        self.ax_cur_text = fig.add_axes([0.85, 0.575, 0.125, 0.14])
        self.ax_cur_text.set_axis_off()
        self.cur_text = self.ax_cur_text.text(0,
                                              1,
                                              '',
                                              fontsize=8,
                                              verticalalignment='top')

        self.ax_all_text = fig.add_axes([0.85, 0.425, 0.125, 0.19])
        self.ax_all_text.set_axis_off()
        self.all_text = self.ax_all_text.text(0,
                                              1,
                                              '',
                                              fontsize=9,
                                              verticalalignment='top')

        self.ax_error_text = fig.add_axes([0.3, 0.05, .6, 0.1])
        self.ax_error_text.set_axis_off()
        self.error_text = self.ax_error_text.text(0,
                                                  1,
                                                  '',
                                                  fontsize=14,
                                                  color='red',
                                                  verticalalignment='top')

        # Keyboard shortcuts mirror the bold letters on the buttons.
        self.fig.canvas.mpl_connect('key_press_event', self._keypress)

        # Turn off the hover event for buttons. The original author found it
        # interferes with the image rectangle remaining displayed.
        self.fig.canvas.mpl_disconnect(self.bt_boxAdd.cids[2])
        self.fig.canvas.mpl_disconnect(self.bt_boxClear.cids[2])
        self.fig.canvas.mpl_disconnect(self.bt_frameAdd.cids[2])
        self.fig.canvas.mpl_disconnect(self.bt_frameClear.cids[2])
        self.fig.canvas.mpl_disconnect(self.bt_framePrevious.cids[2])
        self.fig.canvas.mpl_disconnect(self.bt_quit.cids[2])

        self.bt_boxAdd.on_clicked(self._addBoundingBox)
        self.bt_boxClear.on_clicked(self._clearBoundingBox)
        self.bt_frameClear.on_clicked(self._clearFrame)
        self.bt_framePrevious.on_clicked(self._previousFrame)
        self.bt_frameAdd.on_clicked(self._nextFrame)
        self.bt_quit.on_clicked(self._quit)
        self.slid_saturation.on_changed(self._updateFrame)

        plt.show()

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    def _grabBoundingBox(self, eclick, erelease):
        """RectangleSelector callback: store the drawn box on the annotation.

        Display positions of the press and release events are converted to
        image (data) coordinates and truncated to ints. The box is stored as
        ``(x, y, width, height)``, where ``(x, y)`` is the smaller corner.
        """
        self.error_text.set_text('')

        to_image = self.ax_image.transData.inverted()
        x0, y0 = (int(v) for v in to_image.transform((eclick.x, eclick.y)))
        x1, y1 = (int(v)
                  for v in to_image.transform((erelease.x, erelease.y)))
        self.annotation.coords = (min(x0, x1), min(y0, y1),
                                  abs(x0 - x1), abs(y0 - y1))

    def _keypress(self, event):
        """Dispatch keyboard shortcuts to the matching button handlers."""
        sexes = ['m', 'f', 'u']
        if event.key in sexes:
            self.bt_radio.set_active(sexes.index(event.key))
            return

        handlers = {
            'a': self._addBoundingBox,
            'c': self._clearBoundingBox,
            'n': self._nextFrame,
            'r': self._clearFrame,
            'p': self._previousFrame,
            'q': self._quit,
        }
        handler = handlers.get(event.key)
        if handler is not None:
            handler(event)

    def _addBoundingBox(self, event):
        """Record the drawn box with the selected sex for the current frame."""
        if self.annotation.coords == ():
            self.error_text.set_text('Error: Bounding box not set')
            self.fig.canvas.draw()
            return

        # Radio labels are e.g. "$\bf{M}$ale"; only the initial is stored.
        stored_names = ['m', 'f', 'u']
        self.annotation.sex = stored_names[self.radio_names.index(
            self.bt_radio.value_selected)]

        self.annotation.addRectangle()

        outrow = self.annotation.retRow()
        if isinstance(outrow, str):
            # retRow reports a problem as a message string instead of a row.
            self.error_text.set_text(outrow)
            self.fig.canvas.draw()
            return

        # Rows are deliberately not de-duplicated here. Dropping one would
        # desynchronise f_dt from the drawn patches and annotation_text,
        # which _clearBoundingBox removes one entry at a time.
        self.f_dt.loc[len(self.f_dt)] = outrow

        self.annotation_text += self.annotation.sex + ':' + str(
            self.annotation.coords) + '\n'
        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.annotation.reset()

        self.fig.canvas.draw()

    def _clearBoundingBox(self, event):
        """Undo the most recently added box of the current frame."""
        if not self.annotation.removePatches():
            return

        # Drop only the last line. Identical earlier lines must survive, and
        # empty text must not raise.
        lines = self.annotation_text.splitlines(keepends=True)
        self.annotation_text = ''.join(lines[:-1])

        self.annotation.reset()

        self.f_dt.drop(self.f_dt.tail(1).index, inplace=True)

        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.fig.canvas.draw()

    def _nextFrame(self, event):
        """Save the current frame's boxes, then show the next unannotated frame.

        A frame without boxes is saved as a single row with ``Nfish == 0``.
        The window closes when no frame is left or ``self.number`` frames
        have been annotated this session.
        """
        if self.annotation.coords != ():
            self.error_text.set_text(
                'Save or clear (esc) current annotation before moving on')
            # Redraw so the message is actually shown.
            self.fig.canvas.draw()
            return

        if len(self.f_dt) == 0:
            self.f_dt.loc[0] = [
                self.projectID, self.frames[self.frame_index], '', '',
                self.user, self.now
            ]
            self.f_dt['Nfish'] = 0
        else:
            self.f_dt['Nfish'] = len(self.f_dt)
        # DataFrame.append was removed in pandas 2.0; concat is equivalent.
        self.dt = pd.concat([self.dt, self.f_dt], sort=True)
        # Save after every frame so nothing is lost if the user quits.
        self.dt.to_csv(self.annotationFile,
                       sep=',',
                       columns=self._ALL_COLUMNS)
        self.f_dt = pd.DataFrame(columns=self._FRAME_COLUMNS)
        self.annotated_frames.append(self.frame_index)

        self._removeAnnotationPatches()

        self.annotation = Annotation(self)
        self.annotation_text = ''

        self.frame_index += 1
        self._skipAnnotatedFrames()

        if (self.frame_index >= len(self.frames)
                or len(self.annotated_frames) == self.number):
            # Nothing left to show; stop before indexing past the last frame.
            plt.close(self.fig)
            return

        self.cur_text.set_text('')
        self.all_text.set_text('')

        self._showFrame()

    def _updateFrame(self, event):
        """Slider callback: re-render the current frame with the new saturation."""
        img = self.converter.enhance(self.slid_saturation.val)
        self.image_obj.set_array(img)
        self.fig.canvas.draw()

    def _clearFrame(self, event):
        """Discard every box added to the current frame."""
        print('Clearing')
        self.f_dt = pd.DataFrame(columns=self._FRAME_COLUMNS)
        self._removeAnnotationPatches()
        self.annotation_text = ''
        self.annotation = Annotation(self)

        self.cur_text.set_text(self.annotation_text)
        self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

        self.fig.canvas.draw()

    def _previousFrame(self, event):
        """Go back to the last frame annotated this session so it can be redone."""
        if not self.annotated_frames:
            self.error_text.set_text('Error: no previous frame to go back to')
            self.fig.canvas.draw()
            return

        self.frame_index = self.annotated_frames.pop()
        frame = self.frames[self.frame_index]
        # Drop only this user's rows, so other annotators' work on the same
        # frame survives the next save.
        self.dt = self.dt[~((self.dt.Framefile == frame) & self._userMask())]
        self._clearFrame(event)
        self._showFrame()

    def _quit(self, event):
        """Close the window. Frames already moved past are saved on disk."""
        plt.close(self.fig)
Пример #21
0
class ReviewInterface(object):
    """Lay out the rating, navigation and notes widgets and handle their callbacks."""

    def __init__(self, fig,
                 axes_seg, axes_mri,
                 qcw, subject_id,
                 flagged_as_outlier, outlier_alerts,
                 navig_options=default_navigation_options,
                 annot=None):
        """Build the rating, navigation, transparency and notes widgets on ``fig``.

        Args:
            fig: figure hosting the review.
            axes_seg: objects whose ``set_alpha`` the transparency controls use.
            axes_mri: stored as-is.
            qcw: object providing ``alpha_seg`` and ``rating_list``.
            subject_id: currently unused.
            flagged_as_outlier: show the outlier alert box when truthy.
            outlier_alerts: causes listed in the alert box, each coloured by
                ``cfg.alert_colors_outlier``.
            navig_options: currently unused; kept for compatibility.
            annot: optional text drawn at ``cfg.position_annot_text``.
        """
        self.fig = fig
        self.axes_seg = axes_seg
        self.axes_mri = axes_mri
        self.latest_alpha_seg = qcw.alpha_seg
        # Alpha restored when the overlay is toggled back on. It is set here
        # so toggling works even if the slider was first moved to 0.
        self.prev_alpha_seg = qcw.alpha_seg
        self.rating_list = qcw.rating_list
        self.flagged_as_outlier = flagged_as_outlier

        self.user_rating = None
        self.user_notes = None
        self.quit_now = False

        self.zoomed_in = False
        self.prev_axis = None
        self.prev_ax_pos = None

        if annot is not None:
            # fig.text needs the string itself as its third argument.
            fig.text(cfg.position_annot_text[0], cfg.position_annot_text[1],
                     annot, **cfg.annot_text_props)

        # Alert box right above the rating area.
        if self.flagged_as_outlier and outlier_alerts is not None:
            ax_alert = plt.axes(cfg.position_outlier_alert_box)
            ax_alert.axis('off')
            h_rect = Rectangle((0, 0), 1, 1, zorder=-1, facecolor='xkcd:coral', alpha=0.25)
            ax_alert.add_patch(h_rect)
            alert_msg = 'Flagged as an outlier'
            fig.text(cfg.position_outlier_alert[0], cfg.position_outlier_alert[1],
                     alert_msg, **cfg.annot_text_props)
            for idx, cause in enumerate(outlier_alerts):
                fig.text(cfg.position_outlier_alert[0], cfg.position_outlier_alert[1] - (idx+1) * 0.02,
                         cause, color=cfg.alert_colors_outlier[cause], **cfg.annot_text_props)

        ax_radio = plt.axes(cfg.position_rating_axis, facecolor=cfg.color_rating_axis, aspect='equal')
        self.radio_bt_rating = RadioButtons(ax_radio, self.rating_list,
                                            active=None, activecolor='orange')
        self.radio_bt_rating.on_clicked(self.save_rating)
        for txt_lbl in self.radio_bt_rating.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for circ in self.radio_bt_rating.circles:
            circ.set(radius=0.06)

        # Two separate navigation buttons: a second radio-button group had
        # callback issues.
        ax_bt_quit = plt.axes(cfg.position_quit_button, facecolor=cfg.color_quit_axis, aspect='equal')
        ax_bt_next = plt.axes(cfg.position_next_button, facecolor=cfg.color_quit_axis, aspect='equal')
        self.bt_quit = Button(ax_bt_quit, 'Quit', hovercolor='red')
        self.bt_next = Button(ax_bt_next, 'Next', hovercolor='xkcd:greenish')
        self.bt_quit.on_clicked(self.quit)
        self.bt_next.on_clicked(self.next)
        self.bt_quit.label.set_color(cfg.color_navig_text)
        self.bt_next.label.set_color(cfg.color_navig_text)

        # TODO open images directly in tkmedit with qcw.images_for_id[subject_id]['mri'] ... ['seg']

        # Overlay transparency slider.
        ax_slider = plt.axes(cfg.position_slider_seg_alpha, facecolor=cfg.color_slider_axis)
        self.slider = Slider(ax_slider, label='transparency',
                             valmin=0.0, valmax=1.0, valinit=0.7, valfmt='%1.2f')
        self.slider.label.set_position((0.99, 1.5))
        self.slider.on_changed(self.set_alpha_value)

        # Free-form user notes.
        ax_text = plt.axes(cfg.position_text_input)
        self.text_box = TextBox(ax_text, color=cfg.text_box_color, hovercolor=cfg.text_box_color,
                                label=cfg.textbox_title, initial=cfg.textbox_initial_text)
        self.text_box.label.update(dict(color=cfg.text_box_text_color, wrap=True,
                                   verticalalignment='top', horizontalalignment='left'))
        self.text_box.on_submit(self.save_user_notes)
        # Connecting on_text_change to a second notes callback is avoided:
        # matplotlib had issues with two events bound to the same handler.

    def _non_data_axes(self):
        """Return the widget axes that must never be zoomed in."""
        return [self.slider.ax, self.radio_bt_rating.ax,
                self.bt_next.ax, self.bt_quit.ax, self.text_box.ax]

    def on_mouse(self, event):
        """Handle mouse clicks: right click toggles the overlay, double-click zooms.

        A left/middle click outside the widget axes restores a zoomed axis.
        Right clicks keep the zoom, so the overlay can be toggled while zoomed.
        """
        widget_axes = self._non_data_axes()

        if self.prev_axis is not None:
            if event.inaxes not in widget_axes and event.button not in [3]:
                self.prev_axis.set_position(self.prev_ax_pos)
                self.prev_axis.set_zorder(0)
                self.prev_axis.patch.set_alpha(0.5)
                self.zoomed_in = False

        # TODO another useful application: right click to record erroneous slices
        if event.button in [3]:
            self.toggle_overlay()
        elif (event.dblclick and event.inaxes is not None
              and event.inaxes not in widget_axes):
            # Zoom the data axis to full screen.
            self.prev_ax_pos = event.inaxes.get_position()
            event.inaxes.set_position(zoomed_position)
            event.inaxes.set_zorder(1)  # bring forth
            event.inaxes.set_facecolor('black')
            event.inaxes.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = event.inaxes

        plt.draw()

    def do_shortcuts(self, key_in):
        """Handle keyboard shortcuts for rating, navigation and overlay toggling."""
        # Keys typed while the mouse is over the notes box belong to the notes.
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        if key_pressed in ['right', ' ', 'space']:
            self.user_rating = self.radio_bt_rating.value_selected
            self.next()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.user_rating = self.radio_bt_rating.value_selected
            self.quit()
        elif key_pressed in cfg.default_rating_list_shortform:
            self.user_rating = cfg.map_short_rating[key_pressed]
            self.radio_bt_rating.set_active(cfg.default_rating_list.index(self.user_rating))
        elif key_pressed in ['t']:
            self.toggle_overlay()

    def toggle_overlay(self):
        """Hide the overlay (alpha 0) or restore the alpha it had before hiding."""
        if self.latest_alpha_seg != 0.0:
            self.prev_alpha_seg = self.latest_alpha_seg
            self.latest_alpha_seg = 0.0
        else:
            self.latest_alpha_seg = self.prev_alpha_seg
        self.update()

    def set_alpha_value(self, latest_value):
        """Slider callback: apply the chosen overlay alpha."""
        self.latest_alpha_seg = latest_value
        self.update()

    def save_rating(self, label):
        """Radio-button callback: remember the chosen rating."""
        self.user_rating = label

    def save_user_notes(self, text_entered):
        """Text-box callback: remember the user's free-form notes."""
        self.user_notes = text_entered

    def save_user_notes_duplicate(self, text_entered):
        """Alias of save_user_notes for a second text-box event."""
        self.save_user_notes(text_entered)

    def advance_or_quit(self, label):
        """Quit when ``label`` is 'QUIT' (any case), otherwise advance.

        Handler for a former navigation radio-button group, which did not
        fire reliably.
        """
        if label.upper() == u'QUIT':
            self.quit()
        else:
            self.next()

    def _close_if_rated(self, quit_now):
        """Close the figure, recording ``quit_now``, once a rating was given."""
        if self.user_rating in cfg.ratings_not_to_be_recorded:
            print('You have not rated the current subject! Please rate it before you can advance to next subject, or to quit.')
            return
        self.quit_now = quit_now
        plt.close(self.fig)

    def quit(self, ignore_arg=None):
        """Stop reviewing after the current (rated) subject."""
        self._close_if_rated(True)

    def next(self, ignore_arg=None):
        """Advance to the next subject once the current one is rated."""
        self._close_if_rated(False)

    def update(self):
        """Apply the latest overlay alpha to every item of axes_seg and redraw."""
        for ax in self.axes_seg:
            ax.set_alpha(self.latest_alpha_seg)

        plt.draw()
Пример #22
0
class AlignmentInterface(BaseReviewInterface):
    """Custom interface for rating the quality of alignment between two 3d MR images"""
    def __init__(self,
                 fig,
                 rating_list=cfg.default_rating_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 change_vis_type_callback=None,
                 toggle_animation_callback=None,
                 show_first_image_callback=None,
                 show_second_image_callback=None,
                 alpha_seg=cfg.default_alpha_seg):
        """Create the alignment-review interface on ``fig``.

        The base interface is initialised first with the two navigation
        callbacks. The rating and comparison-method radio buttons are then
        added, and the widget axes that must not be zoomed are recorded.

        Args:
            fig: figure hosting the interface.
            rating_list: labels offered as ratings.
            next_button_callback: invoked to advance (also the 'right' key).
            quit_button_callback: invoked to quit (also ctrl+q).
            change_vis_type_callback: receives the chosen comparison method.
            toggle_animation_callback: invoked for the space key.
            show_first_image_callback: invoked for alt+1.
            show_second_image_callback: invoked for alt+2.
            alpha_seg: initial overlay alpha.
        """
        super().__init__(fig, None, next_button_callback, quit_button_callback)

        # Callbacks used by widgets and keyboard shortcuts.
        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback
        self.change_vis_type_callback = change_vis_type_callback
        self.toggle_animation_callback = toggle_animation_callback
        self.show_first_image_callback = show_first_image_callback
        self.show_second_image_callback = show_second_image_callback

        self.rating_list = rating_list
        self.latest_alpha_seg = alpha_seg

        # Zoom state for on_mouse.
        self.zoomed_in = False
        self.prev_axis = None
        self.prev_ax_pos = None

        # Artists drawn for the current subject, removed by clear_data().
        self.data_handles = []

        self.add_radio_buttons_rating()
        self.add_radio_buttons_comparison_method()

        # Widget axes that a double-click must not zoom.
        self.unzoomable_axes = [
            self.radio_bt_rating.ax,
            self.radio_bt_vis_type.ax,
            self.text_box.ax,
            self.bt_next.ax,
            self.bt_quit.ax,
        ]

    def add_radio_buttons_comparison_method(self):
        """Add the radio buttons that choose how the two images are compared.

        The chosen label is forwarded to ``self.change_vis_type_callback``.
        That callback defaults to None; in that case the buttons are shown
        but left unconnected, because registering None would fail on the
        first click.
        """
        ax_radio = plt.axes(cfg.position_alignment_radio_button_method,
                            facecolor=cfg.color_rating_axis,
                            aspect='equal')
        self.radio_bt_vis_type = RadioButtons(ax_radio,
                                              cfg.choices_alignment_comparison,
                                              active=None,
                                              activecolor='orange')
        if self.change_vis_type_callback is not None:
            self.radio_bt_vis_type.on_clicked(self.change_vis_type_callback)
        for txt_lbl in self.radio_bt_vis_type.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for circ in self.radio_bt_vis_type.circles:
            circ.set(radius=0.06)

    def add_radio_buttons_rating(self):
        """Add the rating radio buttons.

        Nothing is pre-selected. A click stores the label through save_rating.
        """
        rating_ax = plt.axes(cfg.position_alignment_radio_button_rating,
                             facecolor=cfg.color_rating_axis,
                             aspect='equal')
        buttons = RadioButtons(rating_ax,
                               self.rating_list,
                               active=None,
                               activecolor='orange')
        self.radio_bt_rating = buttons
        buttons.on_clicked(self.save_rating)
        for label in buttons.labels:
            label.set(color=cfg.text_option_color, fontweight='normal')
        for circle in buttons.circles:
            circle.set(radius=0.06)

    def save_rating(self, label):
        """Radio-button callback: remember the label the user picked."""
        chosen = label
        self.user_rating = chosen

    def get_ratings(self):
        """Return the rating the user ended up choosing."""
        final_rating = self.user_rating
        return final_rating

    def allowed_to_advance(self):
        """Tell whether the user may move on to the next subject.

        Returns:
            True once a rating radio button is selected, False otherwise.
        """
        selected = self.radio_bt_rating.value_selected
        return selected is not None

    def reset_figure(self):
        """Clear data artists, the rating selection and notes before the next subject."""
        for clear_step in (self.clear_data,
                           self.clear_radio_buttons,
                           self.clear_notes_annot):
            clear_step()

    def clear_data(self):
        """Remove every artist recorded in data_handles and start a fresh list."""
        if not self.data_handles:
            return
        for handle in self.data_handles:
            handle.remove()
        self.data_handles = []

    def clear_radio_buttons(self):
        """Visually and logically deselect the current rating.

        No default rating is re-selected: a preset choice would let users
        advance without actually reviewing.
        """
        buttons = self.radio_bt_rating
        chosen = buttons.value_selected
        hit = next((idx for idx, lbl in enumerate(buttons.labels)
                    if lbl.get_text() == chosen), None)
        if hit is not None:
            buttons.circles[hit].set_facecolor(cfg.color_rating_axis)
        buttons.value_selected = None

    def clear_notes_annot(self):
        """Restore the notes box to its initial text and remove the annotation text."""
        initial_notes = cfg.textbox_initial_text
        self.text_box.set_val(initial_notes)
        # annot_text is a matplotlib artist, so it is detached with remove().
        self.annot_text.remove()

    def on_mouse(self, event):
        """Handle mouse clicks: a double-click zooms a data axis to full screen.

        Any click outside the widget axes first restores a previously zoomed
        axis. Right clicks currently trigger nothing else.
        """
        restore_zoom = (self.prev_axis is not None
                        and event.inaxes not in self.unzoomable_axes)
        if restore_zoom:
            self.prev_axis.set_position(self.prev_ax_pos)
            self.prev_axis.set_zorder(0)
            self.prev_axis.patch.set_alpha(0.5)
            self.zoomed_in = False

        is_right_click = event.button in [3]
        wants_zoom = (not is_right_click
                      and event.dblclick
                      and event.inaxes is not None
                      and event.inaxes not in self.unzoomable_axes)
        if wants_zoom:
            target = event.inaxes
            self.prev_ax_pos = target.get_position()
            target.set_position(cfg.zoomed_position)
            target.set_zorder(1)  # bring to the front
            target.set_facecolor('black')
            target.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = target

        plt.draw()

    def on_keyboard(self, key_in):
        """Handle keyboard shortcuts for navigation, display and rating.

        Shortcuts:
            right            -> next_button_callback
            ctrl+q           -> quit_button_callback
            space            -> toggle_animation_callback
            alt+1 / alt+2    -> show_first/second_image_callback
            rating shortform -> select that rating

        Every callback defaults to None in the constructor. An unset
        callback is skipped instead of raising TypeError.
        """
        # Keys typed while the mouse is over the notes box belong to the notes.
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        shortcuts = {
            'right': self.next_button_callback,
            'ctrl+q': self.quit_button_callback,
            'q+ctrl': self.quit_button_callback,
            ' ': self.toggle_animation_callback,
            'space': self.toggle_animation_callback,
            'alt+1': self.show_first_image_callback,
            '1+alt': self.show_first_image_callback,
            'alt+2': self.show_second_image_callback,
            '2+alt': self.show_second_image_callback,
        }
        if key_pressed in shortcuts:
            callback = shortcuts[key_pressed]
            if callback is not None:
                callback()
        elif key_pressed in cfg.default_rating_list_shortform:
            self.user_rating = cfg.map_short_rating[key_pressed]
            index_to_set = cfg.default_rating_list.index(self.user_rating)
            self.radio_bt_rating.set_active(index_to_set)

        plt.draw()
    def setup_plots(self,
                    catchment_section,
                    ref_orog_field,
                    data_orog_original_scale_field,
                    ref_flowtocellfield,
                    data_flowtocellfield,
                    rdirs_field,
                    super_fine_orog_field,
                    super_fine_data_flowtocellfield,
                    pair,
                    catchment_bounds,
                    scale_factor,
                    ref_to_super_fine_scale_factor,
                    ref_grid_offset_adjustment=0,
                    fine_grid_offset_adjustment=0,
                    super_fine_grid_offset_adjustment=0):
        """Set up one set of interactive plots and bind events to callbacks.

        Creates a new figure holding the reference-scale orography, the data
        orography at its original scale, the catchment plot, and the widgets
        (sliders, buttons, check/radio buttons) that control them. The field
        sections, widgets and callback objects created here are appended to
        the per-plot lists on ``self``; the new plot's index in those lists is
        ``len(self.orog_gridspecs) - 1``.

        Arguments:
        catchment_section -- catchment field drawn in the catchment panel
        ref_orog_field -- orography on the reference grid
        data_orog_original_scale_field -- data orography on a grid
            ``scale_factor`` times finer than the reference grid
        ref_flowtocellfield -- flow-to-cell field on the reference grid
        data_flowtocellfield -- data flow-to-cell field on the reference grid
        rdirs_field -- river directions field on the reference grid
        super_fine_orog_field -- orography on a grid
            ``ref_to_super_fine_scale_factor`` times finer than the reference
            grid, or None when unavailable
        super_fine_data_flowtocellfield -- flow-to-cell field on the super
            fine grid, or None; only used when super_fine_orog_field is given
        pair -- pair whose first element supplies get_lat/get_lon for the
            orography plots and get_coords for self.orog_rmouth_coords
        catchment_bounds -- (row_start, row_end, col_start, col_end) of the
            section on the reference grid; [0] and [2] are also used as the
            lat/lon offsets of the axis labels
        scale_factor -- refinement factor of the data grid w.r.t. the
            reference grid
        ref_to_super_fine_scale_factor -- refinement factor of the super fine
            grid w.r.t. the reference grid
        ref_grid_offset_adjustment -- added to the longitude label offset of
            the reference-grid panels
        fine_grid_offset_adjustment -- added to the longitude label offset of
            the data-grid panel
        super_fine_grid_offset_adjustment -- added to the longitude label
            offset used for super fine orography
        Returns: nothing
        """

        self.scale_factor = scale_factor
        self.ref_to_super_fine_scale_factor = ref_to_super_fine_scale_factor
        self.catchment_bounds = catchment_bounds
        self.catchment_section = catchment_section
        self.ref_orog_field = ref_orog_field
        self.data_orog_original_scale_field = data_orog_original_scale_field
        self.ref_flowtocellfield = ref_flowtocellfield
        self.data_flowtocellfield = data_flowtocellfield
        self.rdirs_field = rdirs_field
        self.super_fine_orog_field = super_fine_orog_field
        self.super_fine_data_flowtocellfield = super_fine_data_flowtocellfield
        self.pair = pair

        # Catchment window on the reference grid.
        row_start = self.catchment_bounds[0]
        row_end = self.catchment_bounds[1]
        col_start = self.catchment_bounds[2]
        col_end = self.catchment_bounds[3]

        def section(field, factor=1):
            # Catchment window of `field` on a grid `factor` times finer
            # than the reference grid.
            return field[row_start * factor:row_end * factor,
                         col_start * factor:col_end * factor]

        plt.figure(figsize=(25, 14))
        gs = gridspec.GridSpec(6,
                               9,
                               width_ratios=[8, 1, 1.5, 8, 1, 1.5, 8, 1, 1],
                               height_ratios=[12, 1, 1, 1, 1, 1])
        ax = plt.subplot(gs[0, 0])
        ax2 = plt.subplot(gs[0, 3])
        ax_catch = plt.subplot(gs[0, 6])
        cax_cct = plt.subplot(gs[0, 7])
        self.using_numbers_ref.append(False)
        self.using_numbers_data.append(False)
        self.use_super_fine_orog_flags.append(False)
        rc_pts.plot_catchment(ax_catch,
                              self.catchment_section,
                              cax=cax_cct,
                              remove_ticks_flag=False,
                              format_coords=True,
                              lat_offset=row_start,
                              lon_offset=col_start,
                              colors=None)

        # Tick-label formatters. Every value a formatter depends on (offsets
        # AND scale factors) is bound as a default argument, so each plot
        # keeps the values it was created with even after a later call to
        # setup_plots overwrites self.scale_factor and
        # self.ref_to_super_fine_scale_factor.
        def format_fn_y(tick_val, tick_pos, offset=row_start):
            return pts.calculate_lat_label(tick_val, offset)

        self.y_formatter_funcs.append(format_fn_y)

        def format_fn_x(tick_val,
                        tick_pos,
                        offset=col_start + ref_grid_offset_adjustment):
            return pts.calculate_lon_label(tick_val, offset)

        self.x_formatter_funcs.append(format_fn_x)
        ax_catch.yaxis.set_major_formatter(FuncFormatter(format_fn_y))
        ax_catch.xaxis.set_major_formatter(FuncFormatter(format_fn_x))
        plt.setp(ax_catch.xaxis.get_ticklabels(), rotation='vertical')

        ref_orog_field_section = section(ref_orog_field)
        data_orog_original_scale_field_section = section(
            data_orog_original_scale_field, self.scale_factor)
        ref_flowtocellfield_section = section(self.ref_flowtocellfield)
        data_flowtocellfield_section = section(data_flowtocellfield)
        rdirs_field_section = section(rdirs_field)
        if self.super_fine_orog_field is not None:
            self.super_fine_orog_field_sections.append(
                section(self.super_fine_orog_field,
                        self.ref_to_super_fine_scale_factor))
            if self.super_fine_data_flowtocellfield is not None:
                self.super_fine_data_flowtocellfield_sections.append(
                    section(self.super_fine_data_flowtocellfield,
                            self.ref_to_super_fine_scale_factor))

        cax1 = plt.subplot(gs[0, 1])
        cax2 = plt.subplot(gs[0, 4])
        ref_field_min = np.min(ref_orog_field_section)
        ref_field_max = np.max(ref_orog_field_section)
        def_num_levels = 50
        rc_pts.plot_orography_section(ax,
                                      cax1,
                                      ref_orog_field_section,
                                      ref_field_min,
                                      ref_field_max,
                                      pair[0].get_lat(),
                                      pair[0].get_lon(),
                                      new_cb=True,
                                      num_levels=def_num_levels,
                                      lat_offset=row_start,
                                      lon_offset=col_start)
        ax.yaxis.set_major_formatter(FuncFormatter(format_fn_y))
        ax.xaxis.set_major_formatter(FuncFormatter(format_fn_x))
        plt.setp(ax.xaxis.get_ticklabels(), rotation='vertical')
        # The data panel is given the reference section's min/max, so both
        # orography panels are drawn against the same value range.
        rc_pts.plot_orography_section(
            ax2,
            cax2,
            data_orog_original_scale_field_section,
            ref_field_min,
            ref_field_max,
            pair[0].get_lat(),
            pair[0].get_lon(),
            new_cb=True,
            num_levels=def_num_levels,
            lat_offset=row_start * self.scale_factor,
            lon_offset=col_start * self.scale_factor)

        def format_fn_y_scaled(tick_val,
                               tick_pos,
                               offset=row_start,
                               factor=self.scale_factor):
            return pts.calculate_lat_label(tick_val,
                                           offset,
                                           scale_factor=factor)

        self.y_scaled_formatter_funcs.append(format_fn_y_scaled)
        ax2.yaxis.set_major_formatter(FuncFormatter(format_fn_y_scaled))

        def format_fn_x_scaled(tick_val,
                               tick_pos,
                               offset=col_start + fine_grid_offset_adjustment,
                               factor=self.scale_factor):
            return pts.calculate_lon_label(tick_val,
                                           offset,
                                           scale_factor=factor)

        self.x_scaled_formatter_funcs.append(format_fn_x_scaled)
        ax2.xaxis.set_major_formatter(FuncFormatter(format_fn_x_scaled))

        def format_fn_y_super_fine_scaled(
                tick_val,
                tick_pos,
                offset=row_start,
                factor=self.ref_to_super_fine_scale_factor):
            return pts.calculate_lat_label(tick_val,
                                           offset,
                                           scale_factor=factor)

        self.y_super_fine_scaled_formatter_funcs.append(
            format_fn_y_super_fine_scaled)

        def format_fn_x_super_fine_scaled(
                tick_val,
                tick_pos,
                offset=col_start + super_fine_grid_offset_adjustment,
                factor=self.ref_to_super_fine_scale_factor):
            return pts.calculate_lon_label(tick_val,
                                           offset,
                                           scale_factor=factor)

        self.x_super_fine_scaled_formatter_funcs.append(
            format_fn_x_super_fine_scaled)
        plt.setp(ax2.xaxis.get_ticklabels(), rotation='vertical')

        # Widget axes laid out on the lower rows of the grid.
        ax3 = plt.subplot(gs[2, 0])
        ax4 = plt.subplot(gs[3, 0])
        ax5 = plt.subplot(gs[2, 2])
        ax6 = plt.subplot(gs[2:4, 3])
        ax7 = plt.subplot(gs[4, 0])
        ax8 = plt.subplot(gs[4, 3])
        ax9 = plt.subplot(gs[2, 6])
        ax10 = plt.subplot(gs[3, 6])
        ax11 = plt.subplot(gs[5, 0])
        ax12 = plt.subplot(gs[4:6, 6])
        ax13 = plt.subplot(gs[5, 3])
        min_height = Slider(ax3,
                            'Minimum Height (m)',
                            ref_field_min,
                            ref_field_max,
                            valinit=ref_field_min)
        max_height = Slider(ax4,
                            'Maximum Height (m)',
                            ref_field_min,
                            ref_field_max,
                            valinit=ref_field_max,
                            slidermin=min_height)
        num_levels_slider = Slider(ax7,
                                   'Number of\nContour Levels',
                                   2,
                                   100,
                                   valinit=def_num_levels)
        fmap_threshold_slider = Slider(ax9,
                                       "Flowmap Threshold",
                                       1,
                                       min(np.max(ref_flowtocellfield_section),
                                           200),
                                       valinit=50)
        reset_height_range_button = Button(ax5,
                                           "Reset",
                                           color='lightgrey',
                                           hovercolor='grey')
        linear_colors_buttons = CheckButtons(ax8, ["Uniform Linear Colors"],
                                             [False])
        overlay_flowmap_buttons = CheckButtons(ax10, ['Flowmap Overlay'],
                                               [False])
        overlay_flowmap_on_super_fine_orog_buttons = CheckButtons(
            ax13, ['Overlay Flowmaps on\n'
                   'Super Fine Orography'], [False])
        display_height_buttons = CheckButtons(ax11, ['Display Height Values'],
                                              [False])
        plot_type_radio_buttons = RadioButtons(
            ax6, ('Contour', 'Filled Contour', 'Image'))
        plot_type_radio_buttons.set_active(2)
        secondary_orog_radio_buttons = RadioButtons(
            ax12, ('Reference Scale Orography', 'Super Fine Scale Orography'))
        secondary_orog_radio_buttons.set_active(0)
        # The secondary-orography selector is only shown when a super fine
        # field exists; the super-fine overlay checkbox starts hidden.
        ax12.set_visible(self.super_fine_orog_field is not None)
        ax13.set_visible(False)

        # Per-plot bookkeeping; all lists are indexed by plot number.
        self.ref_orog_field_sections.append(ref_orog_field_section)
        self.data_orog_original_scale_field_sections.append(
            data_orog_original_scale_field_section)
        self.ref_flowtocellfield_sections.append(ref_flowtocellfield_section)
        self.data_flowtocellfield_sections.append(data_flowtocellfield_section)
        self.rdirs_field_sections.append(rdirs_field_section)
        self.catchment_sections.append(self.catchment_section)
        self.orog_gridspecs.append(gs)
        self.orog_rmouth_coords.append(pair[0].get_coords())
        plot_num = len(self.orog_gridspecs) - 1

        # A single Update callback object drives every control of this plot.
        update = Update(orog_plot_num=plot_num,
                        lat_offset=row_start,
                        lon_offset=col_start,
                        plots_object=self)
        min_height.on_changed(update)
        max_height.on_changed(update)
        self.update_funcs.append(update)
        self.min_heights.append(min_height)
        self.max_heights.append(max_height)
        num_levels_slider.on_changed(update)
        self.num_levels_sliders.append(num_levels_slider)
        fmap_threshold_slider.on_changed(update)
        self.fmap_threshold_sliders.append(fmap_threshold_slider)
        overlay_flowmap_buttons.on_clicked(update)
        self.overlay_flowmap_buttons_list.append(overlay_flowmap_buttons)
        overlay_flowmap_on_super_fine_orog_buttons.on_clicked(update)
        self.overlay_flowmap_on_super_fine_orog_buttons_list.append(
            overlay_flowmap_on_super_fine_orog_buttons)
        linear_colors_buttons.on_clicked(update)
        self.linear_colors_buttons_list.append(linear_colors_buttons)
        display_height_buttons.on_clicked(update)
        self.display_height_buttons_list.append(display_height_buttons)
        if self.super_fine_orog_field is not None:
            secondary_orog_radio_buttons.on_clicked(update)
        self.secondary_orog_radio_buttons_list.append(
            secondary_orog_radio_buttons)
        reset = Reset(orog_plot_num=plot_num, plots_object=self)
        reset_height_range_button.on_clicked(reset)
        self.reset_height_range_buttons.append(reset_height_range_button)
        plot_type_radio_buttons.on_clicked(update)
        self.plot_type_radio_buttons_list.append(plot_type_radio_buttons)

        # Keep the zoom of the three main panels in sync.
        match_right_scaling = Match_Right_Scaling(
            orog_plot_num=plot_num, plots_object=self)
        match_left_scaling = Match_Left_Scaling(
            orog_plot_num=plot_num, plots_object=self)
        match_catch_scaling = Match_Catch_Scaling(
            orog_plot_num=plot_num, plots_object=self)
        ax.callbacks.connect('xlim_changed', match_left_scaling)
        ax.callbacks.connect('ylim_changed', match_left_scaling)
        ax2.callbacks.connect('xlim_changed', match_right_scaling)
        ax2.callbacks.connect('ylim_changed', match_right_scaling)
        ax_catch.callbacks.connect('xlim_changed', match_catch_scaling)
        ax_catch.callbacks.connect('ylim_changed', match_catch_scaling)
        self.match_left_scaling_callback_functions.append(match_left_scaling)
        self.match_right_scaling_callback_functions.append(match_right_scaling)
        self.match_catch_scaling_callback_functions.append(match_catch_scaling)
Пример #24
0
class SubjectCurvesVisualizer:
    """Displays all curves (experiment data) recorded for a subject. Trial curves are included.

    The figure is a 4 x 6 grid of subplots sharing x and y axes, one subplot
    per tested cycle of the subject. A column of radio buttons (one per
    subject) selects which subject is displayed.
    """

    def __init__(self, expe, default_subject_index, show_radio_selector=True):
        """Create the figure and display subject ``default_subject_index``.

        Args:
            expe: experiment providing ``subjects``, ``synths`` and
                ``global_params``; only 12-synth experiments are supported.
            default_subject_index: index of the subject displayed first.
            show_radio_selector: when False, the subject selector is placed
                past the right edge of the figure and every displayed
                subject's figure is saved as a PDF.
        """
        assert expe.global_params.synths_count == 12, 'This display function is made for 12 synths (with 2 trials) at the moment'

        self.expe = expe

        self.rows_count = 4
        self.cols_count = 6
        self.fig, self.axes = plt.subplots(nrows=self.rows_count,
                                           ncols=self.cols_count,
                                           sharex=True,
                                           sharey=True,
                                           figsize=(9, 8))

        # global x and y labels
        self.fig.text(0.5, 0.04, 'Time [s]', ha='center', fontsize='12')
        self.fig.text(0.04,
                      0.5,
                      'Normalized parametric errors',
                      va='center',
                      rotation='vertical',
                      fontsize='12')
        self.fig.subplots_adjust(bottom=0.1, left=0.1)

        # radio buttons for going through all subjects; when hidden they are
        # placed at x=1.0, i.e. outside the visible figure area
        self.show_radio_selector = show_radio_selector
        radio_selector_x = 0.85 if show_radio_selector else 1.0
        self.fig.subplots_adjust(right=radio_selector_x - 0.05)
        self.widget_ax = plt.axes([radio_selector_x, 0.1, 0.3, 0.8],
                                  frameon=True)
        self.radio_buttons = RadioButtons(
            self.widget_ax,
            tuple(str(i) for i in range(len(self.expe.subjects))))
        if show_radio_selector:
            self.fig.text(0.93, 0.93, 'Subject:', ha='center', fontsize='10')
        self.radio_buttons.on_clicked(self.on_radio_button_changed)
        # set_active triggers on_radio_button_changed, which draws the subject
        self.radio_buttons.set_active(default_subject_index)

    def close(self):
        """Clear and close the figure."""
        self.fig.clear()
        plt.close(self.fig)

    def on_radio_button_changed(self, label):
        """Radio-button callback: ``label`` is the subject index as a string."""
        subject_index = int(label)
        subject = self.expe.subjects[subject_index]
        self.update_plot(subject)

    def update_plot(self, subject):
        """Redraw every subplot with the curves recorded for ``subject``."""
        synths = self.expe.synths

        self.fig.suptitle("Recorded data for subject #{}".format(
            subject.index))

        # one subplot per tested cycle, filled row by row
        for j in range(subject.tested_cycles_count):
            row, col = divmod(j, self.cols_count)
            ax = self.axes[row, col]
            ax.clear()

            synth_index = subject.synth_indexes_in_appearance_order[j]
            search_type = subject.search_types_in_appearance_order[j]
            # cycles with a negative synth/search index, or flagged invalid,
            # are only labelled, not plotted
            if synth_index >= 0 and search_type >= 0 and subject.is_cycle_valid[
                    synth_index, search_type]:
                fader_or_interp_char = 'S' if search_type == 0 else 'I'
                ax.set_title('[{}]{}'.format(fader_or_interp_char,
                                             synths[synth_index].name),
                             size=10)

                curves = subject.data[synth_index][search_type]
                for k in range(subject.params_count):
                    ax.plot(curves[k][:, 0], curves[k][:, 1])
            else:
                ax.set_title('Invalid data.', size=10)

        # remaining subplots are hidden
        for j in range(subject.tested_cycles_count,
                       self.rows_count * self.cols_count):
            row, col = divmod(j, self.cols_count)
            self.axes[row, col].clear()
            self.axes[row, col].axis('off')

        # the last subplot also carries the legend; the four dummy lines make
        # sure there are lines for the four labels even if it is hidden
        legend_ax = self.axes[self.rows_count - 1, self.cols_count - 1]
        for _ in range(4):
            legend_ax.plot([0], [0])
        legend_ax.legend(
            ['Parameter 1', 'Parameter 2', 'Parameter 3', 'Parameter 4'],
            loc='lower right')

        # axes are shared, so limits set on one subplot apply to all
        self.axes[0, 0].set_ylim([-1, 1])
        self.axes[0, 0].set_xlim([0, self.expe.global_params.allowed_time])

        # printing of the subject's remark
        print('[Subject {:02d}] Remark=\'{}\''.format(subject.index,
                                                      subject.remark))

        plt.draw()  # or does not update graphically...
        if not self.show_radio_selector:
            figurefiles.save_in_subjects_folder(
                self.fig, "Rec_data_subject_{:02d}.pdf".format(subject.index))
class COCO_dataset_generator(object):
    """Interactive matplotlib tool for annotating images with COCO-style boxes.

    Boxes are drawn with a RectangleSelector on ``ax`` and labelled with the
    class selected in the radio buttons on the left. Keyboard shortcuts
    (handled by ``onkeyboard``; the pointer must be over an axes):

    * ``a`` -- over the image, pick up the predicted box under the cursor
      into the rectangle selector for editing;
    * ``i`` -- store the selector's current box as an annotation;
    * ``n``/``N`` -- advance to the next image;
    * ``q``/``Q`` -- quit.

    Scrolling over the image zooms it; the "Save" button writes
    ``output.json``.

    NOTE(review): configuration is read from a module-level ``args`` mapping
    (``no_feedback``, ``config_path``, ``weights_path``) rather than from the
    constructor, and ``model_path`` is accepted but never used.
    """

    def __init__(self, fig, ax, img_dir, classes, model_path, json_file):
        """Set up the widgets, load the image list and show the first image.

        Args:
            fig: figure containing ``ax``; used for event connections.
            ax: axes the images are drawn in.
            img_dir: directory scanned for ``*.jpg`` images.
            classes: path of a text file with one class per line; only the
                text before the first comma is used and classes are sorted.
            model_path: unused.
            json_file: existing annotation JSON (with ``images`` and
                ``annotations`` lists) to resume from, or None.
        """
        self.RS = RectangleSelector(ax, self.line_select_callback,
                                    drawtype='box', useblit=True,
                                    button=[1, 3],  # don't use middle button
                                    minspanx=5, minspany=5,
                                    spancoords='pixels',
                                    interactive=True)

        ax.set_yticklabels([])
        ax.set_xticklabels([])

        with open(classes, 'r') as f:
            self.classes = sorted([x.strip().split(',')[0] for x in f.readlines()])
        img_paths = glob.glob(os.path.abspath(os.path.join(img_dir, '*.jpg')))
        plt.tight_layout()

        self.ax = ax
        self.fig = fig
        self.axradio = plt.axes([0.0, 0.0, 0.1, 1])
        self.radio = RadioButtons(self.axradio, self.classes)
        self.zoom_scale = 1.2
        self.zoom_id = self.fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.keyboard_id = self.fig.canvas.mpl_connect('key_press_event', self.onkeyboard)
        self.selected_poly = False
        self.axsave = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_save = Button(self.axsave, 'Save')
        self.b_save.on_clicked(self.save)
        self.objects, self.existing_patches, self.existing_rects = [], [], []
        self.num_pred = 0
        # Mask R-CNN model used to pre-populate boxes; stays None when
        # args['no_feedback'] is set.
        self.model = None
        if json_file is None:
            self.images, self.annotations = [], []
            self.index = 0
            self.ann_id = 0
        else:
            with open(json_file, 'r') as g:
                d = json.loads(g.read())
            self.images, self.annotations = d['images'], d['annotations']
            self.index = len(self.images)
            self.ann_id = len(self.annotations)

        # NOTE(review): self.index is advanced for every image already present
        # in json_file, on top of len(self.images); confirm it cannot run past
        # the end of self.images.
        prev_files = {x['file_name'] for x in self.images}
        for i, f in enumerate(img_paths):
            if f in prev_files:
                self.index += 1
                continue
            # Context manager so the lazily-opened file handle is released.
            with Image.open(f) as im:
                width, height = im.size
            self.images.append({'file_name': f, 'id': self.index + i,
                                'height': height, 'width': width})
        image = plt.imread(self.images[self.index]['file_name'])
        self.ax.imshow(image, aspect='auto')

        if not args['no_feedback']:
            from mask_rcnn.get_json_config import get_demo_config
            from mask_rcnn import model as modellib
            from mask_rcnn.visualize_cv2 import random_colors

            self.config = get_demo_config(len(self.classes)-1, True)

            if 'config_path' in args:
                self.config.from_json(args['config_path'])

            plt.connect('draw_event', self.persist)

            # Create model object in inference mode.
            self.model = modellib.MaskRCNN(mode="inference", model_dir='/'.join(args['weights_path'].split('/')[:-2]), config=self.config)

            # Load weights trained on MS-COCO
            self.model.load_weights(args['weights_path'], by_name=True)

            self._show_predictions(image)

    def _show_predictions(self, image):
        """Draw the model's detections for ``image`` as editable red boxes.

        Resets the per-image lists (``objects``, ``existing_patches``,
        ``existing_rects``) and sets ``num_pred`` to the number of boxes drawn.
        With no model loaded the lists stay empty and ``num_pred`` is 0.
        """
        self.existing_rects, self.existing_patches, self.objects = [], [], []
        if self.model is not None:
            r = self.model.detect([image], verbose=0)[0]
            # rois are indexed as (y1, x1, y2, x2); class ids are 1-based
            # relative to self.classes.
            for class_id, (y1, x1, y2, x2) in zip(r['class_ids'], r['rois']):
                pat = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                        linewidth=1, edgecolor='r',
                                        facecolor='r', alpha=0.4)
                self.ax.add_patch(pat)
                self.objects.append(self.classes[class_id - 1])
                self.existing_patches.append(pat.get_bbox().get_points())
                self.existing_rects.append(pat)
        self.num_pred = len(self.objects)

    def line_select_callback(self, eclick, erelease):
        """RectangleSelector callback; intentionally does nothing.

        The selected box is read from ``self.RS.extents`` when ``i`` is
        pressed (see ``onkeyboard``).
        """

    def zoom(self, event):
        """Zoom the image axes around the cursor on mouse-wheel events.

        Scrolling 'down' zooms in and 'up' zooms out by ``self.zoom_scale``.
        Events outside the image axes (e.g. over the class radio buttons) are
        ignored, because their data coordinates are not image coordinates.
        """
        if event.inaxes is not self.ax:
            return
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location

        if event.button == 'down':
            # deal with zoom in
            scale_factor = 1 / self.zoom_scale
        elif event.button == 'up':
            # deal with zoom out
            scale_factor = self.zoom_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the point under the cursor fixed while rescaling.
        relx = (cur_xlim[1] - xdata)/(cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - ydata)/(cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([xdata - new_width * (1-relx), xdata + new_width * (relx)])
        self.ax.set_ylim([ydata - new_height * (1-rely), ydata + new_height * (rely)])
        self.ax.figure.canvas.draw()

    def save(self, event):
        """Write the images seen so far and all annotations to output.json."""
        data = {'images': self.images[:self.index+1], 'annotations': self.annotations, 'categories': [], 'classes': self.classes}

        with open('output.json', 'w') as outfile:
            json.dump(data, outfile)

    def persist(self, event):
        """On redraw, re-render the rectangle selector if it is active."""
        if self.RS.active:
            self.RS.update()

    def onkeyboard(self, event):
        """Handle the annotation keyboard shortcuts listed on the class."""
        if not event.inaxes:
            return
        elif event.key == 'a':
            # Picking a box uses the cursor position, which is only meaningful
            # in image coordinates.
            if event.inaxes is not self.ax:
                return
            for i, ((xmin, ymin), (xmax, ymax)) in enumerate(self.existing_patches):
                if xmin <= event.xdata <= xmax and ymin <= event.ydata <= ymax:
                    self.radio.set_active(self.classes.index(self.objects[i]))
                    self.RS.set_active(True)
                    self.rectangle = self.existing_rects[i]
                    self.rectangle.set_visible(False)
                    coords = self.rectangle.get_bbox().get_points()
                    self.RS.extents = [coords[0][0], coords[1][0], coords[0][1], coords[1][1]]
                    self.RS.to_draw.set_visible(True)
                    # The picked box now lives in the selector only.
                    self.existing_rects.pop(i)
                    self.existing_patches.pop(i)
                    self.objects.pop(i)
                    self.fig.canvas.draw()
                    break

        elif event.key == 'i':
            b = self.RS.extents  # xmin, xmax, ymin, ymax
            b = [int(x) for x in b]
            if b[1]-b[0] > 0 and b[3]-b[2] > 0:
                poly = [b[0], b[2], b[0], b[3], b[1], b[3], b[1], b[2], b[0], b[2]]
                area = (b[1]-b[0])*(b[3]-b[2])
                bbox = [b[0], b[2], b[1], b[3]]
                dic2 = {'segmentation': [poly], 'area': area, 'iscrowd': 0, 'image_id': self.index, 'bbox': bbox, 'category_id': self.classes.index(self.radio.value_selected)+1, 'id': self.ann_id}
                if dic2 not in self.annotations:
                    self.annotations.append(dic2)
                self.ann_id += 1
                rect = patches.Rectangle((b[0], b[2]), b[1]-b[0], b[3]-b[2], linewidth=1, edgecolor='g', facecolor='g', alpha=0.4)
                self.ax.add_patch(rect)

                self.RS.set_active(False)

                self.fig.canvas.draw()
                self.RS.set_active(True)
        elif event.key in ['N', 'n']:
            self.ax.clear()
            self.index += 1
            # NOTE(review): when none of the predicted boxes were picked up
            # with 'a', the current image is dropped from self.images (and so
            # from the saved output); confirm this is intended.
            if len(self.objects) == self.num_pred:
                self.images.pop(self.index-1)
                self.index -= 1
            if self.index == len(self.images):
                raise SystemExit
            image = plt.imread(self.images[self.index]['file_name'])
            self.ax.imshow(image)
            self.ax.set_yticklabels([])
            self.ax.set_xticklabels([])
            self._show_predictions(image)
            self.fig.canvas.draw()

        elif event.key in ['q', 'Q']:
            raise SystemExit
Пример #26
0
class EventViewer():
    """Interactive viewer showing a camera image and a single-pixel readout.

    The left axis shows one value per pixel, obtained by reducing each pixel
    trace according to ``camera_view_type``; the right axis shows the trace
    of the selected pixel for the current ``readout_view_type``. Buttons,
    radio buttons, check boxes and key presses (see :meth:`press`) move
    through events, time bins and pixels and toggle overlays.
    """

    # Readout view types whose values are labelled in photo-electrons
    _PE_VIEW_TYPES = ('photon', 'reconstructed charge')

    def __init__(self,
                 event_stream,
                 camera_config_file,
                 n_samples,
                 scale='lin',
                 limits_colormap=None,
                 limits_readout=None,
                 time_bin_start=0,
                 pixel_id_start=0):
        """Create the figure, widgets and per-event state.

        Args:
            event_stream: iterable of events consumed by :meth:`next`.
            camera_config_file: configuration file passed to ``camera.Camera``.
            n_samples: number of samples in each pixel trace.
            scale: ``norm`` argument forwarded to ``CameraDisplay``.
            limits_colormap: ``[low, high]``; camera pixels below ``low`` are
                masked and values above ``high`` are clipped to ``high``.
                ``None`` means ``[-inf, inf]``.
            limits_readout: fixed y-limits of the readout axis, or ``None``
                to use ``[min, max + 10]`` of the displayed trace.
            time_bin_start: initial time bin (used by the 'time' camera view).
            pixel_id_start: initial pixel shown in the readout axis.
        """

        # Assigning ``mpl.figure.autolayout`` only created an unused module
        # attribute; the setting actually lives in rcParams.
        mpl.rcParams['figure.autolayout'] = False
        self.first_call = True
        self.event_stream = event_stream
        self.scale = scale
        self.limits_colormap = limits_colormap if limits_colormap is not None else [
            -np.inf, np.inf
        ]
        self.limits_readout = limits_readout
        self.time_bin = time_bin_start
        self.pixel_id = pixel_id_start
        self.mask_pixels = False
        self.hillas = False

        self.event_clicked_on = EventClicked(pixel_start=self.pixel_id)
        self.camera = camera.Camera(_config_file=camera_config_file)
        self.geometry = geometry.generate_geometry(camera=self.camera)[0]
        self.n_pixels = len(self.camera.Pixels)
        self.n_samples = n_samples
        self.cluster_matrix = np.zeros(
            (len(self.camera.Clusters_7), len(self.camera.Clusters_7)))

        # cluster_matrix[c, p] == 1 when patch p belongs to cluster c
        for cluster in self.camera.Clusters_7:

            for patch in cluster.patchesID:
                self.cluster_matrix[cluster.ID, patch] = 1

        self.event_id = None
        self.r0_container = None
        self.r1_container = None
        self.dl0_container = None
        self.dl1_container = None
        self.dl2_container = None
        self.adc_samples = None
        self.trigger_output = None
        self.trigger_input = None
        self.trigger_patch = None
        self.nsb = [np.nan] * self.n_pixels
        self.gain_drop = [np.nan] * self.n_pixels
        self.baseline = [np.nan] * self.n_pixels
        self.std = [np.nan] * self.n_pixels
        self.flag = None
        # Last per-pixel camera image computed by compute_image()
        self.image = np.zeros(self.n_pixels)

        self.readout_view_types = [
            'raw', 'baseline substracted', 'photon', 'trigger input',
            'trigger output', 'cluster 7', 'reconstructed charge'
        ]
        self.readout_view_type = 'raw'

        self.camera_view_types = ['sum', 'std', 'mean', 'max', 'time']
        self.camera_view_type = 'std'

        self.figure = plt.figure(figsize=(20, 10))
        self.axis_readout = self.figure.add_subplot(122)
        self.axis_camera = self.figure.add_subplot(121)
        self.axis_camera.axis('off')

        self.axis_readout.set_xlabel('t [ns]')
        self.axis_readout.set_ylabel('[ADC]')
        self.axis_readout.legend(loc='upper right')
        self.axis_readout.yaxis.set_major_formatter(FormatStrFormatter('%d'))
        self.axis_readout.yaxis.set_major_locator(
            MaxNLocator(integer=True, nbins=10))

        # x axis in ns: samples are plotted 4 ns apart
        self.trace_readout, = self.axis_readout.step(
            np.arange(self.n_samples) * 4,
            np.ones(self.n_samples),
            where='mid')
        self.trace_time_plot, = self.axis_readout.plot(
            np.array([self.time_bin, self.time_bin]) * 4,
            np.ones(2),
            color='k',
            linestyle='--')

        self.camera_visu = visualization.CameraDisplay(self.geometry,
                                                       ax=self.axis_camera,
                                                       title='',
                                                       norm=self.scale,
                                                       cmap='viridis',
                                                       allow_pick=True)

        self.camera_visu.image = np.zeros(self.n_pixels)
        self.camera_visu.cmap.set_bad(color='k')
        self.camera_visu.add_colorbar(orientation='horizontal',
                                      pad=0.03,
                                      fraction=0.05,
                                      shrink=.85)

        self.camera_visu.colorbar.set_label('[LSB]')
        self.camera_visu.axes.get_xaxis().set_visible(False)
        self.camera_visu.axes.get_yaxis().set_visible(False)
        self.camera_visu.on_pixel_clicked = self.draw_readout
        # Disable snapping of the pixel polygons to the screen pixel grid
        self.camera_visu.pixels.set_snap(False)

        # Buttons

        self.axis_next_event_button = self.figure.add_axes(
            [0.35, 0.9, 0.15, 0.07], zorder=np.inf)
        self.axis_next_camera_view_button = self.figure.add_axes(
            [0., 0.85, 0.1, 0.15], zorder=np.inf)
        self.axis_next_view_type_button = self.figure.add_axes(
            [0., 0.18, 0.1, 0.15], zorder=np.inf)
        self.axis_check_button = self.figure.add_axes([0.35, 0.18, 0.1, 0.1],
                                                      zorder=np.inf)
        self.axis_next_camera_view_button.axis('off')
        self.axis_next_view_type_button.axis('off')
        self.button_next_event = Button(self.axis_next_event_button, 'Next')
        self.radio_button_next_camera_view = RadioButtons(
            self.axis_next_camera_view_button,
            self.camera_view_types,
            active=self.camera_view_types.index(self.camera_view_type))
        self.radio_button_next_view_type = RadioButtons(
            self.axis_next_view_type_button,
            self.readout_view_types,
            active=self.readout_view_types.index(self.readout_view_type))

        self.check_button = CheckButtons(self.axis_check_button,
                                         ('mask', 'hillas'),
                                         (self.mask_pixels, self.hillas))
        self.radio_button_next_view_type.set_active(
            self.readout_view_types.index(self.readout_view_type))
        self.radio_button_next_camera_view.set_active(
            self.camera_view_types.index(self.camera_view_type))

    def next(self, event=None, step=1):
        """Advance ``step`` events in ``event_stream`` and redraw.

        Args:
            event: matplotlib event from the 'Next' button callback; unused.
            step: number of events to advance; the last one read is shown.

        If the stream yields no event, a message is printed and the display
        is left unchanged.
        """
        data_event = None
        # zip() stops after ``step`` items, so at most ``step`` events are
        # consumed from the stream.
        for _, data_event in zip(range(step), self.event_stream):
            pass

        if data_event is None:
            print('No more events to display')
            return

        telescope_id = data_event.r0.tels_with_data[0]
        self.event_id = data_event.r0.event_id
        self.r0_container = data_event.r0.tel[telescope_id]
        self.r1_container = data_event.r1.tel[telescope_id]
        self.dl0_container = data_event.dl0.tel[telescope_id]
        self.dl1_container = data_event.dl1.tel[telescope_id]
        self.dl2_container = data_event.dl2
        self.adc_samples = self.r0_container.adc_samples
        self.trigger_output = self.r0_container.trigger_output_patch7
        self.trigger_input = self.r0_container.trigger_input_traces

        # One NaN per pixel, used when a per-pixel quantity is missing
        nan_per_pixel = np.full(self.n_pixels, np.nan)

        try:
            r0, r1 = self.r0_container, self.r1_container
            # Fall back to the r0 baseline when the DigiCam one is all NaN
            self.baseline = (r0.baseline
                             if np.isnan(r0.digicam_baseline).all()
                             else r0.digicam_baseline)
            self.std = (r0.standard_deviation
                        if r0.standard_deviation is not None
                        else nan_per_pixel)
            self.flag = (r0.camera_event_type
                         if r0.camera_event_type is not None else np.nan)
            self.nsb = r1.nsb if r1.nsb is not None else nan_per_pixel
            self.gain_drop = (r1.gain_drop
                              if r1.gain_drop is not None else nan_per_pixel)
        except Exception:
            # Best effort: containers may lack some of these fields; values
            # not reached keep their previous setting.
            pass

        if self.first_call:

            self.first_call = False

        self.update()

    def update(self):
        """Redraw the readout, the camera image and the 'Next' button label."""
        self.draw_readout(self.pixel_id)
        self.draw_camera()
        self.button_next_event.label.set_text('Next : current event #%d' %
                                              (self.event_id))
        # Request a redraw so the trace/label changes above become visible
        self.figure.canvas.draw_idle()

    def draw(self):
        """Show the first event, connect all callbacks and start the GUI."""
        self.next()
        self.button_next_event.on_clicked(self.next)
        self.radio_button_next_camera_view.on_clicked(self.next_camera_view)
        self.radio_button_next_view_type.on_clicked(self.next_view_type)
        self.check_button.on_clicked(self.draw_on_camera)
        self.figure.canvas.mpl_connect('key_press_event', self.press)
        self.camera_visu._on_pick(self.event_clicked_on)
        plt.show()

    def draw_camera(self, plot_hillas=False):
        """Push the current per-pixel image to the camera display.

        Args:
            plot_hillas: if True, also overlay the dl2 shower moments.
        """
        self.camera_visu.image = self.compute_image()

        if plot_hillas:
            self.camera_visu.overlay_moments(self.dl2_container.shower)

    def _format_pixel_value(self, template, values):
        """Return ``template`` formatted with ``values[self.pixel_id]``.

        Returns '' when the value is missing (``values`` is None, too short)
        or cannot be formatted with the template (e.g. it is not a scalar).
        """
        try:
            return template.format(values[self.pixel_id])
        except (TypeError, ValueError, IndexError, KeyError):
            return ''

    def draw_readout(self, pixel):
        """Show the trace of ``pixel`` in the readout axis and select it.

        The legend lists the event flag, pixel and time bin and, when
        available for this pixel, the baseline, standard deviation, gain drop
        and NSB rate. The y-limits are ``limits_readout`` or
        ``[min, max + 10]`` of the trace.
        """
        y = self.compute_trace()[pixel]
        if self.limits_readout is not None:
            limits_y = self.limits_readout
        else:
            limits_y = [np.min(y), np.max(y) + 10]
        self.pixel_id = pixel
        self.event_clicked_on.ind[-1] = self.pixel_id
        self.trace_readout.set_ydata(y)

        legend = ' flag = {},'.format(self.flag)
        legend += ' pixel = {},'.format(self.pixel_id)
        legend += ' bin = {} \n'.format(self.time_bin)
        legend += self._format_pixel_value(' B = {:0.2f} [LSB],',
                                           self.baseline)
        # '\\s' keeps the literal backslash for mathtext without relying on
        # an invalid escape sequence.
        legend += self._format_pixel_value(' $\\sigma = $ {:0.2f} [LSB] \n',
                                           self.std)
        legend += self._format_pixel_value(' $G_{{drop}} = $ {:0.2f},',
                                           self.gain_drop)
        legend += self._format_pixel_value(' $f_{{nsb}} = $ {:0.2f} [GHz]',
                                           self.nsb)

        self.trace_readout.set_label(legend)
        self.trace_time_plot.set_ydata(limits_y)
        # set_xdata needs a sequence; the vertical marker has two points
        self.trace_time_plot.set_xdata(np.full(2, self.time_bin * 4))
        self.axis_readout.set_ylim(limits_y)
        self.axis_readout.legend(loc='upper right')

        if self.readout_view_type in self._PE_VIEW_TYPES:
            self.axis_readout.set_ylabel('[p.e.]')
        else:
            self.axis_readout.set_ylabel('[LSB]')

    def compute_trace(self):
        """Return the per-pixel traces for the current readout view type.

        Rows are pixels (the zero fallback has shape
        ``(n_pixels, n_samples)``; other shapes come from the event data).
        An all-zero array is returned when the view type is unknown or the
        data it needs is missing from the current event.
        """
        view = self.readout_view_type
        if view not in self.readout_view_types:
            return np.zeros((self.n_pixels, self.n_samples))

        if view == 'raw':
            return self.adc_samples

        if view == 'trigger output' and self.trigger_output is not None:
            return np.array([
                self.trigger_output[pixel.patch]
                for pixel in self.camera.Pixels
            ])

        if view == 'trigger input' and self.trigger_input is not None:
            return np.array([
                self.trigger_input[pixel.patch]
                for pixel in self.camera.Pixels
            ])

        if view == 'cluster 7' and self.trigger_input is not None:
            trigger_input_patch = np.dot(self.cluster_matrix,
                                         self.trigger_input)
            return np.array([
                trigger_input_patch[pixel.patch]
                for pixel in self.camera.Pixels
            ])

        if view == 'photon' and self.dl1_container.pe_samples_trace is not None:
            return self.dl1_container.pe_samples_trace

        # NOTE(review): the guard checks the r1 samples while the result is
        # built from the r0 samples (self.adc_samples); confirm intent.
        if view == 'baseline substracted' and self.r1_container.adc_samples is not None:
            return self.adc_samples - np.asarray(self.baseline)[:, np.newaxis]

        # Both quantities are used below, so both must be present
        if (view == 'reconstructed charge'
                and self.dl1_container.time_bin is not None
                and self.dl1_container.pe_samples is not None):
            image = np.zeros((self.n_pixels, self.n_samples))
            # NOTE(review): this assigns whole rows selected by time_bin. If
            # time_bin holds one sample index per pixel, the intended form is
            # probably image[np.arange(self.n_pixels), time_bin] -- confirm.
            image[self.dl1_container.time_bin] = self.dl1_container.pe_samples
            return image

        return np.zeros((self.n_pixels, self.n_samples))

    def _set_colorbar_unit(self):
        """Label the camera colorbar in p.e. or LSB for the readout view."""
        if self.readout_view_type in self._PE_VIEW_TYPES:
            self.camera_visu.colorbar.set_label('[p.e.]')
        else:
            self.camera_visu.colorbar.set_label('[LSB]')

    def next_camera_view(self, camera_view, event=None):
        """Radio-button callback: switch the camera reduction and redraw."""
        self.camera_view_type = camera_view
        self._set_colorbar_unit()
        self.update()

    def next_view_type(self, view_type, event=None):
        """Radio-button callback: switch the readout view type and redraw."""
        self.readout_view_type = view_type
        self._set_colorbar_unit()
        self.update()

    def draw_on_camera(self, to_draw_on, event=None):
        """Check-box callback: toggle 'hillas' overlay or 'mask' and redraw."""
        if to_draw_on == 'hillas':
            self.hillas = not self.hillas

        if to_draw_on == 'mask':
            self.mask_pixels = not self.mask_pixels

        self.update()

    def set_time(self, time):
        """Select time bin ``time`` if it lies in ``[0, n_samples)``."""
        if 0 <= time < self.n_samples:
            self.time_bin = time
            self.update()

    def set_pixel(self, pixel_id):
        """Select pixel ``pixel_id`` if it lies in ``[0, n_pixels)``."""
        # Pixel ids index pixels, so the bound is n_pixels (not n_samples)
        if 0 <= pixel_id < self.n_pixels:
            self.pixel_id = pixel_id
            self.update()

    def compute_image(self):
        """Reduce the current traces to one value per pixel.

        ``camera_view_type`` selects 'mean', 'std', 'max' or 'sum' over the
        sample axis, or the sample at ``self.time_bin`` for 'time'; the result
        is stored in ``self.image``. Pixels below ``limits_colormap[0]`` (and,
        when masking is enabled, pixels outside the dl1 cleaning mask) are
        masked in the returned array; values above ``limits_colormap[1]`` are
        clipped. The Hillas overlay is redrawn or cleared per ``self.hillas``.

        Returns:
            numpy.ma.MaskedArray with one value per pixel.
        """
        trace = self.compute_trace()

        if self.camera_view_type in self.camera_view_types:

            if self.camera_view_type == 'mean':
                self.image = np.mean(trace, axis=1)
            elif self.camera_view_type == 'std':
                self.image = np.std(trace, axis=1)
            elif self.camera_view_type == 'max':
                self.image = np.max(trace, axis=1)
            elif self.camera_view_type == 'sum':
                self.image = np.sum(trace, axis=1)
            elif self.camera_view_type == 'time':
                self.image = trace[:, self.time_bin]

        else:
            print('Cannot compute for camera type : %s' %
                  self.camera_view_type)
            if getattr(self, 'image', None) is None:
                self.image = np.zeros(self.n_pixels)

        mask = np.ones(np.shape(self.image), dtype=bool)
        if self.limits_colormap is not None:
            lower, upper = self.limits_colormap[0], self.limits_colormap[1]
            mask = self.image >= lower
            if upper != np.inf:
                # Clip into a new array: self.image may be a view of the
                # trace data (e.g. the 'time' view of the raw samples), which
                # must not be modified.
                self.image = np.minimum(self.image, upper)

        if self.mask_pixels:
            mask = np.logical_and(mask, self.dl1_container.cleaning_mask)

        # Clear first so repeated updates do not stack overlays
        self.camera_visu.clear_overlays()
        if self.hillas:
            self.camera_visu.overlay_moments(self.dl2_container.shower,
                                             color='r',
                                             linewidth=4)

        return np.ma.masked_where(~mask, self.image)

    def press(self, event):
        """Key-press callback.

        enter: next event; right/left: next/previous time bin;
        +/-: next/previous pixel; h/v: hide/show the 'Next' button.
        The display is refreshed after every key press.
        """
        sys.stdout.flush()

        if event.key == 'enter':
            self.next()

        if event.key == 'right':
            self.set_time(self.time_bin + 1)

        if event.key == 'left':
            self.set_time(self.time_bin - 1)

        if event.key == '+':
            self.set_pixel(self.pixel_id + 1)

        if event.key == '-':
            self.set_pixel(self.pixel_id - 1)

        if event.key == 'h':
            self.axis_next_event_button.set_visible(False)

        if event.key == 'v':
            self.axis_next_event_button.set_visible(True)

        self.update()
Пример #27
0
class ObjectLabeler():
	def __init__(self, frameDirectory, annotationFile):

		self.frameDirectory = frameDirectory
		self.annotationFile = annotationFile
		self.frames = [x for x in os.listdir(self.frameDirectory) if '.jpg' in x]
		assert len(self.frames) > 0

		# Initialize flag to keep track of whether the user is drawing and object box
		self.drawing = False

		# Keep track of the frame we are on
		self.frame_index = 0

		# Intialize lists to hold annotated objects
		self.coords = ()
		self.poses = () 

		# Create dataframe to hold annotations
		if os.path.exists(self.annotationFile):
			self.dt = pd.read_csv(self.annotationFile)
		else:
			self.dt = pd.DataFrame(columns=['Framefile', 'Nfish', 'Sex', 'Box', 'Nose', 'LEye', 'REye', 'Tail', 'User', 'DateTime'])
		self.f_dt = pd.DataFrame(columns=['Framefile','Sex', 'Box', 'Nose', 'LEye', 'REye', 'Tail', 'User', 'DateTime'])

		# Get user and current time
		self.user = os.getenv('USER')
		self.now = datetime.datetime.now()

		# Create Annotation object
		self.annotation = Annotation(self)

		# Start figure
		self._createFigure()

	def _createFigure(self):
		# Create figure
		self.fig = fig = plt.figure(1, figsize=(10,7))

		# Create image subplot
		self.ax_image = fig.add_axes([0.05,0.2,.8,0.75])
		while len(self.dt[(self.dt.Framefile == self.frames[self.frame_index]) & (self.dt.User == self.user)]) != 0:
			self.frame_index += 1
		img = plt.imread(self.frameDirectory + self.frames[self.frame_index])
		self.image_obj = self.ax_image.imshow(img)
		self.ax_image.set_title(self.frames[0])

		# Create selectors for identifying bounding bos and body parts (nose, left eye, right eye, tail)
		self.PS = PolygonSelector(self.ax_image, self._grabPoses,
									   useblit=True )
		self.PS.set_active(False)

		self.RS = RectangleSelector(self.ax_image, self._grabBoundingBox,
									   drawtype='box', useblit=True,
									   button=[1, 3],  # don't use middle button
									   minspanx=5, minspany=5,
									   spancoords='pixels',
									   interactive=True)
		self.RS.set_active(True)

		# Create radio buttons
		self.ax_radio = fig.add_axes([0.85,0.85,0.125,0.1])
		self.radio_names = [r"$\bf{M}$" + 'ale',r"$\bf{F}$" + 'emale',r"$\bf{O}$" +'ccluded',r"$\bf{U}$" +'nknown']
		self.bt_radio =  RadioButtons(self.ax_radio, self.radio_names, active=0, activecolor='blue' )

		# Create click buttons for adding annotations
		self.ax_box = fig.add_axes([0.85,0.775,0.125,0.04])
		self.ax_pose = fig.add_axes([0.85,0.725,0.125,0.04])
		self.bt_box = Button(self.ax_box, 'Add ' + r"$\bf{B}$" + 'ox')
		self.bt_poses = Button(self.ax_pose, 'Add ' + r"$\bf{P}$" + 'ose')
		
		# Create click buttons for keeping or clearing annotations
		self.ax_anAdd = fig.add_axes([0.85,0.525,0.125,0.04])
		self.ax_anClear = fig.add_axes([0.85,0.475,0.125,0.04])
		self.bt_anAdd = Button(self.ax_anAdd, r"$\bf{K}$" + 'eep Ann')
		self.bt_anClear = Button(self.ax_anClear, r"$\bf{C}$" + 'lear Ann')

		# Create click buttons for saving frame annotations or starting over
		self.ax_frameAdd = fig.add_axes([0.85,0.225,0.125,0.04])
		self.ax_frameClear = fig.add_axes([0.85,0.175,0.125,0.04])
		self.bt_frameAdd = Button(self.ax_frameAdd, r"$\bf{N}$" + 'ext Frame')
		self.bt_frameClear = Button(self.ax_frameClear, r"$\bf{R}$" + 'estart')

		# Add text boxes to display info on annotations
		self.ax_cur_text = fig.add_axes([0.85,0.575,0.125,0.14])
		self.ax_cur_text.set_axis_off()
		self.cur_text =self.ax_cur_text.text(0, 1, '', fontsize=9, verticalalignment='top')

		self.ax_all_text = fig.add_axes([0.85,0.275,0.125,0.19])
		self.ax_all_text.set_axis_off()
		self.all_text =self.ax_all_text.text(0, 1, '', fontsize=9, verticalalignment='top')

		self.ax_error_text = fig.add_axes([0.1,0.05,.7,0.1])
		self.ax_error_text.set_axis_off()
		self.error_text =self.ax_error_text.text(0, 1, 'TEST', fontsize=14, color = 'red', verticalalignment='top')


		# Set buttons in active that shouldn't be pressed
		#self.bt_poses.set_active(False)
		
		# Turn on keypress events to speed things up
		self.fig.canvas.mpl_connect('key_press_event', self._keypress)

		# Turn off hover event for buttons (no idea why but this interferes with the image rectange remaining displayed)
		self.fig.canvas.mpl_disconnect(self.bt_box.cids[2])
		self.fig.canvas.mpl_disconnect(self.bt_poses.cids[2])
		self.fig.canvas.mpl_disconnect(self.bt_anAdd.cids[2])
		self.fig.canvas.mpl_disconnect(self.bt_anClear.cids[2])
		self.fig.canvas.mpl_disconnect(self.bt_frameAdd.cids[2])
		self.fig.canvas.mpl_disconnect(self.bt_frameClear.cids[2])

		# Connect buttons to specific functions		
		self.bt_box.on_clicked(self._addBoundingBox)
		self.bt_poses.on_clicked(self._addPose)
		self.bt_anAdd.on_clicked(self._saveAnnotation)
		self.bt_anClear.on_clicked(self._clearAnnotation)
		self.bt_frameAdd.on_clicked(self._nextFrame)
		self.bt_frameClear.on_clicked(self._clearFrame)

		# Show figure
		plt.show()

	def _grabBoundingBox(self, eclick, erelease):
		self.error_text.set_text('')

		# Transform and store image coords
		image_coords = list(self.ax_image.transData.inverted().transform((eclick.x, eclick.y))) + list(self.ax_image.transData.inverted().transform((erelease.x, erelease.y)))
		
		# Convert to integers:
		image_coords = tuple([int(x) for x in image_coords])

		xy = (min(image_coords[0], image_coords[2]), min(image_coords[1], image_coords[3]))
		width = abs(image_coords[0] - image_coords[2])
		height = abs(image_coords[1] - image_coords[3])
		self.annotation.coords = xy + (width, height)

	def _grabPoses(self, event):
		self.error_text.set_text('')

		if len(event) != 4:
			self.error_text.set_text('Error: must have exactly four points selected. Try again')

			self.PS._xs, self.PS._ys = [0], [0]
			self.PS._polygon_completed = False
		else:
			self.annotation.poses = tuple([(int(x[0]),int(x[1])) for x in event])

	def _keypress(self, event):
		if event.key in ['m', 'f', 'o', 'u']:
			self.bt_radio.set_active(['m', 'f', 'o', 'u'].index(event.key))
			#self.fig.canvas.draw()
		elif event.key == 'b':
			self._addBoundingBox(event)
		elif event.key == 'p':
			self._addPose(event)	
		elif event.key == 'k':
			self._saveAnnotation(event)
		elif event.key == 'c':
			self._clearAnnotation(event)
		elif event.key == 'n':
			self._nextFrame(event)
		elif event.key == 'r':
			self._clearFrame(event)
		else:
			pass

	def _addBoundingBox(self, event):
		if self.annotation.coords == ():
			self.error_text.set_text('Error: Bounding box not set')
			self.fig.canvas.draw()

			self.RS.set_active(True)
			self.PS.set_active(False)
			return

		# Add new patch rectangle
		#colormap = {self.radio_names[0]:'blue', self.radio_names[1]:'pink', self.radio_names[2]: 'red', self.radio_names[3]: 'black'}
		#color = colormap[self.bt_radio.value_selected]
		self.annotation.addRectangle()
		self.fig.canvas.draw()

		# Change to pose selection
		self.RS.set_active(False)
		self.PS.set_active(True)
		#self.bt_poses.set_active(True) # Need to save bounding box before you can select poses
		#self.bt_box.set_active(False) # Need to save bounding box before you can select poses

	def _addPose(self, event):
		if self.annotation.poses == ():
			self.error_text.set_text('Error: Poses not set')
			self.fig.canvas.draw()

			self.RS.set_active(False)
			self.PS.set_active(True)
			return

		#print(self.poses)
		#self.img_annotations.add(self.xy + (self.width, self.height))
		self.annotation.addCircles()
		self.fig.canvas.draw()

		self.PS._xs, self.PS._ys = [0], [0]
		self.PS._polygon_completed = False

		# Change to box selection
		self.RS.set_active(False)
		self.PS.set_active(False)

	def _saveAnnotation(self, event):

		displayed_names = [r"$\bf{M}$" + 'ale',r"$\bf{F}$" + 'emale',r"$\bf{O}$" +'ccluded',r"$\bf{U}$" +'nknown']
		stored_names = ['m','f','o','u']
		
		self.annotation.sex = stored_names[displayed_names.index(self.bt_radio.value_selected)]

		outrow = self.annotation.retRow()

		if type(outrow) == str:
			self.error_text.set_text(outrow)
			self.fig.canvas.draw()
			self.RS.set_active(False)
			self.PS.set_active(True)
			return
		else:
			self.f_dt.loc[len(self.f_dt)] = outrow

		# Add annotation to the temporary data frame
		self.cur_text.set_text('')
		self.all_text.set_text('# Ann = ' + str(len(self.f_dt)))

		self.annotation.reset()

		self.fig.canvas.draw()

		self.RS.set_active(True)
		self.PS.set_active(False)


	def _clearAnnotation(self, event):

		print(self.ax_image.patches)
		self.annotation.removePatches()
		
		print(self.ax_image.patches)

		self.annotation.reset()

		self.cur_text.set_text('')
		
		self.fig.canvas.draw()

		self.RS.set_active(True)
		self.PS.set_active(False)

	def _nextFrame(self, event):
		"""Save this frame's annotations and advance to the next unannotated frame.

		The kept rows in ``self.f_dt`` (or one placeholder row with
		``Nfish == 0`` when nothing was kept) are appended to ``self.dt``,
		which is immediately written to ``self.annotationFile``. Frames the
		user has already annotated are skipped. When no frames remain, the
		figure is closed.
		"""
		if len(self.f_dt) == 0:
			# Record that this frame was seen but contains no annotations
			self.f_dt.loc[0] = [self.frames[self.frame_index],'','','','','','',self.user, self.now]
			self.f_dt['Nfish'] = 0
		else:
			self.f_dt['Nfish'] = len(self.f_dt)
		# pd.concat replaces DataFrame.append, which was removed in pandas 2.0
		self.dt = pd.concat([self.dt, self.f_dt], sort=True)
		# Save dataframe (in case user quits). index=False stops every reload
		# from adding another "Unnamed: 0" column to the file.
		self.dt.to_csv(self.annotationFile, sep = ',', index=False)
		self.f_dt = pd.DataFrame(columns=['Framefile','Sex', 'Box', 'Nose', 'LEye', 'REye', 'Tail', 'User', 'DateTime'])

		# Remove old patches one by one: Axes.patches is a read-only
		# property in recent matplotlib and cannot be reassigned
		for patch in list(self.ax_image.patches):
			patch.remove()

		# Reset annotations
		self.annotation.reset()

		# Advance past frames this user has already annotated, stopping at the end
		annotated = set(self.dt.loc[self.dt.User == self.user, 'Framefile'])
		self.frame_index += 1
		while self.frame_index < len(self.frames) and self.frames[self.frame_index] in annotated:
			self.frame_index += 1

		if self.frame_index == len(self.frames):
			# All images are annotated: close the figure and stop here
			plt.close(self.fig)
			return

		self.cur_text.set_text('')
		self.all_text.set_text('')

		# Load new image and show it in the existing image artist
		img = plt.imread(os.path.join(self.frameDirectory, self.frames[self.frame_index]))
		self.image_obj.set_array(img)
		self.ax_image.set_title(self.frames[self.frame_index])
		self.fig.canvas.draw()
		self.RS.set_active(True)
		self.PS.set_active(False)


	def _clearFrame(self, event):
		self.f_dt = pd.DataFrame(columns=['Framefile','Sex', 'Box', 'Nose', 'LEye', 'REye', 'Tail', 'User', 'DateTime'])
		# Remove old patches
		self.ax_image.patches = []

		self.fig.canvas.draw()

		# Reset annotations
		self.annotation.reset()
		self.RS.set_active(True)
		self.PS.set_active(False)