class plot_window(QtWidgets.QDialog):
    """Dialog that imports a whitespace-separated table and draws it in 3D.

    The table is read with ``pd.read_csv(..., sep=r'\\s+')``.  Column 0 is
    skipped and the remaining columns are consumed in groups of three
    (x, y, z) per body.  For every row after the first, ``plot`` draws a
    polyline that starts at the origin and passes through each body's
    ``(x, y, -z)`` position.
    """

    def __init__(self,parent=None):
        super(plot_window,self).__init__(parent)
        self.window = QtWidgets.QWidget(self)
        self.resize(653, 500)

        self.figure = plt.figure(figsize=(2, 2))
        self.canvas = FigureCanvas(self.figure)
        # Create the 3D axes on this dialog's own figure rather than on
        # whatever pyplot currently considers the "current figure".
        self.ax = self.figure.add_subplot(111, projection='3d')

        self.import_button = QtWidgets.QPushButton('import')
        self.plot_button = QtWidgets.QPushButton('plot')

        self.import_button.clicked.connect(self.import_csv)
        self.plot_button.clicked.connect(self.plot)

        self.main_layout = QtWidgets.QVBoxLayout()

        self.main_layout.addWidget(self.canvas)
        self.main_layout.addWidget(self.import_button)
        self.main_layout.addWidget(self.plot_button)
        self.window.setLayout(self.main_layout)
        self.i = 1

        # Nothing is loaded yet; plot() checks these before drawing.
        self.csv = None
        self.nbody = 0
        self.timeticks = 0

    def import_csv(self):
        """Ask the user for a file and load it into ``self.csv``.

        Does nothing (and keeps any previously loaded table) when the
        dialog is cancelled.  Errors raised by ``pd.read_csv`` propagate.
        """
        self.fileName_choose, self.filetype = QtWidgets.QFileDialog.getOpenFileName(
            self, "Find Files", QtCore.QDir.currentPath(), 'CSV (*.csv)')
        if not self.fileName_choose:
            # getOpenFileName returns an empty name when the user cancels.
            return
        # Raw string: '\s' is not a valid escape in a normal string literal.
        self.csv = pd.read_csv(self.fileName_choose, sep=r'\s+')
        self.nbody = int(((len(self.csv.columns) - 1) / 3))
        self.timeticks = int(len(self.csv))

    def plot(self):
        """Clear the axes and draw one polyline per table row (from row 1 on).

        The canvas is redrawn after every row so the lines appear
        progressively.  Does nothing when no table has been imported.
        """
        if self.csv is None:
            return

        self.ax.clear()
        # NOTE(review): row 0 is skipped, as in the original code - confirm
        # whether the first data row is meant to be left out.
        for i in range(1, self.timeticks):
            # Every polyline starts at the origin.
            x = [0]
            y = [0]
            z = [0]
            for j in range(self.nbody):
                base = 3 * j
                x.append(self.csv.iloc[i, base + 1])
                y.append(self.csv.iloc[i, base + 2])
                z.append(-self.csv.iloc[i, base + 3])  # z is plotted negated
            self.ax.plot(x, y, z)
            self.canvas.draw()
            self.canvas.flush_events()
# Example #2
# 0
class HeatMapWidget(QWidget):
    """Widget embedding a matplotlib canvas that shows a PSC matrix heatmap."""

    def __init__(self):
        QWidget.__init__(self)

        self.setWindowTitle("Heatmap")

        # QWidget Layout
        self.main_layout = QHBoxLayout()
        # Kept under its own name: assigning to ``self.size`` would shadow
        # the QWidget.size() method on this instance.
        self.size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
        self.size_policy.setHorizontalStretch(4)

        # generate the plot
        # (the former ``self.fig.tight_layout = True`` only overwrote the
        # tight_layout method with a bool; the margins below are what apply)
        self.fig, self.ax = plt.subplots()
        self._apply_axes_layout()
        self.canvas = FigureCanvas(self.fig)

        ## Main Layout
        self.canvas.setSizePolicy(self.size_policy)
        self.main_layout.addWidget(self.canvas)

        # Set the layout to the QWidget
        self.setLayout(self.main_layout)

    def _apply_axes_layout(self):
        """Apply the fixed figure margins and zero data margins to ``self.ax``."""
        self.fig.subplots_adjust(left=0.10, right=0.90, top=0.90, bottom=0.10)
        self.ax.margins(0, 0)

    def Load(self, wbname, directory, wave):
        """Load a PSC matrix from disk and display it as a heatmap.

        The file read is
        ``./Datas/PSC Matrix/<directory>/<wave>/matrix_PSC_<directory>_<wbname>_<wave>.mat``
        and the matrix is taken from its ``'PSC'`` entry.  Errors from
        ``sp.loadmat`` (e.g. missing file) and a missing ``'PSC'`` key
        propagate to the caller.
        """
        # generate heatmap
        filename = 'matrix_PSC_' + directory + '_' + wbname + '_' + wave
        mat = sp.loadmat('./Datas/PSC Matrix/'+ directory + '/' + wave + '/' + filename + '.mat')
        M = mat['PSC']

        # Start from an empty figure so that repeated calls do not stack
        # heatmaps and colorbars on top of each other.
        self.fig.clf()
        self.ax = self.fig.add_subplot(111)
        self._apply_axes_layout()

        # Draw explicitly on this widget's axes instead of pyplot's
        # "current axes", which may belong to another figure.
        sns.heatmap(M, cmap="jet", yticklabels=False, xticklabels=getElectrodesList(),
                    vmin=0, vmax=1, ax=self.ax)
        self.ax.set_title("Heatmap for synchronization " + wave + " waves for " + directory + ":" + wbname)
        self.ax.set_xlabel("Electrodes")
        self.ax.set_ylabel("Synchronization between electrodes")
        # generate the canvas to display the plot
        self.canvas.draw()
        self.canvas.flush_events()
class TrainingController(QMainWindow, Ui_TrainingWindow):
    """Window that trains an ``MLPRegressor`` and plots its progress live.

    Only the numeric columns of ``data`` are used.  The user picks one
    output column and any number of input columns; training then runs one
    ``partial_fit`` per epoch and redraws one panel per input column after
    every epoch until the epoch limit is reached or Stop is pressed.
    """

    def __init__(self, parent, data: pd.DataFrame, title=''):
        QMainWindow.__init__(self, parent, QtCore.Qt.WindowStaysOnTopHint)
        Ui_TrainingWindow.__init__(self)
        self.setupUi(self)

        self.data: pd.DataFrame = data.select_dtypes([np.number])
        self.filterTitle = title
        self.setWindowTitle(
            'Neural Network Training - {}'.format(self.filterTitle))
        self.isStop = False

        # Training state; defined up front so onStart/onSave are safe to
        # call before an output column was chosen or a model was trained.
        self.listCheckBox = []
        self.Xn = []
        self.notXn = []
        self.yn = ''
        self.mlp = None
        self.score_test = 0
        self.score_train = 0

        # Filter Columns
        self.comboOutput.addItems(list(self.data.columns))
        self.comboOutput.currentIndexChanged.connect(self.onComboOutput)
        self.comboOutput.setCurrentIndex(-1)

        # widget inputs
        self.layoutInputs = QVBoxLayout()
        self.layoutInputs.setContentsMargins(0, 0, 0, 0)
        self.widgetInputs.setLayout(self.layoutInputs)

        # widget canvas
        self.canvas = FigureCanvasQTAgg(Figure())
        vLayout = QVBoxLayout()
        vLayout.setContentsMargins(0, 0, 0, 0)
        vLayout.addWidget(self.canvas)
        self.GraphCanvas.setLayout(vLayout)

        self.addToolBar(NavigationToolbar2QT(self.canvas, self))

        # icons
        self.playIcon = Icon('play.svg', '#0A0').getIcon()
        self.stopIcon = Icon('stop.svg', '#A00').getIcon()
        self.saveIcon = Icon('save.svg', '#00A').getIcon()

        # Buttons
        self.pushStart.setIcon(self.playIcon)
        self.pushStart.clicked.connect(self.onStart)

        self.pushStop.setIcon(self.stopIcon)
        self.pushStop.clicked.connect(self.onStop)

        self.pushSave.setIcon(self.saveIcon)
        self.pushSave.clicked.connect(self.onSave)

    def onComboOutput(self, index):
        """Rebuild the input check boxes for the newly selected output column.

        One check box is created for every numeric column except the
        selected output.  An index of -1 (no selection) is ignored.
        """
        if index > -1:
            self.pushStart.setEnabled(True)
            self.clearLayout(self.layoutInputs)

            # ``columns=`` instead of the positional axis argument, which
            # pandas 2.0 no longer accepts.
            inputNames = self.data.drop(
                columns=[self.comboOutput.currentText()]).columns
            self.listCheckBox = [QCheckBox(name) for name in inputNames]
            for checkBox in self.listCheckBox:
                self.layoutInputs.addWidget(checkBox)

    def clearLayout(self, layout):
        """Remove every item from ``layout``, recursing into nested layouts.

        Widgets are scheduled for deletion with ``deleteLater``.
        """
        if layout is not None:
            while layout.count():
                item = layout.takeAt(0)
                widget = item.widget()
                if widget is not None:
                    widget.deleteLater()
                else:
                    self.clearLayout(item.layout())

    def onStart(self):
        """Validate the settings, then train and plot until done or stopped.

        Shows a warning and returns when a spin box is not positive or no
        input column is checked.  The configuration widgets are locked
        while training runs and are always unlocked again, even if
        training raises.
        """
        settingsValid = (self.spinIterations.value() > 0
                         and self.spinLR.value() > 0
                         and self.spinAlpha.value() > 0)
        if not settingsValid:
            QMessageBox.warning(self, "Zeros Can't be Selected",
                                'Learning Rate, Alpha and Maximum Iterations must be greater than zero')
            return

        # isChecked() rather than the truthiness of checkState(), which is
        # an enum value and not a bool.
        self.Xn = [box.text() for box in self.listCheckBox if box.isChecked()]
        self.notXn = [box.text() for box in self.listCheckBox if not box.isChecked()]
        self.yn = self.comboOutput.currentText()

        if not self.Xn:
            QMessageBox.warning(
                self, 'No Xs inputs for Training', 'You Must Select at Least a X for Input')
            return

        self._setTrainingLocked(True)
        try:
            epochs = self._trainAndPlot()
        finally:
            # Never leave the window locked, whatever happened in training.
            self.isStop = False
            self._setTrainingLocked(False)

        QMessageBox.information(self, 'Training Complete', 'Test Score={:.2f}, Training Score={:.2f} at {} Iterations'.format(
            self.score_test, self.score_train, epochs))

    def _setTrainingLocked(self, locked):
        """Disable the configuration widgets while training (or re-enable them)."""
        self.comboOutput.setEnabled(not locked)
        self.widgetInputs.setEnabled(not locked)
        self.pushStart.setEnabled(not locked)
        self.pushSave.setEnabled(not locked)
        self.pushStop.setEnabled(locked)
        self.spinIterations.setEnabled(not locked)
        self.spinLR.setEnabled(not locked)
        self.spinAlpha.setEnabled(not locked)

    def _trainAndPlot(self):
        """Run the epoch loop and return the number of epochs reported to the user."""
        maxEpochs = self.spinIterations.value()
        nColumns = ceil(len(self.Xn)/2)

        # Same columns, in the same order, as dropping the output and the
        # unchecked inputs from the data.
        X = self.data[self.Xn]
        y = self.data[self.yn]
        X_train, X_test, y_train, y_test = train_test_split(X, y)

        self.mlp = MLPRegressor(solver='adam', alpha=self.spinAlpha.value(),
                                learning_rate_init=self.spinLR.value(),
                                max_iter=maxEpochs)
        self.score_test = 0
        self.score_train = 0

        # Create the panels once and reuse them: calling add_subplot with
        # the same arguments every epoch stacks new axes on top of the old
        # ones in current matplotlib versions.
        self.canvas.figure.clf()
        panels = [self.getAxes(2, nColumns, position)
                  for position in range(1, len(self.Xn) + 1)]

        # Evenly spaced values between each input's minimum and maximum,
        # used to draw the model curve; this does not change between epochs.
        auxDF = pd.DataFrame({
            name: np.linspace(self.data[name].min(), self.data[name].max())
            for name in self.Xn
        })

        epoch = 1
        while epoch <= maxEpochs:
            self.mlp.partial_fit(X_train, y_train)
            self.score_test = r2_score(y_test, self.mlp.predict(X_test))
            self.score_train = self.mlp.score(X_train, y_train)
            prediction = self.mlp.predict(auxDF)

            for name, axes in zip(self.Xn, panels):
                axes.clear()
                axes.set_title('Training X={}, y={} - Epoch {}/{}'.format(
                    name, self.comboOutput.currentText(), epoch, maxEpochs))
                axes.scatter(
                    self.data[name], y, c='blue', label='Real Data - Test Score={:.2f}'.format(self.score_test))
                # Colour given once: 'r--' together with c='red' defined it twice.
                axes.plot(auxDF[name], prediction, '--', color='red',
                          label='NN Model - Training Score={:.2f}'.format(self.score_train))
                axes.grid(True)
                axes.legend(loc='upper right')

            # toGraph() lets the GUI process events, so onStop can set isStop.
            self.toGraph()
            if self.isStop:
                break
            epoch += 1

        # After a full run the counter is one past the last epoch.
        return min(epoch, maxEpochs)

    def onStop(self):
        """Request the training loop to stop after the current epoch."""
        self.isStop = True
        self.pushStart.setEnabled(True)
        self.pushStop.setEnabled(False)
        self.pushSave.setEnabled(True)
        self.comboOutput.setEnabled(True)
        self.widgetInputs.setEnabled(True)

    def onSave(self):
        """Ask for a directory and dump the trained model there with joblib.

        The file name encodes the title, output column, input columns and
        both scores.  Nothing is written when the dialog is cancelled or
        when no model has been trained yet.
        """
        if self.mlp is None or not self.Xn:
            QMessageBox.warning(self, 'No Trained Neural Network',
                                'Train a Neural Network before saving it')
            return

        auxName = ','.join(self.Xn)
        path = QFileDialog.getExistingDirectory(
            self, 'Save Trained Neural Network', '')
        if path != '':
            path = '{}/{}-y={}-X={}-Test={:.2f}-Trainig={:.2f}.pkl'.format(
                path, self.filterTitle, self.yn, auxName, self.score_test, self.score_train)
            # save model
            joblib.dump(self.mlp, path)
            QMessageBox.information(self, 'Trained Neural Network Saved',
                                    'Trained Neural Network Saved at "{}"'.format(path))

    def getAxes(self, nRows, nColumns, position):
        """Add a subplot at ``position`` of an ``nRows`` x ``nColumns`` grid and return it."""
        return self.canvas.figure.add_subplot(nRows, nColumns, position)

    def toGraph(self):
        """Re-layout and redraw the canvas, then pause briefly."""
        self.canvas.figure.tight_layout()
        self.canvas.draw()
        self.canvas.flush_events()
        pause(.000001)
# Example #4
# 0
class MatplotlibWidget(QWidget):

    def __init__(self, parent=None, axtype='', title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', showtoolbar=True, dpi=100, *args, **kwargs):
        '''
        Build the figure, canvas, toolbar and axes of the widget.

        axtype: selects the figure padding here ('sp', 'legend', anything
            else) and is stored for later use; 'sp_fit' additionally
            replaces the toolbar's press_zoom with press_zoomX.
        title, xlabel, ylabel, xlim, ylim, xscale, yscale: forwarded to
            self.initial_axes().
        showtoolbar: a tuple limits the toolbar to the tool items whose
            name is in the tuple; any truthy value adds the toolbar below
            the canvas; a falsy value hides it.
        dpi: resolution of the figure.
        *args, **kwargs: accepted but not used here.
        '''
        super(MatplotlibWidget, self).__init__(parent)

        # initialize axes (ax) and plots (l)
        self.axtype = axtype
        self.ax = [] # list of axes
        self.l = {} # all the plot stored in dict
        self.l['temp'] = [] # for temp lines in list
        self.txt = {} # all text in fig
        self.leg = '' # initiate legend

        # set padding size
        if axtype == 'sp':
            self.fig = Figure(tight_layout={'pad': 0.05}, dpi=dpi)
        elif axtype == 'legend':
            self.fig = Figure(dpi=dpi, facecolor='none')
        else:
            self.fig = Figure(tight_layout={'pad': 0.2}, dpi=dpi)

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setSizePolicy(QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        self.canvas.setFocusPolicy(Qt.ClickFocus)
        self.canvas.setFocus()

        # layout
        self.vbox = QVBoxLayout()
        self.vbox.setContentsMargins(0, 0, 0, 0) # set layout margins
        self.vbox.addWidget(self.canvas)
        self.setLayout(self.vbox)

        # add toolbar and buttons given by showtoolbar
        # toolitems is read when the toolbar is constructed, so the filtered
        # list has to live on a subclass created before instantiation.
        if isinstance(showtoolbar, tuple):
            class NavigationToolbar(NavigationToolbar2QT):
                toolitems = [t for t in NavigationToolbar2QT.toolitems if t[0] in showtoolbar]
        else:
            class NavigationToolbar(NavigationToolbar2QT):
                pass

        self.toolbar = NavigationToolbar(self.canvas, self)
        self.toolbar.setMaximumHeight(config_default['max_mpl_toolbar_height'])
        self.toolbar.setStyleSheet("QToolBar { border: 0px;}")

        if showtoolbar:
            self.vbox.addWidget(self.toolbar)
            if self.axtype == 'sp_fit':
                self.toolbar.press_zoom = types.MethodType(press_zoomX, self.toolbar)
        else:
            # hide toolbar. remove this will make every figure with showtoolbar = False show tiny short toolbar
            self.toolbar.hide()

        # yscale is forwarded like every other argument; it used to be
        # hard-coded to 'linear' here, which silently ignored the caller.
        self.initial_axes(title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)

        self.canvas_draw()


    def initax_xy(self, *args, **kwargs):
        '''
        add one subplot (111) with a transparent face to the figure
        and append it to self.ax

        *args, **kwargs are accepted but not used
        '''
        single_axes = self.fig.add_subplot(111, facecolor='none')
        self.ax.append(single_axes)


    def initax_xyy(self):
        '''
        create the primary axes and a twin sharing its x axis

        after the call self.ax is [ax1, ax2]
        ax1 (left y axis) is coloured with color[0],
        ax2 (right y axis) is coloured with color[1]
        '''
        self.initax_xy()

        left_axes = self.ax[0]
        right_axes = left_axes.twinx()
        # raise the primary axes one z-order step above its current value
        left_axes.set_zorder(left_axes.get_zorder()+1)

        # (axes, colour, spine to colour, spine to hide) - twin first, as before
        styling = (
            (right_axes, color[1], 'right', 'left'),
            (left_axes, color[0], 'left', 'right'),
        )
        for axes, tint, own_side, other_side in styling:
            axes.tick_params(axis='y', labelcolor=tint, color=tint)
            axes.yaxis.label.set_color(tint)
            axes.spines[own_side].set_color(tint)
            axes.spines[other_side].set_visible(False)

        # the twin becomes self.ax[1]
        self.ax.append(right_axes)


    def init_sp(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the sp[n]
        initialize ax[0]: .lG, .lGfit, .ltollb, .ltolub, .lP, .lPfit,
                          .strk, .srec, .lsp plot
        initialize ax[1]: .lB, .lBfit plot
        initialize text:  .txt['sp_harm'], .txt['chi']

        Only `title` is used from the arguments; the axis labels are fixed
        below and the remaining arguments are accepted but not used.
        '''

        self.initax_xyy()

        self.ax[0].margins(x=0)
        self.ax[1].margins(x=0)
        self.ax[0].margins(y=.05)
        self.ax[1].margins(y=.05)

        # style shared by the raw-data (dots only, hollow markers) lines
        dots = dict(marker='.', linestyle='none', markerfacecolor='none')

        self.l['lG'] = self.ax[0].plot([], [], color=color[0], **dots) # G
        self.l['lB'] = self.ax[1].plot([], [], color=color[1], **dots) # B
        self.l['lGfit'] = self.ax[0].plot([], [], color='k') # G fit
        self.l['lBfit'] = self.ax[1].plot([], [], color='k') # B fit

        # NOTE: 'strk' and 'srec' used to be created here as well and again
        # further down; the second creation overwrote the dict entries and
        # left the first pair as unreachable artists on the axes. They are
        # now created once, at the position of the pair that was kept.

        self.l['ltollb'] = self.ax[0].plot([], [], linestyle='--', color='k') # tolerance interval lines
        self.l['ltolub'] = self.ax[0].plot([], [], linestyle='--', color='k') # tolerance interval lines

        self.l['lP'] = self.ax[0].plot([], [], color=color[0], **dots) # polar plot
        self.l['lPfit'] = self.ax[0].plot([], [], color='k') # polar plot fit
        self.l['strk'] = self.ax[0].plot(
            [], [],
            marker='+',
            linestyle='none',
            color='r'
        ) # center of tracking peak
        self.l['srec'] = self.ax[0].plot(
            [], [],
            marker='x',
            linestyle='none',
            color='g'
        ) # center of recording peak

        self.l['lsp'] = self.ax[0].plot([], [], color=color[2]) # peak freq span

        # add text
        self.txt['sp_harm'] = self.fig.text(0.01, 0.98, '', va='top',ha='left') # option: weight='bold'
        self.txt['chi'] = self.fig.text(0.01, 0.01, '', va='bottom',ha='left')

        # set labels of ax[0] and ax[1]
        self.set_ax(self.ax[0], title=title, xlabel=r'$f$ (Hz)',ylabel=r'$G_P$ (mS)')
        self.set_ax(self.ax[1], xlabel=r'$f$ (Hz)',ylabel=r'$B_P$ (mS)')

        # three evenly spaced major ticks on the x axis
        self.ax[0].xaxis.set_major_locator(ticker.LinearLocator(3))



    def update_sp_text_harm(self, harm):
        '''
        show harm in the 'sp_harm' figure text
        harm: an int is converted with str(); anything else is passed on as is
        '''
        label = str(harm) if isinstance(harm, int) else harm
        self.txt['sp_harm'].set_text(label)


    def update_sp_text_chi(self, chi=None):
        '''
        show chi in the 'chi' figure text with four decimals
        chi: number; any falsy value (None, 0, ...) clears the text
        '''
        caption = r'$\chi^2$ = {:.4f}'.format(chi) if chi else ''
        self.txt['chi'].set_text(caption)


    def init_sp_fit(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the spectra fit
        initialize ax[0]: .lG, .lGfit, .strk, .srec, .lsp plot
        initialize ax[1]: .lB, .lBfit plot
        add two horizontal span selectors on ax[0]:
            left button  -> sp_spanselect_zoomin_callback
            right button -> sp_spanselect_zoomout_callback

        The arguments are accepted but not used; the axis labels are fixed.
        '''
        import inspect

        self.initax_xyy()

        # style shared by the raw-data (dots only, hollow markers) lines
        dots = dict(marker='.', linestyle='none', markerfacecolor='none')

        self.l['lG'] = self.ax[0].plot([], [], color=color[0], **dots) # G
        self.l['lB'] = self.ax[1].plot([], [], color=color[1], **dots) # B
        self.l['lGfit'] = self.ax[0].plot([], [], color='k') # G fit
        self.l['lBfit'] = self.ax[1].plot([], [], color='k') # B fit
        self.l['strk'] = self.ax[0].plot(
            [], [],
            marker='+',
            linestyle='none',
            color='r'
        ) # center of tracking peak
        self.l['srec'] = self.ax[0].plot(
            [], [],
            marker='x',
            linestyle='none',
            color='g'
        ) # center of recording peak
        self.l['lsp'] = self.ax[0].plot([], [], color=color[2]) # peak freq span

        # set labels of ax[0] and ax[1]
        self.set_ax(self.ax[0], xlabel=r'$f$ (Hz)',ylabel=r'$G_P$ (mS)')
        self.set_ax(self.ax[1], xlabel=r'$f$ (Hz)',ylabel=r'$B_P$ (mS)')

        self.ax[0].xaxis.set_major_locator(ticker.LinearLocator(3))

        self.ax[0].margins(x=0)
        self.ax[1].margins(x=0)
        self.ax[0].margins(y=.05)
        self.ax[1].margins(y=.05)

        # matplotlib 3.5 renamed SpanSelector's `rectprops` to `props` and
        # `span_stays` to `interactive`; the old names were removed in 3.7.
        # Pick whichever pair the installed version accepts.
        if 'props' in inspect.signature(SpanSelector.__init__).parameters:
            style_kw, persist_kw = 'props', 'interactive'
        else:
            style_kw, persist_kw = 'rectprops', 'span_stays'

        # add span selector
        self.span_selector_zoomin = SpanSelector(
            self.ax[0],
            self.sp_spanselect_zoomin_callback,
            direction='horizontal',
            useblit=True,
            button=[1],  # left click
            minspan=5,
            **{persist_kw: False, style_kw: dict(facecolor='red', alpha=0.2)}
        )

        self.span_selector_zoomout = SpanSelector(
            self.ax[0],
            self.sp_spanselect_zoomout_callback,
            direction='horizontal',
            useblit=True,
            button=[3],  # right
            minspan=5,
            **{persist_kw: False, style_kw: dict(facecolor='blue', alpha=0.2)}
        )


    def sp_spanselect_zoomin_callback(self, xclick, xrelease): 
        '''
        callback of span_selector_zoomin
        set the x limits of ax[0] to the selected span (xclick, xrelease)
        '''
        primary_axes = self.ax[0]
        primary_axes.set_xlim(xclick, xrelease)


    def sp_spanselect_zoomout_callback(self, xclick, xrelease): 
        '''
        callback of span_selector_zoomout
        widen the x range of ax[0] by the ratio current span / selected span

        xclick, xrelease: bounds of the selected span
        A selection of zero span is ignored (it would divide by zero).
        '''
        # current range; the converters presumably map (f1, f2) <-> (center, span)
        curr_f1, curr_f2 = self.ax[0].get_xlim()
        curr_fc, curr_fs = UIModules.converter_startstop_to_centerspan(curr_f1, curr_f2)
        # selected range
        sel_fc, sel_fs = UIModules.converter_startstop_to_centerspan(xclick, xrelease)
        if sel_fs == 0:
            # nothing selected: no ratio can be computed, keep the current view
            return

        # calculate the new span
        ratio = curr_fs / sel_fs
        new_fs = curr_fs * ratio
        new_fc = curr_fc * (1 + ratio) - sel_fc * ratio
        # center/span to f1/f2
        new_f1, new_f2 = UIModules.converter_centerspan_to_startstop(new_fc, new_fs)
        # set new xlim
        self.ax[0].set_xlim(new_f1, new_f2)
       

    def init_sp_polar(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the spectra polar
        initialize plot: l['l'], l['lfit'], l['lsp'] on a single axes
        with equal aspect ratio

        The arguments are accepted but not used; the axis labels are fixed.
        '''
        self.initax_xy()
        polar_axes = self.ax[0]

        # G vs. B: dots only, hollow markers
        self.l['l'] = polar_axes.plot(
            [], [],
            color=color[0],
            markerfacecolor='none',
            linestyle='none',
            marker='.'
        )
        # fit
        self.l['lfit'] = polar_axes.plot([], [], color='k')
        # fit in span range
        self.l['lsp'] = polar_axes.plot([], [], color=color[2])

        # set labels of ax[0]
        self.set_ax(polar_axes, xlabel=r'$G_P$ (mS)',ylabel=r'$B_P$ (mS)')

        polar_axes.set_aspect('equal')


    def init_data(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the mpl_plt1 & mpl_plt2
        initialize plot (n = 1, 3, 5, ... from range(1, max_harmonic+2, 2)):
            .l<n>   data line (pickable)
            .lm<n>  marked points (pickable), same colour as .l<n>
            .lt<n>  temporary points, same colour as .l<n>
            .ls<n>  points inside the rectangle selector
            .lp     point of the picker
        also creates the (initially inactive) rectangle selector and a
        checkable toolbar button that toggles it

        Only `ylabel` is used from the arguments; the x label is fixed.
        '''
        import inspect

        self.sel_mode = 'none' # 'none', 'selector', 'picker'

        self.initax_xy()

        # The two loops are kept separate on purpose: every .l<n> line takes
        # its colour from the colour cycle before any other line is created.
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['l' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                markerfacecolor='none', 
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            ) # l

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            line_color = self.l['l' + str(i)][0].get_color() # set the same color as .l

            self.l['lm' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                color=line_color,
                linestyle='none',
                picker=True,
                pickradius=5, # 5 points tolerance
                label='lm'+str(i),
                alpha=0.75,
            ) # marked points of line

            self.l['lt' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                color=line_color,
                linestyle='none',
                # picker is off by default
                label='lt'+str(i),
                alpha=0.35,
            ) # temporary points of line

            self.l['ls' + str(i)] = self.ax[0].plot(
                [], [], 
                marker='o', 
                markeredgecolor=color[1], 
                markerfacecolor=color[1],
                alpha= 0.5,
                linestyle='none',
                label='ls'+str(i),
            ) # points in rectangle_selector

        self.l['lp'] = self.ax[0].plot(
            [], [], 
            marker='+', 
            markersize = 12,
            markeredgecolor=color[1],
            markeredgewidth=1, 
            alpha= 1,
            linestyle='none',
            label='',
        ) # points of picker

        self.cid = None # for storing the cid of pick_event

        # set label of ax[0]
        self.set_ax(self.ax[0], xlabel='Time (s)',ylabel=ylabel)

        # add rectangle_selector
        selector_kwargs = dict(
            button=[1], # left
            useblit=True,
            minspanx=5,
            minspany=5,
            spancoords='pixels', # default 'data'
            interactive=False, # change rect after drawn
            state_modifier_keys=None,
        )
        rect_style = dict(edgecolor = 'black', facecolor='none', alpha=0.2, fill=False)
        # matplotlib 3.5 renamed rectprops -> props, maxdist -> grab_range and
        # marker_props -> handle_props and deprecated drawtype; the old names
        # were removed in 3.7. Use whichever set the installed version accepts.
        if 'props' in inspect.signature(RectangleSelector.__init__).parameters:
            selector_kwargs.update(props=rect_style, grab_range=10, handle_props=None)
        else:
            selector_kwargs.update(drawtype='box', rectprops=rect_style, maxdist=10, marker_props=None)

        self.rect_selector = RectangleSelector(
            self.ax[0], 
            self.data_rectselector_callback,
            **selector_kwargs
        )
        # start inactive; the toolbar button below switches it on
        self.rect_selector.set_active(False)

        # create a toggle button for the selector
        self.pushButton_selectorswitch = QPushButton()
        self.pushButton_selectorswitch.setText('')
        self.pushButton_selectorswitch.setCheckable(True)
        self.pushButton_selectorswitch.setFlat(True)
        # icon
        icon_sel = QIcon()
        icon_sel.addPixmap(QPixmap(':/button/rc/selector.svg'), QIcon.Normal, QIcon.Off)
        self.pushButton_selectorswitch.setIcon(icon_sel)

        self.pushButton_selectorswitch.clicked.connect(self.data_rectsleector_picker_switch)

        # add it to toolbar
        self.toolbar.addWidget(self.pushButton_selectorswitch)
        toolbar_children = self.toolbar.children()
        # NOTE(review): relies on the home button being child 4 of the
        # toolbar - confirm this holds for the Qt/matplotlib versions in use
        toolbar_children[4].clicked.connect(self.data_show_all) # 4 is the home button


    def data_rectsleector_picker_switch(self, checked):
        '''
        slot of the selector toggle button
        checked: True  -> activate the rectangle selector and the pick event
                 False -> deactivate both and clear the selection lines
        '''
        if not checked:
            # deactivate rectangle selector
            self.rect_selector.set_active(False)
            # reset .l['ls<n>']
            selector_lines = ['ls'+ str(i) for i in range(1, int(config_default['max_harmonic']+2), 2)]
            self.clr_lines(l_list=selector_lines)
            # deactivate pick event
            self.canvas.mpl_disconnect(self.cid)
            # clear data .l['lp']
            self.clr_lines(l_list=['lp'])

            # reset
            self.sel_mode = 'none'
            return

        logger.info(True) 
        # activate rectangle selector
        self.rect_selector.set_active(True)
        # connect pick event and remember its id for the later disconnect
        self.cid = self.canvas.mpl_connect('pick_event', self.data_onpick_callback)

        logger.info('%s', self.toolbar.mode)

        # call pan()/zoom() when the toolbar reports that tool as its mode
        # (mode strings as used by matplotlib >= 3.3; before 3.3 they were
        # "PAN" / "ZOOM")
        current_mode = self.toolbar.mode
        if current_mode == "pan/zoom":
            self.toolbar.pan()
        elif current_mode == "zoom rect":
            self.toolbar.zoom()


    def data_rectselector_callback(self, eclick, erelease):
        '''
        callback of the rectangle selector
        eclick, erelease: press and release events of the drawn rectangle

        collects, per harmonic, the points of the plotted lines that lie
        inside the rectangle and passes them to update_data() as
        {'ln': 'ls<n>', 'x': ..., 'y': ...}; sets sel_mode accordingly
        '''
        # clear pick data .l['lp']
        self.clr_lines(l_list=['lp'])

        logger.info(eclick) 
        logger.info(erelease) 
        x_low, x_high = sorted([eclick.xdata, erelease.xdata])
        y_low, y_high = sorted([eclick.ydata, erelease.ydata])

        # entries for updating the selected data
        picked = []
        for prefix in ['l', 'lm']: # only one will be not empty
            for n in range(1, config_default['max_harmonic']+2, 2):
                suffix = str(n)
                # clear .l['ls<n>']
                self.clr_lines(l_list=['ls'+suffix])

                # data of the currently plotted line
                xs, ys = self.l[prefix + suffix][0].get_data()
                if not isinstance(xs, pd.Series):
                    # only lines holding a Series are considered (empty lines do not)
                    continue

                inside = xs.between(x_low, x_high) & ys.between(y_low, y_high)
                picked.append({'ln': 'ls'+suffix, 'x': xs[inside], 'y': ys[inside]})

            if (prefix == 'l') and picked: # UI mode showall
                break
                #TODO It can also set the display mode from UI (showall/showmarked) and do the loop by the mode

        # 'selector' as soon as at least one entry was collected
        self.sel_mode = 'selector' if picked else 'none'

        # plot the selected data
        self.update_data(*picked)


    def data_onpick_callback(self, event):
        '''
        callback function of mpl_data pick_event

        Clears the rectangle-selection lines .l['ls<n>'], copies the first
        picked point of the picked line into .l['lp'] and labels .l['lp'] as
        '<label of picked line>_<index>'. Finally sets self.sel_mode to
        'picker'.

        event: pick event; event.artist is the picked line and event.ind
            holds the indices of the points that were hit (only the first
            one is used).

        NOTE(review): the line data are read through .name and .iloc, so they
        are assumed to be pandas Series -- confirm against the code that
        fills the lines.
        '''
        # clear selector data .l['ls<n>']
        self.clr_lines(l_list=['ls'+ str(i) for i in range(1, int(config_default['max_harmonic']+2), 2)])

        thisline = event.artist
        x_p = thisline.get_xdata()
        y_p = thisline.get_ydata()
        ind = event.ind[0]
        logger.info(thisline) 
        logger.info(thisline.get_label()) 
        logger.info(x_p.name) 
        logger.info(y_p.name) 
        logger.info(ind) 

        # the single picked point
        x_sel = x_p.iloc[ind]
        y_sel = y_p.iloc[ind]
        logger.info('x_p %s', x_p) 
        logger.info('%s %s', x_sel, y_sel) 

        # Line2D.set_data expects sequences; newer matplotlib rejects bare
        # scalars, so the point is wrapped in one-element lists.
        self.l['lp'][0].set_data([x_sel], [y_sel])
        self.l['lp'][0].set_label(thisline.get_label() + '_' + str(ind)) # transfer the label of picked line and ind to 'lp'
        self.canvas_draw()

        # set
        self.sel_mode = 'picker'


    def init_contour(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the mechanics_contour1 & mechanics_contour2

        Builds (or rebuilds) the filled contour on ax[0] and its colorbar on
        ax[1], and creates the empty harmonic artists on ax[0].

        Artists stored in self.l:
            .l['C']         filled contour (contourf)
            .l['colorbar']  colorbar drawn into ax[1]
            .l['l<n>']      empty pickable line
            .l['p<n>']      errorbar container, markerfacecolor='none'
            .l['pm<n>']     errorbar container without markerfacecolor='none' (prop marked)
        with n taken from range(1, int(max_harmonic + 2), 2).

        Recognized kwargs:
            X, Y, Z: contour data. Unless all three are given, a placeholder
                grid is built from config_default['contour_array'] and Z is
                filled with random numbers.
            levels:  passed to contourf. Default: np.linspace between min(Z)
                and max(Z) with config_default['contour_array']['levels'] points.
            cmap:    passed to contourf. Default: config_default['contour_array']['cmap'].

        Of the named parameters only `title` is used; xlabel, ylabel, xlim,
        ylim, xscale and yscale are accepted but not forwarded (the axis
        labels are fixed at the end of this method).

        NOTE: this function should be used every time contour is changed
        '''
        logger.info("self.l.get('C'): %s", self.l.get('C')) 
        logger.info('kwargs: %s', kwargs.keys()) 

        # if self.l.get('colorbar'):
        #     self.l['colorbar'].remove()
        if self.ax:
            # axes already exist: wipe both the plot axes and the colorbar axes
            self.ax[0].cla()
            self.ax[1].clear()
        else:
            self.initax_xy()
            # create axes for the colorbar
            self.ax.append(make_axes_locatable(self.ax[0]).append_axes("right", size="5%", pad="2%"))

        # X, Y and Z are only used when all three are supplied
        if not 'X' in kwargs or not 'Y' in kwargs or not 'Z' in kwargs:
            num = config_default['contour_array']['num']
            phi_lim = config_default['contour_array']['phi_lim']
            dlam_lim = config_default['contour_array']['dlam_lim']

            # initiate X, Y, Z data
            x = np.linspace(phi_lim[0], phi_lim[1], num=num)
            y = np.linspace(dlam_lim[0], dlam_lim[1], num=num)
            # note the swap: X varies along dlam_lim, Y along phi_lim
            X, Y = np.meshgrid(y, x)
            # placeholder values; replaced when real data are passed in
            Z = np.random.rand(*X.shape)
        else:
            X = kwargs.get('X')
            Y = kwargs.get('Y')
            Z = kwargs.get('Z')

        if 'levels' in kwargs:
            levels = kwargs.get('levels')
        else:
            # config value is the number of levels; spread them over the Z range
            levels = config_default['contour_array']['levels']
            levels = np.linspace(np.min(Z), np.max(Z), levels)
        
        if 'cmap' in kwargs:
            cmap = kwargs.get('cmap')
        else:
            logger.info('cmap not in kwargs') 
            cmap = config_default['contour_array']['cmap']
        
        logger.info('levels %s', type(levels), ) 
        self.l['C'] = self.ax[0].contourf(
            X, Y, Z, # X, Y, Z
            levels=levels, 
            cmap=cmap,
        ) # contour

        self.l['colorbar'] = plt.colorbar(self.l['C'], cax=self.ax[1]) # colorbar
        self.l['colorbar'].locator = ticker.MaxNLocator(nbins=6)
        self.l['colorbar'].update_ticks()

        # the lines are created first: the errorbar loops below read their colors
        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['l' + str(i)] = self.ax[0].plot(
                [], [], 
                # marker='o', 
                markerfacecolor='none', 
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            ) # l

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['p' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                xerr=np.nan,
                yerr=np.nan,
                marker='o', 
                markerfacecolor='none', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop

        for i in range(1, int(config_default['max_harmonic']+2), 2):
            self.l['pm' + str(i)] = self.ax[0].errorbar(
                np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines () 
                yerr=np.nan,
                xerr=np.nan,
                marker='o', 
                linestyle='none',
                color=self.l['l' + str(i)][0].get_color(), # set the same color as .l
                # picker=True,
                # pickradius=5, # 5 points tolerance
                label=str(i),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                capsize=config_default['mpl_capsize'],
            ) # prop marked   

        # set the (fixed) axis labels and the title of ax[0]
        self.set_ax(self.ax[0], xlabel=r'$d/\lambda$',ylabel=r'$\Phi$ ($\degree$)', title=title)

        self.canvas_draw()
        # autoscale is switched off on the contour axes after the first draw
        self.ax[0].autoscale(enable=False)


    def init_legendfig(self, *args, **kwargs):
        ''' 
        plot a figure with only legend

        One empty line per harmonic (labelled with the harmonic number) is
        plotted on ax[0]; the figure-level legend built from them is stored
        in self.leg and the axes itself is switched off.
        *args / **kwargs are forwarded to set_ax().
        '''
        self.initax_xy()

        # empty lines: they only exist to provide the legend entries
        for harm in range(1, config_default['max_harmonic']+2, 2):
            self.ax[0].plot([], [], label=harm)

        legend_options = dict(
            loc='upper center',
            bbox_to_anchor=(0.5, 1),
            borderaxespad=0.,
            borderpad=0.,
            ncol=int((config_default['max_harmonic']+1)/2),
            frameon=False,
            facecolor='none',
            labelspacing=0.0,
            columnspacing=0.5,
        )
        self.leg = self.fig.legend(**legend_options)
        self.canvas_draw()

        # apply the (empty) axis items and the fonts to ax[0]
        self.set_ax(self.ax[0], title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs)

        # turn off the axis, only the legend stays visible
        self.ax[0].set_axis_off()
 

    def init_prop(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize property plot

        Creates, for n in range(1, int(max_harmonic + 2), 2):
            .l['l<n>']   empty pickable line
            .l['p<n>']   errorbar container, markerfacecolor='none'
            .l['pm<n>']  errorbar container with default marker face (prop marked)
        and then applies title/labels/limits/scales to ax[0] via set_ax().
        '''
        self.initax_xy()

        def harmonics():
            # harmonic numbers used in the artist names
            return range(1, int(config_default['max_harmonic']+2), 2)

        # lines first: the errorbars below reuse their colors
        for n in harmonics():
            self.l['l' + str(n)] = self.ax[0].plot(
                [], [],
                markerfacecolor='none',
                picker=True,
                pickradius=5, # 5 points tolerance
                label='l'+str(n),
                alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
            )

        # 'p' (prop) and 'pm' (prop marked) differ only in the marker face
        for prefix, face in (('p', {'markerfacecolor': 'none'}), ('pm', {})):
            for n in harmonics():
                self.l[prefix + str(n)] = self.ax[0].errorbar(
                    np.nan, np.nan, # Note: for matplotlib >= 3.3, can't be [], which will make caplines ()
                    xerr=np.nan,
                    yerr=np.nan,
                    marker='o',
                    linestyle='none',
                    color=self.l['l' + str(n)][0].get_color(), # set the same color as .l
                    label=str(n),
                    alpha=0.75, # TODO markerfacecolor becomes dark on Linux when alpha used
                    capsize=config_default['mpl_capsize'],
                    **face
                )

        # set the items of ax[0]
        self.set_ax(self.ax[0], title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale)

    # def sizeHint(self):
    #     return QSize(*self.get_width_height())

    # def minimumSizeHint(self):
    #     return QSize(10, 10)

    def set_ax(self, ax, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        apply title/labels/limits/scales to ax, then the font settings

        Extra *args / **kwargs are accepted but not used.
        '''
        items = dict(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            xscale=xscale,
            yscale=yscale,
        )
        self.set_ax_items(ax, **items)
        self.set_ax_font(ax)


    def set_ax_items(self, ax, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        set title, axis labels, scales and limits of ax

        title/xlabel/ylabel are applied only when non-empty; xscale/yscale
        and xlim/ylim only when not None. Limits are unpacked into the
        set_xlim/set_ylim call. Extra *args / **kwargs are ignored.
        '''
        # texts: skipped when empty
        for item, text in (('title', title), ('xlabel', xlabel), ('ylabel', ylabel)):
            if text:
                getattr(ax, 'set_' + item)(text)

        # scales are set before the limits
        for item, scale in (('xscale', xscale), ('yscale', yscale)):
            if scale is not None:
                getattr(ax, 'set_' + item)(scale)

        for item, lim in (('xlim', xlim), ('ylim', ylim)):
            if lim is not None:
                getattr(ax, 'set_' + item)(*lim)


    def set_ax_font(self, ax, *args, **kwargs):
        '''
        apply the configured font sizes to ax and to the texts in self.txt

        The 'sp' axtype uses the mpl_sp_* sizes from config_default, every
        other axtype the mpl_* sizes. For 'contour' the colorbar tick labels
        and offset text are resized too, for 'legend' the legend texts.
        Extra *args / **kwargs are ignored.
        '''
        axtype = self.axtype
        if axtype == 'sp':
            fontsize = config_default['mpl_sp_fontsize']
            legfontsize = config_default['mpl_sp_legfontsize']
            txtfontsize = config_default['mpl_sp_txtfontsize']
        else:
            fontsize = config_default['mpl_fontsize']
            legfontsize = config_default['mpl_legfontsize']
            txtfontsize = config_default['mpl_txtfontsize']

        if axtype == 'contour':
            cbar_yaxis = self.l['colorbar'].ax.yaxis
            for ticklabel in cbar_yaxis.get_ticklabels():
                ticklabel.set_size(fontsize)
            cbar_yaxis.offsetText.set_size(fontsize)

        if axtype == 'legend':
            plt.setp(self.leg.get_texts(), fontsize=legfontsize)

        # title and axis labels are one point larger than the tick labels
        ax.title.set_fontsize(fontsize+1)
        for axis in (ax.xaxis, ax.yaxis):
            axis.label.set_size(fontsize+1)
        ax.tick_params(labelsize=fontsize)
        for axis in (ax.xaxis, ax.yaxis):
            axis.offsetText.set_size(fontsize)

        # set text fontsize; the 'sp_harm' text has its own size
        for key, text in self.txt.items():
            if key == 'sp_harm':
                text.set_fontsize(config_default['mpl_sp_harmfontsize'])
            else:
                text.set_fontsize(txtfontsize)
        


    def resize(self, event):
        '''
        re-run tight_layout (pad=1.08) on the figure

        event: not used. Written as a callback; per the original note it
            needs to be connected by the caller, e.g.
            self.canvas.mpl_connect("resize_event", self.resize)
        '''
        figure = self.fig
        figure.tight_layout(pad=1.08)

    def initial_axes(self, title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        '''
        initialize the axes according to self.axtype

        'xy', 'xyy'                   -> initax_xy() / initax_xyy(), no arguments
        'sp', 'sp_fit', 'sp_polar',
        'data', 'prop'                -> init_<axtype>() with the axis items
        'contour'                     -> init_contour() with the axis items
                                         plus *args / **kwargs
        'legend'                      -> init_legendfig(), no arguments
        any other value               -> nothing is done
        '''
        axtype = self.axtype
        axis_items = dict(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            xscale=xscale,
            yscale=yscale,
        )

        if axtype == 'contour':
            # the only initializer that also receives the extra arguments
            self.init_contour(*args, **axis_items, **kwargs)
        elif axtype in ('sp', 'sp_fit', 'sp_polar', 'data', 'prop'):
            getattr(self, 'init_' + axtype)(**axis_items)
        elif axtype == 'legend':
            self.init_legendfig()
        elif axtype in ('xy', 'xyy'):
            getattr(self, 'initax_' + axtype)()

    def update_data(self, *args):
        ''' 
        update data of given in args (dict)
        arg = {'ln':, 'x':, 'y':, 'xerr':, 'yerr':, 'label':,}
            ln: string of line name (key of self.l)
            x : x data
            y : y data
            xerr: x error
            yerr: y error
            label: optional new label for the (data) line

        An arg containing 'xerr' or 'yerr' is treated as an errorbar update
        and is applied only when self.l[ln] is an ErrorbarContainer; both
        'xerr' and 'yerr' are then read. Any other arg updates the plain line
        self.l[ln][0]. The limits of every touched axes are reset and the
        canvas is redrawn once at the end.
        NOTE: don't use this func to update contour
        '''
        axs = set() # axes whose limits have to be reset afterwards
        
        for arg in args:
            keys = arg.keys()
            # resolve the name first so that the label step below always
            # refers to this arg (never to an undefined or previous name)
            ln = arg['ln']

            if ('xerr' in keys) or ('yerr' in keys): # errorbar with caps and barlines
                if isinstance(self.l[ln], ErrorbarContainer): # type match 
                    # since we initialize the errorbar plots with xerr and yerr, we don't check if they exist here. If you want, use self.l[ln].has_yerr
                    x = arg['x'] 
                    y = arg['y'] 
                    xerr = arg['xerr'] 
                    yerr = arg['yerr']

                    line, caplines, barlinecols = self.l[ln]
                    line.set_data(x, y)
                    # Find the ending points of the errorbars 
                    error_positions = (x-xerr,y), (x+xerr,y), (x,y-yerr), (x,y+yerr) 
                    # Update the caplines 
                    for i, pos in enumerate(error_positions): 
                        caplines[i].set_data(pos) 
                    # Update the error bars 
                    barlinecols[0].set_segments(zip(zip(x-xerr,y), zip(x+xerr,y))) 
                    barlinecols[1].set_segments(zip(zip(x,y-yerr), zip(x,y+yerr))) 
                    
                    axs.add(line.axes)
            else: # not errorbar
                x = arg['x'] 
                y = arg['y']
                self.l[ln][0].set_data(x, y)
                axs.add(self.l[ln][0].axes)
            
            if 'label' in keys: # with label
                self.l[ln][0].set_label(arg['label'])

        for ax in axs:
            self.reset_ax_lim(ax)

        self.canvas_draw()


    def get_data(self, ls=[]):
        '''
        get data of given ls (list of line-name strings)
        return a list of (xdata, ydata) tuples, one per name, in the order
        of ls; the data come from get_data() of self.l[name][0]
        '''
        line_data = (self.l[name][0].get_data() for name in ls)
        return [(xdata, ydata) for xdata, ydata in line_data]


    def del_templines(self, ax=None):
        ''' 
        del all temp lines .l['temp'][:] 

        ax: axes whose limits are reset afterwards (default: self.ax[0]).

        Every line stored in self.l['temp'] is removed from the axes it is
        attached to, self.l['temp'] is emptied, the limits of ax are reset
        and the canvas is redrawn.
        '''
        if ax is None:
            ax = self.ax[0]

        for l_temp in self.l['temp']:
            # Artist.remove() detaches the line from its own axes. Newer
            # matplotlib no longer allows editing ax.lines directly.
            l_temp[0].remove()
        self.l['temp'] = [] # start again with an empty list

        self.reset_ax_lim(ax)
        self.canvas_draw()


    def clr_all_lines(self):
        '''
        clear every line: clr_lines() without a key filter
        '''
        self.clr_lines(l_list=None)


    def clr_lines(self, l_list=None):
        ''' 
        clear the data of the lines in .l whose key is in l_list

        l_list: keys of self.l to clear; None clears every entry.

        'C' and 'colorbar' are never touched. If a 'temp' entry exists, the
        temp lines are deleted through del_templines() for every axes,
        whatever l_list is. Errorbar containers get their data line, caplines
        and bar collections emptied; all other entries get the data of their
        first line emptied.
        Unless a contour ('C') is present, the limits of all axes in self.ax
        are reset; the canvas is redrawn at the end.
        '''
        for key in self.l:
            if key == 'temp':
                for ax in self.ax:
                    self.del_templines(ax=ax)
                continue
            if key in ('C', 'colorbar'):
                continue
            if l_list is not None and key not in l_list:
                continue # not requested

            if isinstance(self.l[key], ErrorbarContainer): # errorbar plot
                logger.info('key: %s', key)
                logger.info('len(self.l[key]): %s', len(self.l[key]))
                line, caplines, barlinecols = self.l[key]

                logger.info(line) 
                logger.info(caplines) 
                logger.info('caplines len %s', len(caplines))
                logger.info(barlinecols) 
                logger.info('barlinecols len %s', len(barlinecols))

                line.set_data([], [])
                # empty whatever caps / bars exist, without assuming how many
                for capline in caplines:
                    capline.set_data([], [])
                for barlinecol in barlinecols:
                    barlinecol.set_segments([])
            else:
                self.l[key][0].set_data([], []) # line plot

        logger.info('it is a contour: %s', 'C' in self.l) 
        if 'C' not in self.l: # not contour
            # we don't reset contour limit
            # every axes is reset explicitly instead of relying on a loop
            # variable left over from the 'temp' branch
            for ax in self.ax:
                self.reset_ax_lim(ax)
        self.canvas_draw()


    def change_style(self, line_list, **kwargs):
        '''
        set style of artists in class

        line_list: keys of self.l whose first line is restyled
        kwargs: <property>=<value> pairs, e.g. linestyle='--', markersize=4;
            the same keywords as in matplotlib. Each pair results in a call
            of set_<property>(<value>) on every listed line.

        The setter is looked up with getattr instead of eval, so the value is
        handed over unchanged (not as its string representation) and nothing
        is executed from strings.
        '''
        logger.info(line_list) 
        logger.info(self.l) 
        logger.info(self.l.keys()) 
        for key, val in kwargs.items():
            setter_name = 'set_' + key
            for l in line_list:
                # the eval based version addressed self.l by the string form
                # of the name; keep that as a fallback
                name = l if l in self.l else str(l)
                getattr(self.l[name][0], setter_name)(val)


    def new_plt(self, xdata=[], ydata=[], title='', xlabel='', ylabel='', xlim=None, ylim=None, xscale='linear', yscale='linear', *args, **kwargs):
        ''' 
        plot data of in new plots 

        xdata, ydata: sequences of data sets, paired element by element
            ([[x1], [x2], ...] / [[y1], [y2], ...]); one line with hollow
            'o' markers is plotted per pair and stored in self.l[<index>].
        title, xlabel, ylabel, xlim, ylim, xscale, yscale and **kwargs are
            forwarded to set_ax(); extra positional arguments are ignored.

        The axes is autoscaled at the end.
        #TODO need to define xdata, ydata structure
        '''
        self.initax_xy()
        # forward the caller's axis items to ax[0]
        self.set_ax(self.ax[0], title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale, **kwargs)

        # enumerate yields (index, (x, y)), so the pair is unpacked explicitly
        for i, (x, y) in enumerate(zip(xdata, ydata)):
            self.l[i] = self.ax[0].plot(
                x, y, 
                marker='o', 
                markerfacecolor='none', 
            ) # l[i]

        self.ax[0].autoscale()


    def add_temp_lines(self, ax=None, xlist=[], ylist=[], label_list=[]):
        '''
        add lines to self.l['temp']

        ax: axes to plot on (default: self.ax[0])
        xlist, ylist: x / y data sets, paired element by element
        label_list: optional labels, paired with the data sets; when empty,
            no line is labelled

        Each line is dashed and drawn in color[-1]. The canvas is redrawn
        once at the end.
        '''
        logger.info('add_temp_lines')
        # without labels, pair every data set with an empty one
        labels = label_list if len(label_list) != 0 else [''] * len(xlist)

        for x, y, label in zip(xlist, ylist, labels):
            if ax is None:
                ax = self.ax[0]

            temp_lines = self.l['temp']
            temp_lines.append(ax.plot(
                x, y,
                linestyle='--',
                color=color[-1],
                )
            )

            if label:
                temp_lines[-1][0].set_label(label)
        
        self.canvas_draw()


    def canvas_draw(self):
        '''
        redraw canvas after data changed, then flush the GUI events
        '''
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()


    def data_show_all(self):
        '''
        switch autoscale back on for every axes and reset its limits
        '''
        for axes in self.ax:
            # autoscale might have been turned off by zoom/pan
            axes.set_autoscale_on(True)
            self.reset_ax_lim(axes)


    def reset_ax_lim(self, ax):
        '''
        reset the lim of ax
        this change the display and where home button goes back to

        The data limits are recomputed from the visible artists only, then
        the view is autoscaled with all three positional flags set to True.
        '''
        ax.relim(visible_only=True)
        flags = (True, True, True)
        ax.autoscale_view(*flags)
Beispiel #5
0
class figureToPlot(QtWidgets.QVBoxLayout):
    def __init__(self, parent=None):
        """Build the layout: a row with file selector, column spin box and
        info label, and below it the Matplotlib canvas.

        The data of the initially selected file are loaded during
        construction (through selectFileFromList).
        """
        super(figureToPlot, self).__init__(parent)

        # figure to plot on and the canvas widget showing it
        self.figure = Figure(frameon=True)
        self.canvas = FigureCanvas(self.figure)

        # state flags of the plotting; all start as False
        self.draw_thread = False
        self.pausePlot = False
        self.plotting = False
        self.semaforoPlot = False
        self.dataRestarted = False

        # combo box with the selectable data files
        self.fileList = QtWidgets.QComboBox()
        for entry in ("RAW_ACCELEROMETERS",
                      "RAW_GPS",
                      "PROC_LANE_DETECTION",
                      "PROC_VEHICLE_DETECTION",
                      "PROC_OPENSTREETMAP_DATA"):
            self.fileList.addItem(entry)
        self.fileList.setCurrentIndex(0)
        # connected only after the initial selection has been made
        self.fileList.currentIndexChanged.connect(self.selectFileFromList)
        self.fileList.setFixedWidth(220)

        # spin box choosing the data column to plot
        self.indexCol = 1
        self.spinSelectCol = QtWidgets.QSpinBox()
        self.spinSelectCol.valueChanged.connect(self.setIndexCol)
        self.spinSelectCol.setFixedWidth(45)

        # label describing the selected column
        self.columnInfo = QtWidgets.QLabel()

        self.hBoxFileCol = QtWidgets.QHBoxLayout()
        for widget in (self.fileList, self.spinSelectCol, self.columnInfo):
            self.hBoxFileCol.addWidget(widget)

        self.selectFileFromList()  #sets self.datafilename and calls loadData

        self.addLayout(self.hBoxFileCol)
        self.addWidget(self.canvas)

    def removeAll(self):
        """Schedule the child widgets, the inner layout and finally this
        layout itself for deletion (deleteLater)."""
        for member in ('fileList', 'spinSelectCol', 'hBoxFileCol',
                       'columnInfo', 'canvas'):
            getattr(self, member).deleteLater()
        self.deleteLater()

    def loadData(self):
        """Load self.dataFileName (space delimited numbers) into self.data.

        On success self.dataCorrect is set to True, the column spin box range
        becomes 1 .. <number of columns> - 1 and setIndexCol() is called.
        Any failure while doing so (e.g. missing or malformed file) sets
        self.dataCorrect to False instead of raising. The shape of the data
        and the resulting flag are printed.
        """
        try:
            # builtin float: the np.float alias no longer exists in current
            # NumPy, which made every load end up in the except branch
            self.data = np.genfromtxt(self.dataFileName,
                                      dtype=float,
                                      delimiter=' ')
            self.dataCorrect = True
            print(self.data.shape)
            self.spinSelectCol.setRange(1, self.data.shape[1] - 1)
            self.setIndexCol()  #necessary to restart datax and datay
        except Exception:
            # best effort: unusable data only disables plotting
            self.dataCorrect = False
        print(self.dataCorrect)

    def selectFileFromList(self):
        """Point self.dataFileName at '<dataFolderName>/<selected entry>.txt'
        and load that file."""
        selected = self.fileList.currentText()
        self.dataFileName = dataFolderName + '/' + str(selected + '.txt')
        self.loadData()

    def setIndexCol(self):
        """Take the column index from the spin box and, when valid data are
        loaded, re-slice datax (column 0) and datay (selected column).

        The column description is refreshed in either case.
        """
        column = self.spinSelectCol.value()
        self.indexCol = column
        if self.dataCorrect:
            table = self.data
            self.datax = table[:, 0]
            self.datay = table[:, column]
            # marks that datax / datay were replaced
            self.dataRestarted = True
        self.setColumnInfo()

    def setColumnInfo(
            self):  #def info shown depending on index and file selected
        """Show the description of the selected column in self.columnInfo.

        The text depends on the selected file and on self.indexCol (columns
        are counted from 1). For an unknown file or column the label is left
        unchanged.
        """
        # per file: descriptions of column 1, 2, 3, ...
        descriptions = {
            'RAW_ACCELEROMETERS': (
                'Activation bool (1 if speed>50Km/h)',
                'X acceleration (Gs)',
                'Y acceleration (Gs)',
                'Z acceleration (Gs)',
                'X accel filtered by KF (Gs)',
                'Y accel filtered by KF (Gs)',
                'Z accel filtered by KF (Gs)',
                'Roll (degrees)',
                'Pitch (degrees)',
                'Yaw (degrees)',
            ),
            'RAW_GPS': (
                'Speed (Km/h)',
                'Latitude',
                'Longitude',
                'Altitude',
                'Vertical accuracy',
                'Horizontal accuracy',
                'Course (degrees)',
                'Difcourse: course variation',
                'Position state [internal val]',
                'Lanex dist state [internal val]',
                'Lanex history [internal val]',
            ),
            'PROC_LANE_DETECTION': (
                'Car pos. from lane center (meters)',
                'Phi',
                'Road width (meters)',
                'State of lane estimator',
            ),
            'PROC_VEHICLE_DETECTION': (
                'Distance to ahead vehicle (meters)',
                'Impact time to ahead vehicle (secs.)',
                'Detected # of vehicles',
                'Gps speed (Km/h) [redundant val]',
            ),
            'PROC_OPENSTREETMAP_DATA': (
                'Current road maxspeed',
                'Maxspeed reliability [Flag]',
                'Road type [graph not available]',
                '# of lanes in road',
                'Estimated current lane',
                'Latitude used to query OSM',
                'Longitude used to query OSM',
                'Delay answer OSM query (seconds)',
                'Speed (Km/h) [redundant val]',
            ),
        }

        selected_file = self.fileList.currentText()
        column = self.indexCol
        by_column = dict(enumerate(descriptions.get(selected_file, ()), start=1))
        if column in by_column:
            self.columnInfo.setText(by_column[column])

    def getYaxisMinMax(self):
        """Return (minY, maxY) for the y axis of the selected column.

        By default this is the range of self.datay; columns 1 to 4 of
        'PROC_LANE_DETECTION' use fixed ranges instead.
        """
        upper = np.amax(self.datay)
        lower = np.amin(self.datay)

        selected_file = self.fileList.currentText()
        column = self.indexCol
        if selected_file == 'PROC_LANE_DETECTION':
            # fixed (min, max) per column
            fixed_ranges = {
                1: (-1.5, 1.5),
                2: (-2, 2),
                3: (0, 5),
                4: (-1.5, 2.5),
            }
            lower, upper = fixed_ranges.get(column, (lower, upper))

        return (lower, upper)

    def startPlot(self):
        """Start or resume plotting.

        Does nothing unless ``self.dataCorrect`` is truthy. If the plot is
        paused, only the pause flag is cleared. Otherwise plotting is
        enabled and, when no draw thread is registered, a new thread
        running ``self.plot`` is created and started.
        """
        if not self.dataCorrect:
            return

        if self.pausePlot:
            self.pausePlot = False
            return

        self.plotting = True
        if self.draw_thread == False:
            self.draw_thread = threading.Thread(target=self.plot)
            self.draw_thread.start()
        else:
            self.draw_thread.stopped = False

    def stopPlot(self):
        """Reset the plotting state flags.

        Clears ``plotting`` and ``pausePlot``, drops the draw-thread
        reference and marks the data as restarted.
        """
        # Store order is kept as-is: other code may read these flags
        # while they are being updated.
        for flag, state in (
                ('plotting', False),
                ('draw_thread', False),
                ('dataRestarted', True),
                ('pausePlot', False),
        ):
            setattr(self, flag, state)

    def plot(self):
        """Run the drawing loop until ``self.plotting`` becomes false.

        ``startPlot`` uses this method as the target of a
        ``threading.Thread``. Each pass of the loop does one of:

        * sleeps while ``self.pausePlot`` is set;
        * skips drawing for column 3 of ``PROC_OPENSTREETMAP_DATA``;
        * scrolls the x-axis and redraws the canvas, when the data has
          not just been restarted;
        * re-plots ``self.datax``/``self.datay`` and reconfigures the
          locators, grid and axis limits, when ``self.dataRestarted`` is
          set, then clears that flag.

        ``currentSecond``, ``delayVideoToData`` and ``timeWindow`` are
        read without ``self`` -- presumably module-level globals updated
        elsewhere; TODO confirm.
        """
        self.ax = self.figure.add_subplot(111)
        #self.figure.tight_layout()
        #self.ax.axis([0, timeWindow , 0, 1000])  #just to fix the padding calculation
        #self.figure.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)

        # NOTE: ``i`` is only referenced by commented-out code below.
        i = 0
        #indexPrev = 0
        while (self.plotting == True):

            # While paused, poll every 5 ms; also leave as soon as plotting
            # is switched off.
            while (self.pausePlot == True and self.plotting == True):
                time.sleep(0.005)

            # Column 3 of the OpenStreetMap file has no graph (it is labelled
            # 'Road type [graph not available]'), so nothing is drawn.
            if ((self.fileList.currentText() == "PROC_OPENSTREETMAP_DATA")
                    and (self.indexCol == 3)):
                #i = np.abs(self.datax - (currentSecond - delayVideoToData + 1/fps)).argmin()
                #self.columnInfo.setText('Road type: ' + str(self.datay[i]))
                time.sleep(0.033)
                continue

            if (self.dataRestarted == False
                ):  # if data was just reloaded, wait for the next loop to update i
                # Flag raised for the duration of the redraw.
                self.semaforoPlot = True

                # Scroll the visible window so that it ends at the current
                # data time.
                self.ax.set_xlim([
                    currentSecond - delayVideoToData - timeWindow,
                    currentSecond - delayVideoToData
                ])
                self.canvas.draw(
                )  # refresh canvas  # now we do it in the main thread
                # The bare string below is a disabled draft of a faster
                # partial redraw; it is a no-op expression statement.
                """
            #self.figure.draw_artist(self.figure.patch)
            self.ax.draw_artist(self.ax.patch)
            self.ax.draw_artist(self.plotter)
            #self.ax.draw_artist(self.ax.xaxis)
            #self.ax.draw_artist(self.ax.yaxis)
            #TODO: encontrar manera de reemplazar el draw, falta los ejes de la figura
            self.canvas.update()
            """

                self.canvas.flush_events(
                )  # IMPORTANT LINE! Otherwise the Qt loop is not updated, the way plt.pause does for its plot
                #QtWidgets.QApplication.processEvents()
                #time.sleep(0.033)  #OLD: small value makes plots flicker
                time.sleep(
                    0.2
                )  # important, otherwise it plots too fast and flickers
                self.semaforoPlot = False
            else:
                # Data was (re)loaded: draw the curve and set up the axes.
                self.plotter, = self.ax.plot(self.datax, self.datay,
                                             '-')  # plot data

                # Major x ticks every 5 units, minor ticks every 1 unit.
                loc = plticker.MultipleLocator(base=float(5))
                self.ax.xaxis.set_major_locator(loc)
                loc2 = plticker.MultipleLocator(base=float(1))
                self.ax.xaxis.set_minor_locator(loc2)
                self.ax.grid(True)
                self.ax.grid(which='both')
                self.ax.grid(which='minor', alpha=0.5)
                self.ax.grid(which='major', alpha=1)
                self.miny, self.maxy = self.getYaxisMinMax()

                # Pad the y-range by 1/20 of its span on each side.
                self.margen = (self.maxy - self.miny) / 20
                self.ax.axis([
                    currentSecond - delayVideoToData - timeWindow,
                    currentSecond - delayVideoToData, self.miny - self.margen,
                    self.maxy + self.margen
                ])  #[xmin, xmax, ymin, ymax]

                self.figure.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)

                #ani = animation.FuncAnimation(self.figure, self.animate, range(1, 200), interval=0, blit=False)
                #print('hasta aqui')
                #self.plotting = False
                #ani = animation.FuncAnimation(self.figure, self.animate, interval=50, blit=False)

                #self.canvas.draw()

                # Next pass takes the scrolling branch above.
                self.dataRestarted = False

#TODO: optimize loop: http://bastibe.de/2013-05-30-speeding-up-matplotlib.html

#http://stackoverflow.com/questions/8955869/why-is-plotting-with-matplotlib-so-slow

    def animate(self):
        """Set the x-limits to a ``timeWindow``-wide span ending at the data time.

        The right edge is ``currentSecond - delayVideoToData``; all three
        names are read without ``self``.
        """
        right_edge = currentSecond - delayVideoToData
        left_edge = right_edge - timeWindow
        self.ax.set_xlim([left_edge, right_edge])
class PlotWrapper(QtWidgets.QWidget):
    """Qt widget embedding a matplotlib axes with one line per channel.

    Channels are drawn stacked vertically (see ``updatePlotData``) on an
    x-axis that ends at 0.0. Event markers (a vertical line plus a text
    label) can be added, moved and removed.
    """

    def __init__(self, plotParams, parent=None):
        """Build the figure, configure the axes and show the widget.

        Args:
            plotParams: mapping read with the keys ``"name"``,
                ``"max_time_range"``, ``"init_time_range"``,
                ``"chann_num"`` and ``"ticks"``. ``plotParams["ticks"][0]``
                is iterated as ``(position, label)`` pairs for the y ticks.
            parent: optional parent widget.
        """
        super(PlotWrapper, self).__init__(parent)
        self.setWindowTitle('LSL Plot ' + plotParams["name"])
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)
        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.canvas)
        self.setLayout(layout)
        self.ax = self.figure.add_subplot(111)
        self.curves = []
        self.markerList = []
        self.tableColorList = ['r', 'g', 'y', 'c', 'm', 'b']
        self.maxTimeRange = plotParams["max_time_range"]

        ticks = plotParams["ticks"][0]
        tick_pos = [tick[0] for tick in ticks]
        tick_name = [tick[1] for tick in ticks]

        self.ax.set(xlim=(-plotParams["init_time_range"], 0.0),
                    ylim=(-1.0 * plotParams["chann_num"], 1.0),
                    xlabel='Time (s)', ylabel='Activation',
                    yticks=tick_pos, yticklabels=tick_name)
        self.ax.set_title(plotParams["name"], y=1.04)

        # One initially empty line per channel; filled by updatePlotData.
        for _ in range(plotParams["chann_num"]):
            line, = self.ax.plot([], [])
            self.curves.append(line)
        self.canvas.draw()
        self.canvas.flush_events()
        self.show()

    def getWindow(self):
        """Return this widget."""
        return self

    def updatePlotData(self, timeData, yData, scale):
        """Update every channel line and redraw the canvas.

        Args:
            timeData: x values shared by all channels.
            yData: array indexed as ``yData[ch, :]``; its first dimension
                is the channel count.
            scale: divisor applied to each channel before stacking.
        """
        n_channels = np.shape(yData)[0]
        for ch in range(n_channels):
            # Shift channel ch by (ch + 1 - n_channels) so the traces stack.
            self.curves[ch].set_data(
                timeData, ((yData[ch, :]) / scale) - n_channels + ch + 1)
        self.canvas.draw()
        self.canvas.flush_events()

    def addMarker(self, event, clock_val):
        """Add a vertical marker line and a text label for ``event``.

        Args:
            event: indexable with ``event[0]`` the timestamp, ``event[1]``
                an integer colour index and ``event[2]`` the label text.
            clock_val: value subtracted from the timestamp to get the
                x position.

        The canvas is not redrawn here.
        """
        # Wrap the index so it can never run past the colour table.
        # In-range and negative indices select the same colour as plain
        # indexing would.
        color = self.tableColorList[event[1] % len(self.tableColorList)]
        x_pos = event[0] - clock_val

        eventParam = {"ts": event[0]}
        eventParam["line"] = self.ax.axvline(x=x_pos, color=color)
        eventParam["text"] = self.ax.text(
            x_pos, self.ax.get_ylim()[1] + 0.1, event[2],
            fontsize=9, color=color, horizontalalignment='center')
        self.markerList.append(eventParam)

    def deleteMarker(self, markerNum):
        """Remove marker ``markerNum`` from the axes and from the list.

        The canvas is not redrawn here.
        """
        marker = self.markerList[markerNum]
        marker["line"].remove()
        marker["text"].remove()
        # Delete by position rather than by an equality search.
        del self.markerList[markerNum]

    def setNewValMarker(self, markerNum, newVal, plot_duration):
        """Move marker ``markerNum`` to x position ``newVal``.

        The text label is hidden when ``newVal`` is less than
        ``-plot_duration``; the line is always moved.
        """
        marker = self.markerList[markerNum]
        marker["text"].set_x(newVal)
        marker["text"].set_visible(not (newVal < -plot_duration))
        marker["line"].set_xdata((newVal, newVal))

    def decr_timerange(self):
        """Scale the left x-limit by 0.9 and return the new left limit."""
        xlim = (self.ax.get_xlim()[0] * 0.9, 0.0)
        self.ax.set_xlim(xlim)
        return xlim[0]

    def incr_timerange(self):
        """Scale the left x-limit by 1.1, no lower than ``-maxTimeRange``.

        Returns the new left limit.
        """
        newAxisVal = self.ax.get_xlim()[0] * 1.1
        if newAxisVal < -self.maxTimeRange:
            newAxisVal = -self.maxTimeRange
        xlim = (newAxisVal, 0.0)
        self.ax.set_xlim(xlim)
        return xlim[0]
        
Beispiel #7
0
class LoadImage(QtWidgets.QMainWindow, Ui_MainWindow):
    """Load board images from the menu and the pins icon list."""
    def __init__(self):
        """Set up the user interface from Qt Designer and build the menus.

        Creates the File and Boards menus, fills the Boards menu through
        ``refresh_menus`` and adds the File menu actions. The central
        widget starts disabled.
        """
        super().__init__()
        # Instantiate the widgets generated by Qt Designer.
        self.setupUi(self)

        # Per-board state, all empty at start-up.
        self.list_items = {}
        self.resize_timer = None
        self.first_resize = True
        self.first_load = True
        self.board_label = None
        self.board_type = None
        self.df_cls = None

        self.centralwidget.setDisabled(True)

        # Two menus: File and Boards.
        self.file_menu = self.menuBar().addMenu('&File')
        self.boards_menu = self.menuBar().addMenu('&Boards')

        self.refresh_menus()

        add_file_action = self.file_menu.addAction
        self.export_action = add_file_action('Save Board Label')
        self.run_macro_action = add_file_action('Run the Macro File')
        self.exit_action = add_file_action('Exit')
        self.exit_action.triggered.connect(QtWidgets.qApp.quit)

        # self.export_action = self.export_menu.addAction('Save Board Label')

    def refresh_menus(self):
        """Rebuild the Boards menu from the board folders on disk.

        The root folder is read from option ``dirpath`` in section
        ``path`` of ``socketear.ini`` (looked up in the working directory).

        * ``<root>/Processed``: every folder that has a ``pins``
          sub-folder and the files ``classification.csv``,
          ``stitched.png`` and ``info.txt`` gets an entry that calls
          ``load_image(dirpath)``.
        * ``<root>/Unprocessed``: every leaf folder holding at least one
          file, all of them ``.tif``, gets an entry named after its parent
          folder that calls ``load_model(dirpath, board_name, save_path)``.
        """
        self.boards_menu.clear()

        unprocessed_menu = self.boards_menu.addMenu('Unprocessed')
        processed_menu = self.boards_menu.addMenu('Processed')

        config = ConfigParser()
        config.read('socketear.ini')
        file_path = os.path.normpath(config.get('path', 'dirpath'))
        save_path = os.path.join(file_path, 'Processed')

        required_files = ('classification.csv', 'stitched.png', 'info.txt')
        for dirpath, dirnames, filenames in os.walk(save_path):
            if 'pins' in dirnames and all(
                    fn in filenames for fn in required_files):
                action_cb = processed_menu.addAction(
                    os.path.basename(dirpath))
                action_cb.triggered.connect(
                    functools.partial(self.load_image, dirpath))

        for dirpath, dirnames, filenames in os.walk(
                os.path.join(file_path, 'Unprocessed')):
            # all() is True for an empty list, so an empty leaf folder
            # must be rejected explicitly or it shows up as a board.
            if dirnames or not filenames:
                continue
            if not all(fn.endswith(".tif") for fn in filenames):
                continue

            board_name = os.path.basename(os.path.dirname(dirpath))
            action_nb = unprocessed_menu.addAction(board_name)
            action_nb.triggered.connect(
                functools.partial(self.load_model, dirpath, board_name,
                                  save_path))

    def remove_board(self):
        """Tear down the current board view.

        Removes and closes the matplotlib canvas, empties the pin list and
        disconnects the list, button and export-action signals. The export
        action is disabled at the end.
        """
        self.mpl_vl.removeWidget(self.canvas)
        self.canvas.close()

        self.mpl_figs.clear()
        self.mpl_figs.currentItemChanged.disconnect()

        # Disconnected one after another, in this order.
        widget_names = (
            'nwo_button', 'nco_button', 'normal_button', 'hnp_button',
            'sb_button',
            'type1_button', 'type2_button', 'type3_button', 'type4_button',
            'reject_button',
            'type_a_button', 'type_b_button', 'type_c_button',
            'type_d_button', 'type_e_button',
            'refresh_sjr_button', 'refresh_sjq_button', 'sjr_image_button',
            'export_action',
        )
        for widget_name in widget_names:
            getattr(self, widget_name).disconnect()

        self.export_action.setDisabled(True)

    def add_mpl(self):
        """Create a new figure and canvas, add the canvas to ``mpl_vl`` and draw it."""
        figure = Figure(figsize=(5, 4), dpi=100)
        canvas = FigureCanvas(figure)
        self.figure = figure
        self.canvas = canvas
        self.mpl_vl.addWidget(canvas)
        canvas.draw()
        # self.toolbar = NavigationToolbar(self.canvas,
        #                                  self.mpl_window,
        #                                  coordinates=True)
        # self.addToolBar(self.toolbar)

    def button_clicked(self, value):
        """Store a correction for the selected pin and rewrite the CSV.

        Args:
            value: the new classification. On an ``'SJR'`` board a string
                is written to ``'DYE Correction'`` and an int to
                ``'SJR Correction'``; on an ``'SJQ'`` board the value is
                written to ``'SJQ Correction'``.

        ``classification.csv`` is written to the folder two levels above
        the pin's ``'Image Path'``. Nothing happens when no list item is
        selected.
        """
        current_item = self.mpl_figs.currentItem()
        if current_item is None:
            # No pin selected: there is no row to update.
            return
        pin_index = current_item.data(32)

        # DataFrame.set_value was removed in pandas 1.0; .at is the
        # scalar label-based setter that replaces it.
        if self.board_type == 'SJR':
            if isinstance(value, str):
                self.df_cls.at[pin_index, 'DYE Correction'] = value
            elif isinstance(value, int):
                self.df_cls.at[pin_index, 'SJR Correction'] = value
        elif self.board_type == 'SJQ':
            self.df_cls.at[pin_index, 'SJQ Correction'] = value

        image_path = self.df_cls.loc[pin_index, 'Image Path']

        # <board folder>/<sub-folder>/<image> -> <board folder>
        csv_path = os.path.dirname(
            os.path.dirname(os.path.normpath(image_path)))

        # Saving the CSV file
        header = [
            'Row', 'Col', 'StitchedX', 'StitchedY', 'Pin', 'SJQ', 'SJR', 'DYE',
            'SJQ Correction', 'SJR Correction', 'DYE Correction',
            'Type Change', 'Sorted SJQ', 'Image Path', 'Label Coords',
            'DYE Image Path'
        ]
        self.df_cls.to_csv(os.path.join(csv_path, 'classification.csv'),
                           columns=header)

    def nwo_button_clicked(self, enabled):
        """Record class 0 via ``button_clicked`` when the nwo button is switched on."""
        if not enabled:
            return
        self.button_clicked(0)

    def nco_button_clicked(self, enabled):
        """Record class 1 via ``button_clicked`` when the nco button is switched on."""
        if not enabled:
            return
        self.button_clicked(1)

    def normal_button_clicked(self, enabled):
        """Record class 2 via ``button_clicked`` when the normal button is switched on."""
        if not enabled:
            return
        self.button_clicked(2)

    def hnp_button_clicked(self, enabled):
        """Record class 3 via ``button_clicked`` when the hnp button is switched on."""
        if not enabled:
            return
        self.button_clicked(3)

    def sb_button_clicked(self, enabled):
        """Record class 4 via ``button_clicked`` when the sb button is switched on."""
        if not enabled:
            return
        self.button_clicked(4)

    def type1_button_clicked(self, enabled):
        """Record class 0 via ``button_clicked`` when the type1 button is switched on."""
        if not enabled:
            return
        self.button_clicked(0)

    def type2_button_clicked(self, enabled):
        """Record class 1 via ``button_clicked`` when the type2 button is switched on."""
        if not enabled:
            return
        self.button_clicked(1)

    def type3_button_clicked(self, enabled):
        """Record class 2 via ``button_clicked`` when the type3 button is switched on."""
        if not enabled:
            return
        self.button_clicked(2)

    def type4_button_clicked(self, enabled):
        """Record class 3 via ``button_clicked`` when the type4 button is switched on."""
        if not enabled:
            return
        self.button_clicked(3)

    def reject_button_clicked(self, enabled):
        """Record class 4 via ``button_clicked`` when the reject button is switched on."""
        if not enabled:
            return
        self.button_clicked(4)

    # def no_crack_button_clicked(self, enabled):
    #     """Activate the no crack button for SJR."""
    #     if enabled:
    #         self.button_clicked('O')

    def type_a_button_clicked(self, enabled):
        """Record dye class ``'A'`` via ``button_clicked`` when the typeA button is switched on."""
        if not enabled:
            return
        self.button_clicked('A')

    def type_b_button_clicked(self, enabled):
        """Record dye class ``'B'`` via ``button_clicked`` when the typeB button is switched on."""
        if not enabled:
            return
        self.button_clicked('B')

    def type_c_button_clicked(self, enabled):
        """Record dye class ``'C'`` via ``button_clicked`` when the typeC button is switched on."""
        if not enabled:
            return
        self.button_clicked('C')

    def type_d_button_clicked(self, enabled):
        """Record dye class ``'D'`` via ``button_clicked`` when the typeD button is switched on."""
        if not enabled:
            return
        self.button_clicked('D')

    def type_e_button_clicked(self, enabled):
        """Record dye class ``'E'`` via ``button_clicked`` when the typeE button is switched on."""
        if not enabled:
            return
        self.button_clicked('E')

    def refresh(self, predictions, dye_predictions=None):
        """Redraw the label board and rebuild the sorted pin list.

        Args:
            predictions: passed to ``drawing_tool`` as ``pred``.
            dye_predictions: passed to ``drawing_tool`` as ``dyepred``.

        Steps: regenerate the label image, replace both selection markers,
        cache the axes2 background for blitting, re-sort the
        classification table, recompute the dividers, reload the pin list
        and select its first item.
        """
        self.board_label = self.drawing_tool(shape=1000,
                                             pred=predictions,
                                             dyepred=dye_predictions)

        # Drop the old markers so they are not baked into the background
        # captured below.
        self.axes1_scatter.remove()
        self.axes2_scatter.remove()

        self.axes2_image = self.axes2.imshow(self.board_label)

        self.canvas.draw()

        self.background_board = self.canvas.copy_from_bbox(self.axes2.bbox)

        pin_index = self.mpl_figs.currentItem().data(32)

        self.load_pin_cls(pin_index)

        # DataFrame.ix was removed in pandas 1.0; the row keys used here
        # are index labels, so .loc is the replacement.
        cx_image = self.df_cls.loc[pin_index, 'StitchedX']
        cy_image = self.df_cls.loc[pin_index, 'StitchedY']

        self.axes1_scatter = self.axes1.scatter(cy_image,
                                                cx_image,
                                                edgecolor='#ff01d0',
                                                marker='s',
                                                s=80,
                                                linewidth='2',
                                                facecolors='none')

        label_coords = self.df_cls.loc[pin_index, 'Label Coords']
        cx_label = label_coords[0] * 1000
        cy_label = label_coords[1] * 1000

        self.axes2_scatter = self.axes2.scatter(cy_label,
                                                cx_label,
                                                edgecolor='#ff01d0',
                                                marker='s',
                                                s=80,
                                                linewidth='2',
                                                facecolors='none')

        self.canvas.blit(self.axes1.bbox)
        self.canvas.blit(self.axes2.bbox)

        self.sort_classification()

        self.pin_dividers()

        self.mpl_figs.clear()

        self.load_pins()

        self.mpl_figs.setCurrentItem(self.mpl_figs.item(0))
        # NOTE(review): pin_index was read before sort_classification()
        # reset the index, so this may name a different row than the one
        # selected before the refresh -- confirm the intended behaviour.
        board_pin_info = self.board_info + "\nPin name: {0}".format(
            str(self.df_cls.loc[pin_index, 'Pin']))
        self.board_display_label.setText(board_pin_info)

    def refresh_sjq_clicked(self):
        """Redraw the label board from the current SJQ corrections.

        Calls ``refresh`` with a ``{pin name: 'SJQ Correction'}`` mapping
        and no dye predictions.
        """
        print('Starting to refresh...')

        by_pin = self.df_cls.set_index('Pin')
        sjq_by_pin = by_pin['SJQ Correction'].to_dict()

        self.refresh(sjq_by_pin, None)

    def refresh_sjr_clicked(self):
        """Redraw the label board from the current SJR and DYE corrections.

        Calls ``refresh`` with ``{pin name: 'SJR Correction'}`` as the
        predictions and ``{pin name: 'DYE Correction'}`` as the dye
        predictions.
        """
        print('Starting to refresh...')

        sjr_by_pin = self.df_cls.set_index('Pin')['SJR Correction'].to_dict()
        dye_by_pin = self.df_cls.set_index('Pin')['DYE Correction'].to_dict()

        self.refresh(sjr_by_pin, dye_by_pin)

    def sjr_image_switch(self):
        """Toggle the zoomed pin picture between the plain and dyed image.

        If the displayed path equals the pin's ``'Image Path'`` the dyed
        image is shown; if it equals ``'DYE Image Path'`` the plain image
        is shown. Any other value leaves the display untouched.
        """
        pin_index = self.mpl_figs.currentItem().data(32)
        plain_path = self.df_cls.loc[pin_index, 'Image Path']
        dyed_path = self.df_cls.loc[pin_index, 'DYE Image Path']

        if self.pin_dye_image == plain_path:
            next_path = dyed_path
        elif self.pin_dye_image == dyed_path:
            next_path = plain_path
        else:
            return

        self.pin_dye_image = next_path
        self.pin_image.setPixmap(QtGui.QPixmap(QtGui.QImage(next_path)))

    def on_click(self, event):
        """Select the pin nearest to a mouse click on either board view.

        A click in ``axes1`` is matched against the ``StitchedX`` and
        ``StitchedY`` columns; a click in ``axes2`` against the guide's pin
        positions scaled by 1000. The list item whose data (role 32)
        equals the matched row is made current and its pin is loaded.
        """

        def activate_row(row):
            # Keep the last list entry whose stored row matches.
            match = None
            for position in range(self.mpl_figs.count()):
                candidate = self.mpl_figs.item(position)
                if candidate.data(32) == row:
                    match = candidate

            if match:
                self.mpl_figs.setCurrentItem(match)
                self.load_pin_cls(match.data(32))
            else:
                print('No item found for row {}'.format(row))

        if not event.inaxes:
            print('Clicked outside axes bounds but inside plot window')
            return

        # Compared as (ydata, xdata) against the rows of each table below.
        click_point = np.array([event.ydata, event.xdata])

        if event.inaxes == self.axes1:
            stitched = self.df_cls[['StitchedX', 'StitchedY']]
            self.df_cls['Image Coords'] = stitched.apply(tuple, axis=1)
            pin_coords = self.df_cls[['StitchedX', 'StitchedY']].values
            gaps = np.linalg.norm((pin_coords - click_point), axis=1)
            activate_row(np.argmin(gaps))

        if event.inaxes == self.axes2:
            pin_coords = np.zeros(shape=(len(self.guide.pins), 2))
            for slot, pin in enumerate(self.guide.pins):
                pin_coords[slot] = 1000 * np.array(self.guide.position(pin))

            gaps = np.linalg.norm((pin_coords - click_point), axis=1)
            nearest_pin = self.guide.pins[np.argmin(gaps)]

            pins = self.df_cls.Pin
            activate_row(pins[pins == nearest_pin].index.tolist()[0])

    def current_item_changed(self):
        """Update the zoom-in pin image and both board markers.

        Based on the current item of the scrollable list. Loads the pin,
        moves the selection marker on both axes, blits the cached
        backgrounds plus the new markers and updates the board info label.
        Does nothing when the list has no current item.
        """
        current_item = self.mpl_figs.currentItem()
        if not current_item:
            return

        pin_index = current_item.data(32)

        self.load_pin_cls(pin_index)

        # DataFrame.ix was removed in pandas 1.0; .loc is the label-based
        # replacement.
        label_coords = self.df_cls.loc[pin_index, 'Label Coords']
        cx_label = label_coords[0] * 1000
        cy_label = label_coords[1] * 1000

        # Best effort: the previous marker may not exist yet or may
        # already have been removed. Unlike a bare ``except``, this does
        # not swallow KeyboardInterrupt or SystemExit.
        try:
            self.axes2_scatter.remove()
        except Exception:
            pass

        self.axes2_scatter = self.axes2.scatter(cy_label,
                                                cx_label,
                                                edgecolor='#ff01d0',
                                                marker='s',
                                                s=80,
                                                linewidth='2',
                                                facecolors='none')

        cx_image = self.df_cls.loc[pin_index, 'StitchedX']
        cy_image = self.df_cls.loc[pin_index, 'StitchedY']

        try:
            self.axes1_scatter.remove()
        except Exception:
            pass

        self.axes1_scatter = self.axes1.scatter(cy_image,
                                                cx_image,
                                                edgecolor='#ff01d0',
                                                marker='s',
                                                s=80,
                                                linewidth='2',
                                                facecolors='none')

        # Restore the cached backgrounds and draw only the two markers
        # instead of redrawing the whole canvas.
        self.canvas.restore_region(self.background_image)
        self.canvas.restore_region(self.background_board)
        self.axes1.draw_artist(self.axes1_scatter)
        self.axes2.draw_artist(self.axes2_scatter)
        self.canvas.blit(self.axes1.bbox)
        self.canvas.blit(self.axes2.bbox)

        self.canvas.flush_events()
        board_pin_info = self.board_info + "\nPin name: {0}".format(
            str(self.df_cls.loc[pin_index, 'Pin']))
        self.board_display_label.setText(board_pin_info)

    def load_pin_cls(self, pin_index):
        """Load the zoom-in pin image and check the matching radio buttons.

        Args:
            pin_index: index label of the pin's row in ``self.df_cls``.

        On an ``'SJR'`` board the button for ``'SJR Correction'`` (0-4)
        and the one for ``'DYE Correction'`` ('A'-'E') are checked; on an
        ``'SJQ'`` board the button for ``'SJQ Correction'`` (0-4). Values
        outside these sets leave the buttons untouched.
        """

        def check_matching(value, choices):
            # Check the first button whose code equals ``value``.
            for code, button_name in choices:
                if value == code:
                    getattr(self, button_name).setChecked(True)
                    return

        # DataFrame.ix was removed in pandas 1.0; .loc is the label-based
        # replacement.
        self.pin_dye_image = self.df_cls.loc[pin_index, 'Image Path']
        self.pin_image.setScaledContents(True)
        self.pin_image.setPixmap(
            QtGui.QPixmap(QtGui.QImage(self.pin_dye_image)))

        if self.board_type == 'SJR':

            classification = self.df_cls.loc[pin_index, 'SJR Correction']
            dye_cls = self.df_cls.loc[pin_index, 'DYE Correction']

            check_matching(classification, (
                (0, 'type1_button'),
                (1, 'type2_button'),
                (2, 'type3_button'),
                (3, 'type4_button'),
                (4, 'reject_button'),
            ))
            check_matching(dye_cls, (
                ('A', 'type_a_button'),
                ('B', 'type_b_button'),
                ('C', 'type_c_button'),
                ('D', 'type_d_button'),
                ('E', 'type_e_button'),
            ))

        elif self.board_type == 'SJQ':

            classification = self.df_cls.loc[pin_index, 'SJQ Correction']

            check_matching(classification, (
                (0, 'nwo_button'),
                (1, 'nco_button'),
                (2, 'normal_button'),
                (3, 'hnp_button'),
                (4, 'sb_button'),
            ))

    def sort_sjq(self, row):
        """Map a row's ``'SJQ Correction'`` code to its sort-key letter.

        Codes 0, 1, 2, 3, 4 map to 'a', 'b', 'e', 'c', 'd'.

        Raises:
            ValueError: if the code is none of 0-4.
        """
        sort_keys = ((0, 'a'), (1, 'b'), (2, 'e'), (3, 'c'), (4, 'd'))
        for code, letter in sort_keys:
            if row['SJQ Correction'] == code:
                return letter
        raise ValueError('Unknown classification')

    def set_buttons(self, column_names):
        """Detect the board type and enable the matching button group.

        Args:
            column_names: column names of the classification table; must
                contain ``'SJR'``, ``'SJQ'`` and ``'DYE'``.

        The board is ``'SJR'`` when SJR and DYE hold data and SJQ is all
        null, and ``'SJQ'`` when SJR and SJQ hold data and DYE is all
        null.

        Raises:
            ValueError: ``'Unknown board'`` if a required column is
                missing, ``'Incomplete CSV file'`` if neither pattern
                matches.
        """
        if not all(cn in column_names for cn in ['SJR', 'SJQ', 'DYE']):
            raise ValueError('Unknown board')

        def has_data(column):
            return not (self.df_cls[column].isnull().all())

        sjq_widgets = ('nwo_button', 'nco_button', 'normal_button',
                       'hnp_button', 'sb_button', 'refresh_sjq_button')
        sjr_widgets = ('type1_button', 'type2_button', 'type3_button',
                       'type4_button', 'reject_button',
                       'type_a_button', 'type_b_button', 'type_c_button',
                       'type_d_button', 'type_e_button',
                       'refresh_sjr_button', 'sjr_image_button')

        if has_data('SJR') and has_data('DYE') and not has_data('SJQ'):
            board = 'SJR'
        elif has_data('SJR') and has_data('SJQ') and not has_data('DYE'):
            board = 'SJQ'
        else:
            raise ValueError('Incomplete CSV file')

        self.board_type = board
        # The group of the other board type is the one that gets disabled.
        for widget_name in sjq_widgets:
            getattr(self, widget_name).setDisabled(board == 'SJR')
        for widget_name in sjr_widgets:
            getattr(self, widget_name).setDisabled(board == 'SJQ')

    def sort_classification(self):
        """Sort ``self.df_cls`` for the current board type and renumber it.

        ``'SJQ'`` boards get a ``'Sorted SJQ'`` key column (built with
        ``sort_sjq``) and are ordered by it, then by ``'SJR Correction'``,
        both ascending. ``'SJR'`` boards are ordered by ``'SJR Correction'``
        ascending, then ``'DYE Correction'`` descending. Other board types
        are left unchanged.
        """
        board = self.board_type

        if board == 'SJQ':
            self.df_cls['Sorted SJQ'] = self.df_cls['SJQ Correction']
            self.df_cls['Sorted SJQ'] = self.df_cls.apply(self.sort_sjq,
                                                          axis=1)
            sort_columns = ['Sorted SJQ', 'SJR Correction']
            directions = [True, True]
        elif board == 'SJR':
            sort_columns = ['SJR Correction', 'DYE Correction']
            directions = [True, False]
        else:
            return

        ordered = self.df_cls.sort_values(by=sort_columns,
                                          ascending=directions)
        self.df_cls = ordered.reset_index(drop=True)

    def pin_dividers(self):
        """Flag the rows after which the classification changes.

        Adds a boolean ``'Type Change'`` column to ``self.df_cls``. A row
        is ``True`` when the next row has a different correction value;
        the last row is always ``True``. The column compared is
        ``'SJQ Correction'`` or ``'SJR Correction'`` depending on the board
        type; other board types are left unchanged.
        """
        correction_columns = {'SJQ': 'SJQ Correction',
                              'SJR': 'SJR Correction'}
        column = correction_columns.get(self.board_type)
        if column is None:
            return

        corrections = self.df_cls[column]
        # shift(-1) lines each row up with its successor, so this already
        # compares row 0 with row 1. No separate first-row lookup is
        # needed, and a single-row table no longer fails on row 1.
        self.df_cls['Type Change'] = corrections.shift(-1) != corrections

    def load_pins(self, first_load=False):
        """Fill the scrollable pin list from ``self.df_cls``.

        Args:
            first_load: when true, the icon cache ``self.list_items`` is
                reset and every icon is built from its image path;
                otherwise cached icons are reused.

        Each pin item stores its row position as data (role 32). After a
        row flagged ``'Type Change'``, blank non-selectable items (data
        -1) are appended: enough to complete the current group of five,
        plus five more.
        """
        if first_load:
            self.list_items = {}

        count_image_cat = 0

        image_paths = self.df_cls['Image Path'].tolist()
        type_changes = self.df_cls['Type Change'].tolist() if image_paths else []

        # enumerate gives the row position directly; a list.index() lookup
        # per row is quadratic and returns the first match for duplicate
        # paths.
        for index, image_path in enumerate(image_paths):

            icon = None if first_load else self.list_items.get(image_path)
            if icon is None:
                # Not cached (first load or a new path): build the icon.
                icon = QtGui.QIcon(image_path)

            # Adding the pin images to the scrollable list (WidgetListItem)
            item = QtWidgets.QListWidgetItem()
            item.setIcon(icon)
            self.list_items[image_path] = icon
            item.setData(32, index)

            self.mpl_figs.addItem(item)

            count_image_cat += 1

            # Explicit comparison kept: a missing (NaN) flag must not
            # count as a change.
            if type_changes[index] == True:

                mod_of_5 = count_image_cat % 5
                if mod_of_5 != 0:
                    amount_of_white = (5 - mod_of_5) + 5
                else:
                    amount_of_white = 5

                for _ in range(amount_of_white):
                    item = QtWidgets.QListWidgetItem()  # delimiter
                    item.setData(32, -1)
                    item.setFlags(
                        QtCore.Qt.NoItemFlags)  # item should not be selectable
                    self.mpl_figs.addItem(item)

                count_image_cat = 0

    def get_pins(self, root):
        """Load the classification table of a board and attach image paths.

        Steps, in order:

        1. Read ``classification.csv`` from *root* into ``self.df_cls``.
        2. Create the ``'SJQ Correction'``, ``'SJR Correction'`` and
           ``'DYE Correction'`` columns as copies of ``'SJQ'``, ``'SJR'``
           and ``'DYE'`` where they are missing.
        3. Build the ``'Pin'`` column as ``"<Row>_<Col>"``.
        4. Call ``self.set_buttons`` with the column names, then
           ``self.sort_classification``.
        5. Fill ``'Image Path'`` from the ``*_1.tif`` files in
           ``<root>/pins``.  For ``'SJR'`` boards, also fill
           ``'DYE Image Path'`` from the files in ``<root>/dye``; if that
           folder is missing, ``self.sjr_image_button`` is disabled.
        6. Call ``self.pin_dividers`` and ``self.load_pins(first_load=True)``.

        Args:
            root: Directory holding ``classification.csv`` and the ``pins``
                (optionally ``dye``) image folders.

        Changes from the previous version: the pin name is built from the
        ``'Row'``/``'Col'`` labels instead of deprecated positional Series
        access, the path columns are created with object dtype so that
        storing strings is not an incompatible-dtype write, and pins are
        matched through a dictionary instead of scanning the table for every
        file.
        """
        # Read the CSV file from the board directory.
        self.df_cls = pd.read_csv(os.path.join(root, 'classification.csv'))

        for label in ('SJQ', 'SJR', 'DYE'):
            corrected = '{} Correction'.format(label)
            if corrected not in self.df_cls.columns:
                self.df_cls[corrected] = self.df_cls[label].copy()

        # Pin name = row and column joined by an underscore.
        self.df_cls['Pin'] = self.df_cls[['Row', 'Col']].apply(
            lambda row: '{}_{}'.format(row['Row'], row['Col']), axis=1)

        # Label the column index.
        self.df_cls.columns.names = ['Index']

        self.set_buttons(list(self.df_cls))

        self.sort_classification()

        # Map every pin name to the first row label that carries it.
        pin_rows = {}
        for row_label, pin in zip(self.df_cls.index, self.df_cls['Pin']):
            pin_rows.setdefault(pin, row_label)

        # Object dtype: the columns hold path strings, NaN marks "no image".
        self.df_cls['Image Path'] = pd.Series(
            np.nan, index=self.df_cls.index, dtype=object)
        self.df_cls['DYE Image Path'] = pd.Series(
            np.nan, index=self.df_cls.index, dtype=object)

        def attach_paths(folder, column, required_suffix=None):
            # Store the path of every file in <root>/<folder> whose name,
            # with "_1.tif" removed, equals a known pin.
            directory = os.path.join(root, folder)
            for file_name in os.listdir(directory):
                if (required_suffix is not None
                        and not file_name.endswith(required_suffix)):
                    continue
                file_path = os.path.join(os.getcwd(), directory, file_name)
                pin_name = os.path.basename(file_path).replace("_1.tif", "")
                if pin_name in pin_rows:
                    self.df_cls.loc[pin_rows[pin_name], column] = file_path

        attach_paths('pins', 'Image Path', required_suffix="_1.tif")

        if self.board_type == 'SJR':
            if 'dye' in os.listdir(os.path.join(root, '')):
                attach_paths('dye', 'DYE Image Path')
            else:
                print('Unable to locate the DYE folder. Thus no dye images')
                self.sjr_image_button.setDisabled(True)

        self.pin_dividers()

        self.load_pins(first_load=True)

    def load_board(self, root):
        """Load the board stored under *root* and prepare the two plot axes.

        The central widget is disabled while the method runs.  It reads
        ``info.txt`` (JSON with ``guidePath`` and ``b2p_ratio``), builds the
        ``GeneralGuide``, stores per-pin label coordinates in
        ``self.df_cls['Label Coords']``, renders the prediction overlay with
        ``DrawPredictions``, loads ``stitched.png``, creates two side-by-side
        axes with hidden tick axes, and connects the resize, export and
        macro actions.

        Args:
            root: Result directory of the board; its base name is shown as
                the board name.

        Raises:
            ValueError: If ``self.board_type`` is neither ``'SJR'`` nor
                ``'SJQ'``.

        Pins that are not part of the guide now get ``None`` as their label
        coordinates.  Previously they were skipped, which made the coordinate
        list shorter than the table and the column assignment fail.
        """
        self.centralwidget.setDisabled(True)

        print('Path: ', root)
        board_name = os.path.basename(root)

        print('Board name: ', board_name)
        print('Board type: ', self.board_type)

        self.board_display_label.setFont(
            QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
        self.board_info = "Board name: {0}\nBoard type: {1}".format(
            board_name, self.board_type)
        self.board_display_label.setText(self.board_info)

        with open(os.path.join(root, 'info.txt')) as json_file:
            data = json.load(json_file)
            print('guide path: ', data['guidePath'])

        csvpath = data['guidePath']
        b2p_ratio = int(data['b2p_ratio'])
        from guides.generalguide import GeneralGuide

        self.guide = GeneralGuide(csvpath, b2p_ratio)

        # One entry per table row keeps the list aligned with the index.
        pin_coords = [
            tuple(self.guide.position(pin))
            if pin in self.guide.pins else None
            for pin in self.df_cls['Pin'].tolist()
        ]
        self.df_cls['Label Coords'] = pd.Series(pin_coords,
                                                index=self.df_cls.index)

        self.drawing_tool = DrawPredictions(guide=self.guide,
                                            mode=self.board_type)

        cls_corrected_col = "{} Correction".format(self.board_type)
        predictions = self.df_cls.set_index('Pin')[cls_corrected_col].to_dict()

        if self.board_type == 'SJR':
            dye_predictions = self.df_cls.set_index(
                'Pin')['DYE Correction'].to_dict()
            self.board_label = self.drawing_tool(shape=1000,
                                                 pred=predictions,
                                                 dyepred=dye_predictions)
        elif self.board_type == 'SJQ':
            self.board_label = self.drawing_tool(shape=1000, pred=predictions)
        else:
            raise ValueError('Unknown board')

        self.board_img = mpimg.imread(os.path.join(root, 'stitched.png'))

        gs = gridspec.GridSpec(1, 2)
        gs.update(left=0.005, right=0.99, wspace=0.05)

        # Left axes: board photo.  Right axes: rendered label image.
        self.axes1 = self.figure.add_subplot(gs[0])
        self.axes1.get_xaxis().set_visible(False)
        self.axes1.get_yaxis().set_visible(False)

        self.axes2 = self.figure.add_subplot(gs[1])
        self.axes2.get_xaxis().set_visible(False)
        self.axes2.get_yaxis().set_visible(False)

        # No highlight markers exist until a pin is selected.
        self.axes1_scatter = None
        self.axes2_scatter = None

        self.canvas.mpl_connect('resize_event', self.connect_resize)

        # NOTE(review): these connections are added on every call; if the
        # actions outlive a board, earlier roots stay connected - confirm
        # whether they are disconnected elsewhere.
        self.export_action.triggered.connect(
            functools.partial(self.save_board_label, root))
        self.export_action.setDisabled(False)

        self.run_macro_action.triggered.connect(
            functools.partial(self.run_excel_macro, root))

        self.first_load = False

        self.centralwidget.setDisabled(False)

    def connect_resize(self, event):
        """React to a canvas ``resize_event`` by redrawing immediately.

        The event is printed, the central widget is disabled, and
        ``perform_resize`` is called.

        Args:
            event: The resize event delivered by the canvas; only printed.
        """
        print('resize event', event)

        central = self.centralwidget
        central.setDisabled(True)

        self.perform_resize()

        # NOTE: a debounced variant used to live here as commented-out code:
        # cancel a pending ``self.resize_timer`` and start a new one-second
        # ``Timer`` that calls ``perform_resize``.  It is not active.

    def perform_resize(self):
        """Redraw both images and refresh the cached canvas backgrounds.

        The board image and the label image are drawn on ``self.axes1`` and
        ``self.axes2``, any existing scatter artists are removed, and the
        canvas is redrawn.  The bounding-box contents of both axes are then
        stored in ``self.background_image`` / ``self.background_board``, and
        the central widget is re-enabled.

        A scatter that is ``None`` (nothing drawn yet) is now skipped
        silently instead of being reported as a removal failure.  A failing
        ``remove()`` is still only reported, never raised, but the handler
        no longer catches ``KeyboardInterrupt`` / ``SystemExit``.
        """
        # NOTE(review): every call adds another image artist to each axes;
        # consider removing the previous ones if resizes are frequent.
        self.axes1.imshow(self.board_img)
        self.axes2.imshow(self.board_label)

        for attr_name in ('axes1_scatter', 'axes2_scatter'):
            scatter = getattr(self, attr_name, None)
            if scatter is None:
                continue
            try:
                scatter.remove()
            except Exception:
                # Best effort: a stale artist must not break the redraw.
                print('Unable to remove {}'.format(attr_name))

        self.canvas.draw()

        self.background_image = self.canvas.copy_from_bbox(self.axes1.bbox)
        self.background_board = self.canvas.copy_from_bbox(self.axes2.bbox)

        self.centralwidget.setDisabled(False)

    def connect_click(self):
        """Connect the widgets' signals to their handler methods.

        Connected here: the image list's ``currentItemChanged``, the
        ``toggled`` signal of every classification button, the ``clicked``
        signal of the refresh and image-switch buttons, and the canvas
        ``button_press_event`` callback.
        """
        # Selecting an entry of the scrollable image list by click or arrow
        # key ends up in current_item_changed.
        self.mpl_figs.currentItemChanged.connect(self.current_item_changed)

        # Each "<name>_button" reports toggles to "<name>_button_clicked".
        # (A "no_crack" button exists only as a disabled line in the
        # original code and is intentionally not wired.)
        toggle_names = (
            'nwo', 'nco', 'normal', 'hnp', 'sb',
            'type1', 'type2', 'type3', 'type4',
            'reject',
            'type_a', 'type_b', 'type_c', 'type_d', 'type_e',
        )
        for name in toggle_names:
            getattr(self, name + '_button').toggled.connect(
                getattr(self, name + '_button_clicked'))

        click_pairs = (
            ('refresh_sjq_button', 'refresh_sjq_clicked'),
            ('refresh_sjr_button', 'refresh_sjr_clicked'),
            ('sjr_image_button', 'sjr_image_switch'),
        )
        for button_name, handler_name in click_pairs:
            getattr(self, button_name).clicked.connect(
                getattr(self, handler_name))

        self.canvas.callbacks.connect('button_press_event', self.on_click)

    def save_board_label(self, path):
        """Write ``self.board_label`` to ``board.png`` inside *path*.

        The image is scaled by 255 and its last axis is reversed before it
        is handed to ``cv2.imwrite``.

        Args:
            path: Directory that receives ``board.png``.
        """
        import cv2

        target = os.path.join(path, 'board.png')
        # NOTE(review): the reversed last axis presumably converts RGB to
        # OpenCV's BGR channel order - confirm against DrawPredictions.
        pixels = 255 * self.board_label[:, :, ::-1]
        cv2.imwrite(target, pixels)

    def run_excel_macro(self, path):
        """Run macro ``TEST_1`` of ``MacroTest.xlsm`` in *path* through Excel.

        Excel is started through COM, the workbook is opened, the macro is
        run, the workbook is saved and Excel is quit.

        Args:
            path: Directory that contains ``MacroTest.xlsm``.

        Raises:
            Exception: Any error raised while automating Excel is reported
                on stdout and re-raised unchanged.

        On failure the workbook is now closed without saving and Excel is
        quit before the error is re-raised, so a failed run no longer leaves
        an Excel process behind.
        """
        excel = None
        workbook = None
        try:
            excel = win32com.client.Dispatch('Excel.Application')
            xls_path = os.path.abspath(os.path.join(path, 'MacroTest.xlsm'))
            print(xls_path)
            workbook = excel.Workbooks.Open(xls_path)
            excel.Run('MacroTest.xlsm!TEST_1')
            workbook.Save()
            excel.Quit()
            print("Macro ran successfully!")

        except Exception:
            print("Error found while running the excel macro!")
            if excel is not None:
                # Best-effort cleanup; the original error is what matters.
                try:
                    if workbook is not None:
                        # Close without saving so Quit cannot block on a
                        # "save changes?" prompt.
                        workbook.Close(False)
                    excel.Quit()
                except Exception:
                    print("Unable to shut down Excel after the failure")
            raise

    def load_image(self, root):
        """Build the complete board view for the results stored in *root*.

        Unless this is the first load, the previous board is removed first.
        Then, in order: the matplotlib container is added, the pin table and
        icon list are loaded, the classification of pin 0 is loaded, the
        board itself is loaded, and the widget signals are connected.

        Args:
            root: Result directory of the board to display.
        """
        print('Starting to load the images...')

        # Reset the matplotlib container and the buttons on every load
        # except the very first one.
        is_first_load = self.first_load
        if not is_first_load:
            self.remove_board()

        # Fresh matplotlib container.
        self.add_mpl()

        # Pin table plus the scrollable image list.
        self.get_pins(root)

        self.load_pin_cls(0)

        self.load_board(root)

        self.connect_click()

    def closeEvent(self, event):
        """Cancel a pending ``self.resize_timer`` when the window closes.

        Nothing happens if no timer was ever created.  A failing ``cancel()``
        is reported on stdout instead of being swallowed silently, and
        ``KeyboardInterrupt`` / ``SystemExit`` are no longer caught.

        Args:
            event: The close event; not inspected.
        """
        timer = getattr(self, 'resize_timer', None)
        if timer is None:
            return
        try:
            timer.cancel()
        except Exception as error:
            # Best effort: closing the window must not fail on the timer.
            print('Unable to cancel resize_timer:', error)

    def load_model(self, root, board_name, save_path):
        """Ask for analysis settings, run the model and show the result.

        A settings dialog is opened with defaults derived from the
        arguments.  If it is accepted, the analysis runs in a ``QThread``
        behind a progress dialog.  Afterwards the settings are written to
        ``info.txt`` in the chosen save path, the menus are refreshed and
        the board is loaded.  Cancelling the dialog does nothing.

        Args:
            root: Directory with the raw images (default ``rawImgPath``).
            board_name: Default sample ID.
            save_path: Base directory for the default save paths.

        ``info.txt`` is now written inside a ``with`` block so the file
        handle is closed (and flushed) before the board is loaded from it.
        """
        default_settings = {
            'sampleID': board_name,
            'rawImgPath': root,
            'cropSavePath': os.path.join(save_path, "{}", 'pins'),
            'savePath': os.path.join(save_path, "{}"),
            'additionalCrop': 'None',
            'b2p_ratio': '150',
            'device': 'gpu0'
        }

        dialog = Ui_SettingsWindow(default_settings)
        dialog.setAttribute(QtCore.Qt.WA_DeleteOnClose)

        exit_code = dialog.exec_()

        if not exit_code:
            return

        settings_dict = dialog.values

        # Not used by the analysis itself; it is still stored in info.txt
        # through settings_dict.
        userName = settings_dict['userName']
        sampleID = settings_dict['sampleID']
        analysisType = settings_dict['analysisType']
        rawImgPath = settings_dict['rawImgPath']
        cropSavePath = settings_dict['cropSavePath']
        segDataPath = settings_dict['segDataPath']
        savePath = settings_dict['savePath']
        guidePath = settings_dict['guidePath']
        additionalCrop = settings_dict['additionalCrop']
        b2p_ratio = int(settings_dict['b2p_ratio'])
        device = settings_dict['device']

        class TaskThread(QtCore.QThread):
            """Worker thread that builds the guide and runs the model."""

            def run(self):
                print("task is running")
                try:
                    from guides.generalguide import GeneralGuide
                    guide = GeneralGuide(guidePath, b2p_ratio)

                    from model import Model
                    # 'None' means no extra crop; otherwise the same value
                    # is used for both entries of the crop tuple.
                    valid_additionalCrop = None if additionalCrop == 'None' else (
                        int(additionalCrop), int(additionalCrop))

                    f_model = Model(sampleID, analysisType, guide,
                                    valid_additionalCrop, rawImgPath,
                                    segDataPath, cropSavePath, savePath,
                                    device)

                    from time import time
                    st = time()
                    print('Starting the analysis...')
                    f_model(saveResults=True)
                    sp = time() - st
                    print('Analysis took {} seconds'.format(sp))
                except Exception as e:
                    # Errors are only printed; the caller continues anyway.
                    print(e)

        task_thread = TaskThread()

        from progress_bar import ConstantProgressBar
        progress_dialog = ConstantProgressBar(task_thread)
        exit_code = progress_dialog.exec_()
        print(exit_code)

        with open(os.path.join(savePath, 'info.txt'), 'w') as info_file:
            json.dump(settings_dict, info_file, indent=4)

        self.refresh_menus()

        self.load_image(savePath)
Beispiel #8
0
class Window(QDialog):
    def __init__(self, parent=None):
        """Create all widgets of the measurement dialog and wire their signals.

        Widgets are created as children of this dialog and positioned
        afterwards by ``wigets_layout``.

        Args:
            parent: Optional parent widget passed to the base class.
        """
        super().__init__(parent)   

        # Directory where measurement data is saved
        self.label_Save = QLabel('Save Directory:', self)
        self.edit_Save = QLineEdit(self)
        self.edit_Save.setFocusPolicy(Qt.ClickFocus)
        self.edit_Save.setText(SCRIPT_DIR)
        self.btn_Save = QPushButton('Open', self)
        self.btn_Save.setStyleSheet('background-color:white; color:black;')
        self.btn_Save.clicked.connect(self.showDialog)

        # Measurement mode selection
        self.label_Mode = QLabel('Mode:', self)
        self.combo_Mode = QComboBox(self)
        self.combo_Mode.addItems(['RSTF', 'SSTF', 'LSTF', 'Mic_Ajust'])

        # Target (output) selection
        self.label_Output = QLabel('Target:', self)
        self.combo_Output = QComboBox(self)
        self.combo_Output.addItem("Headphones")

        # Subject name entry
        self.label_Sub = QLabel('subject:', self)
        self.edit_Sub = QLineEdit(self)

        # Number of synchronous averages entry
        self.label_averaging_time = QLabel('Avaraging Times:', self)
        self.edit_averaging_time = QLineEdit(self)
        self.edit_averaging_time.setText("10")

        # Start-measurement button
        self.btn_Start = QPushButton('START', self)
        self.btn_Start.setStyleSheet('background-color:cyan; color:black;')
        self.btn_Start.setFont(QFont("", 40))
        
        # Illustration image for the selected mode
        self.image1 = QLabel(self)
        self.image1.setScaledContents(True)
        self.image1.setPixmap(QPixmap(SCRIPT_DIR+"/.texture/RSTF.png"))

        # Plot toggle button (hidden until a measurement has been started)
        self.btn_pltChange = QPushButton('↔', self)
        self.btn_pltChange.setStyleSheet('background-color:white; color:black;')
        self.btn_pltChange.setFont(QFont("", 28))
        self.btn_pltChange.setHidden(True)
        self.btn_pltChange.clicked.connect(self.plotChange) 

        # Log output
        self.label_log = QLabel('Log:', self)
        self.te_log = QTextEdit(self)
        self.te_log.setReadOnly(True) 
        self.te_log.setStyleSheet('background-color:black;')
        self.te_log.setTextColor(QColor(0, 255, 0))
        self.te_log.setFontPointSize(16)

        # Plot window 1
        self.figure1 = plt.figure(figsize=(9, 5), dpi=60)
        self.axes1 = self.figure1.add_subplot(111)
        self.axes1.tick_params(labelsize=15)
        self.canvas1 = FigureCanvas(self.figure1)
        self.canvas1.setParent(self)
        
        # Plot window 2
        self.figure2 = plt.figure(figsize=(9, 5), dpi=60)
        self.axes2 = self.figure2.add_subplot(111)
        self.axes2.tick_params(labelsize=15)
        self.canvas2 = FigureCanvas(self.figure2)
        self.canvas2.setParent(self)

        # Navigation toolbars, one per canvas
        self.toolbar1 = NavigationToolbar(self.canvas1, self)
        self.toolbar1.setGeometry(800, 80, 20, 30)
        self.toolbar2 = NavigationToolbar(self.canvas2, self)
        self.toolbar2.setGeometry(800, 390, 20, 30)

        self.btn_Start.clicked.connect(self.measure)
        # Widget events: changing the mode reconfigures the other widgets
        self.combo_Mode.currentIndexChanged.connect(self.wiget_setting)

        self.wigets_layout()

    def showDialog(self):
        """Let the user choose a directory and show it in the save field.

        The chosen directory, with a trailing ``/`` appended, replaces the
        text of ``self.edit_Save``.  An empty selection leaves the field
        unchanged.
        """
        chosen_dir = QFileDialog.getExistingDirectory(self)
        if chosen_dir == "":
            return
        self.edit_Save.setText(chosen_dir + '/')

    #各ウィジェットの配置
    def wigets_layout(self):
        """Place the dialog and every widget at fixed pixel coordinates."""
        left_x, right_x = 20, 150
        mode_label_y, mode_combo_y = 350, 380
        entry_label_y, entry_edit_y = 450, 480

        self.setGeometry(300, 300, 900, 800)  # window position and size
        self.image1.setGeometry(20, 60, 250, 250)

        # Save-directory row at the top.
        self.label_Save.move(left_x, 10)
        self.edit_Save.setGeometry(130, 10, 650, 20)
        self.btn_Save.setGeometry(800, 10, 80, 20)

        # Mode and target selectors.
        self.label_Mode.move(left_x, mode_label_y)
        self.combo_Mode.setGeometry(left_x - 3, mode_combo_y, 120, 30)
        self.label_Output.move(right_x, mode_label_y)
        self.combo_Output.setGeometry(right_x - 3, mode_combo_y, 140, 30)

        # Subject and averaging-count entries.
        self.label_Sub.move(left_x, entry_label_y)
        self.edit_Sub.setGeometry(left_x, entry_edit_y, 100, 20)
        self.label_averaging_time.move(right_x, entry_label_y)
        self.edit_averaging_time.setGeometry(right_x, entry_edit_y, 100, 20)

        self.btn_Start.setGeometry(60, 540, 180, 80)

        # Log area at the bottom.
        self.label_log.move(50, 640)
        self.te_log.setGeometry(50, 670, 800, 120)

        # Plot canvases and the plot toggle button.
        self.canvas1.move(300, 50)
        self.canvas2.move(300, 360)
        self.btn_pltChange.setGeometry(860, 55, 30, 30)

    #測定モードごとのウィジェットの設定
    def wiget_setting(self):
        """Reconfigure the widgets for the newly selected measurement mode.

        The target combo box is cleared and refilled, the subject field is
        made read-only (and editable again only for RSTF and SSTF), the
        averaging count is reset to ``"10"`` for every mode except
        Mic_Ajust, and the illustration image is replaced.

        The mode index is compared with ``==``.  The previous ``is``
        comparison against integer literals only worked because CPython
        caches small ints, and it raises a ``SyntaxWarning`` on current
        Python versions.
        """
        self.combo_Output.clear()
        self.edit_Sub.setReadOnly(True)

        mode = self.combo_Mode.currentIndex()

        # RSTF
        if mode == 0:
            self.combo_Output.addItem("Headphones")
            self.edit_Sub.setReadOnly(False)
            self.edit_averaging_time.setText("10")
            self.image1.setPixmap(QPixmap(SCRIPT_DIR+"/.texture/RSTF.png"))

        # SSTF
        elif mode == 1:
            self.combo_Output.addItems(['angle: 0-85', 'angle: 90-175', 'angle: 180-265', 'angle: 270-355', 'ITD_Check'])
            self.edit_Sub.setReadOnly(False)
            self.edit_averaging_time.setText("10")
            self.image1.setPixmap(QPixmap(SCRIPT_DIR+"/.texture/SSTF.png"))

        # LSTF
        elif mode == 2:
            self.combo_Output.addItems('Speaker No.' + str(n) for n in range(1, 19))
            self.edit_Sub.clear()
            self.edit_averaging_time.setText("10")
            self.image1.setPixmap(QPixmap(SCRIPT_DIR+"/.texture/LSTF.png"))

        # MicAjust
        elif mode == 3:
            self.combo_Output.addItem("Speaker No.1")
            self.edit_Sub.clear()
            # There is no illustration for this mode; the path is just the
            # script directory.
            self.image1.setPixmap(QPixmap(SCRIPT_DIR+""))

    #測定
    def measure(self):
        """Run the measurement for the selected mode and plot the result.

        The averaging count, target index, subject and output directory are
        read from the widgets and stored on the instance.  The plot toggle
        is reset and its button is shown.  Depending on the mode, the
        matching function of the ``measure`` module is called, a line is
        appended to the log and the result is plotted.  Missing inputs, or
        a missing speaker-selector file for SSTF/LSTF, abort with a warning
        box.

        Integers and strings are compared with ``==`` and the path check
        uses ``not``.  The previous ``is`` comparisons against literals
        relied on CPython object caching and raise ``SyntaxWarning`` on
        current Python versions.
        """
        self.averaging_times = self.edit_averaging_time.text()
        self.speaker_index = self.combo_Output.currentIndex()
        self.subject = self.edit_Sub.text()
        self.outdir = self.edit_Save.text() + "/" + self.subject
        self.Reverse = False  # plot toggle state used by plotChange
        self.btn_pltChange.setHidden(False)

        mode = self.combo_Mode.currentIndex()
        # File through which the speaker selector is told which output to use.
        selector_path = '/Volumes/share/angle'

        # RSTF
        if mode == 0:
            if self.averaging_times == "":
                QMessageBox.warning(self, "Message", u"SANnum is invalid or empty")
                return
            if self.subject == "":
                QMessageBox.warning(self, "Message", u"subject Name is empty")
                return
            measure.RSTF(self.subject, self.averaging_times, 1, 255, 3801, 4823, self.outdir)
            self.te_log.append('cinv_cRSTF_L.DDB and cinv_cRSTF_R.DDB are measured. ('
                                + datetime.now().strftime("%H:%M:%S")  + ')')
            self.plotChange()

        # SSTF
        elif mode == 1:
            self.btn_pltChange.setHidden(True)
            if not os.path.exists(selector_path):
                QMessageBox.warning(self, "Message", u"Speaker selector is not connecting.")
                return
            if self.averaging_times == "":
                QMessageBox.warning(self, "Message", u"SANnum is invalid or empty")
                return
            if self.subject == "":
                QMessageBox.warning(self, "Message", u"subject Name is empty")
                return
            if self.speaker_index == 4:
                # 'ITD_Check' entry: a single check measurement on speaker 1.
                with open(selector_path, 'w') as select:
                    select.write("1")
                measure.SSTF(self.subject, self.averaging_times, 'check', 150, 405, self.outdir)
                # The check files were renamed from cSSTF_000_* to
                # cSSTF_check_*.
                self.plot(self.outdir + '/SSTF/cSSTF_check_L.DDB'
                        , self.outdir + '/SSTF/cSSTF_check_R.DDB')
                return

            # 18 speakers in 5-degree steps inside the selected 90-degree
            # range.
            for n in range(18):
                angle = self.speaker_index * 90 + n * 5
                with open(selector_path, 'w') as select:
                    select.write(str(n+1))
                measure.SSTF(self.subject, self.averaging_times, angle, 150, 405, self.outdir)
                self.te_log.append('SSTF_' + str(angle) + '_L.DDB and SSTF_'
                                + str(angle) + '_R.DDB are measured. ('+ datetime.now().strftime("%H:%M:%S")  + ')')
                self.plot(self.outdir + '/SSTF/cSSTF_' + str(angle) + '_L.DDB'
                        , self.outdir + '/SSTF/cSSTF_' + str(angle) + '_R.DDB')
                # Let the GUI repaint between measurements.
                self.canvas1.flush_events()
                self.canvas2.flush_events()

        # LSTF
        elif mode == 2:
            if not os.path.exists(selector_path):
                QMessageBox.warning(self, "Message", u"Speaker selector is not connecting.")
                return
            if self.averaging_times == "":
                QMessageBox.warning(self, "Message", u"SANnum is invalid or empty")
                return
            with open(selector_path, 'w') as select:
                select.write(str(self.speaker_index+1))
            measure.LSTF(self.speaker_index+1, self.averaging_times, 150, 405, 3800, 4823, self.edit_Save.text())
            self.te_log.append('/LSTF_' + str(self.speaker_index+1) + '.DDB is measured. ('
                                + datetime.now().strftime("%H:%M:%S")  + ')')
            self.plotChange()

        # MicAjust
        elif mode == 3:
            with open(selector_path, 'w') as select:
                select.write("1")
            measure.mic_ajust()
            self.te_log.append('rec_L.DDB and rec_R.DDB are measured. ('+ datetime.now().strftime("%H:%M:%S")  + ')')
            self.plotChange()

        # Disabled closed-loop mode (index 4), kept for reference:
        # elif mode == 4:
        #     cpyconv.closedloop()
        #     with open(SCRIPT_DIR+"/DOUKI_START", 'r') as douki_start: iodelay = douki_start.read()
        #     QMessageBox.about(self, "Message", "The I/O delay is "+iodelay+" sample")

    #データのプロット
    def plot(self, file_L, file_R):
        """Plot the waveform and spectrum of one or two raw float64 files.

        Each file is read as a flat array of 64-bit floats.  ``axes1`` shows
        the samples and ``axes2`` shows ``10 * log10(|FFT|)`` of the first
        255 samples on a logarithmic frequency axis.  Legend entries are the
        file names without their directory.

        Args:
            file_L: Path of the first file (always plotted).
            file_R: Path of the second file, or ``None`` to plot only one.

        Changes from the previous version: files are closed after reading,
        the deprecated binary mode of ``np.fromstring`` is replaced by
        ``np.frombuffer``, ``None`` is tested with ``is not``, and the
        legend name no longer fails for a path without ``/``.
        """
        N = 255  # number of samples used for the spectrum

        def read_samples(path):
            # Whole file interpreted as float64 values.
            with open(path, 'rb') as ddb_file:
                return np.frombuffer(ddb_file.read(), dtype=np.float64)

        def legend_name(path):
            # Text after the last '/', or the whole path if there is none.
            return path.rsplit('/', 1)[-1]

        traces = [(read_samples(file_L), legend_name(file_L))]
        if file_R is not None:
            traces.append((read_samples(file_R), legend_name(file_R)))

        # Plot 1: samples
        self.axes1.clear()
        self.axes1.set_title('Impulse Response', fontsize=15)
        self.axes1.set_xlabel("Sample", fontsize=15)
        self.axes1.set_ylabel("Level", fontsize=15)
        for samples, name in traces:
            self.axes1.plot(samples, '-', label=name)
        self.axes1.legend(bbox_to_anchor=(0., 1.02, 1., .102),borderaxespad=-0.2)
        self.canvas1.draw()

        # Plot 2: spectrum
        self.axes2.clear()
        self.axes2.set_title('Frequency Characteristic', fontsize=15)
        self.axes2.set_xlabel("Frequency [Hz]", fontsize=15)
        self.axes2.set_ylabel("Amplitude [dB]", fontsize=15)
        self.axes2.set_xlim(100, 24000)
        self.axes2.set_xscale('log')
        # The first N entries are k * 48000 / N for k = 0..N-1.
        x = np.fft.fftfreq(N*2,d=1.0/48000)*2
        for samples, name in traces:
            spectrum = np.fft.fft(samples[0:N])
            # NOTE(review): amplitude in dB is usually 20*log10; the factor
            # 10 is kept from the original - confirm which is intended.
            data_decibel = 10.0 * np.log10(np.abs(spectrum))
            self.axes2.plot(x[0:N], data_decibel, '-', label=name)

        self.axes2.legend(bbox_to_anchor=(0., 1.02, 1., .102),borderaxespad=-0.2)
        self.canvas2.draw()

    def plotChange(self):
        """Plot the alternate file set for the current mode and flip the toggle.

        While ``self.Reverse`` is ``False`` the ``c...`` files are plotted
        and the flag becomes ``True``.  While it is ``True`` the
        ``cinv_c...`` files are plotted and the flag becomes ``False``.  For
        Mic_Ajust the two recordings are plotted in swapped order instead.
        Modes without a plot (SSTF) only flip the flag.  Any other value of
        ``self.Reverse`` does nothing.

        The mode index is compared with ``==``.  The previous ``is``
        comparison against integer literals relied on CPython's small-int
        cache.
        """
        if self.Reverse is False:
            prefix = 'c'
            mic_files = ('/tmp/rec_L.DDB', '/tmp/rec_R.DDB')
            next_state = True
        elif self.Reverse is True:
            prefix = 'cinv_c'
            mic_files = ('/tmp/rec_R.DDB', '/tmp/rec_L.DDB')
            next_state = False
        else:
            return

        mode = self.combo_Mode.currentIndex()

        # RSTF
        if mode == 0:
            self.plot(self.outdir + '/RSTF/' + prefix + 'RSTF_L.DDB',
                      self.outdir + '/RSTF/' + prefix + 'RSTF_R.DDB')
        # LSTF
        elif mode == 2:
            self.plot(self.edit_Save.text() + '/LSTF/' + prefix + 'LSTF_'
                      + str(self.speaker_index+1) + '.DDB', None)
        # MicAjust
        elif mode == 3:
            self.plot(*mic_files)

        self.Reverse = next_state


    #ウィンドウ内の座標を表示
    def mousePressEvent(self, event):
        """Print the x and y coordinates carried by a mouse-press event.

        Args:
            event: The mouse event; its ``x()`` and ``y()`` are printed.
        """
        x_pos = event.x()
        y_pos = event.y()
        print('x={}, y={}'.format(x_pos, y_pos))
Beispiel #9
0
class Window(QtWidgets.QMainWindow, Ui_MainWindow):
    """
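    Override or modify this class to build a custom backend UI; any buttons
    can be added as needed. For now this is only a demo.

    Every call to self.canvas.draw() redraws the figure.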
    """
    def __init__(self, figure):
        """Build the editor window around an existing matplotlib figure.

        Wraps ``figure`` in a canvas shown through ``self.graphicsView``, fills
        both toolbars with actions, creates the context menu, records the
        initial view angles of the 3D axes and connects the canvas mouse and
        pick events to the drawing tools.

        If the figure has no axes a warning box is shown and the event wiring
        is skipped; the window state attributes are still initialised so the
        toolbar slots do not hit missing attributes.

        Args:
            figure: The matplotlib figure to display and edit.
        """
        super(Window, self).__init__()
        self.setupUi(self)  # run the generated UI setup first so the widgets exist
        self.figure = figure
        self.canvas = FigureCanvas(self.figure)  # the canvas is the plot itself
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.toolbar.hide()  # hide matplotlib's stock Qt toolbar
        self.scene = QGraphicsScene()
        self.scene.addWidget(self.canvas)
        self.graphicsView.setScene(self.scene)
        self.graphicsView.show()
        self.actionX_X.triggered.connect(self.axes_control_slot)
        # Initialise the current view.
        self.init_gui()

        self.current_path = os.path.dirname(__file__)

        # Toolbar 1: (attribute, icon file, label, slot, shortcut), in display order.
        for attr, icon, text, slot, shortcut in (
                ('saveAction', 'save.png', 'save', self.save_slot, 'Ctrl+S'),
                ('settingAction', 'setting.png', 'setting',
                 self.axes_control_slot, None),
                ('homeAction', 'home.png', 'home', self.home_slot, 'Ctrl+H'),
                ('backAction', 'back.png', 'back', self.back_slot, None),
                ('frontAction', 'front.png', 'front', self.front_slot, None),
                ('zoomAction', 'zoom.png', 'zoom', self.zoom_slot, None),
                ('panAction', 'pan.png', '平移', self.pan_slot, None),
                ('rotateAction', 'rotate.png', 'rotate', self.rotate_slot, None),
                ('textAction', 'text.png', 'text', self.add_text_slot, None),
                ('rectAction', 'rect.png', 'rect', self.add_rect_slot, None),
                ('ovalAction', 'oval.png', 'oval', self.add_oval_slot, None),
                ('arrowAction', 'arrow.png', 'arrow', self.add_arrow_slot, None),
                ('pointAction', 'point.png', 'point', self.add_point_slot, None),
                ('styleAction', 'style.png', 'style', self.add_style_slot, None),
        ):
            self._add_toolbar_action(self.toolBar, attr, icon, text, slot,
                                     shortcut)

        # Push the subplot drop-down to the far right of toolbar 1.
        self.toolBar.addSeparator()
        spacer = QWidget()
        spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.toolBar.addWidget(spacer)
        self.toolBar.addWidget(self.comboBox)

        # Toolbar 2. layoutAction has no slot yet (show_layout_slot is not wired).
        for attr, icon, text, slot in (
                ('gridAction', 'grid.png', '显示/隐藏网格', self.show_grid_slot),
                ('legendAction', 'legend.png', '显示/隐藏图例',
                 self.show_legend_slot),
                ('colorbarAction', 'colorbar.png', '显示/隐藏colorbar',
                 self.show_colorbar_slot),
                ('layoutAction', 'layout.png', '改变布局', None),
                ('mainViewAction', 'mainView.png', 'mainView',
                 self.mainView_slot),
                ('leftViewAction', 'leftView.png', 'leftView',
                 self.leftView_slot),
                ('rightViewAction', 'rightView.png', 'rightView',
                 self.rightView_slot),
                ('topViewAction', 'topView.png', 'topView', self.topView_slot),
                ('bottomViewAction', 'bottomView.png', 'bottomView',
                 self.bottomView_slot),
                ('backViewAction', 'backView.png', 'backView',
                 self.backView_slot),
        ):
            self._add_toolbar_action(self.toolBar_2, attr, icon, text, slot)

        # Right-click context menu used by the style tools.
        self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.customContextMenuRequested.connect(self.rightMenuShow)
        self.rightMenuShow()  # create the context menu up front

        self.comboBox.currentIndexChanged.connect(self.combobox_slot)

        # Give every piece of tool state a value before the axes check below,
        # so that the early return cannot leave the slots without attributes.
        self.current_subplot = None
        self.init_views = []
        self.mouse_pressed = False
        self.press_time = 0
        self.rotate_mouse_point = None  # last mouse position while rotating
        self.artist = None
        self.make_flag_invalid()  # all tool flags start cleared

        # Collect the subplots of the figure.
        self.axes = self.canvas.figure.get_axes()
        if not self.axes:
            QtWidgets.QMessageBox.warning(self.canvas.parent(), "Error",
                                          "There are no axes to edit.")
            return
        # Select the first subplot explicitly instead of relying on the
        # combo box emitting currentIndexChanged when it is populated.
        self.current_subplot = self.axes[0]
        titles = ['图' + str(i + 1) for i in range(len(self.axes))]
        # Remember the initial view of each 3D subplot so it can be restored
        # after rotating.
        self.init_views = [(index, item.azim, item.elev)
                           for index, item in enumerate(self.axes)
                           if hasattr(item, 'azim')]
        self.comboBox.addItems(titles)

        # Canvas event wiring. The order is kept as-is because matplotlib
        # invokes callbacks in the order they were connected.
        for event_name, handler in (
                ('motion_notify_event', self.on_rotate),   # drag to rotate 3D axes
                ('motion_notify_event', self.add_rect),    # drag to draw a rectangle
                ('motion_notify_event', self.add_oval),    # drag to draw an ellipse
                ('motion_notify_event', self.add_arrow),   # drag to draw an arrow
                ('button_press_event', self.add_point),    # click to add / drag to move a point
                ('motion_notify_event', self.add_point),
                ('button_release_event', self.add_point),
                ('pick_event', self.add_point),
                ('button_press_event', self.add_text),
                ('button_press_event', self.add_style),    # curve styling
                ('pick_event', self.add_style),
                ('button_press_event', self.change_legend),  # legend interaction
                ('pick_event', self.change_legend),
        ):
            self.canvas.mpl_connect(event_name, handler)

        # Make every existing curve pickable.
        for ax in self.axes:
            for line in ax.lines:
                line.set_picker(True)
                line.set_pickradius(5)

    def _add_toolbar_action(self, toolbar, attr, icon, text, slot=None,
                            shortcut=None):
        """Create a QAction, store it as ``self.<attr>`` and append it to ``toolbar``.

        The icon is loaded from ``icons/<icon>`` below ``self.current_path``.
        ``shortcut`` and ``slot`` are applied only when given.
        """
        action = QAction(
            QIcon(os.path.join(self.current_path, 'icons/' + icon)), text, self)
        if shortcut is not None:
            action.setShortcut(shortcut)
        if slot is not None:
            action.triggered.connect(slot)
        toolbar.addAction(action)
        setattr(self, attr, action)
        return action

    def make_flag_invalid(self):
        """Set every tool/toggle flag to False and clear the toolbar mode."""
        for flag_name in (
                'add_rect_flag',
                'add_oval_flag',
                'add_text_flag',
                'rotate_flag',
                'home_flag',
                'pan_flag',
                'zoom_flag',
                'add_arrow_flag',
                'add_point_flag',
                'add_style_flag',
                'show_grid_flag',
                'show_legend_flag',
        ):
            setattr(self, flag_name, False)
        # Switch off the toolbar's pan/zoom mode.
        self.toolbar.mode = None

    def home_slot(self):
        """Clear all tool flags, toggle the home flag and run :meth:`home`."""
        self.make_flag_invalid()
        previous = self.home_flag
        self.home_flag = not previous
        self.home()

    def home(self):
        """Restore the original view and remove the shapes drawn on the axes.

        Does nothing unless ``self.home_flag`` is set. Otherwise it resets the
        toolbar view, restores the recorded azim/elev of every 3D subplot,
        empties the ``patches`` and ``artists`` of every subplot and redraws.

        matplotlib keeps curves in ``lines``, shapes in ``patches`` and other
        objects in ``artists``; emptying a container removes those objects.
        """
        if not self.home_flag:
            return
        self.toolbar.home()
        # Restore the initial viewing angle of each 3D subplot.
        for index, azim, elev in self.init_views:
            self.axes[index].view_init(azim=azim, elev=elev)
        # Remove every added shape and point (whether points should also be
        # removed was left open by the original author).
        for item in self.axes:
            for container in ('patches', 'artists'):
                try:
                    setattr(item, container, [])
                except AttributeError:
                    # Newer matplotlib exposes these containers as read-only
                    # properties; remove the artists one at a time instead.
                    for drawn in list(getattr(item, container)):
                        try:
                            drawn.remove()
                        except (NotImplementedError, ValueError):
                            # Already detached, e.g. the same patch was added
                            # to the axes twice by a drawing tool.
                            pass
        self.canvas.draw()

    def zoom_slot(self):
        """Clear all tool flags, toggle the zoom flag and run :meth:`zoom`."""
        self.make_flag_invalid()
        previous = self.zoom_flag
        self.zoom_flag = not previous
        self.zoom()

    def zoom(self):
        """Call the toolbar's ``zoom()`` when ``self.zoom_flag`` is set."""
        if self.zoom_flag:
            self.toolbar.zoom()

    def pan_slot(self):
        """Clear all tool flags, toggle the pan flag and run :meth:`pan`."""
        self.make_flag_invalid()
        previous = self.pan_flag
        self.pan_flag = not previous
        self.pan()

    def pan(self):
        """Call the toolbar's ``pan()`` when ``self.pan_flag`` is set."""
        if self.pan_flag:
            self.toolbar.pan()

    def save_slot(self):
        """Delegate to the (hidden) matplotlib toolbar's ``save_figure()``."""
        nav = self.toolbar
        nav.save_figure()

    def front_slot(self):
        """Step the toolbar's navigation stack forward and refresh the view."""
        nav = self.toolbar
        nav._nav_stack.forward()
        nav._update_view()

    def back_slot(self):
        """Step the toolbar's navigation stack back and refresh the view."""
        nav = self.toolbar
        nav._nav_stack.back()
        nav._update_view()

    def add_text_slot(self):
        """Clear all tool flags, then toggle the add-text flag."""
        self.make_flag_invalid()
        previous = self.add_text_flag
        self.add_text_flag = not previous

    def add_text(self, event):
        """Ask for a note and place it at the clicked position.

        Active only while ``self.add_text_flag`` is set. Events without data
        coordinates and events on axes that have an ``azim`` attribute (3D)
        are ignored. The text is added only if the dialog is accepted with a
        non-empty string.

        Args:
            event: The matplotlib mouse event from ``button_press_event``.
        """
        if not self.add_text_flag:
            return
        # Compare against None rather than testing truthiness, otherwise a
        # click at data coordinate 0.0 would be silently ignored.
        if event.xdata is None or event.ydata is None:
            return
        if hasattr(event.inaxes, 'azim'):
            return
        text, ok = QtWidgets.QInputDialog.getText(self.canvas.parent(),
                                                  '输入文字', '添加注释')
        if ok and text:
            event.inaxes.text(event.xdata, event.ydata, text)
            self.canvas.draw()
    def rotate_slot(self):
        """Clear all tool flags, then toggle the rotate flag."""
        self.make_flag_invalid()
        previous = self.rotate_flag
        self.rotate_flag = not previous

    def on_rotate(self, event):
        """Rotate a 3D subplot while the mouse is dragged over it.

        The original author observed that, while rotating, ``xdata`` and
        ``ydata`` are values within +/-0.1 around the centre of the axes; the
        code maps that range onto +/-180 degrees relative to the view recorded
        in ``self.init_views``.
        """
        if not self.rotate_flag:
            return
        # A pressed button during motion is treated as a drag; an ``azim``
        # attribute marks the axes as 3D.
        dragging_3d = event.button and hasattr(event.inaxes, 'azim')
        if not dragging_3d:
            return
        for index, start_azim, start_elev in self.init_views:
            if self.axes[index] == event.inaxes:
                delta_azim = 180 * event.xdata * -1 / 0.1
                delta_elev = 180 * event.ydata / 0.1
                event.inaxes.view_init(azim=delta_azim + start_azim,
                                       elev=delta_elev + start_elev)
                self.canvas.draw()

    def add_rect_slot(self):
        """Clear every other tool flag, then toggle the add-rectangle flag."""
        self.make_flag_invalid()
        previous = self.add_rect_flag
        self.add_rect_flag = not previous

    def add_rect(self, event):
        """Rubber-band a red rectangle while dragging (motion handler).

        Active only while ``self.add_rect_flag`` is set. Motion without a
        pressed button records the event as the anchor (``self.ax_init``);
        motion with a pressed button replaces the last patch of the axes with
        a rectangle spanning from the anchor to the current position.
        Any exception raised in the drag branch is swallowed.
        """
        if not self.add_rect_flag:
            return
        # No button pressed: remember this position as the anchor of the next
        # drag. If a drag has just ended, the first patch of the axes is
        # added again.
        if not event.button and event.inaxes:
            self.ax_init = event
            if self.mouse_pressed and event.inaxes.patches:
                event.inaxes.add_patch(event.inaxes.patches[0])
            self.mouse_pressed = False

        # Rectangles can only be drawn on 2D axes, with the button pressed and
        # while still inside the axes where the drag started.
        try:
            if event.button and event.inaxes and event.inaxes == self.ax_init.inaxes and not hasattr(
                    event.inaxes, 'azim'):
                self.mouse_pressed = True
                # Drop the most recent patch before adding the updated one.
                # NOTE(review): this pops whatever patch is last, presumably
                # the rectangle from the previous motion event - confirm.
                if event.inaxes.patches:
                    event.inaxes.patches.pop()
                rect = plt.Rectangle((self.ax_init.xdata, self.ax_init.ydata),
                                     event.xdata - self.ax_init.xdata,
                                     event.ydata - self.ax_init.ydata,
                                     fill=False,
                                     edgecolor='red',
                                     linewidth=1)
                event.inaxes.add_patch(rect)
                self.canvas.draw()
                self.canvas.flush_events()
        except Exception:
            # Best effort: also covers ``self.ax_init`` not being set yet.
            pass

    def add_oval_slot(self):
        """Clear all tool flags, then toggle the add-ellipse flag."""
        self.make_flag_invalid()
        previous = self.add_oval_flag
        self.add_oval_flag = not previous

    def add_oval(self, event):
        """Rubber-band a red ellipse while dragging (motion handler).

        Active only while ``self.add_oval_flag`` is set. Motion without a
        pressed button records the event as the anchor (``self.ax_init``);
        motion with a pressed button replaces the last patch of the axes with
        an ellipse centred on the anchor whose width and height are twice the
        distance to the current position. Any exception raised in the drag
        branch is swallowed.
        """
        if not self.add_oval_flag:
            return
        # No button pressed: remember this position as the centre of the next
        # ellipse. If a drag has just ended, the first patch of the axes is
        # added again.
        if not event.button and event.inaxes:
            self.ax_init = event
            if self.mouse_pressed and event.inaxes.patches:
                event.inaxes.add_patch(event.inaxes.patches[0])
            self.mouse_pressed = False
        try:
            # Only on 2D axes, with the button pressed and while still inside
            # the axes where the drag started.
            if event.button and event.inaxes and event.inaxes == self.ax_init.inaxes and not hasattr(
                    event.inaxes, 'azim'):
                self.mouse_pressed = True
                # Drop the most recent patch before adding the updated one.
                if event.inaxes.patches:
                    event.inaxes.patches.pop()
                oval = Ellipse(xy=(self.ax_init.xdata, self.ax_init.ydata),
                               width=abs(event.xdata - self.ax_init.xdata) * 2,
                               height=abs(event.ydata - self.ax_init.ydata) *
                               2,
                               angle=0,
                               fill=False,
                               edgecolor='red',
                               linewidth=1)
                event.inaxes.add_patch(oval)
                self.canvas.draw()
                self.canvas.flush_events()
        except Exception:
            # Best effort: also covers ``self.ax_init`` not being set yet.
            pass

    def add_arrow_slot(self):
        """Clear all tool flags, then toggle the add-arrow flag."""
        self.make_flag_invalid()
        previous = self.add_arrow_flag
        self.add_arrow_flag = not previous

    def add_arrow(self, event):
        """Rubber-band a red arrow while dragging (motion handler).

        Active only while ``self.add_arrow_flag`` is set. Motion without a
        pressed button records the event as the anchor (``self.ax_init``);
        motion with a pressed button replaces the last patch of the axes with
        an arrow from the anchor to the current position. Any exception raised
        in the drag branch is swallowed.
        """
        if not self.add_arrow_flag:
            return
        # No button pressed: remember this position as the tail of the next
        # arrow. If a drag has just ended, the first patch of the axes is
        # added again.
        if not event.button and event.inaxes:
            self.ax_init = event
            if self.mouse_pressed and event.inaxes.patches:
                event.inaxes.add_patch(event.inaxes.patches[0])
            self.mouse_pressed = False
        try:
            # Only on 2D axes, with the button pressed and while still inside
            # the axes where the drag started.
            if event.button and event.inaxes and event.inaxes == self.ax_init.inaxes and not hasattr(
                    event.inaxes, 'azim'):
                self.mouse_pressed = True
                # Drop the most recent patch before adding the updated one.
                if event.inaxes.patches:
                    event.inaxes.patches.pop()
                arrow = event.inaxes.arrow(self.ax_init.xdata,
                                           self.ax_init.ydata,
                                           event.xdata - self.ax_init.xdata,
                                           event.ydata - self.ax_init.ydata,
                                           width=0.01,
                                           length_includes_head=True,
                                           head_width=0.05,
                                           head_length=0.1,
                                           fc='r',
                                           ec='r')
                event.inaxes.add_patch(arrow)
                # Original author's note: "Forgive my ignorance, I do not
                # understand why a second pop is needed here, but it does give
                # the correct result."
                # NOTE(review): presumably Axes.arrow() already adds the patch,
                # so the add_patch() above adds it a second time - confirm.
                if event.inaxes.patches:
                    event.inaxes.patches.pop()
                self.canvas.draw()
                self.canvas.flush_events()
        except Exception:
            # Best effort: also covers ``self.ax_init`` not being set yet.
            pass

    def add_point_slot(self):
        """Clear all tool flags, then toggle the add-point flag."""
        self.make_flag_invalid()
        previous = self.add_point_flag
        self.add_point_flag = not previous

    # def add_point(self, event):
    #     if not self.add_point_flag:
    #         return
    #     if event.inaxes and event.button and not hasattr(event.inaxes, 'azim'):
    #         x_range = np.array(event.inaxes.get_xlim())
    #         y_range = np.array(event.inaxes.get_ylim())
    #         self.offset = np.sqrt(np.sum((x_range - y_range) ** 2)) / 20 # 将坐标轴范围的1/50视为误差
    #         self.nearest_point = None
    #         d_min = 10 * self.offset
    #         for point in event.inaxes.artists:
    #             xt, yt = point.get_data()
    #             d = ((xt - event.xdata) ** 2 + (yt - event.ydata) ** 2) ** 0.5
    #             if d <= self.offset and d < d_min:  # 如果在误差范围内,移动该点
    #                 d_min = d
    #                 self.nearest_point = point
    #         if self.nearest_point:
    #             new_point = Line2D([event.xdata], [event.ydata], ls="",
    #                                marker='o', markerfacecolor='r',
    #                                animated=False)
    #             event.inaxes.add_artist(new_point)
    #             event.inaxes.artists.remove(self.nearest_point)
    #             self.canvas.restore_region(self.bg)
    #             event.inaxes.draw_artist(new_point)
    #             self.canvas.blit(event.inaxes.bbox)
    #             self.bg = self.canvas.copy_from_bbox(event.inaxes.bbox)
    def add_point(self, event):
        """Add a point on click and move a picked point while dragging.

        Connected to the press, motion, release and pick events; active only
        while ``self.add_point_flag`` is set.

        * ``pick_event``: remember the picked artist in ``self.artist``.
        * ``button_press_event`` with no artist picked: add a pickable red
          marker at the click position on 2D axes.
        * ``motion_notify_event`` with a button held and an artist picked:
          move the artist, but only if it holds exactly one data point.
        * ``button_release_event``: forget the picked artist.

        Events outside any axes and events on axes with an ``azim`` attribute
        (3D) neither add nor move a point.
        """
        if not self.add_point_flag:
            return
        if event.name == 'pick_event':
            self.artist = event.artist
            return
        if event.name == 'button_release_event':
            self.artist = None
            return
        # ``inaxes`` always exists on mouse events but is None outside the
        # axes, so test its value instead of using hasattr(); otherwise a
        # click outside the axes would call ``None.add_artist``.
        axes = getattr(event, 'inaxes', None)
        if axes is None or hasattr(axes, 'azim'):
            return
        if event.name == 'button_press_event' and not self.artist:
            point = Line2D([event.xdata], [event.ydata],
                           ls="",
                           marker='o',
                           markerfacecolor='r',
                           animated=False,
                           pickradius=5,
                           picker=True)
            axes.add_artist(point)
            self.canvas.draw()
            return
        if event.name == 'motion_notify_event' and self.artist and event.button:
            # Picked artists are not necessarily lines (e.g. legend entries).
            get_data = getattr(self.artist, 'get_data', None)
            if get_data is None:
                return
            xy = get_data()
            if len(xy[0]) == 1:  # only single points are movable
                self.artist.set_data(([event.xdata], [event.ydata]))
                self.canvas.draw()

    def add_style_slot(self):
        """Clear all tool flags, then toggle the add-style flag."""
        self.make_flag_invalid()
        previous = self.add_style_flag
        self.add_style_flag = not previous

    def add_style(self, event):
        """Highlight the picked curve and open the style menu on right-click.

        Connected to the press and pick events; active only while
        ``self.add_style_flag`` is set.

        * ``pick_event``: remember the picked artist in ``self.artist``.
        * press with an artist picked: dim every other line of the clicked
          axes to alpha 0.5; a right-click (button 3) opens the context menu,
          a left-click (button 1) forgets the artist.
        * left-click with nothing picked: reset every line to alpha 1.

        A press outside any axes has no lines to restyle; the button handling
        still applies.
        """
        if not self.add_style_flag:
            return
        if event.name == 'pick_event':
            self.artist = event.artist
            return
        if event.name != 'button_press_event':
            return
        # ``inaxes`` is None when the press happens outside the axes.
        lines = event.inaxes.lines if event.inaxes is not None else ()
        if self.artist:
            for line in lines:
                if self.artist != line:
                    line.set_alpha(0.5)
                else:
                    line.set_alpha(1)
            self.canvas.draw_idle()
            if event.button == 3:
                self.contextMenu.popup(QCursor.pos())  # show the menu at the cursor
                self.contextMenu.show()
                return
            elif event.button == 1:
                self.artist = None
                return
        if not self.artist and event.button == 1:
            for line in lines:
                line.set_alpha(1)
            self.canvas.draw_idle()

    def show_legend_slot(self):
        """Show a draggable legend labelling the curves 'curve 1', 'curve 2', ...

        Nothing is drawn when the current subplot has no lines. Each legend
        line is made pickable with a pick radius of 10.
        """
        curves = self.current_subplot.lines
        # Labels are numbered starting from 1.
        legend_titles = ['curve ' + str(number)
                         for number in range(1, len(curves) + 1)]
        if not curves:  # a legend is only drawn when curves exist
            return
        leg = self.current_subplot.legend(curves, legend_titles)
        leg.set_draggable(True)  # allow dragging the legend
        for entry in leg.get_lines():
            entry.set_pickradius(10)
            entry.set_picker(True)  # make each legend entry clickable
        self.canvas.draw()

    def change_legend(self, event):
        """Remember picked artists and open the context menu on right-click.

        A ``pick_event`` stores the artist in ``self.artist``; a press with
        button 3 while an artist is stored pops up ``self.contextMenu`` at the
        cursor position.
        """
        if event.name == 'pick_event':
            self.artist = event.artist
        if not self.artist:
            return
        if event.name == 'button_press_event' and event.button == 3:
            self.contextMenu.popup(QCursor.pos())  # show the menu at the cursor
            self.contextMenu.show()

    def show_colorbar_slot(self):
        """Placeholder for the colorbar toggle; currently does nothing."""
        # Earlier attempts, kept for reference:
        # print(self.current_subplot.curves)
        # self.canvas.figure.colorbar(self.canvas,self.current_subplot)
        return None

    def rightMenuShow(self):
        """(Re)create the context menu and connect its three entries.

        The menu is stored in ``self.contextMenu`` with the entries "edit
        curve style", "edit legend style" and "edit curve type", each wired to
        its own handler.
        """
        self.contextMenu = QMenu()
        self.actionStyle = self.contextMenu.addAction('修改曲线样式')
        self.actionLegend = self.contextMenu.addAction('修改图例样式')
        self.actionCurve = self.contextMenu.addAction('修改曲线类型')
        self.actionStyle.triggered.connect(self.styleHandler)
        self.actionLegend.triggered.connect(self.legendHandler)
        # The curve-type entry gets its own handler; it used to be attached to
        # actionLegend, which left actionCurve without any effect.
        self.actionCurve.triggered.connect(self.curveHandler)

    def styleHandler(self):
        """Handler for the curve-style menu entry; prints the picked artist."""
        picked = self.artist
        print(picked)

    def legendHandler(self):
        """Handler for the legend-style menu entry; prints the picked artist."""
        picked = self.artist
        print(picked)

    def curveHandler(self):
        """Handler for the curve-type menu entry; prints the picked artist."""
        picked = self.artist
        print(picked)

    def mainView_slot(self):
        """Set the current subplot's view to azim=0, elev=0 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=0, elev=0)
        self.canvas.draw()

    def leftView_slot(self):
        """Set the current subplot's view to azim=90, elev=0 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=90, elev=0)
        self.canvas.draw()

    def rightView_slot(self):
        """Set the current subplot's view to azim=-90, elev=0 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=-90, elev=0)
        self.canvas.draw()

    def topView_slot(self):
        """Set the current subplot's view to azim=0, elev=90 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=0, elev=90)
        self.canvas.draw()

    def bottomView_slot(self):
        """Set the current subplot's view to azim=0, elev=-90 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=0, elev=-90)
        self.canvas.draw()

    def backView_slot(self):
        """Set the current subplot's view to azim=180, elev=0 and redraw.

        Subplots without an ``azim`` attribute (2D) are left alone.
        """
        subplot = self.current_subplot
        if not hasattr(subplot, 'azim'):
            return
        subplot.view_init(azim=180, elev=0)
        self.canvas.draw()

    def show_grid_slot(self):
        """Toggle ``show_grid_flag`` and apply it to the current subplot's grid."""
        grid_on = not self.show_grid_flag
        self.show_grid_flag = grid_on
        self.current_subplot.grid(self.show_grid_flag)
        self.canvas.draw_idle()

    def init_gui(self):
        """Refresh the view through the toolbar's ``_update_view()``."""
        nav = self.toolbar
        nav._update_view()

    def combobox_slot(self):
        """Make the subplot selected in the combo box the current subplot."""
        selected = self.comboBox.currentIndex()
        self.current_subplot = self.axes[selected]

    def axes_control_slot(self):
        """Open the axes settings form for the current subplot.

        Shows a warning box instead when no subplot is selected, including the
        case where ``current_subplot`` was never assigned because the figure
        has no axes.
        """
        # getattr: the attribute may be missing entirely, in which case the
        # warning below is the intended outcome rather than an AttributeError.
        subplot = getattr(self, 'current_subplot', None)
        if not subplot:
            QtWidgets.QMessageBox.warning(self.canvas.parent(), "错误",
                                          "没有可选的子图!")
            return
        Ui_Form_Manager(subplot, self.canvas)

    def show(self):
        """Show the window by delegating to the base-class ``show()``."""
        super(Window, self).show()
Beispiel #10
0
class subWindow(QMainWindow):
    
    #set up a signal so that the window closes when the main window closes
    #closeWindow = QtCore.pyqtSignal()
    
    #replot = pyqtsignal()
    
    def __init__(self,parent=None):
        """Create a plot sub-window that shares state with its parent window.

        Builds the widgets and then copies references to the parent's active
        pixel, data object ``a``, spin boxes, image and beam flag mask, and
        subscribes to the parent's ``updateActivePix`` signal.

        Args:
            parent: The owning window. Although the default is None, the
                attribute reads below require an actual parent object.
        """
        # Start the MRO lookup at this class. super(QMainWindow, self) would
        # begin the search *after* QMainWindow and skip its initializer.
        super(subWindow, self).__init__(parent)
        self.parent=parent

        self.effExpTime = .010  #units are s

        self.create_main_frame()
        self.activePixel = parent.activePixel
        self.parent.updateActivePix.connect(self.setActivePixel)
        self.a = parent.a
        self.spinbox_startTime = parent.spinbox_startTime
        self.spinbox_integrationTime = parent.spinbox_integrationTime
        self.spinbox_startLambda = parent.spinbox_startLambda
        self.spinbox_stopLambda = parent.spinbox_stopLambda
        self.image = parent.image
        self.beamFlagMask = parent.beamFlagMask
        self.apertureRadius = 2.27/2   #Taken from Seth's paper (not yet published in Jan 2018)
        self.apertureOn = False
        self.lineColor = 'blue'
        

    def getAxes(self):
        """Return the axes object stored in ``self.ax``."""
        axes = self.ax
        return axes
        
    def create_main_frame(self):
        """
        Makes GUI elements on the window.

        Creates the figure, canvas, axes and navigation toolbar and stacks
        them vertically. When the instance already has a
        ``spinbox_effExpTime`` attribute, a row with an exposure-time label,
        that spin box and a "Plot" button is inserted between the canvas and
        the toolbar.
        """
        self.main_frame = QWidget()

        # Figure
        self.dpi = 100
        self.fig = Figure((3.0, 2.0), dpi=self.dpi, tight_layout=True)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax = self.fig.add_subplot(111)

        #create a navigation toolbar for the plot window
        self.toolbar = NavigationToolbar(self.canvas, self)

        #create a vertical box for the plot to go in.
        vbox_plot = QVBoxLayout()
        vbox_plot.addWidget(self.canvas)

        #check if we need effective exposure time controls in the window, and add them if we do.
        #hasattr() replaces a bare try/except probe, which would also have
        #hidden unrelated errors.
        if hasattr(self, 'spinbox_effExpTime'):
            label_expTime = QLabel('effective exposure time [ms]')
            button_plot = QPushButton("Plot")

            hbox_expTimeControl = QHBoxLayout()
            hbox_expTimeControl.addWidget(label_expTime)
            hbox_expTimeControl.addWidget(self.spinbox_effExpTime)
            hbox_expTimeControl.addWidget(button_plot)
            vbox_plot.addLayout(hbox_expTimeControl)

            self.spinbox_effExpTime.setMinimum(1)
            self.spinbox_effExpTime.setMaximum(200)
            # effExpTime is kept in seconds; the spin box shows milliseconds.
            self.spinbox_effExpTime.setValue(1000*self.effExpTime)
            button_plot.clicked.connect(self.plotData)

        vbox_plot.addWidget(self.toolbar)

        #combine everything into another vbox
        vbox_combined = QVBoxLayout()
        vbox_combined.addLayout(vbox_plot)

        #Set the main_frame's layout to be vbox_combined
        self.main_frame.setLayout(vbox_combined)

        #Set the overall QWidget to have the layout of the main_frame.
        self.setCentralWidget(self.main_frame)
                
        
    def draw(self):
        """Redraw the canvas and flush its pending events (called by the plot window)."""
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()
        
        
    def setActivePixel(self):
        self.activePixel = self.parent.activePixel
#        if self.parent.image[self.activePixel[0],self.activePixel[1]] ==0: #only plot data from good pixels
#            self.lineColor = 'red'
#        else:
#            self.lineColor = 'blue'
        try:
            self.plotData() #put this in a try statement in case it doesn't work. This way it won't kill the whole gui. 
        except:
            pass


    def plotData(self):
        """Do nothing; child classes redefine this to draw their plot.

        Having the dummy here lets this base class handle the signal that
        updates the plots for every kind of sub-window.
        """
        return None
                
    def getPhotonList(self):
        """Fetch the photon list for the active pixel from ``self.a``.

        Calls ``getCircularAperturePhotonList`` (with ``self.apertureRadius``
        and ``flagToUse=0``) when ``self.apertureOn`` is set, otherwise
        ``getPixelPhotonList``. The time and wavelength window is read from
        the four spin boxes in both cases.

        Returns:
            The photon list returned by the selected call; the aperture that
            the circular-aperture call also returns is discarded.
        """
        # Selection window shared by both calls.
        window = dict(
            firstSec=self.spinbox_startTime.value(),
            integrationTime=self.spinbox_integrationTime.value(),
            wvlStart=self.spinbox_startLambda.value(),
            wvlStop=self.spinbox_stopLambda.value(),
        )
        if self.apertureOn:
            photonList, _ = self.a.getCircularAperturePhotonList(
                self.activePixel[0], self.activePixel[1],
                radius=self.apertureRadius, flagToUse=0, **window)
        else:
            photonList = self.a.getPixelPhotonList(
                self.activePixel[0], self.activePixel[1], **window)

        return photonList
            
            
            
        
    def getLightCurve(self):
        """Bin the photon time stamps of ``self.photonList`` into a light curve.

        The bin edges run from the start-time spin box value up to (start +
        integration time) in steps of ``self.effExpTime``. The ``'Time'``
        column is divided by 10**6 before binning. ``self.histBinEdges`` and
        ``self.hist`` are stored on the instance.

        Returns:
            (counts per bin, counts per bin divided by ``effExpTime``,
            bin-centre times).
        """
        start = self.spinbox_startTime.value()
        stop = self.spinbox_startTime.value() + self.spinbox_integrationTime.value()
        self.histBinEdges = np.arange(start, stop, self.effExpTime)

        # N bin edges give N-1 histogram values.
        arrival_times = self.photonList['Time']/10**6
        self.hist, _unused_edges = np.histogram(arrival_times, bins=self.histBinEdges)

        counts = 1.*self.hist  # photon counts
        rate = 1.*self.hist/self.effExpTime  # counts/sec
        half_bin = 1.0*self.effExpTime/2
        centres = self.histBinEdges[:-1] + half_bin

        return counts, rate, centres
class mainWindow(QMainWindow):
    
    
    updateActivePix = pyqtSignal()

    def __init__(self,parent=None):
        """Set up the quickLook_img main window.

        The image arrays are created first because the main frame displays
        ``self.image`` while it is being built. The beam map is not loaded
        here (that call is disabled).
        """
        super().__init__(parent=parent)
        self.initializeEmptyArrays()
        self.setWindowTitle('quickLook_img.py')
        # 600 x 850 works for clint's laptop screen; units are presumably pixels.
        windowSize = (600, 850)
        self.resize(*windowSize)
        for buildStep in (self.create_main_frame, self.create_status_bar, self.createMenu):
            buildStep()
        
        
    def initializeEmptyArrays(self,nCol = 80,nRow = 125):
        """Store the array geometry and create zero-filled (nRow, nCol) arrays.

        Sets ``rawCountsImage``, ``hotPixMask``, ``image`` and ``beamFlagMask``
        to independent zero arrays and ``hotPixCut`` to 2400.
        """
        self.nCol, self.nRow = nCol, nRow
        shape = (self.nRow, self.nCol)

        self.rawCountsImage = np.zeros(shape)
        self.hotPixMask = np.zeros(shape)
        self.hotPixCut = 2400
        self.image = np.zeros(shape)
        self.beamFlagMask = np.zeros(shape)

        
        

    def load_IMG_filenames(self,filename):
        """Index every ``.img`` file in the directory that contains ``filename``.

        The timestamp of each file is parsed from its name with the trailing
        ``.img`` removed. Results are stored, sorted by timestamp, in
        ``self.fileListRaw`` (full paths) and ``self.timeStampList`` (float64
        timestamps); ``self.imgPath`` is set to the directory.

        Parameters
        ----------
        filename : str
            Path of one img file; only its directory is used.

        Raises
        ------
        IndexError
            If the directory holds no ``.img`` file, or a file name does not
            start with an integer.
        """
        print('\nloading img filenames')

        self.imgPath = os.path.dirname(filename)
        imgNames = [name for name in os.listdir(self.imgPath) if name.endswith(".img")]

        # Collect in plain lists and convert once instead of growing arrays in
        # the loop, which costs quadratic time for large directories.
        fileListRaw = np.asarray([os.path.join(self.imgPath, name) for name in imgNames])
        # Timestamps are stored as float64.
        timeStampList = np.asarray(
            [np.fromstring(name[:-4], dtype=int, sep=' ')[0] for name in imgNames],
            dtype=float,
        )

        #the files may not be in chronological order, so let's enforce it
        fileListRaw = fileListRaw[np.argsort(timeStampList)]
        timeStampList = np.sort(timeStampList)

        self.fileListRaw = fileListRaw
        self.timeStampList = timeStampList

        print('\nfound {:d} .img files\n'.format(len(self.timeStampList)))
        print('first timestamp: ',self.timeStampList[0])
        print('last timestamp:  ',self.timeStampList[-1],'\n')

        

    def load_log_filenames(self):
        """Index the log files found in ``self.logPath``.

        Files ending in ``telescope.log`` are ignored; every other ``.log``
        file contributes its name and the integer parsed from the first ten
        characters of that name. Both lists are stored sorted by timestamp in
        ``self.logFilenameList`` and ``self.logTimestampList``. If the log
        directory does not exist, a message is shown in ``self.label_log`` and
        both lists are left empty.
        """
        #check if directory exists
        if not os.path.exists(self.logPath):
            self.label_log.setText('log file path not found.\n Check log file path.')
            self.logTimestampList = np.asarray([])
            self.logFilenameList = np.asarray([])
            return

        #load the log filenames
        print('\nloading log filenames\n')
        logNames = [
            entry for entry in os.listdir(self.logPath)
            if entry.endswith(".log") and not entry.endswith("telescope.log")
        ]
        logStamps = [np.fromstring(entry[:10], dtype=int, sep=' ')[0] for entry in logNames]

        #the files may not be in chronological order, so let's enforce it
        sortedNames = np.asarray(logNames)[np.argsort(logStamps)]
        sortedStamps = np.sort(np.asarray(logStamps))

        self.logTimestampList = np.asarray(sortedStamps)
        self.logFilenameList = sortedNames


    def load_beam_map(self):
        """Ask the user for a beam-map text file and flag its unmapped pixels.

        The file is read with ``np.loadtxt`` as four integer columns
        (resID, flag, xPos, yPos). Every entry with a nonzero flag sets
        ``self.beamFlagMask[yPos, xPos]`` to 1, i.e. the mask is 1 where the
        pixel is not beam mapped. Entries whose position lies outside the
        (nRow, nCol) array are skipped. Nothing happens if the dialog is
        cancelled.
        """
        filename, _ = QFileDialog.getOpenFileName(self, 'Select One File', '/mnt/data0/Darkness/20180522/Beammap/',filter = '*.txt')
        if not filename:
            # Dialog cancelled: there is nothing to load.
            return

        _resID, flag, xPos, yPos = np.loadtxt(filename, unpack=True, dtype=int)
        # A one-line file unpacks to scalars; promote so the masks below work.
        flag, xPos, yPos = np.atleast_1d(flag), np.atleast_1d(xPos), np.atleast_1d(yPos)

        notMapped = flag != 0
        insideArray = (xPos >= 0) & (xPos < self.nCol) & (yPos >= 0) & (yPos < self.nRow)
        selected = notMapped & insideArray

        # Index rows and columns in ONE subscript: mask[rows][cols] = 1 would
        # assign into a temporary copy and leave the mask untouched.
        self.beamFlagMask[yPos[selected], xPos[selected]] = 1 #beamFlagMask is 1 when the pixel is not beam mapped



    def initialize_spinbox_values(self,filename):
        """Set limits and start values of the img-timestamp and dark-start spinboxes.

        Both spinboxes start at the timestamp parsed from the name of
        ``filename`` (the file the user selected). The img spinbox spans the
        whole of ``self.timeStampList``; the dark-start spinbox stops at the
        10th timestamp from the end, or at the first timestamp when fewer than
        ten files are loaded.

        Parameters
        ----------
        filename : str
            Path of the selected img file, named ``<timestamp>.img``.
        """
        selectedStamp = int(np.fromstring(os.path.basename(filename)[:-4],dtype=int, sep=' ')[0])

        # The spinboxes are integer widgets, so hand them plain Python ints
        # rather than numpy floats.
        firstStamp = int(self.timeStampList[0])
        lastStamp = int(self.timeStampList[-1])

        # timeStampList[-10] raises IndexError for short lists; clamp instead.
        lastDarkIndex = max(len(self.timeStampList) - 10, 0)
        lastDarkStart = int(self.timeStampList[lastDarkIndex])

        #set up the spinbox limits and start value, which will be the file you selected
        self.spinbox_imgTimestamp.setMinimum(firstStamp)
        self.spinbox_imgTimestamp.setMaximum(lastStamp)
        self.spinbox_imgTimestamp.setValue(selectedStamp)

        self.spinbox_darkStart.setMinimum(firstStamp)
        self.spinbox_darkStart.setMaximum(lastDarkStart)
        self.spinbox_darkStart.setValue(selectedStamp)
        

        
        
    def plotImage(self,filename = None):
        """Read one img file, clean it and show it on the main axes.

        The file is read as uint16 values, reshaped to (nCol, nRow) and
        transposed into ``self.rawImage``. Cleaning optionally subtracts
        ``self.darkFrame`` (clipping negatives to 0), zeroes every pixel above
        ``self.hotPixCut`` and zeroes every pixel flagged in
        ``self.beamFlagMask``. The result is stored in ``self.cleanedImage``
        and ``self.image``. The colorbar limits are then taken either from the
        maximum of the image just computed (auto) or from the colorbar-max
        spinbox.

        Parameters
        ----------
        filename : str, optional
            img file to show. When omitted, the file whose timestamp equals the
            img-timestamp spinbox value is used.
        """
        if filename is None:
            filename = self.fileListRaw[np.where(self.timeStampList==self.spinbox_imgTimestamp.value())[0][0]]

        self.ax1.clear()

        # Passing the path lets np.fromfile open and close the file itself, so
        # no file handle is left dangling.
        rawData = np.fromfile(filename, dtype=np.uint16)
        self.rawImage = np.transpose(np.reshape(rawData, (self.nCol,self.nRow)))

        if self.checkbox_darkSubtract.isChecked():
            if getattr(self, 'darkFrame', None) is None:
                # The dark frame is normally produced by the dark spinboxes'
                # signals; build it now if none of them has fired yet.
                self.getDarkFrame()
            cleaned = self.rawImage - self.darkFrame
            cleaned[cleaned < 0] = 0
        else:
            # Work on a copy so the hot-pixel cut below leaves rawImage intact.
            cleaned = self.rawImage.copy()

        cleaned[cleaned > self.hotPixCut] = 0
        cleaned = cleaned*np.logical_not(self.beamFlagMask)
        self.cleanedImage = cleaned
        self.image = cleaned

        # Choose the limits AFTER self.image has been refreshed, so autoscale
        # follows the image being displayed rather than the previous one.
        if self.checkbox_colorbar_auto.isChecked():
            upperLimit = np.amax(self.image)
        else:
            upperLimit = self.spinbox_colorBarMax.value()
        # Float limits, so later in-place rescaling is not truncated.
        self.cbarLimits = np.array([0, upperLimit], dtype=float)
        self.fig.cbar.set_clim(self.cbarLimits[0],self.cbarLimits[1])
        self.fig.cbar.draw_all()

        self.ax1.imshow(self.image,vmin = self.cbarLimits[0],vmax = self.cbarLimits[1])
        self.ax1.axis('off')

        self.draw()
        
        

    def getDarkFrame(self):
        """Build ``self.darkFrame`` as the per-pixel median of consecutive img frames.

        Frames are looked up by timestamp, from the dark-start spinbox value up
        to dark-start + dark-integration-time - 1. Timestamps without a
        matching file are skipped, so they do not drag the median towards
        zero. If no frame at all is found (for example before any img
        directory has been loaded) the dark frame is an all-zero
        (nRow, nCol) array.
        """
        darkIntTime = self.spinbox_darkIntTime.value()
        darkStart = self.spinbox_darkStart.value()

        frames = []
        for offset in range(darkIntTime):
            try:
                match = np.where(self.timeStampList == (darkStart + offset))[0][0]
                darkFrameFilename = self.fileListRaw[match]
            except (AttributeError, IndexError):
                # AttributeError: no file list loaded yet.
                # IndexError: no img file carries this timestamp.
                continue
            # np.fromfile opens and closes the file on its own.
            rawData = np.fromfile(darkFrameFilename, dtype=np.uint16)
            frames.append(np.transpose(np.reshape(rawData, (self.nCol, self.nRow))))

        if frames:
            self.darkFrame = np.median(np.asarray(frames, dtype=float), axis=0)
        else:
            self.darkFrame = np.zeros((self.nRow, self.nCol))

        
        
        
    def plotBlank(self):
        """Show an all-zero (nRow, nCol) image on the main axes."""
        blankImage = np.zeros((self.nRow, self.nCol))
        self.ax1.imshow(blankImage)
        
        
        
    def updateLogLabel(self,IMG_fileExists = True):
        """Refresh ``self.label_log`` for the timestamp in the img spinbox.

        Depending on the situation the label shows:
        * a hint that no log files are known,
        * a hint that no .img file exists for the timestamp,
        * the distance in seconds to the closest log when none lies within
          one hour, or
        * the content of the chosen log file, which is the latest log that is
          not newer than the img (or the first log when all of them are newer).

        Parameters
        ----------
        IMG_fileExists : bool, optional
            Pass False when no img file matches the current timestamp.
        """
        timestamp = self.spinbox_imgTimestamp.value()
        imgTime = datetime.datetime.fromtimestamp(timestamp)

        #check if self.logTimestampList has more than zero entries. If not, return.
        if len(self.logTimestampList)==0:
            text = imgTime.strftime('%Y-%m-%d %H:%M:%S\n\n') + 'no log file found.\n Check log file path.'
            self.label_log.setText(text)
            return

        #check if the img exists, if not then return
        if not IMG_fileExists:
            text = imgTime.strftime('%Y-%m-%d %H:%M:%S\n\n') + 'no .img file found'
            self.label_log.setText(text)
            return

        #check if a nearby log file exists, then pick the closest one
        diffs = timestamp - self.logTimestampList
        distances = np.abs(diffs)
        if not np.any(distances < 3600): #nearby means within 1 hour.
            # Report the smallest ABSOLUTE distance; the signed minimum would
            # name the farthest log whenever logs are newer than the img.
            text = imgTime.strftime('%Y-%m-%d %H:%M:%S\n\n') + 'nearest log is ' + str(np.amin(distances)) + '\nseconds away from img'
            self.label_log.setText(text)
            return

        # Logs newer than the img get the largest difference, so argmin picks
        # the latest log that precedes the img.
        diffs[diffs < 0] = np.amax(diffs)

        logLabelTimestamp = self.logTimestampList[np.argmin(diffs)]
        labelFilename = self.logFilenameList[np.where(self.logTimestampList==logLabelTimestamp)[0][0]]

        # The context manager closes the file even if reading fails.
        with open(os.path.join(self.logPath,labelFilename),'r') as fin:
            logText = fin.read()

        text = 'img timestamp:\n' + imgTime.strftime('%Y-%m-%d %H:%M:%S') + '\n\nLogfile time:\n' + datetime.datetime.fromtimestamp(logLabelTimestamp).strftime('%Y-%m-%d %H:%M:%S\n') + '\n' + labelFilename[:-4] + '\n' + logText
        self.label_log.setText(text)



    
    def create_main_frame(self):
        """
        Makes GUI elements on the window

        Builds the matplotlib canvas with its image and colorbar, the spinboxes
        and checkboxes that control what is displayed, the log and path labels
        and the layouts holding them, installs the result as the central widget
        and connects the canvas mouse events. ``self.imgPath`` and
        ``self.logPath`` are initialised from the MKID_IMG_DIR and
        MKID_RAW_PATH environment variables, falling back to '/' when a
        variable is not set.
        """
        #Define the plot window.
        self.main_frame = QWidget()
        self.dpi = 100
        #define the figure, set the max size (inches) and resolution. Overall window size is set with QMainWindow parameter.
        self.fig = Figure(figsize = (5.0, 10.0), dpi=self.dpi, tight_layout=True)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax1 = self.fig.add_subplot(111)
        self.ax1.axis('off')
        self.foo = self.ax1.imshow(self.image,interpolation='none')
        # The colorbar is kept as an attribute of the figure object.
        self.fig.cbar = self.fig.colorbar(self.foo)

        #spinboxes for the img timestamp
        self.spinbox_imgTimestamp = QSpinBox()
        self.spinbox_imgTimestamp.valueChanged.connect(self.spinBoxValueChange)

        #spinboxes for specifying dark frames
        self.spinbox_darkStart = QSpinBox()
        self.spinbox_darkStart.valueChanged.connect(self.getDarkFrame)
        self.spinbox_darkIntTime = QSpinBox()
        #set up the limits and initial value of the darkIntTime
        # (done before connecting the signal, so this does not trigger getDarkFrame)
        self.spinbox_darkIntTime.setMinimum(1)
        self.spinbox_darkIntTime.setMaximum(1000)
        self.spinbox_darkIntTime.setValue(10)
        self.spinbox_darkIntTime.valueChanged.connect(self.getDarkFrame)

        #labels for the start/stop time spinboxes
        label_imgTimestamp = QLabel('IMG timestamp')
        label_darkStart = QLabel('dark Start')
        label_darkIntTime = QLabel('dark int time [s]')

        #make a checkbox for the colorbar autoscale
        self.checkbox_colorbar_auto = QCheckBox()
        self.checkbox_colorbar_auto.setChecked(False)
        self.checkbox_colorbar_auto.stateChanged.connect(self.spinBoxValueChange)

        label_checkbox_colorbar_auto = QLabel('Auto colorbar')

        self.spinbox_colorBarMax = QSpinBox()
        self.spinbox_colorBarMax.setRange(1,2500)
        self.spinbox_colorBarMax.setValue(2000)
        self.spinbox_colorBarMax.valueChanged.connect(self.spinBoxValueChange)

        #make a checkbox for the dark subtract
        self.checkbox_darkSubtract = QCheckBox()
        self.checkbox_darkSubtract.setChecked(False)
        self.checkbox_darkSubtract.stateChanged.connect(self.spinBoxValueChange)

        #make a label for the dark subtract checkbox
        label_darkSubtract = QLabel('dark subtract')

        #make a label for the logs
        self.label_log = QLabel('')

        #make a label to display the IMG path and the MKID_RAW_PATH. Also set up log path variable
        imgDir = os.environ.get('MKID_IMG_DIR')
        if imgDir is None:
            labelText = 'MKID_IMG_DIR:      could not find MKID_IMG_DIR'
            self.imgPath = '/'
        else:
            labelText = 'MKID_IMG_DIR:      ' + imgDir
            self.imgPath = imgDir

        self.label_IMG_path = QLabel(labelText)
        self.label_IMG_path.setToolTip('Look for img files in this directory. To change, go to File>Open img file')

        rawDir = os.environ.get('MKID_RAW_PATH')
        if rawDir is None:
            labelText = 'MKID_RAW_PATH:  could not find MKID_RAW_PATH'
            self.logPath = '/'
        else:
            labelText = 'MKID_RAW_PATH:  ' + rawDir
            self.logPath = rawDir

        self.label_log_path = QLabel(labelText)
        self.label_log_path.setToolTip('Look for log files in this directory. To change, go to File>Change log path.')

        #create a vertical box for the plot to go in.
        vbox_plot = QVBoxLayout()
        vbox_plot.addWidget(self.canvas)

        #create a v box for the timestamp spinbox
        vbox_imgTimestamp = QVBoxLayout()
        vbox_imgTimestamp.addWidget(label_imgTimestamp)
        vbox_imgTimestamp.addWidget(self.spinbox_imgTimestamp)

        #make an hbox for the dark start
        hbox_darkStart = QHBoxLayout()
        hbox_darkStart.addWidget(label_darkStart)
        hbox_darkStart.addWidget(self.spinbox_darkStart)

        #make an hbox for the dark integration time
        hbox_darkIntTime = QHBoxLayout()
        hbox_darkIntTime.addWidget(label_darkIntTime)
        hbox_darkIntTime.addWidget(self.spinbox_darkIntTime)

        #make an hbox for the dark subtract checkbox
        hbox_darkSubtract = QHBoxLayout()
        hbox_darkSubtract.addWidget(label_darkSubtract)
        hbox_darkSubtract.addWidget(self.checkbox_darkSubtract)

        #make an hbox for the autoscale colorbar controls
        hbox_autoscale = QHBoxLayout()
        hbox_autoscale.addWidget(label_checkbox_colorbar_auto)
        hbox_autoscale.addWidget(self.checkbox_colorbar_auto)
        hbox_autoscale.addWidget(self.spinbox_colorBarMax)

        #make a vbox for dark times
        vbox_darkTimes = QVBoxLayout()
        vbox_darkTimes.addLayout(hbox_darkStart)
        vbox_darkTimes.addLayout(hbox_darkIntTime)
        vbox_darkTimes.addLayout(hbox_darkSubtract)
        vbox_darkTimes.addLayout(hbox_autoscale)

        hbox_controls = QHBoxLayout()
        hbox_controls.addLayout(vbox_imgTimestamp)
        hbox_controls.addLayout(vbox_darkTimes)
        hbox_controls.addWidget(self.label_log)

        #Now create another vbox, and add the plot vbox and the button's hbox to the new vbox.
        vbox_combined = QVBoxLayout()
        vbox_combined.addLayout(vbox_plot)
        vbox_combined.addLayout(hbox_controls)
        vbox_combined.addWidget(self.label_IMG_path)
        vbox_combined.addWidget(self.label_log_path)

        #Set the main_frame's layout to be vbox_combined
        self.main_frame.setLayout(vbox_combined)

        #Set the overall QWidget to have the layout of the main_frame.
        self.setCentralWidget(self.main_frame)

        #set up the pyqt5 events (the returned connection ids are not needed)
        self.fig.canvas.mpl_connect('motion_notify_event', self.hoverCanvas)
        self.fig.canvas.mpl_connect('scroll_event', self.scroll_ColorBar)
        

        
        
    def spinBoxValueChange(self):
        """Redraw the image and log label for the current img timestamp.

        Looks up the img file whose timestamp equals the spinbox value. When
        one exists it is plotted and the log label refreshed; otherwise a
        blank image is shown and the log label reports the missing img file.
        """
        try:
            filename = self.fileListRaw[np.where(self.timeStampList==self.spinbox_imgTimestamp.value())[0][0]]
        except (AttributeError, IndexError):
            # AttributeError: no img directory has been loaded yet.
            # IndexError: no file carries the requested timestamp.
            self.plotBlank()
            self.updateLogLabel(IMG_fileExists = False)
        else:
            self.plotImage(filename)
            self.updateLogLabel()
            
        


        
        
    def draw(self):
        """Redraw the canvas and flush its pending GUI events.

        Called by the plot window whenever the figure content has changed.
        """
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()
        
        
    def hoverCanvas(self,event):
        """Show the pixel under the mouse in the status bar as '(col,row) value'.

        Nothing happens unless the event lies inside the image axes and the
        rounded coordinates fall inside the (nRow, nCol) image.

        Parameters
        ----------
        event : matplotlib mouse event
            Its ``inaxes``, ``xdata`` and ``ydata`` attributes are used.
        """
        if event.inaxes is not self.ax1:
            return

        col = int(round(event.xdata))
        row = int(round(event.ydata))
        # Check the lower bound too: a negative index would silently wrap
        # around and report a pixel from the opposite edge of the image.
        if 0 <= row < self.nRow and 0 <= col < self.nCol:
            self.status_text.setText('({:d},{:d}) {}'.format(col,row,self.image[row,col]))
                
                
    def scroll_ColorBar(self,event):
        """Rescale the upper colorbar limit when scrolling over the colorbar.

        Scrolling up multiplies the upper limit by 1.1, scrolling down by 0.9;
        the colorbar and the image are then refreshed with the new limits.
        The canvas is redrawn for every scroll event, wherever it occurs.
        """
        if event.inaxes is self.fig.cbar.ax:
            stepSize = 0.1  #fractional change in the colorbar scale
            if event.button == 'up':
                scale = (1 + stepSize)
            elif event.button == 'down':
                scale = (1 - stepSize)
            else:
                scale = None   #any other button leaves the limits alone

            if scale is not None:
                self.cbarLimits[1] *= scale
                lower, upper = self.cbarLimits[0], self.cbarLimits[1]
                self.fig.cbar.set_clim(lower, upper)
                self.fig.cbar.draw_all()
                self.ax1.imshow(self.image, interpolation='none', vmin=lower, vmax=upper)

        self.draw()
        
        


                
                
        
    def create_status_bar(self):
        """Add an empty text label to the status bar and keep it as ``self.status_text``.

        Follows the example in ARCONS-pipeline/util/popup.py.
        """
        statusLabel = QLabel("")
        self.status_text = statusLabel
        self.statusBar().addWidget(statusLabel, 1)
        
        
    def createMenu(self):
        """Create the File menu with its open-img, change-log-directory and exit entries.

        Follows the example in ARCONS-pipeline/util/quicklook.py.
        """
        self.menubar = self.menuBar()
        self.fileMenu = self.menubar.addMenu("&File")

        # File > Open img File
        openImgAction = QAction('Open img File', self)
        openImgAction.setShortcut('Ctrl+O')
        openImgAction.setStatusTip('Open an img File')
        # The lambda swallows the argument supplied by the triggered signal.
        openImgAction.triggered.connect(lambda x: self.getFileNameFromUser(fileType = 'img'))
        self.fileMenu.addAction(openImgAction)

        # File > Change log directory
        changeLogAction = QAction('Change log directory', self)
        changeLogAction.setShortcut('Ctrl+l')
        changeLogAction.setStatusTip('Opens a dialog box so user can select log file manually.')
        changeLogAction.triggered.connect(lambda x: self.getFileNameFromUser(fileType = 'log'))
        self.fileMenu.addAction(changeLogAction)

        self.fileMenu.addSeparator()

        # File > Exit
        exitAction = QAction('Exit', self)
        exitAction.setShortcut('Ctrl+Q')
        exitAction.setStatusTip('Exit application')
        exitAction.triggered.connect(self.close)
        self.fileMenu.addAction(exitAction)

        self.menubar.setNativeMenuBar(False) #This is for MAC OS


        
        
    def getFileNameFromUser(self,fileType):
        """Let the user pick an img or log file and reload the matching data.

        For 'img' the directory of the chosen file becomes ``self.imgPath``;
        the img and log file lists are reloaded and the spinboxes are
        re-initialised. For 'log' the directory becomes ``self.logPath``; the
        log list is reloaded and the log label refreshed. The matching path
        label is updated in both cases. Nothing changes when the dialog is
        cancelled.

        Parameters
        ----------
        fileType : str
            'img' or 'log'; any other value is ignored.
        """
        # look at this website for useful examples
        # https://pythonspot.com/pyqt5-file-dialog/
        if fileType not in ('img', 'log'):
            return

        # Start in self.imgPath / self.logPath, or in '/' if that attribute
        # has not been set yet.
        startDir = getattr(self, fileType + 'Path', '/')
        filename, _ = QFileDialog.getOpenFileName(self, 'Select One File', startDir, filter = '*.' + fileType)

        if filename=='':
            print('\nno file selected\n')
            return

        if fileType == 'img':
            self.imgPath = os.path.dirname(filename)
            self.label_IMG_path.setText('img path:  ' + self.imgPath)

            self.filename = filename
            self.load_IMG_filenames(self.filename)
            self.load_log_filenames()
            self.initialize_spinbox_values(self.filename)
        else:
            self.logPath = os.path.dirname(filename)
            self.label_log_path.setText('log path:  ' + self.logPath)
            self.load_log_filenames()
            self.updateLogLabel()
Beispiel #12
0
class MainWindow(QtWidgets.QMainWindow):
    """ Podedovano glavno okno
    """
    def __init__(self):
        """ Constructor of the MainWindow object
        """
        super().__init__()
        self.setWindowTitle('ADXL - UI')
        self.setWindowIcon(QtGui.QIcon("Logo.png"))
        left, top, width, height = 50, 50, 600, 400
        self.setGeometry(left, top, width, height)
        self.showMaximized()
        # The actions must exist before init_menus adds them to the menu.
        for setup in (self.init_central_widget, self.init_actions, self.init_menus):
            setup()
        self.statusBar()

    def init_central_widget(self):
        """ Content of the central window

        Stacks, top to bottom: the text field for the acquisition time, a row
        with the centred start button, the plot canvas and its toolbar.
        A separate submit button is currently disabled.
        """
        self.central_widget = QtWidgets.QWidget()
        self.buttons_widget = QtWidgets.QWidget()
        column = QtWidgets.QVBoxLayout()
        button_row = QtWidgets.QHBoxLayout()

        self.function_text = QtWidgets.QTextEdit()
        self.function_text.setFontPointSize(30)
        self.function_text.setText('Vnesite čas zajema [ms]')

        self.animate_btn = QtWidgets.QPushButton('Zaženi meritev')
        self.animate_btn.pressed.connect(self.animate_figure)
        self.animate_btn.setCheckable(True)

        # Creates self.canvas and self.canvas_toolbar, which are added below.
        self.get_figure()

        self.central_widget.setLayout(column)
        for widget in (self.function_text, self.buttons_widget,
                       self.canvas, self.canvas_toolbar):
            column.addWidget(widget)

        # Stretches on both sides keep the button centred.
        self.buttons_widget.setLayout(button_row)
        button_row.addStretch()
        button_row.addWidget(self.animate_btn)
        button_row.addStretch()

        self.setCentralWidget(self.central_widget)

    def get_figure(self):
        """Create the figure, its single axes, the plotted line, canvas and toolbar.

        The line ``self.os`` starts out with ten random placeholder values and
        is given real data later through ``set_data``.
        """
        self.fig = Figure(
            figsize=(600, 600),
            dpi=72,
            facecolor=(1, 1, 1),
            edgecolor=(0, 0, 0),
        )
        self.ax = self.fig.add_subplot(111)

        # Initial state
        start_x = np.linspace(0, 200, 10)
        start_y = np.random.randint(100, 140, 10)
        self.os, = self.ax.plot(start_x, start_y)

        self.ax.set_xlabel('$t$ $[s]$', fontsize=24)
        self.ax.set_ylabel('$a$ $[m/s^2]$', fontsize=24)
        self.ax.tick_params(axis='both', which='major', labelsize=16)

        self.canvas = FigureCanvasQTAgg(self.fig)
        self.canvas_toolbar = NavigationToolbar(self.canvas, self)

    def init_menus(self):
        """ Prepare the menus
        """
        menu_bar = self.menuBar()
        self.file_menu = menu_bar.addMenu('&Datoteka')
        self.file_menu.addAction(self.close_app_action)

    def close_app(self):
        """Ask for confirmation and terminate the program if the user agrees."""
        buttons = QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No
        choice = QtWidgets.QMessageBox.question(
            self, "Zapiranje", "Želite zapustiti aplikacijo?", buttons)

        if choice != QtWidgets.QMessageBox.Yes:
            return

        print("Zapuščam aplikacijo.")
        sys.exit()

    def init_actions(self):
        """ Prepare the actions for the menus
        """
        exit_options = dict(
            shortcut=QtGui.QKeySequence.Cancel,
            statusTip="Izhod iz aplikacije",
            triggered=self.close_app,
        )
        self.close_app_action = QtWidgets.QAction('&Izhod', self, **exit_options)

    def animate_figure(self):
        """Receive acceleration samples over the connection and plot them live.

        The acquisition time is read from the text field and multiplied by
        1000. One packet is received per loop pass, for as long as
        ``i <= tk / (dt * stp)``. Each received line holds
        ``<acceleration>\\t<time>``; the time is divided by 1000000 before
        plotting. After every packet the plot shows the i-th window of ``stp``
        samples, with times relative to the first sample.

        An error dialog is shown and the method returns early when the text
        field does not hold an integer or the connection cannot be opened.

        NOTE(review): lines are assumed to be newline-terminated, so a line
        split across two packets is completed by the next packet -- confirm
        against the sender. The connection is not closed here; whether
        ``Povezava`` expects its caller to close it is not visible from this
        method.
        """
        try:
            tk = int(self.function_text.toPlainText()) * 1000
        except ValueError:
            # int() reports unparsable text with ValueError; without a valid
            # acquisition time there is nothing to measure.
            QtWidgets.QMessageBox.about(self, 'Napaka', 'Vnesite čas zajema!')
            return

        try:
            conn, _add = Povezava()
        except Exception:
            QtWidgets.QMessageBox.about(self, 'Napaka',
                                        'Povezava ni vzpostavljena!')
            return

        dt = 10000  # In principle the Python - Arduino communication would supply this
        stp = 49
        i = 1

        t = []  # sample times
        a = []  # sample accelerations
        pending = ""  # text received so far that does not yet end in a newline

        while i <= tk / (dt * stp):  # Receiving packets
            buf = conn.recv(4096)
            if len(buf) > 0:  # True when the packet is not empty
                pending += buf.decode()

            # Everything before the last newline is complete; keep the rest
            # until the next packet arrives.
            *tocke, pending = pending.split("\n")
            for tocka in tocke:  # Separating acceleration and time
                if len(tocka) > 1:
                    try:
                        y, x = tocka.split('\t')
                        cas = float(x) / 1000000
                        pospesek = int(y)
                    except ValueError:
                        continue  # skip a malformed line instead of aborting
                    t.append(cas)
                    a.append(pospesek)

            start = (i - 1) * stp
            if len(t) > start:  # only refresh when this window has samples
                stop = start + stp
                t0 = t[0]
                self.os.set_data(np.asarray(t[start:stop]) - t0,
                                 np.asarray(a[start:stop]))  # Refreshing the plot
                # Right limit: first sample of the next window, or the newest
                # sample when that one has not arrived yet.
                desna = t[min(stop, len(t) - 1)]
                self.ax.set_xlim(t[start] - t0, desna - t0)  # x axis limits
            self.canvas.draw()
            self.canvas.flush_events()

            i += 1  # Increment of the packet counter
Beispiel #13
0
class Window(QDialog):
    def __init__(self, parent=None):
        """Create the dialog: toolbar, plot canvas and a 'Start Scan' button.

        Pressing the button runs ``self.plot``.
        """
        super().__init__(parent)

        self.figure = plt.figure()
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.button = QPushButton('Start Scan')
        self.button.clicked.connect(self.plot)

        # Stack the widgets top to bottom.
        column = QVBoxLayout()
        for widget in (self.toolbar, self.canvas, self.button):
            column.addWidget(widget)
        self.setLayout(column)

    def plot(self):
        """Scan a 50 x 32 grid in a zig-zag and redraw the surface after every reading.

        Even-indexed rows (the first, third, ...) are filled from column 0
        upwards, the rows in between from the last column downwards. After
        each ``get_range()`` reading the 3D surface is redrawn and ``stepA``
        is called with True on upward rows and False on downward rows; after
        each row ``stepB(True, 1000, 100)`` is called. When the grid is
        complete it is printed and ``reset()`` is called.

        Readings are stored in an integer array, so fractional values are
        truncated.

        NOTE(review): stepA/stepB presumably move the scanner along the two
        axes and get_range is the ultrasonic reading -- confirm against their
        definitions.
        """
        n_rows, n_cols = 50, 32
        x1 = np.arange(0, n_cols, 1)
        y1 = np.arange(0, n_rows, 1)
        xs1, ys1 = np.meshgrid(x1, y1)
        # Height map, initially all zero (integer dtype).
        zz1 = np.zeros((n_rows, n_cols), dtype=int)

        for posY in range(n_rows):
            # Alternate the direction on every row, starting left to right.
            left_to_right = posY % 2 == 0
            if left_to_right:
                columns = range(n_cols)
            else:
                columns = range(n_cols - 1, -1, -1)

            for posX in columns:
                # perform the ultrasonic scan
                zz1[posY][posX] = get_range()
                print(zz1[posY][posX])

                self.figure.clear()
                ax = Axes3D(self.figure)
                ax.plot_surface(xs1,
                                ys1,
                                zz1,
                                rstride=1,
                                cstride=1,
                                cmap=cm.coolwarm)
                self.canvas.draw()
                self.canvas.flush_events()
                stepA(left_to_right, 1000, 100)

            stepB(True, 1000, 100)

        print(zz1)
        reset()
Beispiel #14
0
class mainWindow(QMainWindow):

    updateActivePix = pyqtSignal()

    def __init__(self, parent=None):
        """Build the quick-look main window and show an initial noise image.

        parent: optional Qt parent, forwarded to QMainWindow.
        """
        QMainWindow.__init__(self, parent=parent)
        # The placeholder arrays must exist first: create_main_frame() draws
        # self.image and reads self.activePixel.
        self.initializeEmptyArrays()
        self.setWindowTitle('quickLook.py')
        # 600 x 850 worked for the original author's laptop screen (units are
        # presumably pixels).
        window_width, window_height = 600, 850
        self.resize(window_width, window_height)
        for build_step in (self.create_main_frame, self.create_status_bar,
                           self.createMenu, self.plotNoise):
            build_step()

    def initializeEmptyArrays(self, nCol=10, nRow=10):
        """Reset the image buffers to zero-filled (nRow, nCol) arrays.

        Also restores the hot-pixel cut, the active pixel and the list of
        open plot windows to their starting values.
        """
        self.nCol, self.nRow = nCol, nRow
        grid_shape = (self.nRow, self.nCol)

        # Every attribute gets its own array so that writing to one map never
        # shows up in another.
        self.IcMap = np.zeros(grid_shape)
        self.IsMap = np.zeros(grid_shape)
        self.IcIsMap = np.zeros(grid_shape)
        self.rawCountsImage = np.zeros(grid_shape)
        self.hotPixMask = np.zeros(grid_shape)
        self.image = np.zeros(grid_shape)

        # Threshold compared against raw counts in makeHotPixMask().
        self.hotPixCut = 2300
        self.activePixel = [0, 0]
        self.sWindowList = []

    def loadDataFromH5(self, *args):
        """Open self.filename as an ObsFile and refresh the GUI from it.

        The path is always taken from self.filename; any positional
        arguments are accepted and ignored.  On success the image buffers are
        re-sized to the beam image, the beam-flag image is plotted and the
        time/wavelength spinboxes are limited to the values stored in the
        file header.  If the path is not a file, or ObsFile cannot open it, a
        message is printed and the current state is left untouched.
        """
        if not os.path.isfile(self.filename):
            print('Cannot load H5 file, no such file:', self.filename)
            return

        try:
            self.a = ObsFile(self.filename)
        except Exception as err:
            # Report the reason instead of hiding it; Exception (rather than
            # a bare except) lets KeyboardInterrupt/SystemExit through.
            print('darkObsFile failed to load file. Check filename.\n',
                  self.filename)
            print('Reason:', repr(err))
            return

        print('data loaded from .h5 file')
        self.h5_filename_label.setText(self.filename)
        self.initializeEmptyArrays(len(self.a.beamImage),
                                   len(self.a.beamImage[0]))
        self.beamFlagImage = np.transpose(self.a.beamFlagImage.read())
        # Boolean mask: True where the beam-map flag is 0 (good beam map).
        self.beamFlagMask = self.beamFlagImage == 0
        self.makeHotPixMask()
        self.radio_button_beamFlagImage.setChecked(True)
        self.callPlotMethod()

        # Exposure time and wavelength range as stored in the file header.
        self.expTime = self.a.getFromHeader('expTime')
        self.wvlBinStart = self.a.getFromHeader('wvlBinStart')
        self.wvlBinEnd = self.a.getFromHeader('wvlBinEnd')

        # Limit the wavelength spinboxes to the header range and start with
        # the full range selected.  (The start spinbox minimum is not set
        # here; that call was already disabled in the original code.)
        self.spinbox_stopLambda.setMinimum(self.wvlBinStart)
        self.spinbox_startLambda.setMaximum(self.wvlBinEnd)
        self.spinbox_stopLambda.setMaximum(self.wvlBinEnd)
        self.spinbox_startLambda.setValue(self.wvlBinStart)
        self.spinbox_stopLambda.setValue(self.wvlBinEnd)

        # Limit the time spinboxes to the exposure time and start with the
        # whole exposure selected.
        self.spinbox_startTime.setMinimum(0)
        self.spinbox_startTime.setMaximum(self.expTime)
        self.spinbox_integrationTime.setMinimum(0)
        self.spinbox_integrationTime.setMaximum(self.expTime)
        self.spinbox_integrationTime.setValue(self.expTime)

    def plotBeamImage(self):
        """Show the beam-flag image on the main axes.

        Prints a hint and returns if no ObsFile has been loaded yet.
        """
        # self.a only exists once an H5 file has been loaded.
        if not hasattr(self, 'a'):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return

        axes = self.ax1
        axes.clear()

        self.image = self.beamFlagImage
        lowest = np.amin(self.image)
        highest = np.amax(self.image)
        self.cbarLimits = np.array([lowest, highest])

        axes.imshow(self.image, interpolation='none')
        colorbar = self.fig.cbar
        colorbar.set_clim(lowest, highest)
        colorbar.draw_all()

        axes.set_title('beam flag image')
        axes.axis('off')

        # Red crosshair that follows the mouse over the image.
        self.cursor = Cursor(axes, useblit=True, color='red', linewidth=2)

        self.draw()

    def plotImage(self, *args):
        """Plot the raw photon-count image divided by the integration time.

        The time and wavelength windows are read from the four spinboxes.
        Positional arguments are accepted and ignored.  A message is printed
        and the current plot is left untouched when no ObsFile has been
        loaded or when the integration time is not positive.
        """
        # self.a only exists once an H5 file has been loaded.
        if not hasattr(self, 'a'):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return

        integrationTime = self.spinbox_integrationTime.value()
        if integrationTime <= 0:
            # The spinbox minimum is 0; dividing by it below would fill the
            # image and the colorbar limits with inf/nan.
            print('\nIntegration time must be greater than zero.\n')
            return

        self.ax1.clear()

        temp = self.a.getPixelCountImage(
            firstSec=self.spinbox_startTime.value(),
            integrationTime=integrationTime,
            applyWeight=False,
            flagToUse=0,
            wvlStart=self.spinbox_startLambda.value(),
            wvlStop=self.spinbox_stopLambda.value())
        self.rawCountsImage = np.transpose(temp['image'])

        # Non-finite entries are zeroed in place, so rawCountsImage is
        # cleaned as well (both names refer to the same array here).
        self.image = self.rawCountsImage
        self.image[~np.isfinite(self.image)] = 0
        # Normalise the counts by the integration time.
        self.image = 1.0 * self.image / integrationTime

        self.cbarLimits = np.array(
            [np.amin(self.image), np.amax(self.image)])

        self.ax1.imshow(self.image,
                        interpolation='none',
                        vmin=self.cbarLimits[0],
                        vmax=self.cbarLimits[1])

        self.fig.cbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
        self.fig.cbar.draw_all()

        self.ax1.set_title('Raw counts')

        self.ax1.axis('off')

        # Red crosshair that follows the mouse over the image.
        self.cursor = Cursor(self.ax1,
                             useblit=True,
                             color='red',
                             linewidth=2)

        self.draw()

    def plotIcIs(self):
        """Fit Ic and Is for every pixel and display the Is map.

        For each pixel the photon list inside the spinbox time/wavelength
        window is binned into a light curve, histogrammed and passed to
        binnedRE.fitBlurredMR; the fitted values are stored in self.IcMap and
        self.IsMap.  Pixels whose light curve has no counts stay at 0.
        Prints a hint and returns if no ObsFile has been loaded yet.
        """
        # self.a only exists once an H5 file has been loaded.
        if not hasattr(self, 'a'):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return

        self.ax1.clear()  #clear the axes

        # Start from fresh maps so pixels without counts in the current
        # selection do not keep values fitted during an earlier call.
        self.IcMap = np.zeros((self.nRow, self.nCol))
        self.IsMap = np.zeros((self.nRow, self.nCol))

        firstSec = self.spinbox_startTime.value()
        integrationTime = self.spinbox_integrationTime.value()
        wvlStart = self.spinbox_startLambda.value()
        wvlStop = self.spinbox_stopLambda.value()

        # Light-curve bin width handed to binnedRE (original note:
        # "10 ms/1000").
        effExpTime = .00001

        for col in range(self.nCol):
            # Progress indicator: columns done / total columns.
            print(col, '/{}'.format(self.nCol))
            for row in range(self.nRow):
                photonList = self.a.getPixelPhotonList(
                    col,
                    row,
                    firstSec=firstSec,
                    integrationTime=integrationTime,
                    wvlStart=wvlStart,
                    wvlStop=wvlStop)

                lightCurveIntensityCounts, _intensity, _times = binnedRE.getLightCurve(
                    photonList['Time'], firstSec,
                    firstSec + integrationTime, effExpTime)

                intensityHist, bins = binnedRE.histogramLC(
                    lightCurveIntensityCounts)

                # Only fit pixels that actually recorded photons.
                if np.sum(lightCurveIntensityCounts) > 0:
                    Ic_final, Is_final, _covMatrix = binnedRE.fitBlurredMR(
                        bins, intensityHist, effExpTime)

                    self.IcMap[row][col] = Ic_final
                    self.IsMap[row][col] = Is_final

        # The displayed image is the Is map itself (same array), with any
        # non-finite fit results zeroed.
        self.image = self.IsMap
        self.image[~np.isfinite(self.image)] = 0

        self.cbarLimits = np.array(
            [np.amin(self.image), np.amax(self.image)])

        self.ax1.imshow(self.image,
                        interpolation='none',
                        vmin=self.cbarLimits[0],
                        vmax=self.cbarLimits[1])

        self.fig.cbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
        self.fig.cbar.draw_all()

        # The original title ('Raw counts') was copied from plotImage; the
        # data shown here is the Is map.
        self.ax1.set_title('Is map')

        self.ax1.axis('off')

        # Red crosshair that follows the mouse over the image.
        self.cursor = Cursor(self.ax1,
                             useblit=True,
                             color='red',
                             linewidth=2)

        self.draw()

    def plotNoise(self, *args):
        """Fill the main axes with freshly generated random noise.

        Used as a debugging placeholder image; positional arguments are
        accepted and ignored.
        """
        axes = self.ax1
        axes.clear()

        # Normally distributed noise with the current image dimensions.
        self.image = np.random.randn(self.nRow, self.nCol)
        self.foo = axes.imshow(self.image, interpolation='none')

        lowest, highest = np.amin(self.image), np.amax(self.image)
        self.cbarLimits = np.array([lowest, highest])
        colorbar = self.fig.cbar
        colorbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
        colorbar.draw_all()

        axes.set_title('some generated noise...')
        axes.axis('off')

        # Red crosshair that follows the mouse over the image.
        self.cursor = Cursor(axes, useblit=True, color='red', linewidth=2)

        self.draw()

    def callPlotMethod(self):
        """Redraw the main image with the plot for the checked radio button.

        The buttons are tested in a fixed order and the first checked one
        wins; if none is checked the noise plot is drawn.
        """
        # Note: the ".IMG" button currently maps to the noise plot.
        choices = (
            (self.radio_button_img, self.plotNoise),
            (self.radio_button_ic_is, self.plotIcIs),
            (self.radio_button_beamFlagImage, self.plotBeamImage),
            (self.radio_button_rawCounts, self.plotImage),
        )
        for button, plot_method in choices:
            if button.isChecked():
                plot_method()
                return
        self.plotNoise()

    def makeHotPixMask(self):
        """Mark good pixels in self.hotPixMask.

        A one-unit integration starting at 0 is requested from the ObsFile;
        every pixel whose raw count is below self.hotPixCut is set to 1 in
        the mask.  Entries that are not below the cut are left as they are
        (0 after initializeEmptyArrays, i.e. "hot").
        """
        counts = self.a.getPixelCountImage(firstSec=0,
                                           integrationTime=1,
                                           applyWeight=False,
                                           flagToUse=0)
        rawCountsImage = np.transpose(counts['image'])
        # Only the (nRow, nCol) region covered by the mask is inspected.
        below_cut = rawCountsImage[:self.nRow, :self.nCol] < self.hotPixCut
        self.hotPixMask[below_cut] = 1

    def create_main_frame(self):
        """
        Build the central widget: the image canvas with its colorbar, the
        time/wavelength spinboxes, the buttons, the plot-type radio buttons
        and the info labels, then hook up the canvas mouse events.
        """
        self.main_frame = QWidget()

        # Figure, canvas and axes for the image; the colorbar is stored on
        # the figure object so the plot methods can reach it as fig.cbar.
        self.dpi = 100
        self.fig = Figure((5.0, 10.0), dpi=self.dpi, tight_layout=True)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax1 = self.fig.add_subplot(111)
        self.foo = self.ax1.imshow(self.image, interpolation='none')
        self.fig.cbar = self.fig.colorbar(self.foo)

        # Push buttons.
        button_plot = QPushButton("Plot image")
        button_plot.setEnabled(True)
        button_plot.setToolTip('Click to update image.')
        button_plot.clicked.connect(self.callPlotMethod)

        button_quickLoad = QPushButton("Quick Load H5")
        button_quickLoad.setEnabled(True)
        button_quickLoad.setToolTip('Will change functionality later.')
        button_quickLoad.clicked.connect(self.quickLoadH5)

        # Spinboxes (and their labels) for the time window ...
        self.spinbox_startTime = QSpinBox()
        self.spinbox_integrationTime = QSpinBox()
        label_startTime = QLabel('start time')
        label_integrationTime = QLabel('integration time')

        # ... and for the wavelength window.
        self.spinbox_startLambda = QSpinBox()
        self.spinbox_stopLambda = QSpinBox()
        label_startLambda = QLabel('start wavelength [nm]')
        label_stopLambda = QLabel('stop wavelength [nm]')

        # Info labels: loaded file name and the active pixel with its value
        # (the image is indexed [row, col] = [activePixel[1], activePixel[0]]).
        self.h5_filename_label = QLabel('no file loaded')
        active_col, active_row = self.activePixel[0], self.activePixel[1]
        self.activePixel_label = QLabel('Active Pixel ({},{}) {}'.format(
            active_col, active_row, self.image[active_row, active_col]))

        # Radio buttons selecting what callPlotMethod() draws.
        self.radio_button_img = QRadioButton(".IMG")
        self.radio_button_ic_is = QRadioButton("Ic/Is")
        self.radio_button_beamFlagImage = QRadioButton("Beam Flag Image")
        self.radio_button_rawCounts = QRadioButton("Raw Counts")
        self.radio_button_img.setChecked(True)

        # --- layouts, assembled from the innermost boxes outwards ---------
        vbox_plot = QVBoxLayout()
        vbox_plot.addWidget(self.canvas)

        vbox_timespan = QVBoxLayout()
        for widget in (label_startTime, self.spinbox_startTime,
                       label_integrationTime, self.spinbox_integrationTime):
            vbox_timespan.addWidget(widget)

        vbox_lambda = QVBoxLayout()
        for widget in (label_startLambda, self.spinbox_startLambda,
                       label_stopLambda, self.spinbox_stopLambda):
            vbox_lambda.addWidget(widget)

        hbox_buttons = QHBoxLayout()
        for widget in (button_plot, button_quickLoad):
            hbox_buttons.addWidget(widget)

        # Time and wavelength columns side by side.
        hbox_time_lambda = QHBoxLayout()
        for column in (vbox_timespan, vbox_lambda):
            hbox_time_lambda.addLayout(column)

        # Spinboxes on top, buttons underneath.
        vbox_time_lambda_buttons = QVBoxLayout()
        for part in (hbox_time_lambda, hbox_buttons):
            vbox_time_lambda_buttons.addLayout(part)

        vbox_radio_buttons = QVBoxLayout()
        for widget in (self.radio_button_img, self.radio_button_ic_is,
                       self.radio_button_beamFlagImage,
                       self.radio_button_rawCounts):
            vbox_radio_buttons.addWidget(widget)

        # Controls on the left, radio buttons on the right.
        hbox_controls = QHBoxLayout()
        for part in (vbox_time_lambda_buttons, vbox_radio_buttons):
            hbox_controls.addLayout(part)

        vbox_filenames = QVBoxLayout()
        for widget in (self.h5_filename_label, self.activePixel_label):
            vbox_filenames.addWidget(widget)

        # Whole window: plot, then controls, then info labels.
        vbox_combined = QVBoxLayout()
        for part in (vbox_plot, hbox_controls, vbox_filenames):
            vbox_combined.addLayout(part)

        self.main_frame.setLayout(vbox_combined)
        self.setCentralWidget(self.main_frame)

        # Matplotlib canvas events handled by this window.
        for event_name, handler in (('motion_notify_event', self.hoverCanvas),
                                    ('button_press_event', self.mousePressed),
                                    ('scroll_event', self.scroll_ColorBar)):
            self.fig.canvas.mpl_connect(event_name, handler)

    def quickLoadH5(self):
        """Load a fixed, hard-coded H5 file (developer shortcut)."""
        quick_path = '/Users/clint/Documents/mazinlab/ScienceData/PAL2017b/20171004/1507175503_old.h5'
        self.filename = quick_path
        self.loadDataFromH5()

    def draw(self):
        """Redraw the canvas, then flush its GUI events."""
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()

    def hoverCanvas(self, event):
        """Show the pixel under the mouse and its value in the status bar.

        event: matplotlib mouse event; ignored unless it lies inside the
        image axes.  The pointer position is rounded to the nearest pixel
        and the status text is only updated when that pixel is inside the
        (nRow, nCol) grid.
        """
        if event.inaxes is not self.ax1:
            return

        col = int(round(event.xdata))
        row = int(round(event.ydata))
        # Both bounds are checked: a negative index would not raise but
        # silently wrap around to the opposite edge of the image.
        if 0 <= row < self.nRow and 0 <= col < self.nCol:
            self.status_text.setText('({:d},{:d}) {}'.format(
                col, row, self.image[row, col]))

    def scroll_ColorBar(self, event):
        """Rescale the upper colour limit when scrolling over the colorbar.

        event: matplotlib scroll event.  Scrolling up raises the upper limit
        by 10 %, scrolling down lowers it by 10 %.  Events outside the
        colorbar axes change nothing.  The canvas is redrawn in every case.
        """
        if event.inaxes is self.fig.cbar.ax:
            stepSize = 0.1  #fractional change in the colorbar scale
            if event.button == 'up':
                factor = 1 + stepSize
            elif event.button == 'down':
                factor = 1 - stepSize
            else:
                factor = None

            if factor is not None:
                # Limits taken from an integer image form an integer array;
                # writing the scaled value back into it would truncate
                # (e.g. 1 * 1.1 -> 1), so scrolling would have no effect.
                self.cbarLimits = np.asarray(self.cbarLimits, dtype=float)
                self.cbarLimits[1] *= factor
                vmin, vmax = self.cbarLimits[0], self.cbarLimits[1]

                self.fig.cbar.set_clim(vmin, vmax)
                self.fig.cbar.draw_all()

                # Re-scale the image already on the axes rather than
                # stacking a new one on top for every scroll step.
                if self.ax1.images:
                    self.ax1.images[-1].set_clim(vmin, vmax)
                else:
                    self.ax1.imshow(self.image,
                                    interpolation='none',
                                    vmin=vmin,
                                    vmax=vmax)

        self.draw()

    def mousePressed(self, event):
        """Handle mouse clicks on the image and on the colorbar.

        event: matplotlib mouse-button event.
        * Left click on the image: make the clicked pixel the active pixel,
          update its label and emit updateActivePix so other plots can
          refresh.  Clicks that round to a pixel outside the (nRow, nCol)
          grid are ignored.
        * Right click on the image: only prints a message.
        * Left click on the colorbar: reset the colour limits to the
          min/max of the current image and redraw.
        """
        if event.inaxes is self.ax1:  #check if the mouse-click was within the axes.
            if event.button == 1:
                col = int(round(event.xdata))
                row = int(round(event.ydata))
                # Rounding at the border of the axes can give an index one
                # step outside the image; using it would raise IndexError
                # (or wrap for negatives) and publish an invalid pixel.
                if not (0 <= row < self.nRow and 0 <= col < self.nCol):
                    return

                self.activePixel = [col, row]
                self.activePixel_label.setText(
                    'Active Pixel ({},{}) {}'.format(
                        col, row, self.image[row, col]))

                #emit a signal for other plots to update
                self.updateActivePix.emit()

            elif event.button == 3:
                print('\nit was the right button that was pressed!\n')

        elif event.inaxes is self.fig.cbar.ax:  #reset the scale bar
            if event.button == 1:
                self.cbarLimits = np.array(
                    [np.amin(self.image),
                     np.amax(self.image)])
                self.fig.cbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
                self.fig.cbar.draw_all()
                self.ax1.imshow(self.image,
                                interpolation='none',
                                vmin=self.cbarLimits[0],
                                vmax=self.cbarLimits[1])
                self.draw()

    def create_status_bar(self):
        """Put an initially empty text label into the window's status bar."""
        # Pattern taken from ARCONS-pipeline/util/popup.py.
        status_label = QLabel("")
        self.status_text = status_label
        # Second argument is the stretch factor passed to addWidget.
        self.statusBar().addWidget(status_label, 1)

    def createMenu(self):
        """Create the File and Plot menus and connect their actions."""
        # Pattern taken from ARCONS-pipeline/util/quicklook.py.
        self.menubar = self.menuBar()

        # File menu: (text, shortcut, status tip, slot); both entries use
        # the same icon file.
        self.fileMenu = self.menubar.addMenu("&File")
        file_entries = (
            ('Open H5 File', 'Ctrl+O', 'Open an H5 File',
             self.getFileNameFromUser),
            ('Exit', 'Ctrl+Q', 'Exit application', self.close),
        )
        for text, shortcut, tip, slot in file_entries:
            action = QAction(QIcon('exit24.png'), text, self)
            action.setShortcut(shortcut)
            action.setStatusTip(tip)
            action.triggered.connect(slot)
            self.fileMenu.addAction(action)

        # Plot menu: each entry opens one kind of sub-window.
        self.plotMenu = self.menubar.addMenu("&Plot")
        plot_entries = (
            ('Light Curve', self.makeTimestreamPlot),
            ('Intensity Histogram', self.makeIntensityHistogramPlot),
            ('Spectrum', self.makeSpectrumPlot),
        )
        for text, slot in plot_entries:
            action = QAction(text, self)
            action.triggered.connect(slot)
            self.plotMenu.addAction(action)

        self.menubar.setNativeMenuBar(False)  #This is for MAC OS

    def getFileNameFromUser(self):
        """Ask the user for an H5 file and load it.

        The dialog opens in $MKID_DATA_DIR when that variable is set and in
        the current directory otherwise.  If the dialog is cancelled nothing
        is changed, so the previously loaded file name is kept.
        """
        # For file-dialog examples see https://pythonspot.com/pyqt5-file-dialog/
        def_loc = os.environ.get('MKID_DATA_DIR', '.')
        filename, _ = QFileDialog.getOpenFileName(self,
                                                  'Select One File',
                                                  def_loc,
                                                  filter='*.h5')
        if not filename:
            # Cancelled dialog returns an empty string; do not overwrite
            # self.filename with it.
            return

        self.filename = filename
        self.loadDataFromH5(self.filename)

    def makeTimestreamPlot(self):
        """Open a light-curve (timeStream) window for the active pixel.

        Does nothing except print a hint when no H5 file has been loaded:
        the sub-window constructor reads self.a and self.beamFlagMask and
        would otherwise raise AttributeError.
        """
        if not (hasattr(self, 'a') and hasattr(self, 'beamFlagMask')):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return
        sWindow = timeStream(self)
        sWindow.show()
        # Keep a reference so the window is not garbage collected.
        self.sWindowList.append(sWindow)

    def makeIntensityHistogramPlot(self):
        """Open an intensity-histogram window for the active pixel.

        Does nothing except print a hint when no H5 file has been loaded:
        the sub-window constructor reads self.a and self.beamFlagMask and
        would otherwise raise AttributeError.
        """
        if not (hasattr(self, 'a') and hasattr(self, 'beamFlagMask')):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return
        sWindow = intensityHistogram(self)
        sWindow.show()
        # Keep a reference so the window is not garbage collected.
        self.sWindowList.append(sWindow)

    def makeSpectrumPlot(self):
        """Open a spectrum window for the active pixel.

        Does nothing except print a hint when no H5 file has been loaded:
        the sub-window constructor reads self.a and self.beamFlagMask and
        would otherwise raise AttributeError.
        """
        if not (hasattr(self, 'a') and hasattr(self, 'beamFlagMask')):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return
        sWindow = spectrum(self)
        sWindow.show()
        # Keep a reference so the window is not garbage collected.
        self.sWindowList.append(sWindow)
Beispiel #15
0
class subWindow(QMainWindow):

    #set up a signal so that the window closes when the main window closes
    #closeWindow = QtCore.pyqtSignal()

    #replot = pyqtsignal()

    def __init__(self, parent=None):
        """Create a plot sub-window that follows the main window's state.

        parent: the main window.  Although it defaults to None, it must
        provide activePixel, updateActivePix, a, the four spinboxes, image
        and beamFlagMask, all of which are read here.
        """
        # Name this class in super() so the lookup starts right after
        # subWindow in the MRO (the original named QMainWindow, which starts
        # the lookup after QMainWindow instead).
        super(subWindow, self).__init__(parent)
        # NOTE: this attribute hides the Qt parent() method on the instance;
        # kept because the rest of the class reads self.parent.
        self.parent = parent

        self.effExpTime = .010  #units are s

        self.create_main_frame()
        self.activePixel = parent.activePixel
        # Re-plot whenever the main window announces a new active pixel.
        self.parent.updateActivePix.connect(self.setActivePixel)
        self.a = parent.a
        # The spinboxes are shared with the main window, so their current
        # values are always read live.
        self.spinbox_startTime = parent.spinbox_startTime
        self.spinbox_integrationTime = parent.spinbox_integrationTime
        self.spinbox_startLambda = parent.spinbox_startLambda
        self.spinbox_stopLambda = parent.spinbox_stopLambda
        self.image = parent.image
        self.beamFlagMask = parent.beamFlagMask
        self.apertureRadius = 2.27 / 2  #Taken from Seth's paper (not yet published in Jan 2018)
        self.apertureOn = False
        self.lineColor = 'blue'

    def getAxes(self):
        """Return the axes object of this window's figure."""
        axes = self.ax
        return axes

    def create_main_frame(self):
        """
        Makes GUI elements on the window: a figure canvas, optional
        effective-exposure-time controls and a navigation toolbar.

        The exposure-time controls are only added when the instance already
        has a spinbox_effExpTime attribute at the time this method runs.
        """
        self.main_frame = QWidget()

        # Figure
        self.dpi = 100
        self.fig = Figure((3.0, 2.0), dpi=self.dpi, tight_layout=True)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax = self.fig.add_subplot(111)

        #create a navigation toolbar for the plot window
        self.toolbar = NavigationToolbar(self.canvas, self)

        #create a vertical box for the plot to go in.
        vbox_plot = QVBoxLayout()
        vbox_plot.addWidget(self.canvas)

        #check if we need effective exposure time controls in the window, and add them if we do.
        if hasattr(self, 'spinbox_effExpTime'):
            label_expTime = QLabel('effective exposure time [ms]')
            button_plot = QPushButton("Plot")

            hbox_expTimeControl = QHBoxLayout()
            hbox_expTimeControl.addWidget(label_expTime)
            hbox_expTimeControl.addWidget(self.spinbox_effExpTime)
            hbox_expTimeControl.addWidget(button_plot)
            vbox_plot.addLayout(hbox_expTimeControl)

            self.spinbox_effExpTime.setMinimum(1)
            self.spinbox_effExpTime.setMaximum(200)
            # effExpTime is in seconds, the spinbox shows milliseconds.
            # setValue() takes an int; passing the float product is rejected
            # with TypeError by PyQt5 on Python 3.10+, so convert explicitly.
            self.spinbox_effExpTime.setValue(
                int(round(1000 * self.effExpTime)))
            button_plot.clicked.connect(self.plotData)

        vbox_plot.addWidget(self.toolbar)

        #combine everything into another vbox
        vbox_combined = QVBoxLayout()
        vbox_combined.addLayout(vbox_plot)

        #Set the main_frame's layout to be vbox_combined
        self.main_frame.setLayout(vbox_combined)

        #Set the overall QWidget to have the layout of the main_frame.
        self.setCentralWidget(self.main_frame)

    def draw(self):
        """Redraw this window's canvas, then flush its GUI events."""
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()

    def setActivePixel(self):
        """Adopt the main window's active pixel and re-plot.

        Connected to the parent's updateActivePix signal.  A failure inside
        plotData() is reported on stdout instead of being raised, so one
        broken plot window does not take the whole GUI down.
        """
        self.activePixel = self.parent.activePixel
        try:
            self.plotData()
        except Exception as err:
            # Best effort on purpose, but say what went wrong instead of
            # swallowing it silently; KeyboardInterrupt/SystemExit still
            # propagate.
            print('plotData failed for pixel {}: {!r}'.format(
                self.activePixel, err))

    def plotData(self):
        """Placeholder plot routine; does nothing and returns None.

        Child classes are expected to redefine it.  Having it here lets this
        base class handle the plot-update signal entirely on its own.
        """
        return None

    def getPhotonList(self):
        """Fetch the photon list for the active pixel from the ObsFile.

        The time and wavelength window is read from the shared spinboxes.
        With apertureOn set, photons inside a circular aperture of
        apertureRadius around the pixel are requested (flagToUse=0) and the
        aperture returned alongside them is discarded; otherwise only the
        single pixel is read.
        """
        pixel_x, pixel_y = self.activePixel[0], self.activePixel[1]
        window = dict(
            firstSec=self.spinbox_startTime.value(),
            integrationTime=self.spinbox_integrationTime.value(),
            wvlStart=self.spinbox_startLambda.value(),
            wvlStop=self.spinbox_stopLambda.value())

        # Explicit comparison kept on purpose so that only a value equal to
        # True enables the aperture.
        use_aperture = (self.apertureOn == True)
        if not use_aperture:
            return self.a.getPixelPhotonList(pixel_x, pixel_y, **window)

        photonList, _aperture = self.a.getCircularAperturePhotonList(
            pixel_x,
            pixel_y,
            radius=self.apertureRadius,
            flagToUse=0,
            **window)
        return photonList
Beispiel #16
0
class Window(QMainWindow):
    """Class for the whole window.
    """
    def __init__(self, parent=None):
        """Load the lattice, create the simulator and build the UI.
        """
        super().__init__(parent)

        # Lattice loading: limit the s range of the 'DIAD' ring to one sixth
        # of its circumference and slice the ring by the resulting i_range.
        full_ring = atip.utils.load_at_lattice('DIAD')
        period_length = full_ring.circumference / 6.0
        full_ring.s_range = [0, period_length]
        first, last = full_ring.i_range[0], full_ring.i_range[-1]
        self.lattice = full_ring[first:last]  # + [ring[1491]]

        # Disabled alternative, kept for reference: load a Tracy lattice and
        # strip its zero length elements.
        #   self.lattice = at.load_tracy('../atip/atip/rings/for_Tobyn.lat')
        #   zl = []
        #   for idx, elem in enumerate(self.lattice):
        #       elem.Index = idx + 1
        #       if elem.Length == 0.0:
        #           zl.append(idx)
        #   zl.reverse()
        #   for idx in zl:
        #       self.lattice.__delitem__(idx)
        #   print(len(self.lattice))

        self._atsim = atip.simulator.ATSimulator(self.lattice, emit_calc=False)
        self.s_selection = None  # no s position selected yet

        # Super-period support
        element_count = len(self.lattice)
        self.total_len = self.lattice.get_s_pos(element_count)[0]
        self.symmetry = 6
        #self.symmetry = vars(self.lattice).get('periodicity', 1)

        # Create UI
        self.initUI()

    def initUI(self):
        """Low level UI building of the core sections and their components.

        Builds, in order: the fixed-size graph canvas, the lattice
        representation bar (non-zero length elements, a dividing line and the
        zero length elements), four element edit boxes, and the data sidebar.
        The assembled layout becomes the central widget and the window is
        shown.
        """
        # Set initial window size in pixels
        self.setGeometry(0, 0, 1500, 800)
        # Initialise layouts
        layout = QHBoxLayout()
        layout.setSpacing(20)
        self.left_side = QVBoxLayout()
        self.left_side.setAlignment(Qt.AlignLeft)

        # Create graph
        graph = QHBoxLayout()
        self.figure = Figure()
        self.canvas = FigureCanvasQTAgg(self.figure)
        # Mouse clicks on the graph make or clear an s position selection
        self.canvas.mpl_connect('button_press_event', self.graph_onclick)
        self.figure.set_tight_layout({"pad": 0.5, "w_pad": 0, "h_pad": 0})
        self.plot()
        # Make graph fixed size to prevent autoscaling
        self.canvas.setMinimumWidth(1000)
        self.canvas.setMaximumWidth(1000)
        self.canvas.setMinimumHeight(480)
        self.canvas.setMaximumHeight(480)
        # Remember the current size so resize_graph can detect changes
        self.graph_width = 1000
        self.graph_height = 480
        graph.addWidget(self.canvas)
        graph.setStretchFactor(self.canvas, 0)
        # Add graph to left side layout
        self.left_side.addLayout(graph)

        # Create lattice representation bar
        self.full_disp = QVBoxLayout()
        self.full_disp.setSpacing(0)
        self.lat_disp = QHBoxLayout()
        self.lat_disp.setSpacing(0)
        self.lat_disp.setContentsMargins(QMargins(0, 0, 0, 0))
        # Add a stretch at both ends to keep the lattice representation centred
        self.lat_disp.addStretch()
        # Add startline
        self.lat_disp.addWidget(element_repr(-1, Qt.black, 1, drag=False))
        # Add elements
        self.lat_repr = self.create_lat_repr()
        for el_repr in self.lat_repr:
            self.lat_disp.addWidget(el_repr)
        # Add endline
        self.lat_disp.addWidget(element_repr(-1, Qt.black, 1, drag=False))
        # Add offset
        self.lat_disp.addWidget(element_repr(-1, Qt.white, 3, drag=False))
        # Add a stretch at both ends to keep the lattice representation centred
        self.lat_disp.addStretch()
        # Add non-zero length representation to lattice representation layout
        self.full_disp.addLayout(self.lat_disp)

        # Add horizontal dividing line
        self.black_bar = QHBoxLayout()
        self.black_bar.setSpacing(0)
        self.black_bar.addStretch()  # Keep it centred
        self.mid_line = element_repr(-1, Qt.black, 1000, height=1, drag=False)
        self.black_bar.addWidget(self.mid_line)
        # Add offset
        self.black_bar.addWidget(element_repr(-1, Qt.white, 3, height=1,
                                              drag=False))
        self.black_bar.addStretch()  # Keep it centred
        self.full_disp.addLayout(self.black_bar)

        # Create zero length element representation bar
        self.zl_disp = QHBoxLayout()
        self.zl_disp.setSpacing(0)
        # Add a stretch at both ends to keep the lattice representation centred
        self.zl_disp.addStretch()
        # Add startline
        self.zl_disp.addWidget(element_repr(-1, Qt.black, 1, drag=False))
        # Add elements (self.zero_length was filled by create_lat_repr above)
        self.zl_repr = self.calc_zero_len_repr(1000)
        for el_repr in self.zl_repr:
            self.zl_disp.addWidget(el_repr)
        # Add endline
        self.zl_disp.addWidget(element_repr(-1, Qt.black, 1, drag=False))
        # Add offset
        self.zl_disp.addWidget(element_repr(-1, Qt.white, 3, drag=False))
        # Add a stretch at both ends to keep the lattice representation centred
        self.zl_disp.addStretch()
        # Add zero length representation to lattice representation layout
        self.full_disp.addLayout(self.zl_disp)
        # Add full lattice representation to left side layout
        self.left_side.addLayout(self.full_disp)

        # Create element editing boxes to drop to
        bottom = QHBoxLayout()
        self.edit_boxes = []
        # Future possibility to auto determine number of boxes by window size
        for i in range(4):
            box = edit_box(self, self._atsim)
            self.edit_boxes.append(box)
            bottom.addWidget(box)
        # Add edit boxes to left side layout
        self.left_side.addLayout(bottom)

        # All left side components now set, add them to main layout
        layout.addLayout(self.left_side)

        # Create lattice and element data sidebar
        sidebar_border = QWidget()
        # Dividing line
        sidebar_border.setStyleSheet(".QWidget {border-left: 1px solid black}")
        sidebar = QGridLayout(sidebar_border)
        sidebar.setSpacing(10)
        # Determine correct global title
        if self.symmetry == 1:
            title = QLabel("Global Lattice Parameters:")
        else:
            title = QLabel("Global Super Period Parameters:")
        # Ensure sidebar width remains fixed
        title.setMaximumWidth(220)
        title.setMinimumWidth(220)
        title.setStyleSheet("font-weight:bold; text-decoration:underline;")
        sidebar.addWidget(title, 0, 0)
        # Ensure sidebar width remains fixed
        spacer = QLabel("")
        spacer.setMaximumWidth(220)
        spacer.setMinimumWidth(220)
        sidebar.addWidget(spacer, 0, 1)
        # Maps each global field name to the label that shows its value
        self.lattice_data_widgets = {}
        row_count = 1  # start after global title row
        # Create global fields
        for field, value in self.get_lattice_data().items():
            sidebar.addWidget(QLabel("{0}: ".format(field)), row_count, 0)
            lab = QLabel(self.stringify(value))
            sidebar.addWidget(lab, row_count, 1)
            self.lattice_data_widgets[field] = lab
            row_count += 1
        # Add element title
        title = QLabel("Selected Element Parameters:")
        title.setStyleSheet("font-weight:bold; text-decoration:underline;")
        sidebar.addWidget(title, row_count, 0)
        # Maps each local field name to the label that shows its value
        self.element_data_widgets = {}
        row_count += 1  # continue after element title row
        # Create local fields; only the field names of the returned data are
        # used here, the values are replaced by the "N/A" placeholder
        for field, value in self.get_element_data(0).items():
            sidebar.addWidget(QLabel("{0}: ".format(field)), row_count, 0)
            lab = QLabel("N/A")  # default until s selection is made
            sidebar.addWidget(lab, row_count, 1)
            self.element_data_widgets[field] = lab
            row_count += 1
        # Add units tool tips where applicable
        self.lattice_data_widgets["Total Length"].setToolTip("m")
        self.lattice_data_widgets["Horizontal Emittance"].setToolTip("pm")
        self.lattice_data_widgets["Linear Dispersion Action"].setToolTip("m")
        self.lattice_data_widgets["Energy Loss per Turn"].setToolTip("eV")
        self.lattice_data_widgets["Damping Times"].setToolTip("msec")
        self.lattice_data_widgets["Total Bend Angle"].setToolTip("deg")
        self.lattice_data_widgets["Total Absolute Bend Angle"].setToolTip("deg")
        self.element_data_widgets["Selected S Position"].setToolTip("m")
        self.element_data_widgets["Element Start S Position"].setToolTip("m")
        self.element_data_widgets["Element Length"].setToolTip("m")
        self.element_data_widgets["Horizontal Linear Dispersion"].setToolTip("m")
        self.element_data_widgets["Beta Function"].setToolTip("m")
        # Add sidebar to main window layout
        layout.addWidget(sidebar_border)

        # Set and display layout
        wid = QWidget(self)
        wid.setLayout(layout)
        self.setCentralWidget(wid)
        self.setStyleSheet("background-color:white;")
        self.show()

    def create_lat_repr(self):
        """Build the colour coded representations of the lattice elements.

        Each element whose length rounds up to a non-zero width gets one
        representation, in lattice order, coloured by element type; its
        length is recorded in ``self.base_widths``. Elements of zero width
        are collected in ``self.zero_length`` unless they are drifts,
        markers or apertures.
        See also: calc_zero_len_repr
        """
        # Zero length elements of these types are of no interest
        ignorable = (at.elements.Drift, at.elements.Marker,
                     at.elements.Aperture)
        # Checked in order; the first matching type decides the colour
        palette = (
            (at.elements.Drift, Qt.white),
            (at.elements.Dipole, Qt.green),
            (at.elements.Quadrupole, Qt.red),
            (at.elements.Sextupole, Qt.yellow),
            (at.elements.Corrector, Qt.blue),
        )
        self.zero_length = []
        self.base_widths = []
        representations = []
        for element in self.lattice:#[:self.lattice.i_range[-1]]:
            pixel_width = math.ceil(element.Length)
            if pixel_width == 0:
                if not isinstance(element, ignorable):
                    self.zero_length.append(element)
                continue
            self.base_widths.append(element.Length)
            colour = next((shade for kind, shade in palette
                           if isinstance(element, kind)), Qt.gray)
            representations.append(
                element_repr(element.Index, colour, pixel_width))
        return representations

    def calc_new_width(self, new_width):
        """Scale the element representation widths to a new total width.

        The recorded base widths are scaled so that they sum to new_width,
        then converted to whole pixels: zero stays zero, anything below one
        becomes one, everything else is rounded. The rounding error is then
        spread, one pixel each, over the elements at least two pixels wide:
        the widest are narrowed on overshoot, the narrowest are widened on
        undershoot.

        Raises ValueError if the error exceeds the number of such elements.
        """
        factor = new_width / sum(self.base_widths)
        widths = []
        adjustable = []  # (exact scaled width, position) of elements >= 2 px
        for position, base in enumerate(self.base_widths):
            exact = base * factor
            if exact == 0:
                widths.append(exact)
            elif exact < 1:
                widths.append(1)
            else:
                nearest = round(exact)
                widths.append(nearest)
                if nearest >= 2:
                    adjustable.append((exact, position))
        adjustable.sort()  # smallest first
        excess = round(sum(widths) - new_width)
        if abs(excess) > len(adjustable):
            raise ValueError("too many elements with 0<length<1")
        if excess > 0:  # overshoot: take a pixel from the widest elements
            for _, position in adjustable[len(adjustable) - excess:]:
                widths[position] = numpy.maximum(widths[position] - 1, 1)
        elif excess < 0:  # undershoot: give a pixel to the narrowest ones
            for _, position in adjustable[:abs(excess)]:
                widths[position] = widths[position] + 1
        return widths

    def calc_zero_len_repr(self, width):
        """Create element representations for elements in the lattice with 0
        length, to be displayed below the non-zero length element
        representations.

        Each zero length element becomes a 1 px coloured representation,
        preceded by a white spacer whose width is the scaled distance from
        the previous zero length element (or from the start). One trailing
        white representation is always appended, so the returned list has
        the same length for every width.
        See also: create_lat_repr
        """
        scale_factor = width / self.total_len
        all_s = self._atsim.get_s()
        # Scaled s position of every zero length element, preceded by the
        # start of the displayed section.
        positions = [0.0]
        for elem in self.zero_length:
            positions.append(all_s[elem.Index-1] * scale_factor)
        zero_len_repr = []
        for i, elem in enumerate(self.zero_length, start=1):
            gap_length = int(round(positions[i] - positions[i-1]))
            # N.B. zero length gap spacers are not drag-and-drop-able as they
            # are not drifts, however this could potentially be added in future
            # to allow zero length elements to be moved.
            zero_len_repr.append(element_repr(-1, Qt.white, gap_length,
                                              drag=False))
            if isinstance(elem, at.elements.Monitor):
                elem_repr = element_repr(elem.Index, Qt.magenta, 1)
            elif isinstance(elem, at.elements.RFCavity):
                elem_repr = element_repr(elem.Index, Qt.cyan, 1)
            elif isinstance(elem, at.elements.Corrector):
                elem_repr = element_repr(elem.Index, Qt.blue, 1)
            else:
                elem_repr = element_repr(elem.Index, Qt.gray, 1)
            zero_len_repr.append(elem_repr)
        diff = int(sum([el_repr.width for el_repr in zero_len_repr]) - width)
        if diff < 0:  # undershoot
            # unless the last zero length element is very close to the end of
            # the displayed section this should always occur.
            zero_len_repr.append(element_repr(-1, Qt.white, abs(diff),
                                              drag=False))
        elif diff > 0:  # overshoot
            # this should rarely occur
            # add zero len elem_repr at the end to maintain consistent length
            zero_len_repr.append(element_repr(-1, Qt.white, 0, drag=False))
            while diff > 1:
                shrunk = False
                for el_repr in zero_len_repr:
                    if el_repr.width > 1:
                        el_repr.changeSize(el_repr.width - 1)
                        diff -= 1
                        shrunk = True
                    if diff < 1:
                        break
                if not shrunk:
                    # Nothing wider than 1 px is left to narrow; stop rather
                    # than loop forever with an unchanged diff.
                    break
        else:
            # add zero len elem_repr at the end to maintain consistent length
            zero_len_repr.append(element_repr(-1, Qt.white, 0, drag=False))
        return zero_len_repr

    def get_lattice_data(self):
        """Collect the global linear optics data of the lattice.

        Waits for the simulator's calculations, then returns an ordered
        dictionary of the values keyed by the field names shown in the
        sidebar. The horizontal emittance is multiplied by 1e12 and the
        damping times by 1e3 before being stored.
        """
        sim = self._atsim
        sim.wait_for_calculations()
        return OrderedDict([
            ("Number of Elements", len(self.lattice)),
            ("Total Length", self.total_len),
            ("Total Bend Angle", sim.get_total_bend_angle()),
            ("Total Absolute Bend Angle",
             sim.get_total_absolute_bend_angle()),
            ("Cell Tune", [sim.get_tune('x'), sim.get_tune('y')]),
            ("Linear Chromaticity",
             [sim.get_chromaticity('x'), sim.get_chromaticity('y')]),
            ("Horizontal Emittance", sim.get_horizontal_emittance() * 1e12),
            ("Linear Dispersion Action",
             sim.get_linear_dispersion_action()),
            ("Momentum Spread", sim.get_energy_spread()),
            ("Linear Momentum Compaction", sim.get_momentum_compaction()),
            ("Energy Loss per Turn", sim.get_energy_loss()),
            ("Damping Times", sim.get_damping_times() * 1e3),
            ("Damping Partition Numbers",
             sim.get_damping_partition_numbers()),
        ])

    def get_element_data(self, selected_s_pos):
        """Collect the local linear optics data at a selected s position.

        The element used is the last one whose s position is less than or
        equal to selected_s_pos. Returns an ordered dictionary keyed by the
        field names shown in the sidebar; "Element Index" is one-based.
        """
        sim = self._atsim
        sim.wait_for_calculations()
        all_s = sim.get_s()
        at_or_before = [s <= selected_s_pos for s in all_s]
        index = int(numpy.where(at_or_before)[0][-1])
        full_turn = 2*numpy.pi
        return OrderedDict([
            ("Selected S Position", selected_s_pos),
            ("Element Index", index + 1),
            ("Element Start S Position", all_s[index]),
            ("Element Length", sim.get_at_element(index+1).Length),
            ("Horizontal Linear Dispersion", sim.get_dispersion()[index, 0]),
            ("Beta Function", sim.get_beta()[index]),
            ("Derivative of Beta Function", sim.get_alpha()[index]),
            ("Normalized Phase Advance", sim.get_mu()[index]/full_turn),
        ])

    def stringify(self, value):
        """Convert numerical data into a string that can be displayed.

        A single number is formatted on its own; any other value is
        iterated and each item formatted. Python ints print as integers,
        zero prints as "0.0", magnitudes below 0.1 use scientific notation
        and everything else uses five decimal places. Exactly one formatted
        item is returned bare, otherwise the items are joined inside square
        brackets.
        """
        def render(number):
            if isinstance(number, int):
                return "{0:d}".format(number)
            if number == 0:
                return "0.0"
            template = "{0:.5e}" if abs(number) < 0.1 else "{0:.5f}"
            return template.format(number)

        if numpy.issubdtype(type(value), numpy.number):
            value = [value]
        rendered = [render(item) for item in value]
        if len(rendered) == 1:
            return rendered[0]
        return "[" + ', '.join(rendered) + "]"

    def update_lattice_data(self):
        """Refresh every global linear optics label with its current value.

        Usually called after a change has been made to the lattice.
        """
        labels = self.lattice_data_widgets
        for name, current in self.get_lattice_data().items():
            text = self.stringify(current)
            labels[name].setText(text)

    def update_element_data(self, s_pos):
        """Refresh every local linear optics label for the given s position.

        Usually called when a new s position selection is made.
        """
        labels = self.element_data_widgets
        for name, current in self.get_element_data(s_pos).items():
            text = self.stringify(current)
            labels[name].setText(text)

    def plot(self):
        """Redraw the graph: clear the figure and plot the beta functions.

        A left axis and a twinned right axis are created and stored as
        ``self.axl`` and ``self.axr``.
        """
        figure = self.figure
        figure.clear()
        left_axis = figure.add_subplot(111, xmargin=0, ymargin=0.025)
        self.axl = left_axis
        left_axis.set_xlabel('s position [m]')
        right_axis = left_axis.twinx()
        self.axr = right_axis
        right_axis.margins(0, 0.025)
        self.lattice.radiation_off()  # ensure radiation state for linopt call
        at.plot.plot_beta(self.lattice, axes=(left_axis, right_axis))
        self.canvas.draw()

    def graph_onclick(self, event):
        """Handle a mouse click on the graph.

        A left click selects the clicked s position, marks it with a black
        dashed line and updates the element data. A click with any other
        button clears the selection. Clicks without an x data value are
        ignored.
        """
        clicked_s = event.xdata
        if clicked_s is None:
            return
        previous = self.s_selection
        if previous is not None:
            previous.remove()  # remove old s selection line
        if event.button != 1:
            # Not a left click: clear the selection and its data
            self.s_selection = None
            for label in self.element_data_widgets.values():
                label.setText("N/A")
        else:
            self.s_selection = self.axl.axvline(clicked_s, color="black",
                                                linestyle='--', zorder=3)
            self.update_element_data(clicked_s)
        self.canvas.draw()

    def resize_graph(self, width, height, redraw=False):
        """Resize the graph canvas to a new width and/or height.

        Nothing happens when the integer size is unchanged, unless redraw
        is true, which forces the resize to be applied anyway.
        """
        new_width, new_height = int(width), int(height)
        size_changed = bool((new_width != int(self.graph_width)) or
                            (new_height != int(self.graph_height)))
        if not (redraw or size_changed):
            return  # nothing to do and not forced
        canvas = self.canvas
        canvas.flush_events()
        canvas.setMaximumWidth(new_width)
        canvas.setMaximumHeight(new_height)
        canvas.resize(new_width, new_height)
        self.graph_width = new_width
        self.graph_height = new_height

    def refresh_all(self):
        """Refresh the graph, the global and local optics data and the
        edit boxes.

        If an s position is currently selected, its element data and its
        marker line on the freshly plotted graph are recreated.
        """
        self.plot()
        self.resizeEvent(None)
        self.update_lattice_data()
        selected_text = self.element_data_widgets["Selected S Position"].text()
        if selected_text != "N/A":
            position = float(selected_text)
            self.update_element_data(position)
            self.s_selection.remove()
            self.s_selection = self.axl.axvline(position, color="black",
                                                linestyle='--', zorder=3)
            self.canvas.draw()
        for edit_panel in self.edit_boxes:
            edit_panel.refresh()

    def resizeEvent(self, event):
        """Resize the graph and the lattice representation to fit the window.

        Called when the window is resized, and with ``event=None`` by
        refresh_all, in which case the event is not forwarded to the base
        class.
        N.B.
            1) The hard-coded pixel offsets are close to arbitrary: they
               were tuned by hand on one machine and may need changing for
               the alignment to work elsewhere.
            2) The resizing code is fragile; change it with care.
        """
        # Determine graph size from window size, never below 1000 x 480
        frame = self.frameGeometry()
        graph_width = int(max(frame.width() - 500, 1000))
        graph_height = int(max(frame.height() - 350, 480))
        self.resize_graph(graph_width, graph_height)

        # Non-zero length element representations follow the graph width
        bar_width = graph_width - 127
        for representation, target in zip(self.lat_repr,
                                          self.calc_new_width(bar_width)):
            if target != representation.width:
                representation.changeSize(target)

        # Two px more to account for end bars
        self.mid_line.changeSize(graph_width - 125)

        # Zero length representations take the widths of a fresh layout
        fresh_layout = self.calc_zero_len_repr(bar_width)
        for representation, template in zip(self.zl_repr, fresh_layout):
            representation.changeSize(template.width)

        # If not a refresh call then resize the window
        if event is not None:
            super().resizeEvent(event)
Beispiel #17
0
class MatplotlibWidget(QWidget):
    '''
    Overlays matplotlib figure onto different parts of the gui
    '''
    def __init__(self,
                 parent=None,
                 axtype='',
                 title='',
                 xlabel='',
                 ylabel='',
                 xlim=None,
                 ylim=None,
                 xscale='linear',
                 yscale='linear',
                 showtoolbar=True,
                 dpi=100,
                 **kwargs):
        super(MatplotlibWidget, self).__init__(parent)
        self.axtype = axtype
        #self.axtype = axtype
        self.leg = ''

        # Make the Figure
        self.fig = Figure(tight_layout={'pad': 0.05}, dpi=dpi)

        # Make the Canvas
        self.canvas = FigureCanvas(self.fig)
        # Have it resize to the Gui
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.canvas.setFocusPolicy(Qt.ClickFocus)
        self.canvas.setFocus()

        self.vbox = QVBoxLayout()
        self.vbox.setContentsMargins(0, 0, 0, 0)  # set layout margins
        self.vbox.addWidget(self.canvas)
        self.setLayout(self.vbox)

        # TODO maybe- add tool bar

        # Set the inital axes
        self.ax = self.fig.add_subplot(111, facecolor='none')

        # Draw the canvas
        self.canvas_draw()

    def make_mpl_plot(self,
                      title='',
                      xlabel='',
                      ylabel='',
                      xlim=None,
                      ylim=None,
                      xscale='log',
                      yscale='log',
                      x_data=None,
                      y_data=None,
                      marker='.',
                      *args,
                      **kwargs):

        self.ax.plot(x_data, y_data, marker=marker)

        self.ax.set_xscale(xscale)
        self.ax.set_yscale(yscale)

        self.ax.set_xlabel(xlabel)
        self.ax.set_ylabel(ylabel)

        self.ax.set_title(title)

        # Redraw the new scale
        self.canvas_draw()

    def clr_lines(self):
        # TODO look into clearing the whole plot before
        '''
        Clear all lines in all mpls
        '''
        line_list = self.ax.get_lines()
        # Clear each line in the figure
        for line in line_list:
            self.ax.lines.remove(line)

        #self.reset_ax_lim(ax)
        self.canvas_draw()

    def canvas_draw(self):
        '''
        redraw canvas after data changed
        '''
        self.canvas.draw()
        self.canvas.flush_events()  # flush the GUI events

    def reset_ax_lim(self, ax):
        '''
        reset the lim of ax
        this change the display and where home button goes back to
        '''
        ax.relim(visible_only=True)
        ax.autoscale_view(True, True, True)
Beispiel #18
0
class mainWindow(QMainWindow):
    
    
    updateActivePix = pyqtSignal()

    def __init__(self,parent=None):
        """Create the main window and show an initial noise image."""
        super().__init__(parent=parent)
        # The placeholder arrays must exist before the widgets are built,
        # because the main frame displays self.image.
        self.initializeEmptyArrays()
        self.setWindowTitle('quickLook.py')
        # Build the GUI pieces
        self.create_main_frame()
        self.create_status_bar()
        self.createMenu()
        # Something to look at until an H5 file is loaded
        self.plotNoise()
        
        
    def initializeEmptyArrays(self,nCol = 10,nRow = 10):
        """Reset the image arrays to zeros and restore the default state.

        Every array gets the shape (nRow, nCol). The hot pixel cut, the
        active pixel and the list of sub-windows are reset as well.
        """
        self.nCol = nCol
        self.nRow = nRow
        shape = (self.nRow, self.nCol)
        for name in ('IcMap', 'IsMap', 'IcIsMap', 'rawCountsImage',
                     'hotPixMask', 'image'):
            # a separate array per attribute, never a shared one
            setattr(self, name, np.zeros(shape))
        self.hotPixCut = 2300
        self.activePixel = [0,0]
        self.sWindowList = []
        
        
    def loadDataFromH5(self,*args):
        #a = darkObsFile.ObsFile('/Users/clint/Documents/mazinlab/ScienceData/PAL2017b/20171004/1507175503.h5')
        if os.path.isfile(self.filename):
            try:
                self.a = ObsFile(self.filename)
            except:
                print('darkObsFile failed to load file. Check filename.\n',self.filename)
            else:
                print('data loaded from .h5 file')
                self.h5_filename_label.setText(self.filename)
                self.initializeEmptyArrays(len(self.a.beamImage),len(self.a.beamImage[0]))
                self.beamFlagImage = np.transpose(self.a.beamFlagImage.read())
                self.beamFlagMask = self.beamFlagImage==0  #make a mask. 0 for good beam map
                self.makeHotPixMask()
                self.radio_button_beamFlagImage.setChecked(True)
                self.callPlotMethod()
                #set the max integration time to the h5 exp time in the header
                self.expTime = self.a.getFromHeader('expTime')
                self.wvlBinStart = self.a.getFromHeader('wvlBinStart')
                self.wvlBinEnd = self.a.getFromHeader('wvlBinEnd')
                
                #set the max and min values for the lambda spinboxes
#                self.spinbox_startLambda.setMinimum(self.wvlBinStart)
                self.spinbox_stopLambda.setMinimum(self.wvlBinStart)
                self.spinbox_startLambda.setMaximum(self.wvlBinEnd)
                self.spinbox_stopLambda.setMaximum(self.wvlBinEnd)
                self.spinbox_startLambda.setValue(self.wvlBinStart)
                self.spinbox_stopLambda.setValue(self.wvlBinEnd)
                
                #set the max value of the integration time spinbox
                self.spinbox_startTime.setMinimum(0)
                self.spinbox_startTime.setMaximum(self.expTime)
                self.spinbox_integrationTime.setMinimum(0)
                self.spinbox_integrationTime.setMaximum(self.expTime)
                self.spinbox_integrationTime.setValue(self.expTime)

        
        
        
        
    def plotBeamImage(self):
        """Show the beam flag image on the axes.

        Prints a hint and returns when no obsfile has been loaded yet.
        """
        #check if obsfile object exists
        if not hasattr(self, 'a'):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return

        #clear the axes
        self.ax1.clear()

        self.image = self.beamFlagImage

        self.cbarLimits = np.array([np.amin(self.image),np.amax(self.image)])

        self.ax1.imshow(self.image,interpolation='none')
        self.fig.cbar.set_clim(np.amin(self.image),np.amax(self.image))
        self.fig.cbar.draw_all()

        self.ax1.set_title('beam flag image')

        self.ax1.axis('off')

        self.cursor = Cursor(self.ax1, useblit=True, color='red', linewidth=2)

        self.draw()
        
        
        
    def plotImage(self,*args):        
        #check if obsfile object exists
        try:
            self.a
        except:
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return
        else:
        
        
            #clear the axes
            self.ax1.clear()  
            
            temp = self.a.getPixelCountImage(firstSec = self.spinbox_startTime.value(), integrationTime=self.spinbox_integrationTime.value(),applyWeight=False,flagToUse = 0,wvlStart=self.spinbox_startLambda.value(), wvlStop=self.spinbox_stopLambda.value())
            self.rawCountsImage = np.transpose(temp['image'])
            
            self.image = self.rawCountsImage
            self.image[np.where(np.logical_not(np.isfinite(self.image)))]=0
#            self.image = self.rawCountsImage*self.beamFlagMask
            #self.image = self.rawCountsImage*self.beamFlagMask*self.hotPixMask
            self.image = 1.0*self.image/self.spinbox_integrationTime.value()
            
            self.cbarLimits = np.array([np.amin(self.image),np.amax(self.image)])
            
            self.ax1.imshow(self.image,interpolation='none',vmin = self.cbarLimits[0],vmax = self.cbarLimits[1])
            
            self.fig.cbar.set_clim(self.cbarLimits[0],self.cbarLimits[1])
            self.fig.cbar.draw_all()

            self.ax1.set_title('Raw counts')
            
            self.ax1.axis('off')
            
            self.cursor = Cursor(self.ax1, useblit=True, color='red', linewidth=2)
            
            #self.ax1.plot(np.arange(10),np.arange(10)**2)
            
            
            self.draw()
            


        
        
        
        
        
    def plotIcIs(self):
        """Clear the axes in preparation for an Ic/Is plot.

        Prints a hint and returns when no obsfile has been loaded yet.
        Nothing is plotted at present.
        """
        #check if obsfile object exists
        if not hasattr(self, 'a'):
            print('\nNo obsfile object defined. Select H5 file to load.\n')
            return

        # NOTE(review): the other plot methods use self.ax1; confirm that
        # self.ax is defined, or whether self.ax1 was meant here.
        self.ax.clear() #clear the axes
            
#            for col in range(self.nCol):
#                for row in range(self.nRow):
                
            






        
    def plotNoise(self,*args):
        """Show a freshly generated random noise image (debugging aid).

        NOTE(review): Colorbar.set_clim and Colorbar.draw_all may be
        unavailable in recent Matplotlib releases - confirm against the
        Matplotlib version this GUI targets.
        """
        axes = self.ax1
        axes.clear()  # start from empty axes

        # debugging - generate some noise to plot
        noise = np.random.randn(self.nRow,self.nCol)
        self.image = noise

        self.foo = axes.imshow(noise,interpolation='none')
        limits = np.array([np.amin(noise),np.amax(noise)])
        self.cbarLimits = limits
        colorbar = self.fig.cbar
        colorbar.set_clim(limits[0],limits[1])
        colorbar.draw_all()

        axes.set_title('some generated noise...')
        axes.axis('off')

        self.cursor = Cursor(axes, useblit=True, color='red', linewidth=2)

        self.draw()
        
        
        
    def callPlotMethod(self):
        """
        Run the plot routine that matches the checked radio button.

        ``radio_button_img`` and ``radio_button_ic_is`` (and the case where
        no button is checked) currently fall back to ``plotNoise()``;
        ``radio_button_beamFlagImage`` calls ``plotBeamImage()`` and
        ``radio_button_rawCounts`` calls ``plotImage()``.
        """
        # The two noise-backed choices are tested first so that they keep
        # precedence over the other buttons, exactly as before.
        if self.radio_button_img.isChecked() or self.radio_button_ic_is.isChecked():
            self.plotNoise()
        elif self.radio_button_beamFlagImage.isChecked():
            self.plotBeamImage()
        elif self.radio_button_rawCounts.isChecked():
            self.plotImage()
        else:
            self.plotNoise()
            
            
            
            
    def makeHotPixMask(self):
        """
        Flag pixels whose counts are below ``self.hotPixCut`` as good.

        Mask convention: ``self.hotPixMask[row][col]`` is 1 for a good pixel
        and 0 for a hot pixel. Only 1s are written here; entries of pixels at
        or above the cut keep whatever value they already hold.
        """
        # Counts image requested with firstSec=0 and integrationTime=1,
        # without weights applied.
        counts = self.a.getPixelCountImage(firstSec = 0, integrationTime=1,applyWeight=False,flagToUse = 0)
        # Transposed so that it can be indexed as [row][col].
        countsByRowCol = np.transpose(counts['image'])
        for row in range(self.nRow):
            for col in range(self.nCol):
                if not countsByRowCol[row][col] < self.hotPixCut:
                    continue
                self.hotPixMask[row][col] = 1

        

    
    def create_main_frame(self):
        """
        Makes GUI elements on the window

        Builds the matplotlib canvas (image plus colorbar), the two push
        buttons, the time and wavelength spin boxes, the plot-type radio
        buttons and the filename / active-pixel labels, nests them in box
        layouts, installs the result as the central widget and finally
        connects the canvas mouse events.
        """
        # ---- plot window ----
        self.main_frame = QWidget()
        self.dpi = 100
        self.fig = Figure((1.0, 20.0), dpi=self.dpi, tight_layout=True)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax1 = self.fig.add_subplot(111)
        self.foo = self.ax1.imshow(self.image,interpolation='none')
        # The colorbar is stored on the figure so the event handlers can
        # reach it through self.fig.cbar.
        self.fig.cbar = self.fig.colorbar(self.foo)

        # ---- push buttons: (caption, tooltip, slot) ----
        pushButtons = []
        for caption, tip, slot in (
                ("Plot image", 'Click to update image.', self.callPlotMethod),
                ("Quick Load H5", 'Will change functionality later.', self.quickLoadH5)):
            button = QPushButton(caption)
            button.setEnabled(True)
            button.setToolTip(tip)
            button.clicked.connect(slot)
            pushButtons.append(button)

        # ---- start time / integration time ----
        self.spinbox_startTime = QSpinBox()
        self.spinbox_integrationTime = QSpinBox()
        startTimeCaption = QLabel('start time')
        integrationTimeCaption = QLabel('integration time')

        # ---- start / stop wavelength ----
        self.spinbox_startLambda = QSpinBox()
        self.spinbox_stopLambda = QSpinBox()
        startLambdaCaption = QLabel('start wavelength [nm]')
        stopLambdaCaption = QLabel('stop wavelength [nm]')

        # ---- informational labels ----
        self.h5_filename_label = QLabel('no file loaded')
        # activePixel is stored as [col, row]; the image is indexed [row, col]
        pixel = self.activePixel
        pixelText = 'Active Pixel ({},{}) {}'.format(pixel[0], pixel[1], self.image[pixel[1], pixel[0]])
        self.activePixel_label = QLabel(pixelText)

        # ---- plot-type radio buttons (.IMG is the initial choice) ----
        self.radio_button_img = QRadioButton(".IMG")
        self.radio_button_ic_is = QRadioButton("Ic/Is")
        self.radio_button_beamFlagImage = QRadioButton("Beam Flag Image")
        self.radio_button_rawCounts = QRadioButton("Raw Counts")
        self.radio_button_img.setChecked(True)

        # ---- layouts, innermost first ----
        plotBox = QVBoxLayout()
        plotBox.addWidget(self.canvas)

        timeBox = QVBoxLayout()
        for item in (startTimeCaption, self.spinbox_startTime,
                     integrationTimeCaption, self.spinbox_integrationTime):
            timeBox.addWidget(item)

        lambdaBox = QVBoxLayout()
        for item in (startLambdaCaption, self.spinbox_startLambda,
                     stopLambdaCaption, self.spinbox_stopLambda):
            lambdaBox.addWidget(item)

        buttonBox = QHBoxLayout()
        for button in pushButtons:
            buttonBox.addWidget(button)

        # time and wavelength columns side by side
        timeLambdaBox = QHBoxLayout()
        timeLambdaBox.addLayout(timeBox)
        timeLambdaBox.addLayout(lambdaBox)

        # spin boxes stacked above the push buttons
        spinAndButtonBox = QVBoxLayout()
        spinAndButtonBox.addLayout(timeLambdaBox)
        spinAndButtonBox.addLayout(buttonBox)

        radioBox = QVBoxLayout()
        for radio in (self.radio_button_img, self.radio_button_ic_is,
                      self.radio_button_beamFlagImage, self.radio_button_rawCounts):
            radioBox.addWidget(radio)

        # all controls in one row: spin boxes/buttons left, radio buttons right
        controlsBox = QHBoxLayout()
        controlsBox.addLayout(spinAndButtonBox)
        controlsBox.addLayout(radioBox)

        infoBox = QVBoxLayout()
        infoBox.addWidget(self.h5_filename_label)
        infoBox.addWidget(self.activePixel_label)

        # plot on top, controls in the middle, info labels at the bottom
        outerBox = QVBoxLayout()
        for section in (plotBox, controlsBox, infoBox):
            outerBox.addLayout(section)

        self.main_frame.setLayout(outerBox)
        self.setCentralWidget(self.main_frame)

        # ---- matplotlib canvas events ----
        for eventName, handler in (
                ('motion_notify_event', self.hoverCanvas),
                ('button_press_event', self.mousePressed),
                ('scroll_event', self.scroll_ColorBar)):
            self.fig.canvas.mpl_connect(eventName, handler)
        

        
        
    def quickLoadH5(self):
        """Load one fixed, hard-coded H5 file without showing a file dialog."""
        # Developer shortcut: this absolute path is specific to one machine.
        hardCodedPath = '/Users/clint/Documents/mazinlab/ScienceData/PAL2017b/20171004/1507175503.h5'
        self.filename = hardCodedPath
        self.loadDataFromH5()
        


        
        
    def draw(self):
        """Redraw the canvas, then flush its pending GUI events."""
        #The plot window calls this function
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()
        
        
    def hoverCanvas(self,event):
        """
        Show the pixel under the mouse pointer in the status bar.

        ``event`` is a matplotlib mouse event. Its data coordinates are
        rounded to the nearest integer column/row and, if that pixel lies
        inside the ``nRow`` x ``nCol`` grid, the status text is set to
        ``(col,row) value``. Events outside ``self.ax1`` are ignored.
        """
        if event.inaxes is not self.ax1:
            return

        col = int(round(event.xdata))
        row = int(round(event.ydata))

        # Negative indices are rejected as well: numpy would otherwise wrap
        # them around and report the value of a pixel on the opposite edge.
        if 0 <= row < self.nRow and 0 <= col < self.nCol:
            self.status_text.setText('({:d},{:d}) {}'.format(col,row,self.image[row,col]))
                
                
    def scroll_ColorBar(self,event):
        """
        Rescale the upper colorbar limit with the mouse wheel.

        Scrolling 'up' over the colorbar axes multiplies the upper limit by
        1.1, scrolling 'down' multiplies it by 0.9; the image is then shown
        again with the new limits. The canvas is redrawn on every scroll
        event, including those outside the colorbar.
        """
        colorbar = self.fig.cbar
        if event.inaxes is colorbar.ax:
            stepSize = 0.1  #fractional change in the colorbar scale

            # Pick the multiplier for the scroll direction (None: ignore).
            if event.button == 'up':
                factor = 1 + stepSize
            elif event.button == 'down':
                factor = 1 - stepSize
            else:
                factor = None

            if factor is not None:
                self.cbarLimits[1] *= factor
                lower, upper = self.cbarLimits[0], self.cbarLimits[1]
                colorbar.set_clim(lower, upper)
                colorbar.draw_all()
                # NOTE(review): every call adds another image to ax1 on top
                # of the previous ones - confirm this stacking is acceptable.
                self.ax1.imshow(self.image, interpolation='none', vmin=lower, vmax=upper)

        self.draw()
        
        
                
                
    def mousePressed(self,event):
        """
        Handle mouse clicks on the image and on the colorbar.

        * Left click on the image axes: the clicked pixel becomes
          ``self.activePixel`` (stored as ``[col, row]``), the active-pixel
          label is refreshed and ``updateActivePix`` is emitted. Clicks that
          round to a pixel outside the ``nRow`` x ``nCol`` grid are ignored.
        * Right click on the image axes: prints a message only.
        * Left click on the colorbar axes: resets the colorbar limits to the
          min/max of ``self.image`` and redraws.

        ``event`` is a matplotlib mouse event.
        """
        if event.inaxes is self.ax1:  #check if the mouse-click was within the axes.
            if event.button == 1:
                col = int(round(event.xdata))
                row = int(round(event.ydata))

                # A click near the far edge can round to nCol/nRow (and a
                # zoomed view can yield negative values). Bail out before
                # activePixel is overwritten so the image lookup below can
                # neither raise IndexError nor wrap around.
                if not (0 <= row < self.nRow and 0 <= col < self.nCol):
                    return

                self.activePixel = [col,row]
                self.activePixel_label.setText('Active Pixel ({},{}) {}'.format(self.activePixel[0],self.activePixel[1],self.image[self.activePixel[1],self.activePixel[0]]))

                self.updateActivePix.emit()  #emit a signal for other plots to update

            elif event.button == 3:
                print('\nit was the right button that was pressed!\n')

        elif event.inaxes is self.fig.cbar.ax:   #reset the scale bar
            if event.button == 1:
                self.cbarLimits = np.array([np.amin(self.image),np.amax(self.image)])
                self.fig.cbar.set_clim(self.cbarLimits[0],self.cbarLimits[1])
                self.fig.cbar.draw_all()
                self.ax1.imshow(self.image,interpolation='none',vmin = self.cbarLimits[0],vmax = self.cbarLimits[1])
                self.draw()
        

                
                
        
    def create_status_bar(self):
        """Add an initially empty text label to the window's status bar."""
        #Using code from ARCONS-pipeline as an example:
        #ARCONS-pipeline/util/popup.py
        statusLabel = QLabel("")
        self.status_text = statusLabel
        # second argument is the stretch factor given to the label
        self.statusBar().addWidget(statusLabel, 1)
        
        
    def createMenu(self):
        """
        Build the menu bar with its "File" and "Plot" menus.

        File: "Open H5 File" (Ctrl+O) and "Exit" (Ctrl+Q).
        Plot: "Light Curve", "Intensity Histogram" and "Spectrum", each
        opening the corresponding secondary window.
        """
        #Using code from ARCONS-pipeline as an example:
        #ARCONS-pipeline/util/quicklook.py
        menubar = self.menuBar()
        self.menubar = menubar

        # ---- File menu ----
        self.fileMenu = menubar.addMenu("&File")

        openAction = QAction(QIcon('exit24.png'), 'Open H5 File', self)
        openAction.setShortcut('Ctrl+O')
        openAction.setStatusTip('Open an H5 File')
        openAction.triggered.connect(self.getFileNameFromUser)
        self.fileMenu.addAction(openAction)

        quitAction = QAction(QIcon('exit24.png'), 'Exit', self)
        quitAction.setShortcut('Ctrl+Q')
        quitAction.setStatusTip('Exit application')
        quitAction.triggered.connect(self.close)
        self.fileMenu.addAction(quitAction)

        # ---- Plot menu: (entry title, slot) ----
        self.plotMenu = menubar.addMenu("&Plot")
        plotEntries = (
            ('Light Curve', self.makeTimestreamPlot),
            ('Intensity Histogram', self.makeIntensityHistogramPlot),
            ('Spectrum', self.makeSpectrumPlot),
        )
        plotActions = []
        for title, slot in plotEntries:
            action = QAction(title, self)
            action.triggered.connect(slot)
            plotActions.append(action)
        for action in plotActions:
            self.plotMenu.addAction(action)

        menubar.setNativeMenuBar(False) #This is for MAC OS


        
        
    def getFileNameFromUser(self):
        """
        Ask the user for an H5 file and load it.

        The dialog starts in the directory named by the ``MKID_DATA_DIR``
        environment variable, or in the current directory when that variable
        is not set. If the user cancels the dialog nothing is changed: the
        previously selected ``self.filename`` is kept and no load is started.
        """
        # look at this website for useful examples
        # https://pythonspot.com/pyqt5-file-dialog/
        def_loc = os.environ.get('MKID_DATA_DIR', '.')
        filename, _ = QFileDialog.getOpenFileName(self, 'Select One File', def_loc,filter = '*.h5')

        # getOpenFileName returns an empty string when the dialog is
        # cancelled; do not overwrite the current file name with it.
        if not filename:
            return

        self.filename = filename
        self.loadDataFromH5(self.filename)
        
        
    def makeTimestreamPlot(self):
        """Open a timeStream window and remember it in ``self.sWindowList``."""
        lightCurveWindow = timeStream(self)
        lightCurveWindow.show()
        # the list holds a reference to every secondary window opened
        self.sWindowList.append(lightCurveWindow)
        
        
    def makeIntensityHistogramPlot(self):
        """Open an intensityHistogram window and remember it in ``self.sWindowList``."""
        histogramWindow = intensityHistogram(self)
        histogramWindow.show()
        # the list holds a reference to every secondary window opened
        self.sWindowList.append(histogramWindow)
        
        
    def makeSpectrumPlot(self):
        """Open a spectrum window and remember it in ``self.sWindowList``."""
        spectrumWindow = spectrum(self)
        spectrumWindow.show()
        # the list holds a reference to every secondary window opened
        self.sWindowList.append(spectrumWindow)
Beispiel #19
0
class MainWindow(QMainWindow, Ui_MainWindow):
    """
    Main window that shows TIFF frames and collects user-marked points.

    The widgets used below (tree widget, scroll bars, check box, push
    buttons and layouts) are expected to be created by
    ``Ui_MainWindow.setupUi``.
    """

    # Marked points are mirrored vertically with ``y -> _FLIP_HEIGHT - y``
    # before being drawn.
    # NOTE(review): presumably the frame height in pixels - confirm.
    _FLIP_HEIGHT = 4096

    # Text file used to pass refined coordinates from button 4 to button 3.
    _COORDS_FILE = "/cs2/shuangbo/tmp/coords.txt"

    def __init__(self, parent=None):
        """
        Constructor
        @param parent reference to the parent widget
        @type QWidget
        """
        super(MainWindow, self).__init__(parent)

        # Per-instance bookkeeping. These used to be class-level lists, which
        # made every MainWindow instance share (and mutate) the same objects.
        self.datalist = []      # one dict per data set added to the tree
        self.framelist = []     # one dict per frame item added to the tree
        self.coords = []        # [x, y] of every accepted mouse click
        self.coords_tmp = []    # refined coordinates, saved to / loaded from file
        self._click_cid = None  # matplotlib id of the click-handler connection

        self.setupUi(self)

        self.treeWidget.setColumnWidth(0, 240)

        # NOTE(review): figsize is given in inches, so 384 x 384 at 100 dpi
        # is enormous - confirm whether 3.84 was intended.
        self.figure = Figure(figsize=(384, 384), dpi=100)
        self.figure.subplots_adjust(top=1, bottom=0, left=0,
                                    right=1, hspace=0, wspace=0)
        self.canvas = FigureCanvas(self.figure)
        self.horizontalLayout_2.addWidget(self.canvas)

        self.figure1 = Figure(figsize=(384, 384), dpi=100)
        self.figure1.subplots_adjust(top=1, bottom=0, left=0,
                                     right=1, hspace=0, wspace=0)
        self.canvas1 = FigureCanvas(self.figure1)
        self.horizontalLayout_3.addWidget(self.canvas1)

        # Both scroll bars re-render the image with new display limits.
        self.horizontalScrollBar.valueChanged.connect(self.updatedisplay)
        self.horizontalScrollBar_2.valueChanged.connect(self.updatedisplay)

        self.checkBox.stateChanged.connect(self.sectordiv)

    def sectordiv(self):
        """
        Start a LineBuilder on a line seeded at the computed centre point.

        Requires ``self.centerpts`` (set by on_pushButton_2_clicked) and
        ``self.axes`` (set by on_pushButton_clicked).
        """
        cpx = self.centerpts[0]
        cpy = self._FLIP_HEIGHT - self.centerpts[1]
        line, = self.axes.plot([cpx], [cpy])
        # NOTE(review): the builder is only held in a local variable -
        # confirm that LineBuilder keeps itself alive via its callbacks.
        linebuilder = LineBuilder(line)

    def updatedisplay(self, value):
        """
        Re-render ``self.data`` using the scroll bar values as vmin/vmax.

        ``value`` is the new scroll bar value delivered by the signal; it is
        not used because both bars are read directly.
        """
        tmin = self.horizontalScrollBar.value()
        tmax = self.horizontalScrollBar_2.value()
        self.axes.imshow(self.data, cmap='Greys_r', vmin=tmin, vmax=tmax)
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.set_yticklabels([])
        self.axes.set_xticklabels([])
        self.canvas.draw()
        self.canvas.flush_events()

    def create_datasets(self, path, dataname):
        """
        Add a data set and its frames to the tree widget.

        @param path directory that is scanned for frames
        @param dataname file name of the frame the user selected
        @return full path of the first frame found by load_datainfo()
        """
        data = QTreeWidgetItem(self.treeWidget)
        datadict = {'data': data, 'dataname': dataname, 'framecount': 0}

        num, datas = self.load_datainfo(path, dataname)
        print("datas[0]: ", datas[0])
        for framename in datas:
            frame = QTreeWidgetItem()
            self.framelist.append({'frame': frame, 'framename': framename})
            frame.setText(0, framename)
            data.addChild(frame)

        datadict['framecount'] = num
        data.setText(0, dataname)
        self.datalist.append(datadict)
        return os.path.join(path, datas[0])

    def load_datainfo(self, path, dataname):
        """
        List the ".tif" files in ``path`` that belong to ``dataname``.

        A file matches when its name contains the stem of ``dataname`` with
        the last four characters removed.
        NOTE(review): presumably those four characters are a frame counter;
        a stem of four characters or fewer matches every ".tif" file.

        @return tuple (number of matches, list of matching file names)
        """
        # Loop-invariant, so computed once instead of once per directory entry.
        prefix = os.path.splitext(dataname)[0][:-4]
        data = [x for x in os.listdir(path)
                if prefix in x and os.path.splitext(x)[-1] == ".tif"]
        return len(data), data

    def onclick(self, event):
        """
        Record a clicked point and mark it with a green circle.

        Nothing is recorded while the check box is ticked or when the click
        landed outside the axes (matplotlib then reports no data
        coordinates).

        @return [x, y] of the recorded point, or None if nothing was recorded
        """
        # The module-level ix / iy are still updated, as before, for any
        # external code that may read them.
        global ix, iy
        ix, iy = event.xdata, event.ydata
        if self.checkBox.isChecked():
            return None
        if ix is None or iy is None:
            # Click outside the axes: storing None would break Circle() and
            # the later centre-of-mass computation.
            return None

        self.coords.append([ix, iy])
        c = Circle((ix, iy), 25, fill=False, color='g')
        self.axes.add_patch(c)

        self.canvas.draw()
        self.canvas.flush_events()
        return [ix, iy]

    @pyqtSlot()
    def on_pushButton_clicked(self):
        """
        Let the user pick a frame, load it and display it.

        Registers the data set in the tree, reads the first frame, derives
        display limits of mean +/- 3 sigma (clipped to the data range),
        configures both scroll bars and shows the image.
        """
        filename, _ = QFileDialog.getOpenFileName(self, "Select file", "./")
        if not filename:
            # Dialog cancelled: leave the current state untouched.
            return

        self.axes = self.figure.add_subplot()

        self.dataset = filename
        path, dataname = os.path.split(self.dataset)
        frame1 = self.create_datasets(path, dataname)

        # NOTE(review): the TIFF handle is never closed here - confirm
        # whether libtiff needs an explicit close().
        tif = libtiff.TIFF.open(frame1)
        self.data = np.abs(tif.read_image())

        maximum = np.max(self.data)
        minimum = np.min(self.data)
        mean = np.mean(self.data)
        sigma = np.std(self.data)
        print("maximum: ", maximum)
        print("minimum: ", minimum)
        print("mean: ", mean)
        print("sigma: ", sigma)
        self.curmin = max(minimum, mean - 3.0 * sigma)
        self.curmax = min(maximum, mean + 3.0 * sigma)

        # Both bars need the full data range; the old code set only the
        # minimum of the first bar and only the maximum of the second.
        # Qt scroll bars take plain ints, hence the explicit conversion.
        for bar in (self.horizontalScrollBar, self.horizontalScrollBar_2):
            bar.setMinimum(int(minimum))
            bar.setMaximum(int(maximum))
        self.horizontalScrollBar.setValue(int(self.curmin))
        self.horizontalScrollBar_2.setValue(int(self.curmax))

        self.axes.imshow(self.data,
                         cmap='Greys_r',
                         vmin=self.curmin,
                         vmax=self.curmax)
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.set_yticklabels([])
        self.axes.set_xticklabels([])

        # Connect the click handler only once; reconnecting on every load
        # made each click be recorded several times.
        if self._click_cid is None:
            self._click_cid = self.figure.canvas.mpl_connect(
                'button_press_event', self.onclick)
        self.canvas.draw()

    @pyqtSlot()
    def on_pushButton_2_clicked(self):
        """Compute the centre from the clicked points and mark it in blue."""
        self.centerpts = com.center(np.flipud(self.data), self.coords)
        cpsx = self.centerpts[0]
        cpsy = self._FLIP_HEIGHT - self.centerpts[1]
        c = Circle((cpsx, cpsy), 36, fill=False, color='b')
        self.axes.add_patch(c)
        self.canvas.draw()
        print("centerx: ", self.centerpts[0], " centery: ", self.centerpts[1])

    @pyqtSlot()
    def on_pushButton_3_clicked(self):
        """
        Draw a red circle at every refined coordinate.

        If no refined coordinates are held in memory they are read from the
        coordinates file first.
        """
        radius = 24
        if len(self.coords_tmp) == 0:
            # ndmin=2 keeps a file with a single row iterable as (x, y) pairs.
            self.coords_tmp = np.loadtxt(self._COORDS_FILE, ndmin=2)

        for xi, yi in self.coords_tmp:
            yi = self._FLIP_HEIGHT - yi
            c = Circle((xi, yi), radius, fill=False, color='r')
            self.axes.add_patch(c)
        self.canvas.draw()

    @pyqtSlot()
    def on_pushButton_4_clicked(self):
        """
        Refine every clicked point with ``com.com`` and save the results.

        The refined points are appended to ``self.coords_tmp`` and the whole
        list is written to the coordinates file.
        """
        # coords_tmp is an ndarray after on_pushButton_3_clicked loaded it
        # from file; convert back to a list so it can be extended.
        refined = list(self.coords_tmp)
        for co in self.coords:
            refined.append(com.com(np.flipud(self.data), co))
        self.coords_tmp = refined
        np.savetxt(self._COORDS_FILE, self.coords_tmp)

    @pyqtSlot()
    def on_pushButton_5_clicked(self):
        """
        Start the worker thread on the current data set.

        The button is disabled until the thread emits ``finish``; each
        ``signal`` emission is forwarded to Update().
        """
        self.thread = Thread.Thread(self.dataset, self.centerpts)

        self.pushButton_5.setEnabled(False)
        self.axes1 = self.figure1.add_subplot()
        self.axes1.clear()
        self.thread.signal.connect(self.Update)
        self.thread.finish.connect(self.buttonEnable)
        self.thread.start()

    def buttonEnable(self):
        """Re-enable the start button (connected to the thread's finish signal)."""
        self.pushButton_5.setEnabled(True)

    def Update(self, data, intens, iml, pl):
        """
        Refresh both canvases with data delivered by the worker thread.

        @param data image shown on the first canvas
        @param intens 2-D array; column i is plotted as one curve
        @param iml upper x-limit of the curve plot
        @param pl number of columns of ``intens`` to plot
        """
        self.axes.clear()
        self.axes1.clear()
        self.axes.imshow(data,
                         cmap='Greys_r',
                         vmin=self.curmin,
                         vmax=self.curmax)
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.set_yticklabels([])
        self.axes.set_xticklabels([])

        self.axes1.set_xlim(0, iml)
        for i in range(pl):
            self.axes1.plot(intens[:, i])
        self.canvas.draw()
        self.canvas1.draw()
Beispiel #20
0
class QVisaDynamicPlot(QWidget):
    """
    Matplotlib plotting widget with a trace selector and a clear button.

    Axes are stored in ``self._axes`` under string keys ("111" for
    ``subplot(111)``, "111t" for its ``twinx()`` partner). Line handles are
    stored in ``self._handles`` (a QVisaDataObject) under the same axes keys,
    grouped by handle key. Every public method accepts the axes key either
    as a string or as an int; it is normalised with ``str()``.
    """

    def __init__(self, _app):

        QWidget.__init__(self)

        # Axes are stored in a standard dictionary
        # 	[111]  = subplot(111)
        #	[111t] = subplot(111).twinx()
        self._axes = {}

        # Axes handles will be stored in QVisaDataObject
        #	[111]
        # 		[<key0>] = handles(list)
        # 		[<key1>] = handles(list)
        self._handles = QVisaDataObject()

        # QVisaColorMap class and generator function
        self._cmap = QVisaColorMap()
        self._cgen = self._cmap.gen_next_color()

        # Dictionary to hold plot adjust values
        self._adjust = {'l': 0.15, 'r': 0.90, 't': 0.90, 'b': 0.10}

        # Generate main layout
        self._gen_main_layout()

        # Cache a reference to the calling application
        self._app = _app
        self.sync = False

    def _gen_main_layout(self):
        """Create the widgets and stack the toolbar row above the canvas."""

        self._layout = QVBoxLayout()

        # Generate widgets
        self._gen_mpl_widgets()

        # HBoxLayout for toolbar and clear button
        self._layout_toolbar = QHBoxLayout()
        for _widget in (self.mpl_toolbar, self.mpl_handles_label,
                        self.mpl_handles, self.mpl_refresh):
            self._layout_toolbar.addWidget(_widget)

        # HBoxLayout for plot object
        self._layout_plot = QHBoxLayout()
        self._layout_plot.addWidget(self.mpl_canvas)

        # Add layouts
        self._layout.addLayout(self._layout_toolbar)
        self._layout.addLayout(self._layout_plot)

        # Set widget layout
        self.setLayout(self._layout)

    # Generate matplotlib widgets
    def _gen_mpl_widgets(self):
        """Create the figure, canvas, toolbar, trace selector and clear button."""

        # Generate matplotlib figure and canvas
        self.mpl_figure = plt.figure(figsize=(8, 5))
        self.mpl_canvas = FigureCanvas(self.mpl_figure)
        self.mpl_toolbar = NavigationToolbar(self.mpl_canvas, self)

        # Handle selector
        self.mpl_handles_label = QLabel("<b>Show:</b>")
        self.mpl_handles = QComboBox()
        self.mpl_handles.addItem("all-traces")
        self.mpl_handles.setFixedHeight(30)
        self.mpl_handles.currentTextChanged.connect(
            self.update_visible_handles)

        # Refresh button
        self.mpl_refresh = QPushButton("Clear Data")
        self.mpl_refresh.clicked.connect(self.refresh_canvas)
        self.mpl_refresh.setFixedHeight(32)
        self.mpl_refresh_callback = None

    # Method to enable and disable mpl_refresh button
    def mpl_refresh_setEnabled(self, _bool):
        """Enable or disable the "Clear Data" button."""
        self.mpl_refresh.setEnabled(_bool)

    # Add mechanism to pass app method to run on mpl_refresh.clicked
    def set_mpl_refresh_callback(self, __func__):
        """
        Register the NAME of an application method to run after a clear.

        The value is stored as a string and later resolved with
        ``getattr(self._app, name)``.
        """
        self.mpl_refresh_callback = str(__func__)

    # Run app method attached to mpl_refresh.clicked
    def _run_mpl_refresh_callback(self):
        """Look up the registered callback on the application and call it."""
        if self.mpl_refresh_callback is not None:
            __func__ = getattr(self._app, self.mpl_refresh_callback)
            __func__()

    # Sync application data. When True, refresh lines will attempt to
    # del self._app._data.data[_handle_key] when clearing data in axes
    # self._handles[_axes_key][_handle_key]. This will synchonize plots
    # with application data.
    def sync_application_data(self, _bool):
        """Turn deletion of application data on clear on or off."""
        self.sync = _bool

    # Wrapper method to set(change) colormap
    def gen_cmap_colors(self, _cmap="default"):
        """Forward the colormap selection to the QVisaColorMap object."""
        self._cmap.gen_cmap_colors(_cmap)

    # Wrapper method to generate next color
    def gen_next_color(self):
        """Return the next color from the colormap generator."""
        return next(self._cgen)

    # Add axes object to widget
    def add_subplot(self, _axes_key=111, twinx=False):
        """
        Add a subplot (and optionally its twinx partner) to the figure.

        The axes is registered under ``str(_axes_key)``; the twin axes, if
        requested, under that key with a trailing "t".
        """
        _key = str(_axes_key)

        self._handles.add_key(_key)
        self._axes[_key] = self.mpl_figure.add_subplot(_axes_key)

        if twinx:
            self._handles.add_key(_key + 't')
            self._axes[_key + 't'] = self._axes[_key].twinx()

    # Add axes xlabels
    def set_axes_xlabel(self, _axes_key, _xlabel):
        """Set the x-axis label of the axes registered under ``_axes_key``."""
        self._axes[str(_axes_key)].set_xlabel(str(_xlabel))

    # Add axes ylabels
    def set_axes_ylabel(self, _axes_key, _ylabel):
        """Set the y-axis label of the axes registered under ``_axes_key``."""
        self._axes[str(_axes_key)].set_ylabel(str(_ylabel))

    # Convenience method to set axes labels
    def set_axes_labels(self, _axes_key, _xlabel, _ylabel):
        """Set both axis labels of the axes registered under ``_axes_key``."""
        self.set_axes_xlabel(_axes_key, _xlabel)
        self.set_axes_ylabel(_axes_key, _ylabel)

    # Set axes adjust
    def set_axes_adjust(self, _left, _right, _top, _bottom):
        """Store the subplot margins that update_canvas() applies."""
        self._adjust = {'l': _left, 'r': _right, 't': _top, 'b': _bottom}

    # Add origin lines
    def add_origin_lines(self, _axes_key, key="both"):
        """
        Draw thin dotted lines through the origin.

        ``key`` selects the horizontal line at y=0 ("x"), the vertical line
        at x=0 ("y") or both ("both"). Any other value draws nothing.
        """
        _style = {'color': 'k', 'linewidth': 0.5, 'linestyle': ":"}

        # horizontal line along the x-axis
        if key in ("x", "both"):
            self._axes[str(_axes_key)].axhline(y=0, **_style)

        # vertical line along the y-axis
        if key in ("y", "both"):
            self._axes[str(_axes_key)].axvline(x=0, **_style)

    # Add handle to axes
    def add_axes_handle(self, _axes_key, _handle_key, _color=None):
        """
        Create an empty line on an axes and register it under ``_handle_key``.

        The handle key is also added to the selector combo box if it is not
        listed yet. Without an explicit ``_color`` the next colormap color
        is used.
        """
        # The same normalised key is used for the axes and for the handle
        # store; previously only the axes lookup was normalised.
        _axes_key = str(_axes_key)

        # Get handle keys from comboBox
        _handle_keys = [
            self.mpl_handles.itemText(i)
            for i in range(self.mpl_handles.count())
        ]

        # Check if handle key is in list
        if _handle_key not in _handle_keys:
            self.mpl_handles.addItem(_handle_key)

        # Use the given color, otherwise generate one on the defined map
        if _color is None:
            _color = self.gen_next_color()
        h, = self._axes[_axes_key].plot([], [], color=_color)

        # Add handle to handle keys
        self._handles.add_subkey(_axes_key, _handle_key)
        self._handles.append_subkey_data(_axes_key, _handle_key, h)

    # Method to get axes handles
    def get_axes_handles(self):
        """Return the QVisaDataObject that holds all line handles."""
        return self._handles

    # Update axes handle (set)
    def set_handle_data(self,
                        _axes_key,
                        _handle_key,
                        x_data,
                        y_data,
                        _handle_index=0):
        """Replace the x/y data of one registered line."""

        # Get the list of handles
        _h = self._handles.get_subkey_data(str(_axes_key), _handle_key)

        # Set data values on _handle_index
        _line = _h[_handle_index]
        _line.set_xdata(x_data)
        _line.set_ydata(y_data)

    # Update axes handle (append)
    def append_handle_data(self,
                           _axes_key,
                           _handle_key,
                           x_value,
                           y_value,
                           _handle_index=0):
        """Append one point (or several) to the data of a registered line."""

        # Get the list of handles
        _h = self._handles.get_subkey_data(str(_axes_key), _handle_key)
        _line = _h[_handle_index]

        # Append new values to handle data
        _x = np.append(_line.get_xdata(), x_value)
        _y = np.append(_line.get_ydata(), y_value)

        # Set xdata and ydata to handle
        _line.set_xdata(_x)
        _line.set_ydata(_y)

    # Method to redraw canvas lines
    def update_visible_handles(self):
        """
        Show only the traces selected in the combo box, then redraw.

        With "all-traces" selected every line is made visible; otherwise
        only the lines whose handle key equals the selection stay visible.
        """

        # Get handle
        _show_handle = self.mpl_handles.currentText()
        _show_all = (_show_handle == "all-traces")

        # For each axis (e.g. 111)
        _axes_keys = self._handles.keys() if _show_all else self._axes.keys()
        for _axes_key in _axes_keys:

            # Check if there are handles on the key
            _subitems = self._handles.subitems(_axes_key)
            if _subitems is None:
                continue

            # Loop through handle_key and handle_list
            for _handle_key, _handle_list in _subitems:
                _visible = _show_all or (_show_handle == _handle_key)
                for _h in _handle_list:
                    _h.set_visible(_visible)

        self.update_canvas()

    # Method to update canvas dynamically
    def update_canvas(self):
        """Apply the stored margins, rescale every axes and redraw."""

        # Adjust subplots on THIS widget's figure. pyplot's module-level
        # subplots_adjust() acts on whichever figure is current, which is
        # the wrong one as soon as several plot widgets exist.
        self.mpl_figure.subplots_adjust(left=self._adjust['l'],
                                        right=self._adjust['r'],
                                        top=self._adjust['t'],
                                        bottom=self._adjust['b'])

        # Loop through all figure axes and relimit
        for _axes in self._axes.values():

            _axes.relim()
            _axes.set_xlim(left=None, right=None, emit=True, auto=True)
            _axes.set_ylim(bottom=None, top=None, emit=True, auto=True)
            _axes.autoscale_view(scalex=True, scaley=True)

            # Only needed if plotting on linear scale
            if _axes.get_yscale() == "linear":
                _axes.ticklabel_format(style='sci',
                                       scilimits=(0, 0),
                                       axis='y',
                                       useOffset=False)

        # Draw and flush_events
        self.mpl_canvas.draw()
        self.mpl_canvas.flush_events()

    # Refresh canvas. Note callback will expose args as False
    def refresh_canvas(self, supress_warning=False):
        """
        Clear the visible traces, asking for confirmation if data exists.

        When connected to ``clicked`` the signal's argument arrives as
        ``supress_warning`` (False), so the dialog is shown.

        Returns True if the traces were cleared and False if the user
        declined.
        """

        # Only ask to redraw if there is data present
        _ask_first = ((self._handles.keys_empty() == False)
                      and (supress_warning == False))

        if _ask_first:

            msg = QMessageBox()
            msg.setIcon(QMessageBox.Information)
            msg.setText("Clear measurement data (%s)?" %
                        self.mpl_handles.currentText())
            msg.setWindowTitle("QDynamicPlot")
            msg.setWindowIcon(self._app._get_icon())
            msg.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
            self.msg_clear = msg.exec_()

            if self.msg_clear != QMessageBox.Yes:
                return False

        # Note that mpl_refresh_callback is run after refresh lines.
        self.refresh_lines()
        self._run_mpl_refresh_callback()
        return True

    # Method to delete lines from handle
    def refresh_lines(self):
        """
        Remove every currently visible trace from the plot.

        A handle key counts as visible when the first line of its handle
        list is visible. The key is removed from all axes and from the combo
        box and, when syncing is on, from the application data object.
        """

        # Create empty handle cache
        _del_cache = []

        # For each axes (e.g. 111)
        for _axes_key in self._axes.keys():

            # Check if there are handles on the key
            _subitems = self._handles.subitems(_axes_key)
            if _subitems is None:
                continue

            # Loop through handle_key, handle_list objects on axes
            for _handle_key, _handle_list in _subitems:

                # Cache the handle key for deletion if its first handle is
                # visible and it has not been cached yet
                if (_handle_list[0].get_visible()
                        and _handle_key not in _del_cache):
                    _del_cache.append(_handle_key)

        # Loop through cached keys
        for _handle_key in _del_cache:

            # Check for key on each axis (e.g. 111, 111t)
            for _axes_key in self._axes.keys():

                # Remove handles (mpl.Artist objects) from their axes
                for _handle in self._handles.get_subkey_data(
                        _axes_key, _handle_key):
                    _handle.remove()

                # Delete the _handle_key from _handles object
                self._handles.del_subkey(_axes_key, _handle_key)

            # Remove _handle_key from dropdown
            self.mpl_handles.removeItem(self.mpl_handles.findText(_handle_key))

            # Remove _handle_key from application data if syncing
            if self.sync:
                _data = self._app._get_data_object()
                _data.del_key(_handle_key)

        # If deleting all traces, reset the colormap
        if self.mpl_handles.currentText() == "all-traces":
            self._cmap.gen_reset()

        # Otherwise set text to "all-traces"
        else:
            self.mpl_handles.setCurrentIndex(0)

        # Redraw canvas
        self.update_canvas()

    # Method to reset axes
    def reset_canvas(self):
        """
        Clear every axes and forget all handles, keeping the axis labels.

        The combo box is reset to the single "all-traces" entry and the
        colormap is restarted.
        """

        # Clear the axes
        for _axes in self._axes.values():

            # Pull labels
            _xlabel = _axes.get_xlabel()
            _ylabel = _axes.get_ylabel()

            # clear axes and reset labels
            _axes.clear()
            _axes.set_xlabel(_xlabel)
            _axes.set_ylabel(_ylabel)

        # Clear registered handles
        # Calling add_key() will re-initialize data dictionary to {} for axes.
        # Iterate over a snapshot so the store is not modified while its own
        # key view is being walked.
        for _axes_key in list(self._handles.keys()):
            self._handles.add_key(_axes_key)

        # Clear the combobox
        self.mpl_handles.clear()
        self.mpl_handles.addItem("all-traces")

        # Reset the colormap
        self._cmap.gen_reset()

        # Update canvas
        self.update_canvas()
Beispiel #21
0
class MyMainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        """Build the window: two 3D canvases, the CNC objects and all wiring.

        Args:
            parent: Optional parent widget, forwarded to the Qt base class.

        The jog buttons previously read the manual-drive speed through a
        module-level ``win`` variable; they now use ``self`` so the handlers
        work regardless of how (or whether) the instance is bound globally.
        """
        super(MyMainWindow, self).__init__(parent)
        QtWidgets.qApp.installEventFilter(self)
        self.setupUi(self)

        # First 3D canvas, added to the ``plot3d`` layout.
        self.fig = Figure(dpi=100)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.fig.subplots_adjust(left=0,
                                 right=1,
                                 top=1,
                                 bottom=0,
                                 wspace=0,
                                 hspace=0)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')
        self.plot3d.addWidget(self.canvas)

        # Second 3D canvas, added to the ``plot3dProgramCode`` layout; it
        # also carries the (initially empty) tool marker line.
        self.fig2 = Figure(dpi=100)
        self.canvas2 = FigureCanvas(self.fig2)
        self.canvas2.setParent(self)
        self.ax2 = self.fig2.add_subplot(111, projection='3d')
        self.fig2.subplots_adjust(left=0,
                                  right=1,
                                  top=1,
                                  bottom=0,
                                  wspace=0,
                                  hspace=0)
        self.ax2.set_xlabel('x')
        self.ax2.set_ylabel('y')
        self.ax2.set_zlabel('z')
        self.toolMesh, = self.ax2.plot([], [], [], 'k-', zorder=150)
        self.plot3dProgramCode.addWidget(self.canvas2)
        self.canvas2.draw()

        self.myCncObject = myCNC()
        self.myStepperControl = stepperControl(self.myCncObject, self)

        # File import and scaling controls
        self.stlImportButton.clicked[bool].connect(self.openStlFile)
        self.svgImportButton.clicked[bool].connect(self.openSvgFile)
        self.scaleSvg.valueChanged.connect(self.setSvgScaling)
        self.scaleStl.valueChanged.connect(self.setStlScaling)

        self.invertContour.setHidden(True)
        self.processingStrategy.currentIndexChanged.connect(self.setLabels)

        self.xOffsetLabel.setHidden(True)
        self.offsetX.setHidden(True)
        self.yOffsetLabel.setHidden(True)
        self.offsetY.setHidden(True)
        self.svgProcessingStrategy.currentIndexChanged.connect(
            self.setLabelsSvg)

        # Export and toolpath generation
        self.svgExportButton.clicked[bool].connect(self.saveSvgToNpzFile)
        self.stlExportButton.clicked[bool].connect(self.saveStlToNpzFile)
        self.generateSvgToolpath.clicked[bool].connect(
            self.myCncObject.generateSvgToolpath)
        self.generateStlToolpath.clicked[bool].connect(
            self.myCncObject.generateStlToolpath)

        self.loadConvertedFileButton.clicked[bool].connect(self.npzImport)

        # Program control; the buttons start out disabled.
        self.pauseButton.clicked[bool].connect(
            self.myStepperControl.pauseProgram)
        self.stopButton.clicked[bool].connect(
            self.myStepperControl.stopProgram)
        self.runButton.clicked[bool].connect(self.myStepperControl.runProgram)
        self.pauseButton.setDisabled(True)
        self.stopButton.setDisabled(True)
        self.runButton.setDisabled(True)

        def makeJoggHandler(axis, sign, stepName):
            # Step size and speed are looked up when the button is clicked,
            # not when the handler is created. ``*_signalArgs`` swallows the
            # ``checked`` flag emitted by ``clicked[bool]``.
            def handler(*_signalArgs):
                step = getattr(self.myCncObject, stepName)
                self.myStepperControl.jogg(axis,
                                           step if sign > 0 else -step,
                                           self.speedManualDrive.value())

            return handler

        # Buttons are named jogg<Pos|Neg><X|Y|Z><Slow|Fast>.
        for axis in ('x', 'y', 'z'):
            for direction, sign in (('Pos', 1), ('Neg', -1)):
                for pace, stepName in (('Slow', 'slowStep'),
                                       ('Fast', 'fastStep')):
                    button = getattr(
                        self, 'jogg' + direction + axis.upper() + pace)
                    button.clicked[bool].connect(
                        makeJoggHandler(axis, sign, stepName))
        self.speed.valueChanged.connect(self.myStepperControl.changeSpeed)

        # Homing and positioning
        self.setHome.clicked[bool].connect(self.myCncObject.setHomePosition)
        self.setZero.clicked[bool].connect(self.myCncObject.setZero)
        self.goHome.clicked[bool].connect(self.myStepperControl.goHome)
        self.performReferenceScan.clicked[bool].connect(
            self.myStepperControl.performReferenceScan)

        self.goToPosition.clicked[bool].connect(
            lambda: self.myStepperControl.goToPosition(self.xSetPoint.value(),
                                                       self.ySetPoint.value(),
                                                       self.zSetPoint.value(),
                                                       manual=True))
        self.stopManualDrive.clicked[bool].connect(
            self.myStepperControl.stopManualDrive)
        self.stopManualDrive.setDisabled(True)

        self.toggleMillingMotor.valueChanged.connect(
            self.myStepperControl.toggleMillingMotor)

        self.canvasElements = {}

        self.show()

    def setLabels(self):
        """Adapt the label and the invert-contour control to the strategy.

        Slot for the processing-strategy combo box; the chosen text is read
        from the signal sender. 'line' shows 'Slice Dist.' and hides the
        invert-contour control; 'contour' and 'constant z' show 'Z depth' and
        reveal it. Any other text leaves the widgets untouched.

        Widgets are accessed through ``self`` rather than a module-level
        ``win`` variable, which need not exist when the signal fires.
        """
        strategy = self.sender().currentText()
        if strategy == 'line':
            self.xPosLabel_9.setText('Slice Dist.')
            self.invertContour.setHidden(True)
        elif strategy in ('contour', 'constant z'):
            self.xPosLabel_9.setText('Z depth')
            self.invertContour.setHidden(False)

    def setLabelsSvg(self):
        """Show or hide the X/Y offset controls for the SVG strategy.

        Slot for the SVG processing-strategy combo box; the chosen text is
        read from the signal sender. 'contour' hides the four offset widgets,
        'constant z' shows them, any other text changes nothing.

        Widgets are accessed through ``self`` rather than a module-level
        ``win`` variable, which need not exist when the signal fires.
        """
        strategy = self.sender().currentText()
        if strategy not in ('contour', 'constant z'):
            return

        hidden = strategy == 'contour'
        for widget in (self.xOffsetLabel, self.offsetX, self.yOffsetLabel,
                       self.offsetY):
            widget.setHidden(hidden)

    def openStlFile(self):
        """Ask the user for an STL file and hand it to the CNC object.

        Nothing happens when the dialog is cancelled. Otherwise the path is
        shown in ``stlFilePathLabel``, the triggering button is unchecked and
        ``myCncObject.stlImport`` is called with the path.
        """
        # Grab the sender before the dialog runs.
        button = self.sender()
        dialogOptions = (QtWidgets.QFileDialog.Options()
                         | QtWidgets.QFileDialog.DontUseNativeDialog)
        chosenPath, _selectedFilter = QtWidgets.QFileDialog.getOpenFileName(
            self,
            "QFileDialog.getOpenFileName()",
            "",
            "STL files (*.stl)",
            options=dialogOptions)
        if not chosenPath:
            return

        self.stlFilePathLabel.setText(chosenPath)
        button.setChecked(False)
        self.myCncObject.stlImport(chosenPath)

    def setStlScaling(self):
        """Store the sender's value as STL scaling and redraw the 3D plot."""
        self.myCncObject.stlScaling = self.sender().value()
        self.make3dPlot(self.canvas, self.ax)

    def openSvgFile(self):
        """Ask the user for an SVG file and hand it to the CNC object.

        Nothing happens when the dialog is cancelled. Otherwise the path is
        shown in ``svgFilePathLabel``, the triggering button is unchecked and
        ``myCncObject.svgImport`` is called with the path.
        """
        # Grab the sender before the dialog runs.
        button = self.sender()
        dialogOptions = (QtWidgets.QFileDialog.Options()
                         | QtWidgets.QFileDialog.DontUseNativeDialog)
        chosenPath, _selectedFilter = QtWidgets.QFileDialog.getOpenFileName(
            self,
            "QFileDialog.getOpenFileName()",
            "",
            "SVG files (*.svg)",
            options=dialogOptions)
        if not chosenPath:
            return

        self.svgFilePathLabel.setText(chosenPath)
        button.setChecked(False)
        self.myCncObject.svgImport(chosenPath)

    def setSvgScaling(self):
        """Store the sender's value as SVG scaling and redraw the 2D plot.

        The plot is drawn with the scaling divided by
        ``constants.conversionFactorPPItoMM``.
        """
        self.myCncObject.svgScaling = self.sender().value()
        plotScale = (self.myCncObject.svgScaling /
                     constants.conversionFactorPPItoMM)
        self.make2dPlot(self.canvas, self.ax, scaling=plotScale)

    def saveSvgToNpzFile(self):
        """Save the rescaled contour and its toolpath to a user-chosen .npz.

        The archive holds ``type`` ("2d"), ``shape``, ``toolpath`` and
        ``toolDiameter``. Nothing is written when the dialog is cancelled.
        """
        button = self.sender()
        dialogOptions = (QtWidgets.QFileDialog.Options()
                         | QtWidgets.QFileDialog.DontUseNativeDialog)
        targetPath, _selectedFilter = QtWidgets.QFileDialog.getSaveFileName(
            self,
            "QFileDialog.getSaveFileName()",
            "",
            "*.npz files (*.npz)",
            options=dialogOptions)
        if not targetPath:
            return

        button.setChecked(False)
        np.savez(targetPath,
                 type="2d",
                 shape=self.rescaledContour,
                 toolpath=self.myCncObject.toolpath,
                 toolDiameter=self.toolDiameter.value())

    def saveStlToNpzFile(self):
        """Save the scaled facets and their toolpath to a user-chosen .npz.

        The archive holds ``type`` ("3d"), ``shape`` (facets multiplied by the
        STL scaling), ``toolpath`` and ``toolDiameter``. Nothing is written
        when the dialog is cancelled.
        """
        button = self.sender()
        dialogOptions = (QtWidgets.QFileDialog.Options()
                         | QtWidgets.QFileDialog.DontUseNativeDialog)
        targetPath, _selectedFilter = QtWidgets.QFileDialog.getSaveFileName(
            self,
            "QFileDialog.getSaveFileName()",
            "",
            "*.npz files (*.npz)",
            options=dialogOptions)
        if not targetPath:
            return

        button.setChecked(False)
        scaledFacets = self.myCncObject.stlScaling * self.myCncObject.facets
        np.savez(targetPath,
                 type="3d",
                 shape=scaledFacets,
                 toolpath=self.myCncObject.toolpath,
                 toolDiameter=self.toolDiameterStl.value())

    def make3dPlot(self, canvas, graph):
        """Redraw the scaled STL facets as a shaded mesh on ``graph``.

        Args:
            canvas: Canvas whose ``draw()`` is called at the end.
            graph: 3D axes that is cleared and receives the mesh.

        Also creates ``self.programCode``, an empty red line on ``graph``.
        """
        graph.clear()
        self.myCncObject.translateToOrigin()

        scaledFacets = self.myCncObject.stlScaling * self.myCncObject.facets
        faceColors = self.getFlatColorArray()
        graph.add_collection3d(
            mplot3d.art3d.Poly3DCollection(scaledFacets,
                                           facecolors=faceColors,
                                           zsort='max'))

        # Use every scaled coordinate for all three axes.
        extent = self.myCncObject.stlScaling * self.myCncObject.facets.flatten(
            "A")
        graph.auto_scale_xyz(extent, extent, extent)

        for setLabel, text in ((graph.set_xlabel, 'x'),
                               (graph.set_ylabel, 'y'),
                               (graph.set_zlabel, 'z')):
            setLabel(text)

        self.programCode, = graph.plot([], [], [], 'r-', zorder=100)
        canvas.draw()

    def getFlatColorArray(self):
        """Return one RGBA colour per facet, shaded by the facet's mean z.

        The colour is interpolated linearly between a dark and a light blue
        according to where the facet's mean (scaled) z lies between the
        lowest and highest z of the whole mesh.

        The z range is taken over *all* vertices. The previous code looked
        only at the last vertex of each facet, so the interpolation key could
        leave [0, 1] and produce invalid colour components. A mesh with no z
        extent is coloured uniformly dark instead of dividing by zero.

        Returns:
            list: ``[r, g, b, 1]`` per facet, components in [0, 1]; empty if
            there are no facets.
        """
        dark = np.array([16, 33, 105], dtype=float)
        light = np.array([50, 91, 255], dtype=float)

        # z of every vertex; assumes facets is indexed (facet, vertex, xyz)
        # as the per-facet ``facet[..., 2]`` access of the original implies.
        zValues = self.myCncObject.stlScaling * np.asarray(
            self.myCncObject.facets, dtype=float)[..., 2]
        if zValues.size == 0:
            return []

        minZvalue = zValues.min()
        distance = zValues.max() - minZvalue

        colors = []
        for facetZ in zValues:
            key = (np.mean(facetZ) - minZvalue) / distance if distance else 0.0
            red, green, blue = (dark + (light - dark) * key) / 255
            colors.append([red, green, blue, 1])
        return colors

    def getColorArray(self):
        """Return one RGBA colour per vertex, shaded by the vertex's z.

        The colour is interpolated linearly between a dark and a light blue
        according to where the vertex z lies between the lowest and highest
        z of the whole mesh.

        The z range is taken over *all* vertices. The previous code looked
        only at the last vertex of each facet, so the interpolation key could
        leave [0, 1]. A mesh with no z extent is coloured uniformly dark
        instead of dividing by zero.

        Returns:
            list: One single-element list ``[[r, g, b, 1]]`` per vertex, in
            facet order then vertex order (same nesting as before).
        """
        dark = np.array([16, 33, 105], dtype=float)
        light = np.array([50, 91, 255], dtype=float)

        # z of every vertex; assumes facets is indexed (facet, vertex, xyz)
        # as the ``point[2]`` access of the original implies.
        zValues = np.asarray(self.myCncObject.facets, dtype=float)[..., 2]
        if zValues.size == 0:
            return []

        minZvalue = zValues.min()
        distance = zValues.max() - minZvalue

        colors = []
        # ravel() walks facet by facet, vertex by vertex.
        for pointZ in zValues.ravel():
            key = (pointZ - minZvalue) / distance if distance else 0.0
            red, green, blue = (dark + (light - dark) * key) / 255
            colors.append([[red, green, blue, 1]])
        return colors

    def make2dPlot(self, canvas, graph, scaling=1):
        """Redraw the imported contours, scaled, as blue lines on ``graph``.

        Args:
            canvas: Canvas whose ``draw()`` is called at the end.
            graph: 3D axes that is cleared and receives the lines.
            scaling: Factor applied to every contour coordinate.

        Side effects: rebuilds ``self.rescaledContour`` (one scaled array per
        contour) and creates ``self.programCode``, an empty red line.
        """
        graph.clear()
        self.rescaledContour = []
        for contour in self.myCncObject.importedContourForPlotting:
            # Convert and scale once instead of once per use.
            scaled = scaling * np.array(contour)
            self.rescaledContour.append(scaled)
            graph.plot(scaled[:, 0], scaled[:, 1], scaled[:, 2], 'b-')
        graph.set_xlabel('x')
        graph.set_ylabel('y')
        graph.set_zlabel('z')
        self.programCode, = graph.plot([], [], [], 'r-', zorder=100)

        canvas.draw()

    def plotToolpath(self, canvas, graph):
        """Show the current toolpath on ``self.programCode`` and fit the axes.

        Args:
            canvas: Canvas whose ``draw()`` is called at the end.
            graph: 3D axes whose limits are adjusted.

        x and y share one limit pair, z gets its own. Each bound is pushed
        outward by 10 % of its magnitude. The previous ``min * 1.1`` and
        ``max * 1.1`` moved a positive minimum or a negative maximum inward
        and cut data off. An empty toolpath only triggers a redraw.
        """
        toolpath = self.myCncObject.toolpath
        xs, ys, zs = toolpath[..., 0], toolpath[..., 1], toolpath[..., 2]

        self.programCode.set_xdata(xs)
        self.programCode.set_ydata(ys)
        self.programCode.set_3d_properties(zs)

        if np.size(xs) == 0:
            canvas.draw()
            return

        margin = 0.1

        def lowBound(values):
            lowest = min(values)
            return lowest - margin * abs(lowest)

        def highBound(values):
            highest = max(values)
            return highest + margin * abs(highest)

        lowerLimit = min(lowBound(xs), lowBound(ys))
        upperLimit = max(highBound(xs), highBound(ys))
        graph.set_xlim(lowerLimit, upperLimit)
        graph.set_ylim(lowerLimit, upperLimit)
        graph.set_zlim(lowBound(zs), highBound(zs))
        canvas.draw()

    def clearCanvasElements(self, canvas):
        """Remove all non-grid elements registered for ``canvas``.

        Entries of ``self.canvasElements`` whose key contains 'grid' or that
        belong to a different canvas are left alone. For the others, each
        plot item (or every item of a non-empty list of plot items) is passed
        to the canvas' ``removeItem`` and the entry is deleted.
        """
        removable = []
        for key, entry in self.canvasElements.items():
            if 'grid' in key or entry['canvas'] != canvas:
                continue

            plot = entry['plot']
            if plot and type(plot) is list:
                for contour in plot:
                    entry['canvas'].removeItem(contour)
            else:
                entry['canvas'].removeItem(plot)
            removable.append(key)

        # Delete after the loop so the dict is not mutated while iterating.
        for key in removable:
            del self.canvasElements[key]

    def makeToolMesh(self):
        """Move the tool marker to the current position and redraw canvas2.

        The marker is a vertical segment from the current (x, y, z) in mm to
        10 above it in z.
        """
        position = self.myCncObject.currentPosition
        x = position['x']['mm']
        y = position['y']['mm']
        z = position['z']['mm']

        self.toolMesh.set_xdata([x, x])
        self.toolMesh.set_ydata([y, y])
        self.toolMesh.set_3d_properties([z, z + 10])

        self.canvas2.draw()
        self.canvas2.flush_events()

    def npzImport(self):
        """Load a previously exported .npz program and display it on canvas2.

        Resets the pause flag, the button states and both scalings first
        (this happens even if the dialog is then cancelled, as before). When a
        file is chosen, its shape is plotted according to the stored ``type``
        ("2d" or "3d"), the toolpath and tool diameter are taken over, and
        the lines for already processed steps and the tool marker are
        recreated.
        """
        self.myCncObject.programIsPaused = False
        self.pauseButton.setDisabled(True)
        self.stopButton.setDisabled(True)
        self.runButton.setEnabled(True)
        source = self.sender()
        options = QtWidgets.QFileDialog.Options()
        options |= QtWidgets.QFileDialog.DontUseNativeDialog
        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(
            self,
            "QFileDialog.getOpenFileName()",
            "",
            "*.npz files (*.npz)",
            options=options)

        self.myCncObject.svgScaling = 1
        self.myCncObject.stlScaling = 1

        if not fileName:
            return

        self.npzImportName.setText(fileName)
        source.setChecked(False)

        # NOTE(security): allow_pickle=True executes pickled objects stored
        # in the archive. Only open files that were written by this
        # application or come from a trusted source.
        # The ``with`` block closes the underlying zip file when done.
        with np.load(fileName, allow_pickle=True) as npzfile:
            shapeType = npzfile['type']
            if shapeType == "2d":
                self.myCncObject.importedContourForPlotting = npzfile['shape']
                self.make2dPlot(self.canvas2, self.ax2)
            elif shapeType == "3d":
                self.myCncObject.facets = npzfile['shape']
                self.make3dPlot(self.canvas2, self.ax2)
            self.myCncObject.toolpath = npzfile['toolpath']
            self.myCncObject.toolDiameter = npzfile['toolDiameter']

        self.myCncObject.programLength = len(self.myCncObject.toolpath)
        self.plotToolpath(self.canvas2, self.ax2)
        self.alreadProcessedStepsPlot, = self.ax2.plot([], [], [],
                                                       'k-',
                                                       zorder=150)

        self.toolMesh, = self.ax2.plot([], [], [], 'k-', zorder=150)
class ChartLab(QWidget):
    def __init__(self, datahub_entry: DataHubEntry,
                 factor_center: FactorCenter):
        """Create the widget's state, plot resources and controls.

        Args:
            datahub_entry: Source of the data center and data utility; may be
                None, in which case both are stored as None.
            factor_center: Provider of factor names and comments.
        """
        super(ChartLab, self).__init__()

        # ---------------- ext var ----------------

        self.__data_hub = datahub_entry
        self.__factor_center = factor_center
        if self.__data_hub is not None:
            self.__data_center = self.__data_hub.get_data_center()
            self.__data_utility = self.__data_hub.get_data_utility()
        else:
            self.__data_center = None
            self.__data_utility = None

        self.__inited = False
        self.__plot_table = {}
        # DataFrame behind the most recent plot (shown by the data button)
        self.__paint_data = None

        # ------------- plot resource -------------

        self.__figure = plt.figure()
        self.__canvas = FigureCanvas(self.__figure)

        # -------------- ui resource --------------

        self.__data_frame_widget = None

        self.__combo_factor = QComboBox()
        self.__label_comments = QLabel('')

        # Parallel comparison controls
        self.__radio_parallel_comparison = QRadioButton('横向比较')
        self.__combo_year = QComboBox()
        self.__combo_quarter = QComboBox()
        self.__combo_industry = QComboBox()

        # Longitudinal comparison controls
        self.__radio_longitudinal_comparison = QRadioButton('纵向比较')
        self.__combo_stock = SecuritiesSelector(self.__data_utility)

        # Lower / upper limit inputs
        self.__line_lower = QLineEdit('')
        self.__line_upper = QLineEdit('')

        self.__button_draw = QPushButton('绘图')
        self.__button_show = QPushButton('数据')

        self.init_ui()

    # ---------------------------------------------------- UI Init -----------------------------------------------------

    def init_ui(self):
        """Lay out the controls, then configure them (in that order)."""
        # Layout first: configuration works on the already-placed widgets.
        self.__layout_control()
        self.__config_control()

    def __layout_control(self):
        """Arrange the canvas on top and the control groups in a bottom row.

        The bottom row holds, left to right: the factor group, the
        comparison-mode group, the limit group and the two buttons.
        """

        def make_row(*entries):
            # Each entry is the argument tuple for one addWidget() call:
            # (widget,) or (widget, stretch).
            row = QHBoxLayout()
            for entry in entries:
                row.addWidget(*entry)
            return row

        root = QVBoxLayout()
        self.setLayout(root)
        self.setMinimumSize(1280, 800)

        controls = QHBoxLayout()
        root.addWidget(self.__canvas, 99)
        root.addLayout(controls, 1)

        # Factor selection
        factor_box, factor_layout = create_v_group_box('因子')
        controls.addWidget(factor_box, 2)
        factor_layout.addWidget(self.__combo_factor)
        factor_layout.addWidget(self.__label_comments)

        # Comparison mode
        mode_box, mode_layout = create_v_group_box('比较方式')
        controls.addWidget(mode_box, 2)
        mode_layout.addLayout(
            make_row((self.__radio_parallel_comparison, 1),
                     (self.__combo_industry, 5),
                     (self.__combo_year, 5),
                     (self.__combo_quarter, 5)))
        mode_layout.addLayout(
            make_row((self.__radio_longitudinal_comparison, 1),
                     (self.__combo_stock, 10)))

        # Value limits
        limit_box, limit_layout = create_v_group_box('范围限制')
        controls.addWidget(limit_box, 1)
        limit_layout.addLayout(
            make_row((QLabel('下限'), ), (self.__line_lower, )))
        limit_layout.addLayout(
            make_row((QLabel('上限'), ), (self.__line_upper, )))

        # Action buttons
        buttons = QVBoxLayout()
        buttons.addWidget(self.__button_draw)
        buttons.addWidget(self.__button_show)
        controls.addLayout(buttons, 1)

    def __config_control(self):
        """Fill the combo boxes, set defaults and tooltips, connect signals.

        ``__init__`` stores None for the data utility when no data hub is
        given, and the factor center is checked for None below. Both cases
        are therefore handled here: the industry list then only contains the
        catch-all entry, and the factor comment is not initialised.

        Also sets the matplotlib font rcParams for the plots.
        """
        # Years from the current one back to 1990; index 1 is last year.
        for year in range(now().year, 1989, -1):
            self.__combo_year.addItem(str(year), str(year))
        self.__combo_year.setCurrentIndex(1)

        # Report periods; item data is the month-day suffix of the period.
        self.__combo_quarter.addItem('一季报', '03-31')
        self.__combo_quarter.addItem('中报', '06-30')
        self.__combo_quarter.addItem('三季报', '09-30')
        self.__combo_quarter.addItem('年报', '12-31')
        self.__combo_quarter.setCurrentIndex(3)

        self.__combo_industry.addItem('全部', '全部')
        if self.__data_utility is not None:
            identities = self.__data_utility.get_all_industries()
        else:
            identities = []
        for identity in identities:
            self.__combo_industry.addItem(identity, identity)

        if self.__factor_center is not None:
            factors = self.__factor_center.get_all_factors()
            for fct in factors:
                self.__combo_factor.addItem(fct, fct)
            # Show the comment of the first factor.
            self.on_factor_updated(0)

        self.__combo_stock.setEnabled(False)
        self.__radio_parallel_comparison.setChecked(True)

        self.__radio_parallel_comparison.setToolTip(TIP_PARALLEL_COMPARISON)
        self.__radio_longitudinal_comparison.setToolTip(
            TIP_LONGITUDINAL_COMPARISON)
        self.__line_lower.setToolTip(TIP_LIMIT_UPPER_LOWER)
        self.__line_upper.setToolTip(TIP_LIMIT_UPPER_LOWER)
        self.__button_show.setToolTip(TIP_BUTTON_SHOW)

        self.__button_draw.clicked.connect(self.on_button_draw)
        self.__button_show.clicked.connect(self.on_button_show)

        self.__combo_factor.currentIndexChanged.connect(self.on_factor_updated)
        self.__radio_parallel_comparison.clicked.connect(
            self.on_radio_comparison)
        self.__radio_longitudinal_comparison.clicked.connect(
            self.on_radio_comparison)

        mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei']
        mpl.rcParams['axes.unicode_minus'] = False

    def on_factor_updated(self, value):
        """Clear both limit inputs and show the selected factor's comment.

        Args:
            value: Index of the selected entry in the factor combo box.

        Without a factor center the comment label is simply cleared; the
        previous code raised AttributeError in that case.
        """
        self.__line_lower.setText('')
        self.__line_upper.setText('')
        factor = self.__combo_factor.itemData(value)
        if self.__factor_center is not None:
            comments = self.__factor_center.get_factor_comments(factor)
        else:
            comments = ''
        self.__label_comments.setText(comments)

    def on_button_draw(self):
        """Plot the selected factor in the currently checked comparison mode.

        Parallel mode plots the factor for the chosen industry and report
        period, clipped by the optional limits; otherwise the factor is
        plotted over time for the entered securities.
        """
        factor = self.__combo_factor.currentData()
        lower = str2float_safe(self.__line_lower.text(), None)
        upper = str2float_safe(self.__line_upper.text(), None)

        if not self.__radio_parallel_comparison.isChecked():
            securities = self.__combo_stock.get_input_securities()
            self.plot_factor_longitudinal_comparison(factor, securities)
            return

        # Period text is "<year>-<month-day>", e.g. built from the combo data.
        year = self.__combo_year.currentData()
        month_day = self.__combo_quarter.currentData()
        period = year + '-' + month_day
        industry = self.__combo_industry.currentData()
        self.plot_factor_parallel_comparison(factor, industry,
                                             text_auto_time(period), lower,
                                             upper)

    def on_button_show(self):
        """Show the data behind the last plot in a fresh DataFrameWidget.

        A still-visible previous data window is closed first. Nothing is
        opened when no plot data is available yet.
        """
        previous = self.__data_frame_widget
        if previous is not None and previous.isVisible():
            previous.close()

        if self.__paint_data is None:
            return
        self.__data_frame_widget = DataFrameWidget(self.__paint_data)
        self.__data_frame_widget.show()

    def on_radio_comparison(self):
        """Enable the controls that belong to the checked comparison mode.

        Parallel mode enables year, quarter and both limit inputs and
        disables the stock selector; the other mode does the opposite.
        """
        parallel = bool(self.__radio_parallel_comparison.isChecked())
        for control in (self.__combo_year, self.__combo_quarter,
                        self.__line_lower, self.__line_upper):
            control.setEnabled(parallel)
        self.__combo_stock.setEnabled(not parallel)

    # ---------------------------------------------------------------------------------------

    def plot_factor_parallel_comparison(self, factor: str, industry: str,
                                        period: datetime.datetime,
                                        lower: float, upper: float):
        """Plot a histogram of ``factor`` across stocks for one period.

        Args:
            factor: Name of the factor column to query and plot.
            industry: Industry whose stocks are queried; '全部' queries with
                an empty identity instead.
            period: Report period; used as both ends of the query range.
            lower: Optional lower clip bound, or None.
            upper: Optional upper clip bound, or None.

        The histogram is drawn on this widget's own figure (not on whatever
        figure pyplot currently considers active). The queried frame, sorted
        by the factor, is kept for the data button.
        """
        identities = ''
        if industry != '全部':
            identities = self.__data_utility.get_industry_stocks(industry)
        df = self.__data_center.query_from_factor('Factor.Finance',
                                                  identities, (period, period),
                                                  fields=[factor],
                                                  readable=True)

        # clip() ignores a bound that is None and leaves NaN as NaN; the
        # former comparison lambdas turned NaN into a bound value, which
        # then showed up as a spurious bar at the edge of the histogram.
        s1 = df[factor].clip(lower=lower, upper=upper)

        self.__figure.clf()
        ax = self.__figure.add_subplot(1, 1, 1)
        s1.hist(bins=100, ax=ax)
        ax.set_title(factor)

        self.__canvas.draw()
        self.__canvas.flush_events()

        self.__paint_data = df
        self.__paint_data.sort_values(factor, inplace=True)

    def plot_factor_longitudinal_comparison(self, factor: str,
                                            securities: str):
        """Plot ``factor`` over time for ``securities`` (annual reports only).

        Args:
            factor: Name of the factor column to query and plot.
            securities: Identity passed to the factor query.

        Only rows whose period falls in December are kept. The line is drawn
        on this widget's own figure (not on whatever figure pyplot currently
        considers active). The filtered frame, sorted by period descending,
        is kept for the data button.
        """
        df = self.__data_center.query_from_factor('Factor.Finance',
                                                  securities,
                                                  None,
                                                  fields=[factor],
                                                  readable=True)
        # Only for annual report. Work on a copy so that adding the index
        # column below does not write into a slice of the query result.
        df = df[df['period'].dt.month == 12].copy()
        df['报告期'] = df['period']
        df.set_index('报告期', inplace=True)

        s1 = df[factor]

        self.__figure.clf()
        ax = self.__figure.add_subplot(1, 1, 1)
        s1.plot.line(ax=ax)
        ax.set_title(factor)

        self.__canvas.draw()
        self.__canvas.flush_events()

        self.__paint_data = df
        self.__paint_data.sort_values('period', ascending=False, inplace=True)

    # ---------------------------------------------------------------------------------------

    def plot(self):
        """Draw the default chart; delegates to plot_histogram_statistics."""
        self.plot_histogram_statistics()

    def plot_histogram_statistics(self):
        """Plot expense-ratio histograms for December 2018 reports.

        Queries the expense, revenue and cost columns for all stocks and
        draws two stacked histograms on this widget's figure:

        1. three expenses / total operating revenue
        2. three expenses / gross profit

        "Three expenses" is selling + administrative + the positive part of
        the finance expense. Both ratios are limited to [-0.1, 1].

        This method used to carry many commented-out experiments with other
        ratios; ``plot_proportion`` with ``PlotItem`` entries covers such
        ad-hoc ratio histograms.
        """
        period = (text_auto_time('2018-12-01'), text_auto_time('2018-12-31'))

        mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei']
        mpl.rcParams['axes.unicode_minus'] = False

        df = self.__data_utility.auto_query(
            '', period,
            ['减:财务费用', '减:销售费用', '减:管理费用', '营业总收入', '营业收入', '减:营业成本'],
            ['stock_identity', 'period'])

        df['毛利润'] = df['营业收入'] - df['减:营业成本']
        # Non-positive (and missing) finance expense counts as zero.
        df['财务费用正'] = df['减:财务费用'].apply(lambda x: x if x > 0 else 0)
        df['三费'] = df['减:销售费用'] + df['减:管理费用'] + df['财务费用正']

        # Start from an empty figure so repeated calls do not pile up, and
        # draw on our own figure rather than pyplot's current one.
        self.__figure.clf()

        # clip() keeps NaN ratios as NaN (hist() skips them); the former
        # comparison lambda counted them as -0.1.
        s1 = (df['三费'] / df['营业总收入']).clip(lower=-0.1, upper=1)
        ax1 = self.__figure.add_subplot(2, 1, 1)
        s1.hist(bins=100, ax=ax1)
        ax1.set_title('三费/营业总收入')

        s2 = (df['三费'] / df['毛利润']).clip(lower=-0.1, upper=1)
        ax2 = self.__figure.add_subplot(2, 1, 2)
        s2.hist(bins=100, ax=ax2)
        ax2.set_title('三费/毛利润')

        # Render the new content, as the other plot methods do.
        self.__canvas.draw()
        self.repaint()

    class PlotItem:
        """One ratio histogram: numerator / denominator with optional limits.

        Attributes are set from the constructor arguments: ``numerator``,
        ``denominator``, ``limit_lower``, ``limit_upper`` and ``plot_bins``.
        """

        def __init__(self,
                     num: str,
                     den: str,
                     lower: float or None = None,
                     upper: float or None = None,
                     bins: int = 100):
            # Column names of the ratio
            self.numerator, self.denominator = num, den
            # Optional bounds; None means "no limit on this side"
            self.limit_lower, self.limit_upper = lower, upper
            # Number of histogram bins
            self.plot_bins = bins

    def plot_proportion(self, plot_set: [PlotItem], period: datetime.datetime):
        """Draw one stacked histogram per PlotItem ratio via pyplot.

        Args:
            plot_set: Items naming numerator/denominator columns, optional
                limits and bin count.
            period: Period handed to ``prepare_plot_data`` for the query.
        """
        df = self.prepare_plot_data(plot_set, period)

        total = len(plot_set)
        for position, item in enumerate(plot_set, start=1):
            ratio = df[item.numerator] / df[item.denominator]

            floor, ceiling = item.limit_lower, item.limit_upper
            if floor is not None:
                ratio = ratio.apply(lambda x: max(x, floor))
            if ceiling is not None:
                ratio = ratio.apply(lambda x: min(x, ceiling))

            # Subplots are numbered from 1, top to bottom.
            plt.subplot(total, 1, position)
            ratio.hist(bins=item.plot_bins)
            plt.title(item.numerator + '/' + item.denominator)

    def prepare_plot_data(self, plot_set: [PlotItem],
                          period: datetime.datetime) -> pd.DataFrame:
        fields = []
        for plot_item in plot_set:
            fields.append(plot_item.numerator)
            fields.append(plot_item.denominator)
        fields = list(set(fields))
        return self.__data_utility.auto_query(
            '', (period - datetime.timedelta(days=1), period), fields,
            ['stock_identity', 'period'])
# Example #23
# 0
class MinionTraceUi(QWidget):
    def __init__(self, parent):
        """Store default trace settings, create empty buffers and build the UI.

        Parameters
        ----------
        parent
            Owning widget. It must expose ``hardware_counter`` and, when that
            flag is ``True``, a ``counter`` attribute.
        """
        super(MinionTraceUi, self).__init__(parent)
        # NOTE: keeping the owner under the name ``parent`` hides the
        # inherited ``parent()`` accessor on this instance.
        self.parent = parent
        self.hardware_counter = parent.hardware_counter
        if self.hardware_counter is True:
            self.counter = parent.counter

        # True - allowed to measure, False - forbidden to measure
        # (e.g. if the counter is needed elsewhere)
        self.status = True

        # Default plot window and timing settings; counttime is in seconds
        # (the UI shows it multiplied by 1000 under a "[ms]" label).
        self.tracemin, self.tracemax, self.tracelength = 0., 60., 60.
        self.counttime = 0.005
        self.updatetime = 100

        # Zero-length sample buffers, filled while a trace is running.
        self.tracex = np.empty(0)
        self.tracey1 = np.empty(0)  # apd1
        self.tracey2 = np.empty(0)  # apd2
        self.traceysum = np.empty(0)

        self.uisetup()

    def uisetup(self):
        """Create the canvas, toolbar, controls and arrange them in a grid."""

        def spinbox(low, high, value, slot, decimals=None):
            # Configure in a fixed order: range, optional precision, value,
            # and only then hook up the editingFinished handler.
            box = QDoubleSpinBox()
            box.setRange(low, high)
            if decimals is not None:
                box.setDecimals(decimals)
            box.setValue(value)
            box.editingFinished.connect(slot)
            return box

        # Plot area
        self.tracefigure = Figure()
        self.tracecanvas = FigureCanvas(self.tracefigure)
        self.tracecanvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.tracecanvas.setMinimumSize(50, 50)
        self.toolbar = NavigationToolbar(self.tracecanvas, self)

        # Start / stop buttons
        self.tracestartbutton = QPushButton('start')
        self.tracestartbutton.pressed.connect(self.tracestartclicked)
        self.tracestopbutton = QPushButton('stop')
        self.tracestopbutton.pressed.connect(self.tracestopclicked)

        # Fixed display range (used after acquisition)
        self.traceminlabel = QLabel('t_min:')
        self.tracemintext = spinbox(0, 9999, int(self.tracemin),
                                    self.updatetracesetting)
        self.tracemaxlabel = QLabel('t_max:')
        self.tracemaxtext = spinbox(0, 9999, int(self.tracemax),
                                    self.updatetracesetting)

        # Width of the rolling window shown while acquiring
        self.tracelengthlabel = QLabel('dt:')
        self.tracelengthtext = spinbox(0, 9999, self.tracelength,
                                       self.updatetracesetting, decimals=2)

        # Timing controls
        self.updatetimelabel = QLabel('updateintervall:')
        self.updatetimetext = spinbox(10, 1000, self.updatetime,
                                      self.tracetimechanged, decimals=0)
        self.counttimelabel = QLabel('counttime [ms]:')
        self.counttimetext = spinbox(0, 1000, int(self.counttime*1000),
                                     self.tracetimechanged)

        # Trace visibility checkboxes; 'sum' starts checked. It is toggled
        # before any handler is connected so no update fires during setup.
        self.traceapd1check = QCheckBox('apd1')
        self.traceapd2check = QCheckBox('apd2')
        self.traceapdsumcheck = QCheckBox('sum')
        self.traceapdsumcheck.toggle()
        for checkbox in (self.traceapd1check, self.traceapd2check,
                         self.traceapdsumcheck):
            checkbox.stateChanged.connect(self.checkboxupdate)

        # Grid placement: (widget, row, column[, rowspan, colspan])
        placements = (
            (self.tracecanvas, 0, 0, 5, 10),
            (self.toolbar, 5, 0, 1, 10),
            (self.tracelengthlabel, 6, 0, 1, 1),
            (self.tracelengthtext, 6, 1, 1, 1),
            (self.traceapd1check, 6, 2, 1, 1),
            (self.traceapd2check, 6, 3, 1, 1),
            (self.traceapdsumcheck, 6, 4, 1, 1),
            (self.traceminlabel, 7, 0, 1, 1),
            (self.tracemintext, 7, 1, 1, 1),
            (self.tracemaxlabel, 7, 2, 1, 1),
            (self.tracemaxtext, 7, 3, 1, 1),
            (self.counttimelabel, 8, 0, 1, 1),
            (self.counttimetext, 8, 1, 1, 1),
            (self.updatetimelabel, 8, 2, 1, 1),
            (self.updatetimetext, 8, 3, 1, 1),
            (self.tracestartbutton, 9, 0),
            (self.tracestopbutton, 9, 1),
        )
        grid = QGridLayout()
        for widget, *cell in placements:
            grid.addWidget(widget, *cell)

        grid.setSpacing(2)
        self.setLayout(grid)

    def tracetimechanged(self):
        """Read the timing spin boxes and push the count time to the FPGA.

        The count time is entered in milliseconds and stored in seconds.
        When a hardware counter is present the value is written to the FPGA
        and the value it reports back is kept in ``check_counttime``.
        """
        self.counttime = self.counttimetext.value() / 1000.
        self.updatetime = self.updatetimetext.value()

        if self.hardware_counter is True:
            fpga = self.parent.fpga
            fpga.setcountingtime(self.counttime)
            self.check_counttime = fpga.checkcounttime()
            print('\t fpga counttime:', self.check_counttime)

        print('counttime:', self.counttime)

    def checkboxupdate(self):
        """Show or hide each trace according to its checkbox, then rescale.

        A checked box draws its line with the ``'.'`` marker, an unchecked
        one with no marker. Afterwards the plot limits are refreshed through
        ``updatetraceplot`` with ``updateflag=1``.
        """
        pairs = (
            (self.traceapd1check, 'line1'),
            (self.traceapd2check, 'line2'),
            (self.traceapdsumcheck, 'line3'),
        )
        # The line artists are only created by tracestartclicked(). Toggling
        # a checkbox before the first start used to raise AttributeError
        # inside this slot; there is nothing to restyle yet, so just return.
        if not all(hasattr(self, attr) for _, attr in pairs):
            return

        for checkbox, attr in pairs:
            marker = '.' if checkbox.isChecked() else 'None'
            getattr(self, attr).set_marker(marker)

        self.updatetraceplot([], [], [], 1)

    def tracestartclicked(self):
        """Reset the plot and start a new acquisition in a worker thread.

        Does nothing unless measuring is allowed (``status``) and a hardware
        counter is available.
        """
        if self.status is not True or self.hardware_counter is not True:
            return

        print("[%s] start trace" % QThread.currentThread().objectName())

        # Fresh, empty sample buffers
        self.tracex = np.empty(0)
        self.tracey1 = np.empty(0)  # apd1
        self.tracey2 = np.empty(0)  # apd2
        self.traceysum = np.empty(0)

        # Rebuild the axes with one dotted line per trace
        self.tracefigure.clear()
        self.traceaxes = self.tracefigure.add_subplot(111)
        self.line1, self.line2, self.line3 = (
            self.traceaxes.plot(self.tracex, ydata, '.')[0]
            for ydata in (self.tracey1, self.tracey2, self.traceysum)
        )
        self.traceaxes.set_autoscaley_on(True)
        self.tracefigure.canvas.draw()
        self.traceaxes.grid()

        # Acquisition worker living in its own QThread
        worker = MinionTraceAquisition(self.counttime, self.updatetime, self.parent.fpga)
        thread = QThread(self, objectName='TraceThread')
        self.traceaquisition = worker
        self.tracethread = thread
        worker.moveToThread(thread)
        worker.tracestop.connect(thread.quit)

        thread.started.connect(worker.longrun)
        thread.finished.connect(thread.deleteLater)
        worker.updatetrace.connect(self.updatetraceplot)
        thread.start()

        # Apply the current checkbox selection to the new lines
        self.checkboxupdate()

    def tracestopclicked(self):
        """Ask the acquisition worker to stop and quit its thread.

        If no trace is running this only prints a message.
        """
        print('stop trace')
        try:
            self.traceaquisition.stop()
            self.tracethread.quit()
        except Exception:
            # Typically AttributeError (no trace was ever started) or the
            # RuntimeError PyQt raises when the thread object has already
            # been deleted after finishing. ``Exception`` keeps the original
            # best-effort behaviour without also trapping SystemExit or
            # KeyboardInterrupt the way a bare ``except:`` did.
            print('no trace running')

    def updatetracesetting(self):
        """Copy t_min, t_max and dt from the spin boxes and refresh the plot.

        Each value is rounded to two decimals before it is stored.
        """
        settings = (
            ('tracemin', self.tracemintext),
            ('tracemax', self.tracemaxtext),
            ('tracelength', self.tracelengthtext),
        )
        for attribute, box in settings:
            setattr(self, attribute, np.round(box.value(), decimals=2))
        self.updatetraceplot([], [], [], 1)

    @pyqtSlot(np.ndarray, np.ndarray, np.ndarray)
    def updatetraceplot(self, newx, newy1, newy2, updateflag=0):
        """Append new samples and/or rescale the trace plot, then redraw.

        Parameters
        ----------
        newx, newy1, newy2
            New time values and the matching apd1 / apd2 samples. They are
            only used when ``updateflag == 0``.
        updateflag : int
            ``0`` - append the samples and show a window of ``tracelength``
            ending at the newest time value (never starting below 0).
            ``1`` - leave the data alone and show ``tracemin``..``tracemax``;
            ignored while no samples have been recorded.
            Any other value does nothing.
        """
        if updateflag == 0:
            self.tracex = np.append(self.tracex, newx)
            self.tracey1 = np.append(self.tracey1, newy1)
            self.tracey2 = np.append(self.tracey2, newy2)
            self.traceysum = self.tracey1+self.tracey2

            if self.tracex.size == 0:
                # No samples at all: max() on an empty array would raise.
                return

            for line, ydata in ((self.line1, self.tracey1),
                                (self.line2, self.tracey2),
                                (self.line3, self.traceysum)):
                line.set_xdata(self.tracex)
                line.set_ydata(ydata)

            newest = self.tracex.max()
            xlim = (max(newest - self.tracelength, 0), newest)
        elif updateflag == 1:
            if len(self.tracex) == 0:
                return
            xlim = (self.tracemin, self.tracemax)
        else:
            return

        self.traceaxes.set_xlim(*xlim)
        self._apply_visible_ylim()

        # Still needed for the case where no checkbox is ticked and the
        # y-axis has therefore been left on autoscale.
        self.traceaxes.relim()
        self.traceaxes.autoscale_view()
        self.tracecanvas.draw()
        self.tracecanvas.flush_events()

    def _apply_visible_ylim(self):
        """Fit the y-axis to the traces whose checkbox is ticked.

        With no box ticked the current y-limits are left untouched.
        """
        candidates = (
            (self.traceapd1check, self.tracey1),
            (self.traceapd2check, self.tracey2),
            (self.traceapdsumcheck, self.traceysum),
        )
        visible = [ydata for checkbox, ydata in candidates
                   if checkbox.isChecked()]
        if not visible:
            return
        self.traceaxes.set_ylim(np.min([ydata.min() for ydata in visible]),
                                np.max([ydata.max() for ydata in visible]))
Beispiel #24
0
class VideoDisplayWidget(QtWidgets.QWidget):
    def __init__(self, videos, app: QtWidgets.QMainWindow, *args, **kwargs):
        """Build the display canvas, the frame slider and the video list.

        Parameters
        ----------
        videos
            Indexable collection of video objects; the first one is shown
            initially. This class reads ``name`` and ``frame_count`` from
            each video and calls ``grab_frame(frame_number)`` on it.
        app : QtWidgets.QMainWindow
            Main window; its status bar is used to report frame warnings.
        *args, **kwargs
            Forwarded to the base widget constructor.
        """
        super().__init__(*args, **kwargs)

        self.app = app

        # -----
        # SETUP
        # -----

        # Set videos and initialize frame number
        self.videos = videos
        self.current_video = self.videos[0]
        self.frame_number = 0

        # Set layout
        # NOTE: the attribute name ``layout`` hides the inherited ``layout()``
        # accessor on this instance; add_widget() relies on the attribute.
        self.layout = QtWidgets.QGridLayout()
        self.setLayout(self.layout)

        # -------
        # DISPLAY
        # -------

        # Create display widget for showing images
        self.display_widget = self.add_widget(1,
                                              0,
                                              layout=QtWidgets.QVBoxLayout)

        # Initialize display attributes
        self.display_methods = []
        self.display_image = None

        # Create combobox for switching between images
        # The first entry is registered before the signal is connected, so
        # populating the box here does not call change_display_image().
        self.box_widget = QtWidgets.QComboBox()
        self.add_display('Input image', self.input_image)
        self.display_index = 0
        self.display_function = self.input_image
        self.display_kwargs = {}
        self.box_widget.currentIndexChanged.connect(self.change_display_image)
        self.box_widget.setFixedSize(120, 25)
        self.display_widget.layout().addWidget(self.box_widget,
                                               alignment=QtCore.Qt.AlignRight)

        # Create figure widget
        # A single axes fills the whole figure and has its axis decorations
        # switched off.
        self.figure = plt.figure(facecolor='0.95')
        self.ax = self.figure.add_axes([0, 0, 1, 1])
        self.ax.axis('off')
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setMinimumSize(500, 500)
        self.canvas.setSizePolicy(QtWidgets.QSizePolicy().MinimumExpanding,
                                  QtWidgets.QSizePolicy().MinimumExpanding)
        self.display_widget.layout().addWidget(self.canvas)

        # Initialize the display
        # NOTE(review): if grabbing the first frame raises a FrameErrorWarning,
        # display_image stays None when it is handed to imshow - confirm that
        # this cannot happen for frame 0.
        self.update_display_image()  # initializes the display image
        self.image_ = self.ax.imshow(self.display_image,
                                     origin='upper',
                                     cmap='Greys_r',
                                     vmin=0,
                                     vmax=255)

        # -------
        # SLIDERS
        # -------

        # Create slider widget for adjusting e.g. frame, thresholds etc.
        self.slider_widget = self.add_widget(0,
                                             0,
                                             layout=QtWidgets.QVBoxLayout)
        self.slider_widget.layout().setSpacing(0)
        # Frame slider
        self.frame_slider = SliderWidget('Frame', 0,
                                         self.current_video.frame_count - 1, 0)
        self.frame_slider.value_changed.connect(self.change_frame)
        self.slider_widget.layout().addWidget(self.frame_slider)

        # ------
        # VIDEOS
        # ------

        # Create video widget for switching between videos
        self.video_widget = self.add_widget(0, 1, rowspan=2)
        self.video_widget.setMinimumWidth(150)
        self.video_widget.setMaximumWidth(200)
        self.video_list = QtWidgets.QListWidget()
        self.video_widget.layout().addWidget(self.video_list)
        # Add videos to list
        self.video_list.addItems([video.name for video in self.videos])
        self.video_list.itemSelectionChanged.connect(self.switch_video)
        # Selecting row 0 after connecting presumably runs switch_video()
        # once during construction - the slider and image must exist by now.
        self.video_list.setCurrentRow(0)

    def add_widget(self,
                   i,
                   j,
                   widget=QtWidgets.QWidget,
                   layout: QtWidgets.QLayout = QtWidgets.QGridLayout,
                   rowspan=1,
                   colspan=1):
        """Create a container widget and place it in this widget's grid.

        Parameters
        ----------
        i, j : int
            Grid row and grid column of the new widget.
        widget : QtWidgets.QWidget type
            Class that is instantiated (without arguments) as the container.
        layout : QtWidgets.QLayout type
            Class that is instantiated (without arguments) and installed as
            the container's layout.
        rowspan, colspan : int
            How many grid rows / columns the container spans.

        Returns
        -------
        QtWidgets.QWidget
            The container that was created and added.
        """
        container = widget()
        inner_layout = layout()
        container.setLayout(inner_layout)
        self.layout.addWidget(container, i, j, rowspan, colspan)
        return container

    def add_display(self, name, func, **kwargs):
        """Register a display mode.

        ``name`` is appended to the combo box and ``(func, kwargs)`` is
        appended to ``display_methods`` so both share the same index.
        """
        self.box_widget.addItem(name)
        entry = (func, kwargs)
        self.display_methods.append(entry)

    def draw(self):
        """Push ``display_image`` into the image artist and repaint the canvas."""
        canvas = self.canvas
        self.image_.set_data(self.display_image)
        canvas.draw()
        canvas.flush_events()

    @QtCore.pyqtSlot()
    def switch_video(self):
        """Make the video selected in the list the current one.

        The frame slider is re-ranged to the new video and moved back to
        the first frame.
        """
        selected_video_index = self.video_list.currentRow()
        if selected_video_index < 0:
            # No current row (-1): indexing with it would silently pick the
            # last video, so keep the current one instead.
            return
        self.current_video = self.videos[selected_video_index]
        # Last valid frame is frame_count - 1, matching the range the slider
        # is created with in __init__ (previously frame_count was passed,
        # letting the slider reach one frame past the end).
        self.frame_slider.set_range(0, self.current_video.frame_count - 1)
        self.frame_slider.set_value(0)  # go to first frame of video

    @QtCore.pyqtSlot(int)
    def change_display_image(self, i):
        """Activate the display mode registered at index ``i`` and redraw."""
        method, options = self.display_methods[i]
        self.display_function = method
        self.display_kwargs = options
        self.update_display_image()
        self.draw()

    @QtCore.pyqtSlot(int)
    def change_frame(self, frame):
        """Store ``frame`` as the current frame number and redraw.

        Connected to the frame slider's ``value_changed`` signal.
        """
        self.frame_number = frame
        # Recompute the image for the new frame before repainting.
        self.update_display_image()
        self.draw()

    @staticmethod
    def input_image(image, **kwargs):
        """Identity display function: return ``image`` unchanged.

        Extra keyword arguments are accepted and ignored so the signature
        matches how display functions are called
        (``display_function(image, **display_kwargs)``).
        """
        return image

    def update_display_image(self):
        """Grab the current frame and recompute ``display_image`` from it.

        Warnings raised while grabbing are recorded instead of propagating,
        so that the GUI does not crash. If any of them is a
        ``FrameErrorWarning`` its message is shown in the status bar for
        1000 ms and ``display_image`` keeps its previous value; otherwise
        the active display function is applied to the frame.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            image = self.current_video.grab_frame(self.frame_number)
            frame_errors = [
                entry for entry in caught
                if issubclass(entry.category, FrameErrorWarning)
            ]
            if frame_errors:
                first_message = str(frame_errors[0].message)
                self.app.statusBar().showMessage(first_message, 1000)
            else:
                self.display_image = self.display_function(
                    image, **self.display_kwargs)
class Ui_SignalViewer(object):
    def setupUi(self, SignalViewer):
        """Create all widgets, menus and actions of the signal viewer.

        Builds three matplotlib canvases (ECG, EMG, EEG), each hosted in its
        own fixed-geometry panel, a control bar with playback / shift / zoom
        buttons and a signal selector, and a menu bar with import and export
        actions. Finally wires the actions and buttons to this object's
        handlers.

        Parameters
        ----------
        SignalViewer
            The main window being populated; it receives the central widget,
            menu bar and status bar created here.
        """
        SignalViewer.setObjectName("SignalViewer")
        SignalViewer.resize(1100, 888)
        SignalViewer.setStyleSheet("background-color: rgb(230, 230, 230);")
        ################################################
        # One figure + canvas per signal type.
        self.figureecg = Figure()
        self.canvasecg = FigureCanvas(self.figureecg)
        ################################################
        self.figureemg = Figure()
        self.canvasemg = FigureCanvas(self.figureemg)
        ################################################
        self.figureeeg = Figure()
        self.canvaseeg = FigureCanvas(self.figureeeg)
        ################################################
        self.centralwidget = QtWidgets.QWidget(SignalViewer)
        self.centralwidget.setObjectName("centralwidget")
        # Top panel (hosts the ECG canvas).
        self.verticalWidget = QtWidgets.QWidget(self.centralwidget)
        self.verticalWidget.setGeometry(QtCore.QRect(0, 0, 1101, 261))
        self.verticalWidget.setStyleSheet(
            "background-color: rgb(72, 72, 72);\n"
            "color: rgb(218, 218, 0);")
        self.verticalWidget.setObjectName("verticalWidget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.verticalWidget)
        self.verticalLayout.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout.setObjectName("verticalLayout")
        self.line_2 = QtWidgets.QFrame(self.verticalWidget)
        self.line_2.setFrameShape(QtWidgets.QFrame.VLine)
        self.line_2.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.line_2.setObjectName("line_2")
        self.verticalLayout.addWidget(self.line_2)
        self.line = QtWidgets.QFrame(self.verticalWidget)
        self.line.setFrameShape(QtWidgets.QFrame.VLine)
        self.line.setFrameShadow(QtWidgets.QFrame.Sunken)
        self.line.setObjectName("line")
        self.verticalLayout.addWidget(self.line)
        # Middle panel (hosts the EMG canvas).
        self.verticalWidget_2 = QtWidgets.QWidget(self.centralwidget)
        self.verticalWidget_2.setGeometry(QtCore.QRect(0, 260, 1101, 261))
        self.verticalWidget_2.setMinimumSize(QtCore.QSize(1101, 231))
        self.verticalWidget_2.setStyleSheet(
            "background-color: rgb(72, 72, 72);\n"
            "color: rgb(233, 235, 61);\n"
            "gridline-color: rgb(131, 131, 131);")
        self.verticalWidget_2.setObjectName("verticalWidget_2")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.verticalWidget_2)
        self.verticalLayout_2.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        # Bottom panel (hosts the EEG canvas).
        self.verticalWidget_3 = QtWidgets.QWidget(self.centralwidget)
        self.verticalWidget_3.setGeometry(QtCore.QRect(0, 520, 1101, 271))
        self.verticalWidget_3.setStyleSheet(
            "background-color: rgb(72, 72, 72);")
        self.verticalWidget_3.setObjectName("verticalWidget_3")
        self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.verticalWidget_3)
        self.verticalLayout_3.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout_3.setObjectName("verticalLayout_3")
        ##################################################################
        # Put each canvas into its panel.
        self.verticalLayout.addWidget(self.canvasecg)
        self.verticalLayout_2.addWidget(self.canvasemg)
        self.verticalLayout_3.addWidget(self.canvaseeg)
        ##################################################################
        # Control bar below the three panels.
        self.groupBox = QtWidgets.QGroupBox(self.centralwidget)
        self.groupBox.setGeometry(QtCore.QRect(0, 790, 1075, 41))
        font = QtGui.QFont()
        font.setFamily("MS Sans Serif")
        font.setPointSize(14)
        font.setBold(True)
        font.setWeight(75)
        self.groupBox.setFont(font)
        self.groupBox.setTitle("")
        self.groupBox.setObjectName("groupBox")
        self.Playback = QtWidgets.QPushButton(self.groupBox)
        self.Playback.setGeometry(QtCore.QRect(60, 10, 51, 28))
        self.Playback.setObjectName("Playback")
        self.shift_left = QtWidgets.QPushButton(self.groupBox)
        self.shift_left.setGeometry(QtCore.QRect(110, 10, 51, 28))
        self.shift_left.setStyleSheet("font: 75 12pt \"MS Sans Serif\";")
        self.shift_left.setObjectName("shift_left")
        self.shift_right = QtWidgets.QPushButton(self.groupBox)
        self.shift_right.setGeometry(QtCore.QRect(160, 10, 51, 28))
        self.shift_right.setStyleSheet("font: 75 12pt \"MS Sans Serif\";")
        self.shift_right.setObjectName("shift_right")
        self.zin = QtWidgets.QPushButton(self.groupBox)
        self.zin.setGeometry(QtCore.QRect(260, 10, 51, 28))
        self.zin.setStyleSheet("font: 75 12pt \"MS Sans Serif\";")
        self.zin.setObjectName("zin")
        self.zout = QtWidgets.QPushButton(self.groupBox)
        self.zout.setGeometry(QtCore.QRect(210, 10, 51, 28))
        self.zout.setObjectName("zout")
        # Selector deciding which signal the control buttons act on.
        self.comboBox = QtWidgets.QComboBox(self.groupBox)
        self.comboBox.setGeometry(QtCore.QRect(860, 10, 92, 25))
        self.comboBox.setObjectName("comboBox")
        self.label_choose = QtWidgets.QLabel(self.groupBox)
        self.label_choose.setGeometry(QtCore.QRect(960, 10, 71, 21))
        self.label_choose.setText("")
        self.label_choose.setObjectName("label_choose")
        SignalViewer.setCentralWidget(self.centralwidget)
        # Menu bar, status bar and actions.
        self.menubar = QtWidgets.QMenuBar(SignalViewer)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1100, 31))
        self.menubar.setObjectName("menubar")
        self.menufile = QtWidgets.QMenu(self.menubar)
        self.menufile.setObjectName("menufile")
        self.menuexport_PDF = QtWidgets.QMenu(self.menubar)
        self.menuexport_PDF.setObjectName("menuexport_PDF")
        SignalViewer.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(SignalViewer)
        self.statusbar.setObjectName("statusbar")
        SignalViewer.setStatusBar(self.statusbar)
        self.open_source1 = QtWidgets.QAction(SignalViewer)
        self.open_source1.setObjectName("open_source1")
        self.open_source2 = QtWidgets.QAction(SignalViewer)
        self.open_source2.setObjectName("open_source2")
        self.actioopen_source3 = QtWidgets.QAction(SignalViewer)
        self.actioopen_source3.setObjectName("actioopen_source3")
        self.actionexport_PDF = QtWidgets.QAction(SignalViewer)
        self.actionexport_PDF.setObjectName("actionexport_PDF")
        self.menufile.addAction(self.open_source1)
        self.menufile.addAction(self.open_source2)
        self.menufile.addAction(self.actioopen_source3)
        self.menuexport_PDF.addAction(self.actionexport_PDF)
        self.menubar.addAction(self.menufile.menuAction())
        self.menubar.addAction(self.menuexport_PDF.menuAction())
        # These exact strings are compared against in the playback handler.
        ls = ["ECG", "EMG", "EEG"]
        self.comboBox.addItems(ls)
        # Signal wiring: menu actions import data, buttons control the view.
        self.open_source1.triggered.connect(self.imp_ecg)
        self.open_source2.triggered.connect(self.imp_emg)
        self.actioopen_source3.triggered.connect(self.imp_eeg)
        self.Playback.clicked.connect(self.pb)
        self.shift_left.clicked.connect(self.backword)
        self.shift_right.clicked.connect(self.forword)
        self.zin.clicked.connect(self.zoomin)
        self.zout.clicked.connect(self.zoomout)
        self.actionexport_PDF.triggered.connect(self.report)
        # Same dark grey as the axes background used during playback.
        self.figureecg.set_facecolor((0.29, 0.29, 0.29))
        self.figureemg.set_facecolor((0.29, 0.29, 0.29))
        self.figureeeg.set_facecolor((0.29, 0.29, 0.29))

        self.retranslateUi(SignalViewer)
        QtCore.QMetaObject.connectSlotsByName(SignalViewer)

    def retranslateUi(self, SignalViewer):
        """Set every user-visible text and keyboard shortcut.

        All strings go through Qt's translate function with the
        ``"SignalViewer"`` context.

        Parameters
        ----------
        SignalViewer
            The main window whose title is set.
        """
        _translate = QtCore.QCoreApplication.translate
        SignalViewer.setWindowTitle(_translate("SignalViewer", "MainWindow"))
        # The three plot panels get a tooltip with no visible text.
        self.verticalWidget.setToolTip(
            _translate("SignalViewer",
                       "<html><head/><body><p><br/></p></body></html>"))
        self.verticalWidget_2.setToolTip(
            _translate("SignalViewer",
                       "<html><head/><body><p><br/></p></body></html>"))
        self.verticalWidget_3.setToolTip(
            _translate("SignalViewer",
                       "<html><head/><body><p><br/></p></body></html>"))
        # Control buttons: caption followed by its keyboard shortcut.
        self.Playback.setText(_translate("SignalViewer", "►"))
        self.Playback.setShortcut(_translate("SignalViewer", "W"))
        self.shift_left.setText(_translate("SignalViewer", "<"))
        self.shift_left.setShortcut(_translate("SignalViewer", "E"))
        self.shift_right.setText(_translate("SignalViewer", ">"))
        self.shift_right.setShortcut(_translate("SignalViewer", "R"))
        self.zin.setText(_translate("SignalViewer", "+"))
        self.zin.setShortcut(_translate("SignalViewer", "Y"))
        self.zout.setText(_translate("SignalViewer", "-"))
        self.zout.setShortcut(_translate("SignalViewer", "T"))
        # Menus and their actions (shortcuts 1-4).
        self.menufile.setTitle(_translate("SignalViewer", "file"))
        self.menuexport_PDF.setTitle(_translate("SignalViewer", "print"))
        self.open_source1.setText(_translate("SignalViewer", "open ECG"))
        self.open_source1.setShortcut(_translate("SignalViewer", "1"))
        self.open_source2.setText(_translate("SignalViewer", "open EMG"))
        self.open_source2.setShortcut(_translate("SignalViewer", "2"))
        self.actioopen_source3.setText(_translate("SignalViewer", "open EEG"))
        self.actioopen_source3.setShortcut(_translate("SignalViewer", "3"))
        self.actionexport_PDF.setText(_translate("SignalViewer", "export PDF"))
        self.actionexport_PDF.setShortcut(_translate("SignalViewer", "4"))

    def imp_ecg(self):
        """Let the user pick a CSV file and load it as the ECG signal.

        Column 0 is used as the x-axis and column 1 as the signal; the last
        row of the file is left out. On success the selector is switched to
        "ECG" and the playback state is reset. Unlike the EMG/EEG importers
        this one does not start playback itself.

        If the dialog is cancelled nothing is changed.
        """
        path = QFileDialog.getOpenFileName(None, 'Open CSV ', '/home',
                                           "CSV (*.csv)")[0]
        if not path:
            # Cancelled dialog yields an empty path; reading it would raise.
            return
        self.path_ecg = path
        ds_ecg = pd.read_csv(self.path_ecg)
        self.x_ecg = ds_ecg.iloc[0:-1, 0].values
        self.y_ecg = ds_ecg.iloc[0:-1, 1].values
        self.comboBox.setCurrentText("ECG")
        self.frame_counter_ecg = 25
        self.flag_ecg = False

    def imp_emg(self):
        """Let the user pick a CSV file, load it as the EMG signal and play it.

        Column 0 is the signal (the last row of the file is left out); the
        x-axis is the sample index 0..len-1. On success the selector is
        switched to "EMG", the playback state is reset and playback is
        toggled on via ``pb()``.

        If the dialog is cancelled nothing is changed.
        """
        path = QFileDialog.getOpenFileName(None, 'Open CSV ', '/home',
                                           "CSV (*.csv)")[0]
        if not path:
            # Cancelled dialog yields an empty path; reading it would raise.
            return
        self.path_emg = path
        ds_emg = pd.read_csv(self.path_emg)
        self.y_emg = ds_emg.iloc[0:-1, 0].values
        self.axis_emg = np.linspace(0, len(self.y_emg) - 1, len(self.y_emg))
        self.comboBox.setCurrentText("EMG")
        self.frame_counter_emg = 25
        self.flag_emg = False
        self.pb()

    def imp_eeg(self):
        """Let the user pick a CSV file, load it as the EEG signal and play it.

        Column 0 is the signal (the last row of the file is left out); the
        x-axis is the sample index 0..len-1. On success the selector is
        switched to "EEG", the playback state is reset and playback is
        toggled on via ``pb()``.

        If the dialog is cancelled nothing is changed.
        """
        path = QFileDialog.getOpenFileName(None, 'Open CSV ', '/home',
                                           "CSV (*.csv)")[0]
        if not path:
            # Cancelled dialog yields an empty path; reading it would raise.
            return
        self.path_eeg = path
        ds_eeg = pd.read_csv(self.path_eeg)
        self.y_eeg = ds_eeg.iloc[0:-1, 0].values
        self.axis_eeg = np.linspace(0, len(self.y_eeg) - 1, len(self.y_eeg))
        self.comboBox.setCurrentText("EEG")
        self.frame_counter_eeg = 25
        self.flag_eeg = False
        self.pb()

    def pb(self):
        content = self.comboBox.currentText()
        if content == "ECG":
            if self.flag_ecg == False:
                self.flag_ecg = True
                c = self.frame_counter_ecg
                self.figureecg.clear()
                lines = [ax.plot([], [])[0] for ax in self.figureecg.axes]

                def update(i):
                    if not self.flag_ecg:
                        self.ani_ecg.event_source.stop()
                        self.canvasecg.flush_events()

                    else:
                        self.frame_counter_ecg = i + c
                        range_min = 2 * int(
                            ((self.frame_counter_ecg - 25) +
                             abs(self.frame_counter_ecg - 25)) / 2)
                        xa = self.x_ecg[range_min:2 * self.frame_counter_ecg]
                        y1 = self.y_ecg[range_min:2 * self.frame_counter_ecg]
                        ax = self.figureecg.gca()
                        ax.cla()
                        ax.set_ylim(min(self.y_ecg), max(self.y_ecg))
                        ax.set_facecolor((0.29, 0.29, 0.29))
                        ax.grid(True)
                        #ax.ylim(min(ym),max(ym))
                        ax.plot(xa, y1)
                        self.canvasecg.draw()
                    self.canvasecg.flush_events()
                    return lines

                self.ani_ecg = FuncAnimation(
                    self.figureecg,
                    update,
                    frames=np.arange(0,
                                     int(len(self.x_ecg) / 2) - 25),
                    interval=50,
                    blit=True)
            else:
                self.flag_ecg = False
        elif content == "EMG":
            if self.flag_emg == False:
                self.flag_emg = True
                c = self.frame_counter_emg
                self.figureemg.clear()
                lines = [ax.plot([], [])[0] for ax in self.figureemg.axes]

                def update(i):
                    if not self.flag_emg:
                        self.ani_emg.event_source.stop()
                        self.canvasemg.flush_events()

                    else:
                        self.frame_counter_emg = i + c
                        range_min = 2 * int(
                            ((self.frame_counter_emg - 25) +
                             abs(self.frame_counter_emg - 25)) / 2)
                        xa = self.axis_emg[range_min:self.frame_counter_emg *
                                           2]
                        y1 = self.y_emg[range_min:2 * self.frame_counter_emg]
                        ax = self.figureemg.gca()
                        ax.cla()
                        #ax.set_ylim(min(self.y_emg),max(self.y_emg))
                        ax.set_facecolor((0.29, 0.29, 0.29))
                        ax.grid(True)
                        #ax.set_ylim(min(self.y_emg),max(self.y_emg))
                        ax.plot(xa, y1)
                        self.canvasemg.draw()
                    self.canvasemg.flush_events()
                    return lines

                self.ani_emg = FuncAnimation(
                    self.figureemg,
                    update,
                    frames=np.arange(0,
                                     int(len(self.y_emg) / 2) - 25),
                    interval=50,
                    blit=True)
            else:
                self.flag_emg = False
        elif content == "EEG":
            if self.flag_eeg == False:
                self.flag_eeg = True
                c = self.frame_counter_eeg
                self.figureeeg.clear()
                lines = [ax.plot([], [])[0] for ax in self.figureeeg.axes]

                def update(i):
                    if not self.flag_eeg:
                        self.ani_eeg.event_source.stop()
                        self.canvaseeg.flush_events()

                    else:
                        self.frame_counter_eeg = i + c
                        range_min = 2 * int(
                            ((self.frame_counter_eeg - 25) +
                             abs(self.frame_counter_eeg - 25)) / 2)
                        xa = self.axis_eeg[range_min:self.frame_counter_eeg *
                                           2]
                        y1 = self.y_eeg[range_min:2 * self.frame_counter_eeg]
                        ax = self.figureeeg.gca()
                        ax.cla()
                        #ax.set_ylim(min(self.y_eeg),max(self.y_eeg))
                        ax.set_facecolor((0.29, 0.29, 0.29))
                        ax.grid(True)
                        #ax.set_ylim(min(self.y_eeg),max(self.y_eeg))
                        ax.plot(xa, y1)
                        self.canvaseeg.draw()
                    self.canvaseeg.flush_events()
                    return lines

                self.ani_eeg = FuncAnimation(
                    self.figureeeg,
                    update,
                    frames=np.arange(0,
                                     int(len(self.y_eeg) / 2) - 25),
                    interval=50,
                    blit=True)
            else:
                self.flag_eeg = False

    def backword(self):
        """Move the selected signal's view 10 frames back and redraw it.

        The combo box text ("ECG", "EMG" or "EEG") selects the signal; any
        other text is ignored. The step is only taken while the signal's
        frame counter is above 35.
        """
        # combo box text -> (attribute suffix, attribute holding the x data)
        sources = {
            "ECG": ("ecg", "x_ecg"),
            "EMG": ("emg", "axis_emg"),
            "EEG": ("eeg", "axis_eeg"),
        }
        choice = self.comboBox.currentText()
        if choice not in sources:
            return
        tag, x_attr = sources[choice]
        counter_attr = "frame_counter_" + tag

        if not getattr(self, counter_attr) > 35:
            return
        frame = getattr(self, counter_attr) - 10
        setattr(self, counter_attr, frame)

        # (d + |d|) / 2 clamps d = frame - 25 at zero, so the slice never
        # starts before the first sample; slice bounds are 2 * frame index.
        first = 2 * int(((frame - 25) + abs(frame - 25)) / 2)
        last = 2 * frame
        xa = getattr(self, x_attr)[first:last]
        signal = getattr(self, "y_" + tag)
        ya = signal[first:last]

        figure = getattr(self, "figure" + tag)
        figure.clear()
        axes = figure.add_subplot(111)
        # The y-range is fixed to the whole signal, not just the window.
        axes.set_ylim(min(signal), max(signal))
        axes.set_facecolor((0.29, 0.29, 0.29))
        axes.grid(True)
        axes.plot(xa, ya)

        canvas = getattr(self, "canvas" + tag)
        canvas.draw()
        canvas.flush_events()

    def forword(self):
        """Move the selected signal's view 10 frames forward and redraw it.

        The combo box text ("ECG", "EMG" or "EEG") selects the signal; any
        other text is ignored. The step is only taken while the frame
        counter is below ``len(data) / 2 - 10``, where ``data`` is the x
        array for ECG and the y array for EMG/EEG.
        """
        # combo box text -> (suffix, x data attribute, attribute whose
        # length bounds the frame counter)
        sources = {
            "ECG": ("ecg", "x_ecg", "x_ecg"),
            "EMG": ("emg", "axis_emg", "y_emg"),
            "EEG": ("eeg", "axis_eeg", "y_eeg"),
        }
        choice = self.comboBox.currentText()
        if choice not in sources:
            return
        tag, x_attr, length_attr = sources[choice]
        counter_attr = "frame_counter_" + tag

        upper_bound = len(getattr(self, length_attr)) / 2 - 10
        if not getattr(self, counter_attr) < upper_bound:
            return
        frame = getattr(self, counter_attr) + 10
        setattr(self, counter_attr, frame)

        # (d + |d|) / 2 clamps d = frame - 25 at zero, so the slice never
        # starts before the first sample; slice bounds are 2 * frame index.
        first = 2 * int(((frame - 25) + abs(frame - 25)) / 2)
        last = 2 * frame
        xa = getattr(self, x_attr)[first:last]
        signal = getattr(self, "y_" + tag)
        ya = signal[first:last]

        figure = getattr(self, "figure" + tag)
        figure.clear()
        axes = figure.add_subplot(111)
        # The y-range is fixed to the whole signal, not just the window.
        axes.set_ylim(min(signal), max(signal))
        axes.set_facecolor((0.29, 0.29, 0.29))
        axes.grid(True)
        axes.plot(xa, ya)

        canvas = getattr(self, "canvas" + tag)
        canvas.draw()
        canvas.flush_events()

    def zoomin(self):
        """Redraw the selected signal's current window with tighter x margins.

        The combo box text ("ECG", "EMG" or "EEG") selects the signal; any
        other text is ignored. The frame counter is not changed. A negative
        x margin is set on the axes before plotting (-0.3 for ECG, -0.2 for
        EMG/EEG) together with a y margin of 0.05.
        """
        # combo box text -> (suffix, x data attribute, x margin)
        sources = {
            "ECG": ("ecg", "x_ecg", -0.3),
            "EMG": ("emg", "axis_emg", -0.2),
            "EEG": ("eeg", "axis_eeg", -0.2),
        }
        choice = self.comboBox.currentText()
        if choice not in sources:
            return
        tag, x_attr, x_margin = sources[choice]
        frame = getattr(self, "frame_counter_" + tag)

        # (d + |d|) / 2 clamps d = frame - 25 at zero, so the slice never
        # starts before the first sample; slice bounds are 2 * frame index.
        first = 2 * int(((frame - 25) + abs(frame - 25)) / 2)
        last = 2 * frame
        xa = getattr(self, x_attr)[first:last]
        ya = getattr(self, "y_" + tag)[first:last]

        figure = getattr(self, "figure" + tag)
        figure.clear()
        axes = figure.add_subplot(111)
        axes.set_facecolor((0.29, 0.29, 0.29))
        axes.grid(True)
        axes.margins(x=x_margin, y=0.05)
        axes.plot(xa, ya)

        canvas = getattr(self, "canvas" + tag)
        canvas.draw()
        canvas.flush_events()

    def zoomout(self):
        """Redraw the selected signal's current window with wide y margins.

        The combo box text ("ECG", "EMG" or "EEG") selects the signal; any
        other text is ignored. The frame counter is not changed. An x margin
        of 0.05 and a large y margin (2 for ECG/EMG, 10 for EEG) are set on
        the axes before plotting.
        """
        # combo box text -> (suffix, x data attribute, y margin)
        sources = {
            "ECG": ("ecg", "x_ecg", 2),
            "EMG": ("emg", "axis_emg", 2),
            "EEG": ("eeg", "axis_eeg", 10),
        }
        choice = self.comboBox.currentText()
        if choice not in sources:
            return
        tag, x_attr, y_margin = sources[choice]
        frame = getattr(self, "frame_counter_" + tag)

        # (d + |d|) / 2 clamps d = frame - 25 at zero, so the slice never
        # starts before the first sample; slice bounds are 2 * frame index.
        first = 2 * int(((frame - 25) + abs(frame - 25)) / 2)
        last = 2 * frame
        xa = getattr(self, x_attr)[first:last]
        ya = getattr(self, "y_" + tag)[first:last]

        figure = getattr(self, "figure" + tag)
        figure.clear()
        axes = figure.add_subplot(111)
        # Styling is applied once per redraw (the ECG path used to repeat
        # the same two calls back to back).
        axes.set_facecolor((0.29, 0.29, 0.29))
        axes.grid(True)
        axes.margins(x=0.05, y=y_margin)
        axes.plot(xa, ya)

        canvas = getattr(self, "canvas" + tag)
        canvas.draw()
        canvas.flush_events()

    def report(self):
        """Write a 2x3 summary of the ECG, EEG and EMG signals to a PDF.

        The three CSV files named by ``self.path_ecg``, ``self.path_eeg``
        and ``self.path_emg`` are read again from disk. The top row of the
        figure shows each signal, the bottom row the spectrogram of the same
        samples. The figure is saved as ``Report_Team_2.pdf`` relative to
        the current working directory.
        """
        # The last row of every file is left out (slice 0:-1).
        ecg = pd.read_csv(self.path_ecg)
        ecg_x = ecg.iloc[0:-1, 0].values
        ecg_y = ecg.iloc[0:-1, 1].values

        # EEG and EMG files only provide amplitudes (first column); the
        # sample index serves as x axis.
        eeg_y = pd.read_csv(self.path_eeg).iloc[0:-1, 0].values
        eeg_x = list(range(len(eeg_y)))

        emg_y = pd.read_csv(self.path_emg).iloc[0:-1, 0].values
        emg_x = list(range(len(emg_y)))

        # (title, x label, y label, x data, y data) for the top row,
        # in left-to-right order.
        panels = (
            ("ECG", 'Time', 'milivolt', ecg_x, ecg_y),
            ("EEG", 'Samples', 'Amplitude', eeg_x, eeg_y),
            ("EMG", 'Samples', 'Amplitude', emg_x, emg_y),
        )

        # Use a dedicated figure instead of pyplot's implicit current one:
        # otherwise a second call draws on top of the first report, or on
        # whatever figure happens to be current.
        fig = plt.figure()
        try:
            for column, (title, xlabel, ylabel, xs, ys) in enumerate(
                    panels, start=1):
                signal_ax = fig.add_subplot(2, 3, column)
                signal_ax.set_title(title)
                signal_ax.set_xlabel(xlabel)
                signal_ax.set_ylabel(ylabel)
                signal_ax.plot(xs, ys)

                # Spectrogram of the same samples, directly below.
                spectrum_ax = fig.add_subplot(2, 3, column + 3)
                spectrum_ax.specgram(ys)
                spectrum_ax.set_xlabel('Time')
                spectrum_ax.set_ylabel('Frequency')

            fig.tight_layout()
            print("printed")
            fig.savefig('Report_Team_2.pdf')
        finally:
            # Release the figure so repeated reports do not accumulate.
            plt.close(fig)
Beispiel #26
0
class mainWindow(QMainWindow):

    updateActivePix = pyqtSignal()

    def __init__(self, parent=None):
        """Create the window: empty image arrays first, then the widgets.

        parent -- optional parent object handed to QMainWindow.
        """
        super().__init__(parent=parent)

        # The image arrays must exist before create_main_frame() draws them.
        self.initializeEmptyArrays()

        self.setWindowTitle('quickLook_img.py')
        # 600 x 850 works for clint's laptop screen (units presumably pixels).
        width, height = 600, 850
        self.resize(width, height)

        self.create_main_frame()
        self.create_status_bar()
        self.createMenu()
        # The beam map is not loaded at start-up (load_beam_map is not called).

    def initializeEmptyArrays(self, nCol=80, nRow=125):
        """Reset the detector geometry and all per-pixel arrays to zeros.

        nCol, nRow -- number of columns / rows; every array created here
        has shape (nRow, nCol). Also sets the hot-pixel threshold
        ``hotPixCut`` to 2400.
        """
        self.nCol = nCol
        self.nRow = nRow
        shape = (self.nRow, self.nCol)

        # Each attribute gets its own array; none of them share memory.
        self.rawCountsImage = np.zeros(shape)
        self.hotPixMask = np.zeros(shape)
        self.hotPixCut = 2400
        self.image = np.zeros(shape)
        self.beamFlagMask = np.zeros(shape)

    def get_nPixels(self, filename):
        """Work out the detector geometry from the size of an .img file.

        The file is read as a flat sequence of unsigned 16-bit values and
        the number of values selects the camera:

        * 10000 values -> DARKNESS/PICTURE-C, 80 columns x 125 rows
        * 20440 values -> MEC, 140 columns x 146 rows

        Returns the tuple (nCol, nRow).
        Raises ValueError for any other number of values.
        """
        # Read through a context manager so the handle is closed right away
        # instead of being left to the garbage collector.
        with open(filename, mode='rb') as imgFile:
            npixels = len(np.fromfile(imgFile, dtype=np.uint16))
        print('npixels = ', npixels, '\n')

        if npixels == 10000:  #darkness
            nCol = 80
            nRow = 125
            print('\n\ncamera is DARKNESS/PICTURE-C\n\n')
        elif npixels == 20440:  #mec
            nCol = 140
            nRow = 146
            print('\n\ncamera is MEC\n\n')
        else:
            raise ValueError('img does not have 10000 or 20440 pixels')

        return nCol, nRow

    def load_IMG_filenames(self, filename):
        """Index every .img file in the directory that contains *filename*.

        Sets ``self.imgPath`` to that directory, ``self.fileListRaw`` to the
        full paths of its .img files and ``self.timeStampList`` to the
        integer parsed from each file name (without the ".img" suffix),
        stored as floats. Both arrays are ordered by ascending timestamp.
        """
        print('\nloading img filenames')

        self.imgPath = os.path.dirname(filename)
        imgNames = [
            name for name in os.listdir(self.imgPath) if name.endswith(".img")
        ]
        imgPaths = [os.path.join(self.imgPath, name) for name in imgNames]
        # The timestamp is the leading integer of the file name stem.
        stamps = [
            np.fromstring(name[:-4], dtype=int, sep=' ')[0]
            for name in imgNames
        ]
        stampArray = np.array(stamps, dtype=float)

        #the files may not be in chronological order, so let's enforce it
        chronological = np.argsort(stampArray)
        self.fileListRaw = np.asarray(imgPaths)[chronological]
        self.timeStampList = np.sort(stampArray)

        print('\nfound {:d} .img files\n'.format(len(self.timeStampList)))
        print('first timestamp: ', self.timeStampList[0])
        print('last timestamp:  ', self.timeStampList[-1], '\n')

    def load_log_filenames(self):
        """Index the log files found in ``self.logPath``.

        Sets ``self.logFilenameList`` (bare file names) and
        ``self.logTimestampList`` (integer parsed from the first 10
        characters of each name), both ordered by ascending timestamp.
        Files ending in "telescope.log" are ignored. If the directory does
        not exist, both lists are set to empty arrays and a hint is shown
        in the log label.
        """
        #check if directory exists
        if not os.path.exists(self.logPath):
            self.label_log.setText(
                'log file path not found.\n Check log file path.')
            self.logTimestampList = np.asarray([])
            self.logFilenameList = np.asarray([])
            return

        #load the log filenames
        print('\nloading log filenames\n')
        logNames = [
            name for name in os.listdir(self.logPath)
            if name.endswith(".log") and not name.endswith("telescope.log")
        ]
        logStamps = [
            np.fromstring(name[:10], dtype=int, sep=' ')[0]
            for name in logNames
        ]

        #the files may not be in chronological order, so let's enforce it
        chronological = np.argsort(logStamps)
        self.logTimestampList = np.asarray(np.sort(np.asarray(logStamps)))
        self.logFilenameList = np.asarray(logNames)[chronological]

    def load_beam_map(self):
        """Ask for a beam-map text file and flag its marked pixels.

        The file is read with ``np.loadtxt`` as four integer columns
        (resID, flag, xPos, yPos). For every row whose flag is non-zero,
        ``self.beamFlagMask[yPos, xPos]`` is set to 1, i.e. the mask is 1
        where the pixel is not beam mapped. Cancelling the dialog leaves
        the mask unchanged.
        """
        filename, _ = QFileDialog.getOpenFileName(
            self,
            'Select One File',
            '/mnt/data0/Darkness/20180522/Beammap/',
            filter='*.txt')
        if not filename:
            # The dialog returns an empty name when cancelled; there is
            # nothing to load in that case.
            return

        resID, flag, xPos, yPos = np.loadtxt(filename, unpack=True, dtype=int)

        flagged = flag != 0  #rows whose flag is non-zero

        # Index rows and columns in ONE subscript. The chained form
        # mask[rows][cols] = 1 assigns into a temporary copy produced by
        # the first fancy index and leaves the mask itself untouched.
        self.beamFlagMask[yPos[flagged], xPos[flagged]] = 1

    def initialize_spinbox_values(self, filename):
        """Set the range and start value of the timestamp spinboxes.

        filename -- the .img file the user selected; the integer in its
        name (without the ".img" suffix) becomes the start value of both
        the image-timestamp and the dark-start spinbox.

        The image spinbox spans the first to the last known timestamp. The
        dark-start spinbox stops at the 10th timestamp from the end, or at
        the first timestamp when fewer than 10 images are loaded.
        """
        # The spinbox setters are given plain Python ints rather than
        # numpy scalars (timeStampList may hold floats).
        firstStamp = int(self.timeStampList[0])
        lastStamp = int(self.timeStampList[-1])
        selectedStamp = int(
            np.fromstring(os.path.basename(filename)[:-4], dtype=int,
                          sep=' ')[0])

        # Same element as timeStampList[-10] when at least 10 images exist;
        # clamped to index 0 instead of raising IndexError otherwise.
        darkMaxIndex = max(len(self.timeStampList) - 10, 0)
        darkMaxStamp = int(self.timeStampList[darkMaxIndex])

        #set up the spinbox limits and start value, which will be the file you selected
        self.spinbox_imgTimestamp.setMinimum(firstStamp)
        self.spinbox_imgTimestamp.setMaximum(lastStamp)
        self.spinbox_imgTimestamp.setValue(selectedStamp)

        self.spinbox_darkStart.setMinimum(firstStamp)
        self.spinbox_darkStart.setMaximum(darkMaxStamp)
        self.spinbox_darkStart.setValue(selectedStamp)

    def plotImage(self, filename=None):
        """Read one .img file, clean it and show it on the main axes.

        filename -- path of the image; when None, the file whose timestamp
        equals the image-timestamp spinbox value is used (IndexError if
        there is none).

        Steps: read the file as uint16 and reshape to (nRow, nCol);
        optionally subtract ``self.darkFrame`` (negative results clipped to
        0); zero every pixel above ``self.hotPixCut``; zero every pixel
        flagged in ``self.beamFlagMask``; pick the colorbar limits; draw.
        """
        if filename is None:
            filename = self.fileListRaw[np.where(
                self.timeStampList == self.spinbox_imgTimestamp.value())[0][0]]

        self.ax1.clear()

        with open(filename, mode='rb') as f:
            self.rawImage = np.transpose(
                np.reshape(np.fromfile(f, dtype=np.uint16),
                           (self.nCol, self.nRow)))

        if self.checkbox_darkSubtract.isChecked():
            self.cleanedImage = self.rawImage - self.darkFrame
            self.cleanedImage[self.cleanedImage < 0] = 0
        else:
            # NOTE(review): no copy is made here, so the hot-pixel cut below
            # also zeroes those pixels in self.rawImage - confirm intended.
            self.cleanedImage = self.rawImage

        self.cleanedImage[self.cleanedImage > self.hotPixCut] = 0
        self.cleanedImage = self.cleanedImage * np.logical_not(
            self.beamFlagMask)
        self.image = self.cleanedImage

        # The limits are chosen AFTER the new image is stored, so the auto
        # scale follows the image being drawn rather than the previous one.
        if self.checkbox_colorbar_auto.isChecked():
            upperLimit = np.amax(self.image)
        else:
            upperLimit = self.spinbox_colorBarMax.value()
        # Float limits keep later fractional rescaling from being truncated.
        self.cbarLimits = np.array([0, upperLimit], dtype=float)
        self.fig.cbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
        self.fig.cbar.draw_all()

        self.ax1.imshow(self.image,
                        vmin=self.cbarLimits[0],
                        vmax=self.cbarLimits[1])
        self.ax1.axis('off')

        self.draw()

    def getDarkFrame(self):
        """Build ``self.darkFrame`` as the per-pixel median of a time range.

        The range starts at the dark-start spinbox value and covers as many
        consecutive timestamps as the dark-integration-time spinbox says.
        Timestamps without an .img file are skipped. If no file is found at
        all, the dark frame is all zeros with shape (nRow, nCol).
        """
        darkIntTime = self.spinbox_darkIntTime.value()
        darkStart = self.spinbox_darkStart.value()

        frames = []
        for ii in range(darkIntTime):
            try:
                darkFrameFilename = self.fileListRaw[np.where(
                    self.timeStampList == (darkStart + ii))[0][0]]
            except (IndexError, AttributeError):
                # IndexError: no image carries this timestamp.
                # AttributeError: no image list has been loaded yet.
                continue

            with open(darkFrameFilename, mode='rb') as darkFile:
                frames.append(
                    np.transpose(
                        np.reshape(np.fromfile(darkFile, dtype=np.uint16),
                                   (self.nCol, self.nRow))))

        if frames:
            # Only frames that were actually read take part in the median;
            # an all-zero placeholder per missing file would pull it to 0.
            self.darkFrame = np.median(np.asarray(frames, dtype=float),
                                       axis=0)
        else:
            self.darkFrame = np.zeros((self.nRow, self.nCol))

    def plotBlank(self):
        """Show an all-zero image of shape (nRow, nCol) on the main axes."""
        blankImage = np.zeros((self.nRow, self.nCol))
        self.ax1.imshow(blankImage)

    def updateLogLabel(self, IMG_fileExists=True):
        """Refresh the log label for the currently selected image timestamp.

        IMG_fileExists -- pass False when no .img file exists for the
        timestamp; the label then only reports that.

        Otherwise the label shows the image time, the time and name of the
        chosen log file and that file's content. The chosen log is the
        latest one that is not newer than the image. If no log lies within
        one hour of the image, only the distance to the nearest log is
        shown.
        """
        timestamp = self.spinbox_imgTimestamp.value()
        imgTimeText = datetime.datetime.fromtimestamp(timestamp).strftime(
            '%Y-%m-%d %H:%M:%S')

        #check if self.logTimestampList has more than zero entries. If not, return.
        if len(self.logTimestampList) == 0:
            self.label_log.setText(
                imgTimeText + '\n\n' +
                'no log file found.\n Check log file path.')
            return

        #check if the img exists, if not then return
        if not IMG_fileExists:
            self.label_log.setText(imgTimeText + '\n\n' +
                                   'no .img file found')
            return

        #check if a nearby log file exists, then pick the closest one
        diffs = timestamp - self.logTimestampList
        distances = np.abs(diffs)
        if np.sum(distances < 3600) == 0:  #nearby means within 1 hour.
            # Report the smallest ABSOLUTE distance; the minimum of the
            # signed differences is not the nearest log.
            self.label_log.setText(imgTimeText + '\n\n' +
                                   'nearest log is ' +
                                   str(np.amin(distances)) +
                                   '\nseconds away from img')
            return

        # Logs newer than the image (negative diff) are pushed to the
        # largest diff so argmin prefers a log written before the image.
        diffs[np.where(diffs < 0)] = np.amax(diffs)

        logLabelTimestamp = self.logTimestampList[np.argmin(diffs)]

        labelFilename = self.logFilenameList[np.where(
            self.logTimestampList == logLabelTimestamp)[0][0]]

        # The context manager closes the file even if reading fails.
        with open(os.path.join(self.logPath, labelFilename), 'r') as fin:
            logContent = fin.read()

        text = 'img timestamp:\n' + imgTimeText + (
            '\n\nLogfile time:\n' + datetime.datetime.fromtimestamp(
                logLabelTimestamp).strftime('%Y-%m-%d %H:%M:%S\n') + '\n' +
            labelFilename[:-4] + '\n' + logContent)
        self.label_log.setText(text)

    def create_main_frame(self):
        """
        Makes GUI elements on the window

        Creates the matplotlib figure/canvas with one image axes and a
        colorbar, the spinboxes and checkboxes that control what is shown,
        the labels for the log text and the img/log directories, arranges
        all of them in nested box layouts and installs the result as the
        central widget. Also sets self.imgPath and self.logPath from the
        MKID_IMG_DIR / MKID_RAW_PATH environment variables ('/' when unset).
        """
        #Define the plot window.
        self.main_frame = QWidget()
        self.dpi = 100
        self.fig = Figure(
            figsize=(5.0, 10.0), dpi=self.dpi, tight_layout=True
        )  #define the figure, set the max size (inches) and resolution. Overall window size is set with QMainWindow parameter.
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.main_frame)
        self.ax1 = self.fig.add_subplot(111)
        self.ax1.axis('off')
        #draw the current self.image once so the colorbar has an image to attach to
        self.foo = self.ax1.imshow(self.image, interpolation='none')
        #the colorbar is stored as an attribute on the figure object
        self.fig.cbar = self.fig.colorbar(self.foo)

        #spinboxes for the img timestamp
        self.spinbox_imgTimestamp = QSpinBox()
        self.spinbox_imgTimestamp.valueChanged.connect(self.spinBoxValueChange)

        #spinboxes for specifying dark frames
        self.spinbox_darkStart = QSpinBox()
        self.spinbox_darkStart.valueChanged.connect(self.getDarkFrame)
        self.spinbox_darkIntTime = QSpinBox()
        #set up the limits and initial value of the darkIntTime
        #(valueChanged is connected only after the limits and value are set)
        self.spinbox_darkIntTime.setMinimum(1)
        self.spinbox_darkIntTime.setMaximum(1000)
        self.spinbox_darkIntTime.setValue(10)
        self.spinbox_darkIntTime.valueChanged.connect(self.getDarkFrame)

        #labels for the start/stop time spinboxes
        label_imgTimestamp = QLabel('IMG timestamp')
        label_darkStart = QLabel('dark Start')
        label_darkIntTime = QLabel('dark int time [s]')

        #make a checkbox for the colorbar autoscale
        self.checkbox_colorbar_auto = QCheckBox()
        self.checkbox_colorbar_auto.setChecked(False)
        self.checkbox_colorbar_auto.stateChanged.connect(
            self.spinBoxValueChange)

        label_checkbox_colorbar_auto = QLabel('Auto colorbar')

        #spinbox for the manual colorbar maximum (read when autoscale is off)
        #(valueChanged is connected only after the range and value are set)
        self.spinbox_colorBarMax = QSpinBox()
        self.spinbox_colorBarMax.setRange(1, 2500)
        self.spinbox_colorBarMax.setValue(2000)
        self.spinbox_colorBarMax.valueChanged.connect(self.spinBoxValueChange)

        #make a checkbox for the dark subtract
        self.checkbox_darkSubtract = QCheckBox()
        self.checkbox_darkSubtract.setChecked(False)
        self.checkbox_darkSubtract.stateChanged.connect(
            self.spinBoxValueChange)

        #make a label for the dark subtract checkbox
        label_darkSubtract = QLabel('dark subtract')

        #make a label for the logs
        self.label_log = QLabel('')

        #make a label to display the IMG path and the MKID_RAW_PATH. Also set up log path variable
        #(the bare os.environ lookup only probes whether the variable is set)
        try:
            os.environ['MKID_IMG_DIR']
        except:
            labelText = 'MKID_IMG_DIR:      could not find MKID_IMG_DIR'
            self.imgPath = '/'
        else:
            labelText = 'MKID_IMG_DIR:      ' + os.environ['MKID_IMG_DIR']
            self.imgPath = os.environ['MKID_IMG_DIR']

        self.label_IMG_path = QLabel(labelText)
        self.label_IMG_path.setToolTip(
            'Look for img files in this directory. To change, go to File>Open img file'
        )

        #same probe for the log directory
        try:
            os.environ['MKID_RAW_PATH']
        except:
            labelText = 'MKID_RAW_PATH:  could not find MKID_RAW_PATH'
            self.logPath = '/'
        else:
            labelText = 'MKID_RAW_PATH:  ' + os.environ['MKID_RAW_PATH']
            self.logPath = os.environ['MKID_RAW_PATH']

        self.label_log_path = QLabel(labelText)
        self.label_log_path.setToolTip(
            'Look for log files in this directory. To change, go to File>Change log path.'
        )

        #create a vertical box for the plot to go in.
        vbox_plot = QVBoxLayout()
        vbox_plot.addWidget(self.canvas)

        #create a v box for the timestamp spinbox
        vbox_imgTimestamp = QVBoxLayout()
        vbox_imgTimestamp.addWidget(label_imgTimestamp)
        vbox_imgTimestamp.addWidget(self.spinbox_imgTimestamp)

        #make an hbox for the dark start
        hbox_darkStart = QHBoxLayout()
        hbox_darkStart.addWidget(label_darkStart)
        hbox_darkStart.addWidget(self.spinbox_darkStart)

        #make an hbox for the dark integration time
        hbox_darkIntTime = QHBoxLayout()
        hbox_darkIntTime.addWidget(label_darkIntTime)
        hbox_darkIntTime.addWidget(self.spinbox_darkIntTime)

        #make an hbox for the dark subtract checkbox
        hbox_darkSubtract = QHBoxLayout()
        hbox_darkSubtract.addWidget(label_darkSubtract)
        hbox_darkSubtract.addWidget(self.checkbox_darkSubtract)

        #make an hbox for the autoscale colorbar (label, checkbox, manual maximum)
        hbox_autoscale = QHBoxLayout()
        hbox_autoscale.addWidget(label_checkbox_colorbar_auto)
        hbox_autoscale.addWidget(self.checkbox_colorbar_auto)
        hbox_autoscale.addWidget(self.spinbox_colorBarMax)

        #make a vbox for dark times
        vbox_darkTimes = QVBoxLayout()
        vbox_darkTimes.addLayout(hbox_darkStart)
        vbox_darkTimes.addLayout(hbox_darkIntTime)
        vbox_darkTimes.addLayout(hbox_darkSubtract)
        vbox_darkTimes.addLayout(hbox_autoscale)

        #one row holding the timestamp box, the dark/colorbar controls and the log text
        hbox_controls = QHBoxLayout()
        hbox_controls.addLayout(vbox_imgTimestamp)
        hbox_controls.addLayout(vbox_darkTimes)
        hbox_controls.addWidget(self.label_log)

        #Now create another vbox, and add the plot vbox and the button's hbox to the new vbox.
        vbox_combined = QVBoxLayout()
        vbox_combined.addLayout(vbox_plot)
        #        vbox_combined.addLayout(hbox_imgTimestamp)
        vbox_combined.addLayout(hbox_controls)
        vbox_combined.addWidget(self.label_IMG_path)
        vbox_combined.addWidget(self.label_log_path)

        #Set the main_frame's layout to be vbox_combined
        self.main_frame.setLayout(vbox_combined)

        #Set the overall QWidget to have the layout of the main_frame.
        self.setCentralWidget(self.main_frame)

        #hook up the matplotlib canvas events (mouse move and scroll wheel);
        #the returned connection ids are not used afterwards
        cid = self.fig.canvas.mpl_connect('motion_notify_event',
                                          self.hoverCanvas)
        cid3 = self.fig.canvas.mpl_connect('scroll_event',
                                           self.scroll_ColorBar)

    def spinBoxValueChange(self):
        """Redraw after a control changed.

        Plots the image whose timestamp equals the image-timestamp spinbox
        value and refreshes the log label. When the lookup fails for any
        reason, a blank image is plotted and the label reports that no
        .img file exists.
        """
        try:
            wanted = self.timeStampList == self.spinbox_imgTimestamp.value()
            filename = self.fileListRaw[np.where(wanted)[0][0]]
        except:
            # Any lookup failure is treated as "no image for this timestamp".
            self.plotBlank()
            self.updateLogLabel(IMG_fileExists=False)
            return

        self.plotImage(filename)
        self.updateLogLabel()

    def draw(self):
        """Repaint the canvas, then flush its pending GUI events.

        Called by the plotting methods after they changed the axes.
        """
        canvas = self.canvas
        canvas.draw()
        canvas.flush_events()

    def hoverCanvas(self, event):
        """Show the pixel under the mouse in the status bar.

        event -- matplotlib mouse event. Ignored unless it lies inside the
        image axes. The text has the form "(col,row) value", where value is
        ``self.image[row, col]``.
        """
        if event.inaxes is not self.ax1:
            return

        # Data coordinates are rounded to the nearest pixel index.
        col = int(round(event.xdata))
        row = int(round(event.ydata))
        if row >= self.nRow or col >= self.nCol:
            return

        pixelValue = self.image[row, col]
        self.status_text.setText('({:d},{:d}) {}'.format(col, row, pixelValue))

    def scroll_ColorBar(self, event):
        """Rescale the colorbar maximum with the mouse wheel.

        event -- matplotlib scroll event. When it lies over the colorbar
        axes, scrolling up raises the upper limit by 10 % and scrolling
        down lowers it by 10 %; the image is redrawn with the new limits.
        The canvas is redrawn in every case. Nothing is rescaled before an
        image has been plotted (no limits exist yet).
        """
        limits = getattr(self, 'cbarLimits', None)
        overColorbar = event.inaxes is self.fig.cbar.ax

        if overColorbar and limits is not None and event.button in ('up',
                                                                    'down'):
            stepSize = 0.1  #fractional change in the colorbar scale
            factor = 1 + stepSize if event.button == 'up' else 1 - stepSize

            # Work on a float copy: scaling an integer array in place
            # truncates, so small limits could never grow.
            limits = np.array(limits, dtype=float)
            limits[1] *= factor
            self.cbarLimits = limits

            self.fig.cbar.set_clim(self.cbarLimits[0], self.cbarLimits[1])
            self.fig.cbar.draw_all()

            # Drop the images already on the axes so that repeated
            # scrolling does not stack one image artist per wheel step.
            for oldImage in list(self.ax1.images):
                oldImage.remove()
            self.ax1.imshow(self.image,
                            interpolation='none',
                            vmin=self.cbarLimits[0],
                            vmax=self.cbarLimits[1])

        self.draw()

    def create_status_bar(self):
        """Create the (initially empty) status label and add it to the status bar."""
        #Using code from ARCONS-pipeline as an example:
        #ARCONS-pipeline/util/popup.py
        label = QLabel("")
        self.status_text = label
        self.statusBar().addWidget(label, 1)

    def createMenu(self):
        """Build the "File" menu with open-img, change-log-directory and exit entries."""
        #Using code from ARCONS-pipeline as an example:
        #ARCONS-pipeline/util/quicklook.py
        self.menubar = self.menuBar()
        self.fileMenu = self.menubar.addMenu("&File")

        def add_entry(text, shortcut, tip, slot):
            # Create one action, wire it to ``slot`` and append it to the menu.
            action = QAction(text, self)
            action.setShortcut(shortcut)
            action.setStatusTip(tip)
            action.triggered.connect(slot)
            self.fileMenu.addAction(action)

        add_entry('Open img File', 'Ctrl+O', 'Open an img File',
                  lambda x: self.getFileNameFromUser(fileType='img'))
        add_entry('Change log directory', 'Ctrl+l',
                  'Opens a dialog box so user can select log file manually.',
                  lambda x: self.getFileNameFromUser(fileType='log'))

        self.fileMenu.addSeparator()

        add_entry('Exit', 'Ctrl+Q', 'Exit application', self.close)

        self.menubar.setNativeMenuBar(False)  #This is for MAC OS

    def getFileNameFromUser(self, fileType):
        """Ask the user for an img or log file and load the selection.

        Parameters
        ----------
        fileType : 'img' or 'log'; any other value makes the method return
                   without doing anything.

        The dialog opens in the last used directory (``self.imgPath`` or
        ``self.logPath``) and falls back to '/' when that attribute has not
        been set to a string yet.  Cancelling the dialog prints a note and
        returns.  For 'img' the image/log file lists, pixel counts, empty
        arrays and spin box values are re-initialised; for 'log' the log
        file list and the log label are refreshed.
        """
        # look at this website for useful examples
        # https://pythonspot.com/pyqt5-file-dialog/
        if fileType == 'img':
            path_attr = 'imgPath'
        elif fileType == 'log':
            path_attr = 'logPath'
        else:
            return

        # Resolve the start directory up front instead of wrapping the dialog
        # in a bare ``except:``: that pattern could show the dialog twice and
        # hid every unrelated error (including KeyboardInterrupt).
        start_dir = getattr(self, path_attr, None)
        if not isinstance(start_dir, str):
            start_dir = '/'

        filename, _ = QFileDialog.getOpenFileName(self,
                                                  'Select One File',
                                                  start_dir,
                                                  filter='*.' + fileType)

        if filename == '':
            print('\nno file selected\n')
            return

        if fileType == 'img':
            self.imgPath = os.path.dirname(filename)
            self.label_IMG_path.setText('img path:  ' + self.imgPath)

            self.filename = filename
            self.load_IMG_filenames(self.filename)
            self.load_log_filenames()
            self.nCol, self.nRow = self.get_nPixels(self.filename)
            self.initializeEmptyArrays(self.nCol, self.nRow)
            self.initialize_spinbox_values(self.filename)
        else:
            self.logPath = os.path.dirname(filename)
            self.label_log_path.setText('log path:  ' + self.logPath)
            self.load_log_filenames()
            self.updateLogLabel()
Beispiel #27
0
class Ui_Plot(QWidget):
    def __init__(self):
        """Store the window title and geometry, then apply them via initUI()."""
        super().__init__()
        # NOTE(review): ``width`` and ``height`` are also QWidget method
        # names; these instance attributes presumably shadow them - confirm
        # nothing relies on calling self.width() / self.height().
        self.title = 'Plot'
        self.left, self.top = 328, 10
        self.width, self.height = 1000, 300
        self.initUI()

    def initUI(self):
        """Apply the stored title and geometry to the widget."""
        self.setWindowTitle(self.title)
        geometry = (self.left, self.top, self.width, self.height)
        self.setGeometry(*geometry)

    def plot_initialize(self, config_data):
        """
        Create the figure, canvas, axes and toolbar and add an empty red line.

        Parameters(passed from GUI_operator.py read_write function)
        ----------
        config_data : dictionary, containing data read from the config file
                      Instance is created in the GUI_run_exp.py - run_exp function

        The experiment type read from ``config_data`` selects the axis
        labels: "CA" plots current against time, "LSV" and "CV" plot current
        against voltage.  For any other type no line is created.
        """
        # Constants for every experiment
        figure = Figure(figsize=(10, 7.5))
        self.fig = figure
        self.canvas = FigureCanvas(figure)
        axes = figure.add_subplot(111, position=[0.12, 0.13, 0.85, 0.85])
        self.ax = axes
        self.toolbar = NavigationToolbar(self.canvas, self)
        self.toolbar.setStyleSheet("border: 0 px ;\n")

        axes.clear()  # discards the old graph

        # set the start range
        axes.set_xlim(-3, 3)
        axes.set_ylim(-3, 3)

        kind = cr.get_exp_type(config_data)
        self.exp_type = kind

        if kind == "CA":
            self.exp_time = cr.get_exp_time(config_data)
            axes.set_xlim(0, 2 * self.exp_time)
            axes.set_xlabel("Time (s)")
            axes.set_ylabel("Current (mA)")
            self.lines, = axes.plot([], [], 'r')
        elif kind in ("LSV", "CV"):
            axes.set_xlabel("Voltage (V)")
            axes.set_ylabel("Current (mA)")
            self.lines, = axes.plot([], [], 'r')

    def plot_updater(self, data):
        """
        Update the plot.

        Parameters (passed from GUI_operator.py read_write function)
        ----------
        data : tuple , (times, voltages, currents) , live data during measurement
              Instance is created in the GUI_operator.py - read_write function

        For "CA" experiments current is drawn against time, for "LSV" and
        "CV" against voltage; any other experiment type is ignored.  The
        axis limits follow the data with a margin of 0.1.  When no samples
        have arrived yet the limits are left alone, because min()/max() of
        an empty sequence would raise ValueError.
        """
        times, voltages, currents = list(data)
        edge = 0.1

        if self.exp_type == "CA":
            x_values = times
        elif self.exp_type == "LSV" or self.exp_type == "CV":
            x_values = voltages
        else:
            return

        self.lines.set_xdata(x_values)
        self.lines.set_ydata(currents)

        if len(x_values) > 0 and len(currents) > 0:
            if self.exp_type == "CA":
                # the time axis always starts just left of zero
                self.ax.set_xlim(-edge, max(times) + edge)
            else:
                self.ax.set_xlim(min(voltages) - edge, max(voltages) + edge)
            self.ax.set_ylim(min(currents) - edge, max(currents) + edge)

        self.canvas.draw()
        self.canvas.flush_events()
Beispiel #28
0
class ejemplo_Gui(QMainWindow):
    def __init__(self):
        """Build the UI, embed the plot, connect the buttons and reset the training state."""
        super().__init__()
        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)

        self.init_widget()

        # button -> handler wiring
        for boton, accion in (
                (self.ui.btnSubirPatrones, self.subirPatrones),
                (self.ui.btnSubirPesos, self.subirPesos),
                (self.ui.btnSubirUmbral, self.subirUmbral),
                (self.ui.btnEntrenar, self.entrenar),
        ):
            boton.clicked.connect(accion)

        self.mensajeService = MensajeService()

        # table / file loading state
        self.arrayMensajes = []
        self.cabeceras = []
        self.fila = 0
        self.entradas = []
        self.yd = 0
        self.pesos = []
        self.patrones = []

        # training state
        self.contadorIteracion = 0
        self.seguir = 0
        self.sumatoriaXw = 0
        self.funcion = ""
        self.yr = 0
        self.erroresPatrones = []

        # data series for the error plot
        self.errsIt = []
        self.numsIt = []

    def init_widget(self):
        """Create the error-vs-iteration figure and embed its canvas in the training widget."""
        figura = Figure(dpi=80, constrained_layout=True)
        self.figure = figura

        ejes = figura.add_subplot(111)
        self.axis = ejes
        ejes.set_title('Iteración Vs Error de iteración')
        ejes.set_xlabel('Iteración')
        ejes.set_ylabel('Error de iteración')

        self.canvas = FigureCanvasQTAgg(figura)

        contenedor = QVBoxLayout(self.ui.widgetEntrenamiento)
        self.layoutvertical = contenedor
        contenedor.setContentsMargins(0, 0, 0, 0)
        contenedor.addWidget(self.canvas)

    def subirPesos(self):
        """Load the weights from a ';'-separated text file chosen by the user.

        Nothing is changed when the file dialog is cancelled.  Otherwise
        the weight table is emptied, ``self.pesos`` is reset (so loading a
        second file replaces the weights instead of appending to them) and
        every non-blank line is handed to ``llenarTablaPesos``.  The file
        is closed even if a line cannot be processed.
        """
        ruta = self.buscarArchivo()
        if not ruta:
            return

        self.deleteAllRows(self.ui.tblWidgetPesos)
        self.pesos = []

        f = self.mensajeService.consultar(ruta)
        try:
            self.fila = 0
            for line in f.readlines():
                # strip the line terminator; skip blank lines, which would
                # otherwise end up as float('') further down
                line = line.strip()
                if not line:
                    continue
                self.llenarTablaPesos(line.split(";"))
                self.fila += 1
        finally:
            f.close()

    def llenarTablaPesos(self, datos):
        """Show one row of weights in the weight table and append them to ``self.pesos``.

        Parameters
        ----------
        datos : sequence of values convertible with ``float``; one weight
                per table column.

        The row is always inserted at index 0.  The column headers are
        rebuilt as w11, w21, ... on every call; previously they were
        appended to ``self.cabeceras``, which therefore grew without bound
        (this method runs once per pattern during training).
        """
        tabla = self.ui.tblWidgetPesos
        self.fila = 0
        tabla.setColumnCount(len(datos))
        tabla.insertRow(self.fila)

        if len(datos) > 0:
            # header set-up does not depend on the cell, so do it once
            self.cabeceras = ["w" + str(i + 1) + "1" for i in range(len(datos))]
            tabla.setHorizontalHeaderLabels(self.cabeceras)
            encabezado = tabla.horizontalHeader()
            encabezado.setStretchLastSection(True)
            encabezado.setSectionResizeMode(QHeaderView.Stretch)

        for columna, valor in enumerate(datos):
            tabla.setItem(self.fila, columna, QTableWidgetItem(str(valor)))
            self.pesos.append(float(valor))

    def subirUmbral(self):
        """Load the threshold from a ';'-separated text file chosen by the user.

        The first field of every non-blank line is written to the threshold
        text box, so the last such line wins.  Lines are stripped first, so
        neither a trailing newline nor an empty last line ends up in the
        field (its text is later parsed with ``float``).  The file is
        closed even if reading fails.
        """
        ruta = self.buscarArchivo()
        if not ruta:
            return

        f = self.mensajeService.consultar(ruta)
        try:
            for line in f.readlines():
                line = line.strip()
                if not line:
                    continue
                self.ui.txtUmbral.setText(line.split(";")[0])
        finally:
            f.close()

    def entrenar(self):
        """Run the training loop.

        Each iteration feeds every loaded pattern to ``calcularSoma`` and
        then calls ``calcularErrorIteracion``.  The loop ends when the
        iteration counter reaches the number in the iteration text box or
        when ``self.seguir`` is no longer 'S'.  If it ends with
        ``self.seguir`` still 'S', a message box reports that the network
        did not learn.
        """
        self.contadorIteracion = 0
        self.seguir = 'S'
        self.funcion = self.ui.comboFuncionActivacion.currentText()

        while True:
            limite = int(self.ui.txtNumIteracion.text())
            if self.contadorIteracion >= limite or self.seguir != 'S':
                break
            for patron in self.patrones:
                self.calcularSoma(patron)
            self.calcularErrorIteracion()
            self.erroresPatrones.clear()

        if self.seguir == 'S':
            QMessageBox.question(
                self, "Mensaje",
                "La red no aprendó, porque no alcanzó el error máximo permitido.",
                QMessageBox.Ok, QMessageBox.Ok)

    def calcularSoma(self, patron):
        """Compute the soma (weighted input sum minus threshold) for one pattern.

        The weighted sum is kept in ``self.sumatoriaXw``, the soma is shown
        in the soma text box and then passed to ``funcionActivacion``.
        """
        self.sumatoriaXw = 0
        for indice, entrada in enumerate(patron.entradas):
            producto = entrada * self.pesos[indice]
            self.sumatoriaXw = self.sumatoriaXw + float(producto)

        umbral = float(self.ui.txtUmbral.text())
        soma = self.sumatoriaXw - umbral
        self.ui.txtSoma.setText(str(soma))
        self.funcionActivacion(soma, patron)

    def funcionActivacion(self, soma, patron):
        """Forward the soma to the output method that matches ``self.funcion``.

        "Lineal", "Escalon" and "Sigmoide" are recognised; any other value
        does nothing.
        """
        manejadores = {
            "Lineal": "salidaLineal",
            "Escalon": "salidaEscalon",
            "Sigmoide": "salidaSigmoide",
        }
        nombre = manejadores.get(self.funcion)
        if nombre is not None:
            getattr(self, nombre)(soma, patron)

    def salidaLineal(self, soma, patron):
        """Placeholder for the linear output: only prints an empty line."""
        print()

    def salidaEscalon(self, soma, patron):
        """Step activation: output 1 when the soma is >= 0, otherwise 0.

        The output is stored in ``self.yr``, shown in the YR text box and
        passed on to ``calcularErrorPatron``.
        """
        self.yr = 1 if soma >= 0 else 0
        self.ui.txtYR.setText(str(self.yr))
        self.calcularErrorPatron(self.yr, patron)

    def salidaSigmoide(self, soma, patron):
        """Placeholder for the sigmoid output: only prints an empty line."""
        print()

    def calcularErrorPatron(self, yr, patron):
        """Compute the linear error and the pattern error for one pattern.

        The linear error is ``patron.yd - yr``; the pattern error is its
        absolute value divided by 1.  Both are shown in their text boxes,
        the pattern error is appended to ``self.erroresPatrones`` and the
        linear error is handed to ``algoritmoEntrenamiento``.
        """
        diferencia = float(patron.yd) - float(yr)
        self.ui.txtErrorLineal.setText(str(diferencia))

        magnitud = abs(diferencia) / 1
        self.ui.txtErrorPatron.setText(str(magnitud))
        self.erroresPatrones.append(magnitud)

        self.algoritmoEntrenamiento(diferencia, patron)

    def algoritmoEntrenamiento(self, errorLineal, patron):
        """Update the weights and the threshold from one pattern's linear error.

        Each weight becomes ``w + rate * errorLineal * x``.  The weight
        table is emptied and refilled through ``llenarTablaPesos``, and the
        threshold text box is set to ``threshold + rate * errorLineal``.
        """
        rataAprendizaje = float(self.ui.txtRataAprendizaje.text())

        nuevosPesos = [
            peso + rataAprendizaje * errorLineal * patron.entradas[indice]
            for indice, peso in enumerate(self.pesos)
        ]

        # empty the current list in place, then start from a new one;
        # llenarTablaPesos appends the new weights to it
        self.pesos.clear()
        self.pesos = []

        self.deleteAllRows(self.ui.tblWidgetPesos)
        self.llenarTablaPesos(nuevosPesos)

        umbralActual = float(self.ui.txtUmbral.text())
        nuevoUmbral = umbralActual + rataAprendizaje * errorLineal * 1
        self.ui.txtUmbral.setText(str(nuevoUmbral))

    def calcularErrorIteracion(self):
        """Compute the mean pattern error of the iteration and plot it.

        The iteration counter is incremented, the labels and the
        error-vs-iteration plot are refreshed, and the method pauses for
        0.9 s so the plot update is visible.  If the error is at or below
        the maximum permitted error a message box is shown and
        ``self.seguir`` is set to 'N' to stop training.

        If no pattern errors were recorded (no patterns loaded, or an
        activation function that does not compute an error) there is
        nothing to average: training is stopped by setting ``self.seguir``
        to 'N' instead of failing with ZeroDivisionError.
        """
        numeroPatrones = len(self.erroresPatrones)
        if numeroPatrones == 0:
            self.seguir = 'N'
            return

        errorIteracion = sum(self.erroresPatrones) / numeroPatrones

        self.contadorIteracion += 1
        self.ui.txtIteracionesCumplidas.setText(str(self.contadorIteracion))

        self.ui.lbIteracion.setText(str(self.contadorIteracion))
        self.ui.lbErrorIt.setText(str(errorIteracion))

        self.errsIt.append(errorIteracion)
        self.numsIt.append(self.contadorIteracion)

        # clear() also drops the title and labels, so set them again
        self.axis.clear()

        self.axis.plot(self.numsIt, self.errsIt)
        self.axis.set_title('Iteración Vs Error de iteración')
        self.axis.set_xlabel('Número de iteración', fontsize=12)
        self.axis.set_ylabel('Error de iteración', fontsize=12)

        self.canvas.draw()

        self.canvas.flush_events()
        time.sleep(0.9)

        if errorIteracion <= float(self.ui.txtErrorMaxPermitido.text()):
            QMessageBox.question(
                self, 'Mensaje',
                "El error de iteracion: " + str(errorIteracion) +
                ", es menor o igual al error maximo permitido. " +
                self.ui.txtErrorMaxPermitido.text(), QMessageBox.Ok,
                QMessageBox.Ok)
            self.seguir = 'N'

    def guardar(self):
        """Echo the typed message, save it through the message service and echo the result."""

        def avisar(contenido):
            # Show one "Escribiste: ..." message box.
            QMessageBox.question(self, 'Mensaje', "Escribiste: " + contenido,
                                 QMessageBox.Ok, QMessageBox.Ok)

        escrito = self.txtMensaje.text()
        avisar(escrito)
        self.etiqueta.setText(escrito)
        guardado = self.mensajeService.guardar(escrito)
        avisar(guardado)

    def buscarArchivo(self):
        """Open a file dialog for *.txt files and return the selected path.

        The dialog starts in the user's home directory.
        """
        seleccion = QFileDialog.getOpenFileName(self, 'Buscar Archivo',
                                                QDir.homePath(),
                                                "Text Files (*.txt)")
        ruta, _ = seleccion
        return ruta

    def subirPatrones(self):
        """Load the training patterns from a ';'-separated text file.

        Every non-blank line holds the inputs followed by the desired
        output (last field).  Each line becomes one table row and one
        ``Patron``.  Nothing is changed when the file dialog is cancelled.
        Loading a file replaces the previous patterns: ``self.patrones`` is
        reset together with the table, so the pattern count and the
        training data always match what the table shows.  The file is
        closed even if a line cannot be parsed.
        """
        ruta = self.buscarArchivo()
        if not ruta:
            return

        tabla = self.ui.tblWidgetEntSal
        self.deleteAllRows(tabla)
        self.patrones = []

        f = self.mensajeService.consultar(ruta)
        try:
            fila = 0
            for line in f.readlines():
                # strip the line terminator; skip blank lines
                line = line.strip()
                if not line:
                    continue
                datos = line.split(";")

                tabla.setColumnCount(len(datos))
                tabla.insertRow(fila)

                # inputs are every field but the last, which is YD
                self.entradas = [float(d) for d in datos[:-1]]
                self.yd = float(datos[-1])

                cabeceras = ["x" + str(i + 1) for i in range(len(datos) - 1)]
                cabeceras.append("YD")
                tabla.setHorizontalHeaderLabels(cabeceras)

                for columna, valor in enumerate(datos):
                    tabla.setItem(fila, columna, QTableWidgetItem(valor))

                self.patrones.append(Patron(self.yd, self.entradas))
                fila += 1

            self.ui.txtNEntradas.setText(str(len(self.entradas)))
            self.ui.txtNSalidas.setText("1")
            self.ui.txtPatrones.setText(str(len(self.patrones)))
        finally:
            f.close()

    def deleteAllRows(self, table: QTableWidget) -> None:
        """Remove every row of ``table`` through its model."""
        modelo: QAbstractTableModel = table.model()
        total_filas = modelo.rowCount()
        modelo.removeRows(0, total_filas)
Beispiel #29
0
class thzWindow(QMainWindow):
    def __init__(self):
        """Load the UI, connect the lock-in amplifier and stage, and embed the plot.

        Raises
        ------
        RuntimeError : if ``os.name`` is neither 'nt' nor 'posix'.  No
                       serial port names are known for such a platform;
                       previously this surfaced later as a NameError on
                       ``portLIA``.
        """
        super(thzWindow, self).__init__()
        loadUi('MainWindow.ui', self)
        self.setWindowTitle('THz Scan GUI')

        # A hack is needed to start the drop down menus in a sane place.
        self.ddSens.setCurrentIndex(18)
        self.ddTc.setCurrentIndex(7)

        ########################################################################
        ##           Load InstrumentControl classes and initiate              ##
        ########################################################################
        #Check for os:
        if os.name == 'nt':  #Respond to windows platform
            print('Identified Windows OS')
            portLIA = 'com3'  #Prolific driver
            portStage = 'com4'  #Arduino Uno ID
        elif os.name == 'posix':
            # NOTE: 'posix' is also reported on Linux, not only on Mac OS.
            print('Identified Mac OS')
            portLIA = '/dev/tty.usbserial'
            portStage = '/dev/tty.usbmodem1421'
        else:
            print('CRITICAL: Unidentified OS.')
            raise RuntimeError(
                'Unidentified OS %r: no serial ports configured.' % os.name)

        # Change this line to SR530demo, for the demo-mode
        #self.lia = SR530(portLIA, 19200)
        self.lia = SR530demo(portLIA, 19200)
        self.lia.connect()
        time.sleep(0.25)
        self.lia.standard_setup()

        #Run basic sanity checks for the LIA connection
        response = self.lia.query('W')
        if response != [b'0\r']:
            print(
                'Error: Connection to LIA unsuccesful. Files will not be saved'
            )
            self.update_statusbar('CRITICAL: LIA connection not available!')
            self.save_files = False
        else:
            self.save_files = True

        #self.stage = ArduinoStageController(portStage, 9600)
        self.stage = ArduinoStageControllerDemo(portStage, 9600)
        self.stage.connect()
        self.stage.initialize()

        ########################################################################
        ##           Define execution control variables                       ##
        ########################################################################
        self.StopRunFlag = False
        self.IsHomedFlag = False
        self.SaveAllFlag = False
        self.SaveOnStop = False

        self.dataX = np.array([])
        self.dataY = np.array([])
        self.dataStep = np.array([])
        # Called for its side effect only: it sets self.post_move_wait_time.
        self.estimate_scan_time()

        ########################################################################
        ##           Set up windows and figures for plotting                  ##
        ########################################################################
        # a figure instance to plot on
        self.figure = Figure()
        # this is the Canvas Widget that displays the `figure`
        # it takes the `figure` instance as a parameter to __init__
        self.canvas = FigureCanvas(self.figure)
        # this is the Navigation widget
        # it takes the Canvas widget and a parent
        self.toolbar = NavigationToolbar(self.canvas, self)
        #Scale any fonts accordingly.
        if os.name == 'posix':
            matplotlib.rcParams.update({'font.size': 5})

        # Insert the widgets at appropriate places, and replace the placeholder widget.
        self.verticalLayout.insertWidget(0, self.toolbar)
        self.verticalLayout.replaceWidget(self.wplot, self.canvas)

        ########################################################################
        ##           Define signals and slots for buttons                     ##
        ########################################################################
        self.btnStart.clicked.connect(self.btnStart_clicked)
        self.btnStop.clicked.connect(self.btnStop_clicked)
        self.btnRealtime.clicked.connect(self.btnRealtime_clicked)
        self.btnGoto.clicked.connect(self.btnGoto_clicked)
        self.btnUpdate.clicked.connect(self.btnUpdate_clicked)
        self.cbSaveall.stateChanged.connect(self.update_savestate)

    #@pyqtSlot()

    ############################################################################
    ##           Define button functions                                      ##
    ############################################################################

    def btnStart_clicked(self):
        """Run a scan from nStart to nStop in steps of nStepsize.

        For every step the averaged measurement and the current position
        are appended to the data arrays, the stage is moved one step
        further, the post-move wait is observed and the plot is refreshed.
        The loop ends early when ``self.StopRunFlag`` is set.  The data is
        saved once the loop has finished.

        A step size of zero cannot be scanned (the number of steps would be
        a division by zero); in that case only the status text is updated
        and no scan is started.
        """
        step_size = self.nStepsize.value()
        if step_size == 0:
            # Written directly to the status widget rather than through
            # update_statusbar(), which also recomputes the scan-time
            # estimate from the same step size.
            self.statusBar.setText(
                'Status: Step size must be non-zero; scan not started')
            return

        try:
            self.lia.demo_measure_reset()  #should be removed when out of dev
        except AttributeError:
            pass
        self.update_statusbar('Starting scan')
        self.reset_data_array()
        self.StopRunFlag = False
        self.SaveOnStop = False  #It defaults to the end of the loop where it saves, anyway.

        self.generate_plot()
        # update step point
        self.PresentPosition = self.nStart.value()

        #loop through n steps:
        length_of_scan = int((self.nStop.value() - self.nStart.value()) /
                             self.nStepsize.value())
        for _ in range(length_of_scan):

            #Check for stop flag
            if self.StopRunFlag:
                break

            # Measure data
            measurement = self.high_level_measure()

            #append data to dataarray
            self.dataX = np.append(self.dataX, measurement[0])
            self.dataY = np.append(self.dataY, measurement[1])
            self.dataStep = np.append(self.dataStep, self.PresentPosition)

            #Increment the PresentPosition controller variable
            self.PresentPosition = (self.PresentPosition +
                                    self.nStepsize.value())

            #Execute move start
            self.stage.move(self.PresentPosition)

            #Execute post move wait
            self.interruptable_sleep(self.post_move_wait_time)

            #Update plot
            self.update_plot()

        self.save_data_array()

    def btnStop_clicked(self):
        """Request the running loop to stop, save if requested, and send 'I0' to the LIA."""
        self.update_statusbar('Stopping scan')
        self.StopRunFlag = True
        save_requested = self.SaveOnStop
        if save_requested:
            self.save_data_array()
        self.lia.send('I0')

    def btnRealtime_clicked(self):
        """Show a rolling 200-point live display until the stop flag is set.

        The data arrays start as 200 zeros.  Every pass appends one new
        measurement and drops the oldest one, so the window length stays
        constant.  This measurement is for alignment only and is not saved.
        """
        window = 200

        self.update_statusbar('Realtime display started')
        self.StopRunFlag = False
        self.SaveOnStop = False  #This measurement is made for alignment only, and will not be saved.
        self.lia.send('I 1')

        #Initialize the data set to zeros and sweet nothings.
        self.dataX = np.zeros(window)
        self.dataY = np.copy(self.dataX)
        self.dataStep = np.arange(window)

        #Generate plot
        self.generate_plot()
        self.PresentPosition = window

        #Loop until stop button is clicked:
        while not self.StopRunFlag:
            x_value, y_value = self.high_level_measure()[:2]

            # append the newest sample, then drop the oldest one
            self.dataX = np.append(self.dataX, x_value)[1:]
            self.dataY = np.append(self.dataY, y_value)[1:]
            self.dataStep = np.append(self.dataStep, self.PresentPosition)[1:]

            #plot
            self.ax.set_xlim([self.dataStep.min(), self.dataStep.max()])
            self.update_plot()

            self.PresentPosition += 1

        self.lia.send('I 0')

    def btnGoto_clicked(self):
        """Move the stage to the position entered in nPosition."""
        self.update_statusbar('Starting Goto')
        target = self.nPosition.value()
        self.stage.move(target)
        self.update_statusbar('Goto value reached')

    def btnUpdate_clicked(self):
        """Send the selected sensitivity and time constant to the LIA.

        The sensitivity is the index of the sensitivity drop-down; the time
        constant is 10 minus the index of the time-constant drop-down.
        """
        self.update_statusbar('Updating LIA')

        #Update sensitivity
        sens_index = self.ddSens.currentIndex()
        self.lia.set_sens(sens_index)

        #Update filter Tcs
        tc_code = 10 - self.ddTc.currentIndex()
        self.lia.set_tc(tc_code)

    ############################################################################
    ##           Define update, save and time calc functions                  ##
    ############################################################################

    def high_level_measure(self):
        """Take ``nAvg`` measurements from the LIA and return their mean.

        Returns
        -------
        tuple : (mean of the x values, mean of the y values)
        """
        repeats = int(self.nAvg.value())
        samples = [self.lia.measure() for _ in range(repeats)]
        x_values = [sample[0] for sample in samples]
        y_values = [sample[1] for sample in samples]
        return (np.mean(x_values), np.mean(y_values))

    def update_statusbar(self, new_update):
        """Show ``new_update`` in the status text and refresh the scan-time estimate label."""
        self.statusBar.setText('Status: ' + new_update)

        #As this is an often used function, I will piggy-back on this to ensure
        # the scan time estimate is regularly updated and shown.
        total_seconds = self.estimate_scan_time()
        minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        self.lblEstduration.setText("%dhrs, %02dmins, %02dsecs" %
                                    (hours, minutes, seconds))

    def update_savestate(self):
        """Mirror the "save all" checkbox into ``self.SaveAllFlag`` and report it.

        The flag is True only when the checkbox state equals 2.
        """
        self.SaveAllFlag = bool(self.cbSaveall.checkState() == 2)
        self.update_statusbar('Saves all: ' + str(self.SaveAllFlag))

    def estimate_scan_time(self):
        """Estimate the scan duration in seconds and update the post-move wait time.

        The time constant is decoded from the text of the Tc drop-down,
        which is expected to look like "<number> <unit>" with unit 's' or
        'ms'.  As a side effect ``self.post_move_wait_time`` is set to
        ``Tc * (1 + nPostmove)``.

        Returns
        -------
        float : estimated duration in seconds; 0.0 when the step size is
                zero, since no scan can be run then (this used to raise
                ZeroDivisionError).

        Raises
        ------
        ValueError : if the unit is neither 's' nor 'ms' (previously an
                     UnboundLocalError on ``Tc``).
        """
        nStart = self.nStart.value()
        nStop = self.nStop.value()
        nStepsize = self.nStepsize.value()
        nPostmove = self.nPostmove.value()
        nAvg = self.nAvg.value()

        #The integers were easy, now the slightly trickier part;
        # Decoding the time constant from the Tc drop down menu.
        textTc = self.ddTc.currentText()
        multiplier, unit = textTc.split(' ')
        if unit == 's':
            Tc = float(multiplier)
        elif unit == 'ms':
            Tc = float(multiplier) * 1e-3
        else:
            raise ValueError('Unknown time constant unit: %r' % unit)

        self.post_move_wait_time = Tc * (1 + nPostmove)

        if nStepsize == 0:
            return 0.0

        # Sum and multiply the time for the scan: The factor 120 is the velocity in steps/second. This should be tuned.
        # (named scan_time so the local does not shadow the ``time`` module)
        scan_time = (Tc * (1 + nPostmove) * nAvg +
                     nStepsize * 1 / 120.0) * (nStop - nStart) / nStepsize

        return scan_time

    def interruptable_sleep(self, wait_time):
        """Sleep for ``wait_time`` seconds in 10 ms slices, returning early once the stop flag is set."""
        slices_done = 0
        while True:
            if self.StopRunFlag or slices_done >= int(wait_time * 100):
                return
            time.sleep(0.01)
            slices_done += 1

    def reset_data_array(self):
        """Replace dataX, dataY and dataStep with new empty arrays."""
        for attribute in ('dataX', 'dataY', 'dataStep'):
            setattr(self, attribute, np.array([]))

    def save_data_array(self):
        """Write the data arrays to a timestamped .dat file in the working directory.

        The file holds the columns X, Y and step in CSV form.  Nothing is
        written when ``self.save_files`` is false; a note is printed
        instead.
        """
        #The filename expression will be yyyymmdd-hr-mn-ss.dat
        prefix_string = self.fileprefix.text()
        fname_string = (os.getcwd() + '/' +
                        time.strftime('%Y%m%d-%H-%M-%S_') + prefix_string +
                        '.dat')

        if not self.save_files:
            print('Files not saved, as no proper instrument is connected.')
            return

        print('Saving file to ' + fname_string)

        #Leverage pandas to do the heavy lifting.
        columns = np.array([self.dataX, self.dataY, self.dataStep]).T
        frame = pd.DataFrame(columns, columns=['X', 'Y', 'step'])
        frame.to_csv(fname_string)

    ############################################################################
    ##           Define plotting and plot update functions                    ##
    ############################################################################

    def generate_plot(self):
        """Create a fresh axes with one line for X and one for Y against the step.

        The figure is cleared first.  This method runs at the start of
        every scan and every realtime display; without the clear, each
        ``add_subplot`` call can add another axes on top of the previous
        one (recent Matplotlib versions no longer reuse an existing axes),
        so old curves stay visible and axes pile up in the figure.
        """
        plt.ion()

        # discard everything from the previous run, then create an axis
        self.figure.clear()
        self.ax = self.figure.add_subplot(111)

        self.lineX, = self.ax.plot(self.dataStep, self.dataX)
        self.lineY, = self.ax.plot(self.dataStep, self.dataY)

        self.ax.set_xlim([self.nStart.value(), self.nStop.value()])

        # The canvas is not redrawn here; update_plot() does that.

    def update_plot(self):
        """Push the current data into both lines, rescale the y axis and redraw.

        The y range spans both data sets plus a 10 % margin on each side.
        On 'posix' systems ``plt.pause`` is called after the redraw.
        """
        for line, values in ((self.lineX, self.dataX),
                             (self.lineY, self.dataY)):
            line.set_xdata(self.dataStep)
            line.set_ydata(values)

        #Crop the axis
        both = [self.dataX, self.dataY]
        y_min = np.min(both)
        y_max = np.max(both)
        margin = (y_max - y_min) * 0.1
        self.ax.set_ylim([y_min - margin, y_max + margin])

        self.canvas.draw()
        self.canvas.flush_events()
        if os.name == 'posix':
            plt.pause(0.000001)
Beispiel #30
0
class Ui(QtWidgets.QMainWindow):
    def __init__(self):
        """Load the .ui file, embed a 3x2 grid of plots and connect the menu actions."""
        super(Ui, self).__init__()
        uic.loadUi('DualTrack_Analyzer_simple2.ui', self)
        self.show()
        self.title = 'CICADA Dual Tracking Analysis'
        self.setWindowTitle(self.title)

        # each of ax1/ax2/ax3 is one row (two axes) of the 3x2 grid
        figure, axes_rows = plt.subplots(3, 2)
        self.fig = figure
        self.ax1, self.ax2, self.ax3 = axes_rows
        figure.tight_layout()

        self.canvas = FigureCanvas(figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        for widget in (self.canvas, self.toolbar):
            self.gridLayout.addWidget(widget)

        self.canvas.draw()
        self.canvas.flush_events()

        self.actionExit.triggered.connect(self.exit)
        self.actionOpen.triggered.connect(self.openf)
        self.actionOpen.setStatusTip('Click here to open a file')

    def openf(self):
        """Let the user pick a fixed-width data file and plot its tracking data.

        The file is read with ``pandas.read_fwf`` and must provide the
        columns PollTime, X-defl, Y-defl, CurGimAz, CurGimEl, NFOV-Pwr,
        NFOV-X and NFOV-Y.  Time is plotted relative to the first sample
        and the deflections are multiplied by 1000.

        If the dialog is cancelled nothing is plotted; previously the code
        went on and either failed on the missing ``self.data`` or replotted
        the old file.  All axes are cleared before plotting so that opening
        a second file does not draw over the first one.
        """
        options = QFileDialog.Options()
        fileName, _ = QFileDialog.getOpenFileName(
            self,
            "QFileDialog.getOpenFileName()",
            "",
            "All Files (*)",
            options=options)
        self.fname_loaded.setText(fileName)

        if not fileName:
            print('no file selected')
            return

        print(fileName)
        self.data = pd.read_fwf(fileName)

        t = self.data['PollTime']
        t = t - t[0]
        Xdefl = self.data['X-defl'] * 1000
        Ydefl = self.data['Y-defl'] * 1000
        Az = self.data['CurGimAz']
        El = self.data['CurGimEl']
        NFOV_power = self.data['NFOV-Pwr']
        NFOV_XErr = self.data['NFOV-X']
        NFOV_YErr = self.data['NFOV-Y']

        # drop the curves of a previously opened file
        for axes_row in (self.ax1, self.ax2, self.ax3):
            for axes in axes_row:
                axes.clear()

        # adjust this window's own figure, not whichever figure pyplot
        # currently considers active
        self.fig.subplots_adjust(bottom=0.15)

        self.ax1[0].plot(t, Xdefl)
        self.ax1[0].set_ylabel('X Deflection [mrad]')

        self.ax1[1].plot(t, Ydefl, color='red')
        self.ax1[1].set_ylabel('Y Deflection [mrad]')

        self.ax2[0].plot(t, Az)
        self.ax2[0].set_ylabel('Gimbal Azimuth')

        self.ax2[1].plot(t, El, color='red')
        self.ax2[1].set_ylabel('Gimbal Elevation')

        self.ax3[0].plot(t, NFOV_XErr / NFOV_power)
        self.ax3[0].plot(t, NFOV_YErr / NFOV_power, color='red', alpha=0.7)
        self.ax3[0].set_xlabel('PollTime [s]')
        self.ax3[0].set_ylabel('NFOV x/y error [norm]')

        self.ax3[1].plot(t, NFOV_power)
        self.ax3[1].set_xlabel('PollTime [s]')
        self.ax3[1].set_ylabel('NFOV Total Power')

        self.canvas.draw()

    def exit(self):
        """Quit the application (connected to the Exit menu action)."""
        application = QtWidgets.QApplication
        application.exit()
Beispiel #31
0
class ouuSetupFrame(_ouuSetupFrame, _ouuSetupFrameUI):
    """Frame titled 'Optimization Under Uncertainty (OUU)'.

    The frame lets the user pick a model (a flowsheet node or a model file),
    classify the model inputs, configure the optimization, and follow the
    progress plots and the best-value table.
    """
    # Signal carrying a dict; __init__ connects it to addPlotValues, which
    # reads the 'objective' and 'input' entries of that dict.
    plotSignal = QtCore.pyqtSignal(dict)
    # Choices offered, in this order, by the combo box that initTabs places
    # in column 0 of every outputs-table row.
    NotUsedText = "Not used"
    ObjFuncText = "Objective Function"
    ConstraintText = "Inequality Constraint"
    DerivativeText = "Derivative"

    def __init__(self, dat=None, parent=None):
        """Build the UI, reset every widget to its initial state and wire signals.

        Args:
            dat: Session-like object, or None. When given, refresh() reads
                ``dat.flowsheet.nodes`` to fill the node combo box.
            parent: Parent passed through to the base-class constructor.
        """
        super(ouuSetupFrame, self).__init__(parent=parent)
        self.setupUi(self)
        self.dat = dat
        self.filesDir = ''  # directory used as the start location of file dialogs
        self.scenariosCalculated = False
        self.result = None
        self.useAsConstraint = None
        self.plotsToUpdate = [0]  # plot indices refreshed live; 0 is the objective plot
        self.objLine = None

        # Refresh table
        self.refresh()
        self.plotSignal.connect(self.addPlotValues)

        # Variable-type buttons stay disabled until a model has been loaded
        # (loadNodeData / loadModelFileData enable them).
        self.setFixed_button.setEnabled(False)
        self.setX1_button.setEnabled(False)
        self.setX1d_button.setEnabled(False)
        self.setX2_button.setEnabled(False)
        self.setX3_button.setEnabled(False)
        self.setX4_button.setEnabled(False)

        self.input_table.setColumnHidden(3, True)  # Hide scale column
        self.modelFile_edit.clear()
        self.modelFile_radio.setChecked(True)
        # Keep a reference to the tab at index 2 before removing it so that
        # setCounts can re-insert it later as 'UQ Setup'.
        self.uqTab = self.tabs.widget(2)
        self.tabs.removeTab(2)
        self.tabs.setCurrentIndex(0)
        self.tabs.setEnabled(False)
        self.output_label.setHidden(True)
        self.output_combo.setHidden(True)
        self.output_combo.setEnabled(False)
        self.mean_radio.setChecked(True)
        self.betaDoubleSpin.setValue(0)
        self.alphaDoubleSpin.setValue(0.5)
        self.primarySolver_combo.setEnabled(False)
        self.secondarySolver_combo.setCurrentIndex(0)
        self.z3_table.setRowCount(1)
        self.compressSamples_chk.setEnabled(False)
        self.calcScenarios_button.setEnabled(False)
        self.scenarioSelect_static.setEnabled(False)
        self.scenarioSelect_combo.setEnabled(False)
        self.z4NewSample_radio.setChecked(True)
        self.x4SampleScheme_combo.setCurrentIndex(0)
        self.x4SampleSize_label.setText('Sample Size')
        self.x4SampleSize_spin.setValue(5)
        self.x4SampleSize_spin.setRange(5, 1000)
        self.x4FileBrowse_button.setEnabled(False)
        # Replace whatever items the combo held with the three sampling
        # schemes offered here (LH, LPTAU, FACT).
        self.x4SampleScheme_combo.clear()
        self.x4SampleScheme_combo.addItems([
            SamplingMethods.getFullName(SamplingMethods.LH),
            SamplingMethods.getFullName(SamplingMethods.LPTAU),
            SamplingMethods.getFullName(SamplingMethods.FACT)
        ])
        self.x4RSMethod_check.setChecked(False)
        self.z4_table.setEnabled(False)
        self.z4_table.setRowCount(1)
        self.z4SubsetSize_label.setEnabled(False)
        self.z4SubsetSize_spin.setEnabled(False)
        self.run_button.setEnabled(True)
        self.summary_group.setMaximumHeight(250)
        self.progress_group.setMaximumHeight(250)

        self.setWindowTitle('Optimization Under Uncertainty (OUU)')

        # Connect signals
        # Done after the initial widget setup above, so none of those calls
        # trigger the handlers connected here.
        self.node_radio.toggled.connect(self.chooseNode)
        self.node_combo.currentIndexChanged.connect(self.loadNodeData)
        self.modelFile_radio.toggled.connect(self.chooseModel)
        self.modelFileBrowse_button.clicked.connect(self.loadModelFileData)
        self.input_table.typeChanged.connect(self.setCounts)
        #self.input_table.typeChanged.connect(self.managePlots)
        #self.input_table.typeChanged.connect(self.manageBestValueTable)
        self.setFixed_button.clicked.connect(self.setFixed)
        self.setX1_button.clicked.connect(self.setX1)
        self.setX1d_button.clicked.connect(self.setX1d)
        self.setX2_button.clicked.connect(self.setX2)
        self.setX3_button.clicked.connect(self.setX3)
        self.setX4_button.clicked.connect(self.setX4)
        self.z4NewSample_radio.toggled.connect(self.chooseZ4NewSample)
        self.z4LoadSample_radio.toggled.connect(self.chooseZ4LoadSample)
        self.x4SampleSize_spin.valueChanged.connect(self.setZ4RS)
        self.z4SubsetSize_spin.valueChanged.connect(self.setZ4RS)
        self.x3FileBrowse_button.clicked.connect(self.loadX3Sample)
        self.compressSamples_chk.toggled.connect(self.activateCompressSample)
        self.calcScenarios_button.clicked.connect(self.calcScenarios)
        self.x4SampleScheme_combo.currentIndexChanged.connect(self.setX4Label)
        self.x4FileBrowse_button.clicked.connect(self.loadX4Sample)
        self.x4RSMethod_check.toggled.connect(self.showZ4Subset)
        self.run_button.clicked.connect(self.analyze)
        self.z3_table.cellChanged.connect(self.z3TableCellChanged)
        self.z4_table.cellChanged.connect(self.z4TableCellChanged)
        # Scrolling the progress area decides which plots get live updates.
        self.progressScrollArea.verticalScrollBar().valueChanged.connect(
            self.scrollProgressPlots)

    def freeze(self):
        """Install the wait cursor as the application-wide override cursor."""
        waitCursor = QCursor(QtCore.Qt.WaitCursor)
        QApplication.setOverrideCursor(waitCursor)

    def semifreeze(self):
        """Install the busy cursor as the application-wide override cursor."""
        busyCursor = QCursor(QtCore.Qt.BusyCursor)
        QApplication.setOverrideCursor(busyCursor)

    def unfreeze(self):
        """Undo the last override cursor set by freeze() or semifreeze()."""
        application = QApplication
        application.restoreOverrideCursor()

    def refresh(self):
        """Reset the model-selection widgets, the tables and the plots.

        Without session data the node controls are disabled and the
        model-file option is selected. With session data the node combo is
        refilled with the sorted flowsheet node names, bracketed by the
        'Select node' and 'Full flowsheet' entries, and the node option is
        selected.
        """
        if self.dat is None:
            self.node_radio.setEnabled(False)
            self.node_combo.setEnabled(False)
            self.modelFile_radio.setChecked(True)
        else:
            nodeNames = sorted(self.dat.flowsheet.nodes.keys())
            comboEntries = ['Select node'] + nodeNames + ['Full flowsheet']
            self.node_combo.clear()
            self.node_combo.addItems(comboEntries)
            self.node_radio.setChecked(True)

        # Empty the input table and the model/output selectors.
        self.input_table.clearContents()
        self.input_table.setRowCount(0)
        self.modelFile_edit.clear()
        self.output_combo.clear()

        # The best-value table is cut back to two rows.
        self.bestValue_table.clearContents()
        self.bestValue_table.setRowCount(2)
        self.clearPlots()

    def chooseNode(self, value):
        """When the node option is switched on, enable the node combo and
        disable the model-file widgets. Does nothing when switched off.
        """
        if not value:
            return
        self.node_combo.setEnabled(True)
        for fileWidget in (self.modelFile_edit, self.modelFileBrowse_button):
            fileWidget.setEnabled(False)

    def chooseModel(self, value):
        """When the model-file option is switched on, disable the node combo
        and enable the model-file widgets. Does nothing when switched off.
        """
        if not value:
            return
        self.node_combo.setEnabled(False)
        for fileWidget in (self.modelFile_edit, self.modelFileBrowse_button):
            fileWidget.setEnabled(True)

    def loadNodeData(self):
        """Build the UQ model for the node chosen in the combo box and load it.

        The empty entry and the 'Select node' placeholder are ignored.
        'Full flowsheet' converts the whole flowsheet; any other entry
        converts that single node. The input table is then initialised, the
        variable-type buttons are enabled, and the tabs and counts are reset.
        """
        selected = self.node_combo.currentText()
        if selected in ('', 'Select node'):
            return

        flowsheet = self.dat.flowsheet
        if selected == 'Full flowsheet':
            self.model = flowsheetToUQModel(flowsheet)
        else:
            self.model = nodeToUQModel(selected, flowsheet.nodes[selected])

        self.input_table.init(self.model, InputPriorTable.OUU)
        typeButtons = (self.setFixed_button, self.setX1_button,
                       self.setX1d_button, self.setX2_button,
                       self.setX3_button, self.setX4_button)
        for button in typeButtons:
            button.setEnabled(True)
        self.initTabs()
        self.setCounts()

    def loadModelFileData(self):
        """Let the user pick a model file and load its model into the frame.

        Cancelling the dialog leaves everything untouched. Otherwise the
        chosen directory is remembered for later dialogs, the path is shown
        in the model-file edit box, the file is read as a PSUADE sample and
        its ``model`` attribute becomes the current model.
        """
        # The "All files" wildcard differs between Windows and other systems.
        allFiles = '*.*' if platform.system() == 'Windows' else '*'
        fileFilter = ('Model files (*.in *.dat *.psuade *.filtered);;'
                      'All files (%s)' % allFiles)
        fname, _ = QFileDialog.getOpenFileName(self, 'Open Model File',
                                               self.filesDir, fileFilter)
        if fname == '':
            return

        self.filesDir = os.path.split(fname)[0]
        self.modelFile_edit.setText(fname)
        self.model = LocalExecutionModule.readSampleFromPsuadeFile(fname)
        self.model = self.model.model
        self.input_table.init(self.model, InputPriorTable.OUU)

        typeButtons = (self.setFixed_button, self.setX1_button,
                       self.setX1d_button, self.setX2_button,
                       self.setX3_button, self.setX4_button)
        for button in typeButtons:
            button.setEnabled(True)
        self.initTabs()
        self.setCounts()

    ##### Brenda:  Start here! #####
    def getSampleFileData(self):
        """Ask the user for a sample file and read its data.

        Files whose name ends in ``.csv`` (any letter case) are read with
        ``LocalExecutionModule.readDataFromCsvFile``; everything else is read
        with ``LocalExecutionModule.readDataFromSimpleFile``. The directory
        of the chosen file is remembered in ``self.filesDir``.

        Returns:
            tuple: ``(fname, data)`` where ``data`` is element 0 of the
            reader's result, or ``(None, None)`` when the dialog is cancelled
            or the file cannot be read (an error box is shown in that case).
        """
        # The "All files" wildcard differs between Windows and other systems.
        allFiles = '*.*' if platform.system() == 'Windows' else '*'

        fname, _ = QFileDialog.getOpenFileName(
            self, 'Open Sample File', self.filesDir,
            "Psuade Simple Files (*.smp);;CSV (Comma delimited) (*.csv);;All files (%s)"
            % allFiles)
        if fname == '':
            return (None, None)

        self.filesDir = os.path.split(fname)[0]

        try:
            # Compare the extension case-insensitively so that "DATA.CSV"
            # is not sent to the simple-file reader.
            if fname.lower().endswith('.csv'):
                data = LocalExecutionModule.readDataFromCsvFile(
                    fname, askForNumInputs=False)
            else:
                data = LocalExecutionModule.readDataFromSimpleFile(
                    fname, hasColumnNumbers=False)
            return (fname, data[0])
        except Exception:
            # Catch Exception rather than using a bare ``except`` so that
            # KeyboardInterrupt and SystemExit still propagate.
            import traceback
            traceback.print_exc()
            QMessageBox.critical(
                self, 'Incorrect format',
                'File does not have the correct format! Please consult the users manual about the format.'
            )
            return (None, None)

    def loadX3Sample(self):
        """Load a user-chosen sample file into the Z3 table.

        The file is accepted only when its column count equals the number of
        UQ discrete variables in the input table; otherwise a warning is
        shown and the table is left unchanged.
        """
        fname, data = self.getSampleFileData()
        if fname is None:
            return

        numFileVars = data.shape[1]
        numZ3Vars = len(self.input_table.getUQDiscreteVariables()[0])
        if numFileVars == numZ3Vars:
            self.compressSamples_chk.setEnabled(True)
            self.loadTable(self.z3_table, data)
            return

        QMessageBox.warning(
            self, "Number of variables don't match",
            'The number of variables from the file (%d) does not match the number of Z3 discrete variables (%d).  You will not be able to perform analysis until this is corrected.'
            % (numFileVars, numZ3Vars))

    def loadTable(self, table, data):
        """Fill ``table`` with the values of the 2-D array ``data``.

        The table gets one row more than ``data`` has, leaving an empty row
        after the samples. Each value is written with the ``%g`` format.
        """
        sampleCount, inputCount = data.shape[0], data.shape[1]
        table.setRowCount(sampleCount + 1)
        for rowIndex in xrange(sampleCount):
            for colIndex in xrange(inputCount):
                cellText = '%g' % data[rowIndex, colIndex]
                table.setItem(rowIndex, colIndex, QTableWidgetItem(cellText))
        table.resizeColumnsToContents()

    def z3TableCellChanged(self, row, col):
        """Handle an edit in the Z3 table: grow the table if its last row was
        touched, then make the compress-samples checkbox available.
        """
        z3Table = self.z3_table
        self.randomVarTableCellChanged(z3Table, row, col)
        self.compressSamples_chk.setEnabled(True)

    def z4TableCellChanged(self, row, col):
        """Handle an edit in the Z4 table: grow it if its last row was touched."""
        z4Table = self.z4_table
        self.randomVarTableCellChanged(z4Table, row, col)

    def randomVarTableCellChanged(self, table, row, col):
        """Append one empty row to ``table`` when the edited cell is in its
        last row, so there is always a spare row to type into.
        """
        lastRow = table.rowCount() - 1
        if row != lastRow:
            return
        table.setRowCount(lastRow + 2)

    def writeTableToFile(self, table, fileName, numCols):
        """Write the numeric rows of ``table`` to ``fileName`` as a simple file.

        Only the first ``numCols`` columns are read. A row is kept when every
        one of those cells parses as a float. Rows with no numeric cell are
        skipped. Scanning stops at the first row that has some, but not all,
        numeric cells, and that counts as a failure.

        Returns:
            bool: True when the file was written. False, after a warning box,
            when no complete row was found or a partially filled row was hit.
        """
        tableLabels = {self.z3_table: 'Z3', self.z4_table: 'Z4'}
        assert (numCols <= table.columnCount())

        samples = []
        stoppedAtPartialRow = False
        for rowIndex in xrange(table.rowCount()):
            numbers = []
            complete = True
            for colIndex in xrange(numCols):
                cell = table.item(rowIndex, colIndex)
                text = cell.text() if cell else ''
                if not text:
                    complete = False
                    continue
                try:
                    numbers.append(float(text))
                except ValueError:
                    complete = False
            if complete:
                samples.append(numbers)
            elif numbers:
                # Some cells are numeric but the row is incomplete.
                stoppedAtPartialRow = True
                break

        if not samples or stoppedAtPartialRow:
            QMessageBox.warning(
                self, "Missing data",
                'The %s table is missing required data!' % tableLabels[table])
            return False  # Failed
        LocalExecutionModule.writeSimpleFile(fileName, samples, rowLabels=False)
        return True

    def activateCompressSample(self, on):
        """React to the compress-samples checkbox being toggled.

        When switched on, the Z3 table must hold at least 100 complete
        sample rows; otherwise a warning is shown and the checkbox is
        unchecked again. A row is complete when every cell in the columns
        that belong to a Z3 discrete variable holds a number. Trailing
        unused columns are ignored, matching what calcScenarios writes out.

        The previous implementation reset its "row is full" flag inside the
        column loop and left it set when it hit a missing cell, so entirely
        empty rows were counted as samples. The flag was also unbound for a
        table with no columns.

        Args:
            on (bool): New checked state of the checkbox.
        """
        if on:

            def holdsNumber(item):
                # True when the cell exists and its text parses as a float.
                if not item:
                    return False
                try:
                    float(item.text())
                except ValueError:
                    return False
                return True

            table = self.z3_table
            numVars = min(len(self.input_table.getUQDiscreteVariables()[0]),
                          table.columnCount())
            sampleCount = 0
            if numVars > 0:
                for r in xrange(table.rowCount()):
                    if all(
                            holdsNumber(table.item(r, c))
                            for c in xrange(numVars)):
                        sampleCount += 1
            if sampleCount < 100:
                QMessageBox.warning(
                    self, "Not enough samples in file",
                    'The file requires at least 100 samples for compression.')
                self.compressSamples_chk.setChecked(False)
                return
        self.calcScenarios_button.setEnabled(on)
        # NOTE(review): the scenario selector is enabled here even when the
        # checkbox is being switched off - confirm this is intended.
        if self.scenariosCalculated:
            self.scenarioSelect_static.setEnabled(True)
            self.scenarioSelect_combo.setEnabled(True)

    def calcScenarios(self):
        """Compress the Z3 samples and list the resulting scenario files.

        The Z3 table is first written to 'z3Samples.smp'. If that fails
        (writeTableToFile has already warned the user) nothing else happens,
        so a stale or missing file is never compressed. While the
        compression runs the wait cursor is shown; it is restored even if
        the compression raises.
        """
        numVars = len(self.input_table.getUQDiscreteVariables()[0])
        if not self.writeTableToFile(self.z3_table, 'z3Samples.smp', numVars):
            return

        self.freeze()
        try:
            self.scenarioFiles = OUU.compress('z3Samples.smp')
            if self.scenarioFiles is not None:
                self.scenarioSelect_combo.clear()
                for i, n in enumerate(sorted(self.scenarioFiles.keys())):
                    self.scenarioSelect_combo.addItem(str(n))
                    # Element 1 of each entry is shown as the bin count.
                    self.scenarioSelect_combo.setItemData(
                        i, '%d bins per dimension' % self.scenarioFiles[n][1],
                        QtCore.Qt.ToolTipRole)
                self.scenarioSelect_static.setEnabled(True)
                self.scenarioSelect_combo.setEnabled(True)
                self.scenariosCalculated = True
        finally:
            # Always drop the wait cursor, even when compression fails.
            self.unfreeze()

    def loadX4Sample(self):
        """Load a user-chosen sample file into the Z4 table.

        The subset-size spin box is capped at the number of samples in the
        file and set to min(samples, 100) before the column count is checked.
        The table is filled only when the file has as many columns as there
        are continuous variables; otherwise a warning is shown.
        """
        fname, inData = self.getSampleFileData()
        if fname is None:
            return

        numFileVars = inData.shape[1]
        sampleCount = inData.shape[0]
        self.z4SubsetSize_spin.setMaximum(sampleCount)
        self.z4SubsetSize_spin.setValue(min(sampleCount, 100))

        numZ4Vars = len(self.input_table.getContinuousVariables()[0])
        if numFileVars == numZ4Vars:
            self.loadTable(self.z4_table, inData)
            return

        QMessageBox.warning(
            self, "Number of variables don't match",
            'The number of input variables from the file (%d) does not match the number of Z4 continuous variables (%d).  You will not be able to perform analysis until this is corrected.'
            % (numFileVars, numZ4Vars))

    def setX4Label(self):
        """Adapt the X4 sample-size label and spin box to the chosen scheme.

        LH and LPTAU use a 'Sample Size' between (number of primary
        variables + 1) and 1000 in steps of 1. FACT uses a 'Number of Levels'
        between 3 and 100 in steps of 2. Any other text changes nothing.
        """
        scheme = self.x4SampleScheme_combo.currentText()
        sizeLabel = self.x4SampleSize_label
        sizeSpin = self.x4SampleSize_spin

        sampleSizeSchemes = [
            SamplingMethods.getFullName(SamplingMethods.LH),
            SamplingMethods.getFullName(SamplingMethods.LPTAU)
        ]
        if scheme in sampleSizeSchemes:
            sizeLabel.setText('Sample Size')
            # NOTE(review): this minimum is based on the primary variables,
            # whereas setCounts uses the continuous variables - confirm.
            smallest = len(self.input_table.getPrimaryVariables()[0]) + 1
            sizeSpin.setRange(smallest, 1000)
            sizeSpin.setValue(smallest)
            sizeSpin.setSingleStep(1)
        elif scheme == SamplingMethods.getFullName(SamplingMethods.FACT):
            sizeLabel.setText('Number of Levels')
            sizeSpin.setRange(3, 100)
            sizeSpin.setValue(3)
            sizeSpin.setSingleStep(2)

    def initTabs(self):
        """Reset every tab for the model currently stored in ``self.model``.

        Enables the tab widget, restores the optimization defaults, rebuilds
        the outputs table with one role combo box per model output (the
        first output starts as the objective function), resets the UQ
        widgets, and creates a fresh progress plot for the objective.
        """
        self.tabs.setEnabled(True)
        self.tabs.setCurrentIndex(0)
        outputNames = self.model.getOutputNames()
        numOutputs = len(outputNames)

        # Optimization Setup
        self.output_combo.setEnabled(True)
        self.output_combo.clear()
        self.output_combo.addItems(outputNames)
        self.mean_radio.setChecked(True)
        self.betaDoubleSpin.setValue(0)
        self.alphaDoubleSpin.setValue(0.5)
        self.secondarySolver_combo.setCurrentIndex(0)

        # Outputs: column 0 holds the role combo, column 1 the output name.
        outputsTable = self.outputs_table
        outputsTable.blockSignals(True)
        outputsTable.setColumnCount(2)
        outputsTable.setRowCount(numOutputs)
        self.useAsConstraint = [False] * numOutputs
        self.useAsDerivative = [False] * numOutputs
        for row, outputName in enumerate(outputNames):
            roleCombo = QComboBox()
            roleCombo.addItems([
                ouuSetupFrame.NotUsedText, ouuSetupFrame.ObjFuncText,
                ouuSetupFrame.ConstraintText, ouuSetupFrame.DerivativeText
            ])
            if row == 0:
                # Index 1 is the objective-function entry.
                roleCombo.setCurrentIndex(1)
            outputsTable.setCellWidget(row, 0, roleCombo)
            outputsTable.setItem(row, 1, QTableWidgetItem(outputName))
        outputsTable.resizeColumnsToContents()
        outputsTable.blockSignals(False)

        # UQ Setup
        self.compressSamples_chk.setChecked(False)
        self.compressSamples_chk.setEnabled(False)
        self.scenariosCalculated = False
        self.scenarioSelect_static.setEnabled(False)
        self.scenarioSelect_combo.setEnabled(False)
        self.scenarioSelect_combo.clear()
        self.x4SampleScheme_combo.setCurrentIndex(0)
        self.x4SampleSize_label.setText('Sample Size')
        self.x4SampleSize_spin.setValue(5)
        self.x4SampleSize_spin.setRange(5, 1000)
        self.x4RSMethod_check.setChecked(False)

        # Launch/Progress
        # TO DO: disable until inputs are validated
        self.run_button.setEnabled(True)
        self.bestValue_table.setColumnCount(1)
        self.bestValue_table.clearContents()

        # Plots: a new group box inside the scroll area hosts the canvases.
        self.plots_group = QGroupBox()
        self.plotsLayout = QVBoxLayout()
        self.plots_group.setLayout(self.plotsLayout)
        self.progressScrollArea.setMinimumHeight(150)
        self.progressScrollArea.setWidget(self.plots_group)
        self.plots_group.setMinimumHeight(150)

        self.objFig = Figure(figsize=(400, 200),
                             dpi=72,
                             facecolor=(1, 1, 1),
                             edgecolor=(0, 0, 0),
                             tight_layout=True)
        self.objCanvas = FigureCanvas(self.objFig)
        objAxes = self.objFig.add_subplot(111)
        self.objFigAx = objAxes
        objAxes.set_title('OUU Progress')
        objAxes.set_ylabel('Objective')
        objAxes.set_xlabel('Iteration')
        self.objLine = None
        self.plotsLayout.addWidget(self.objCanvas)
        self.objCanvas.setParent(self.plots_group)

        # Per-input plots are created later by managePlots.
        self.inputPlots = []

        self.objXPoints = []
        self.objYPoints = []
        self.objPlotPoints = None

    # def setObjectiveFunction(self, on):
    #     self.outputs_table.blockSignals(True)
    #     row = self.sender().property('row')
    #     item = self.outputs_table.item(row, 2) # Checkbox for inequality constraint
    #     flags = item.flags()
    #     if on:
    #         flags &= ~QtCore.Qt.ItemIsEnabled
    #         item.setCheckState(QtCore.Qt.Unchecked)
    #     else:
    #         flags |= QtCore.Qt.ItemIsEnabled
    #         item.setCheckState(QtCore.Qt.Checked if self.useAsConstraint[row] else QtCore.Qt.Unchecked)
    #     item.setFlags(flags)

    # def toggleConstraintStatus(self, item):
    #     self.useAsConstraint[item.row()] = (item.checkState() == QtCore.Qt.Checked)
    #

    def manageBestValueTable(self):
        """Prepare the best-value table for a new run.

        Forgets the best objective seen so far and lays the table out as
        'Iteration', 'Objective Value', then one row per primary variable,
        with all cells emptied.
        """
        self.bestValue = None
        primaryNames, _ = self.input_table.getPrimaryVariables()
        rowLabels = ['Iteration', 'Objective Value'] + primaryNames
        self.bestValue_table.setRowCount(len(primaryNames) + 2)
        self.bestValue_table.setVerticalHeaderLabels(rowLabels)
        self.bestValue_table.clearContents()

    def setBestValueTable(self, iteration, objValue, inputs):
        """Show the current iteration and, when improved, the new best point.

        Row 0 always receives the iteration number. Rows 1 and up (objective
        value, then one row per input) are rewritten only when ``objValue``
        is lower than the best value recorded so far, or when none has been
        recorded yet.

        Args:
            iteration: Iteration number, formatted with ``%d``.
            objValue: Objective value of this iteration, formatted with ``%f``.
            inputs: Iterable of input values, each formatted with ``%f``.
        """
        table = self.bestValue_table

        def setCell(row, text):
            # Reuse the existing item in column 0, or create it on first use.
            item = table.item(row, 0)
            if item is None:
                table.setItem(row, 0, QTableWidgetItem(text))
            else:
                item.setText(text)

        setCell(0, '%d' % iteration)  #iteration

        # Identity test: "== None" is the wrong idiom and misbehaves for
        # values that overload ==.
        if self.bestValue is None or objValue < self.bestValue:
            self.bestValue = objValue
            setCell(1, '%f' % objValue)  #objective value
            for i, value in enumerate(inputs):
                setCell(i + 2, '%f' % value)  #input

    def addPlotValues(self, valuesDict):
        """Feed one progress report into the plots and the best-value table.

        ``valuesDict['objective']`` is an (iteration, objective value) pair.
        ``valuesDict['input']`` is a sequence whose first element is skipped
        when the remaining input values are passed to the best-value table.
        """
        objective = valuesDict['objective']
        self.addPointToObjPlot(objective)
        inputs = valuesDict['input']
        self.addToInputPlots(inputs)
        iteration, objValue = objective
        self.setBestValueTable(iteration, objValue, inputs[1:])

    def addPointToObjPlot(self, x):
        """Record the point ``(x[0], x[1])`` for the objective plot.

        The plot is redrawn only when plot 0 is among the plots being
        updated, and then only on every ceil(n / 30)-th point, where n is
        the number of points so far.
        """
        self.objXPoints.append(x[0])
        self.objYPoints.append(x[1])
        if 0 not in self.plotsToUpdate:
            return
        pointCount = len(self.objXPoints)
        # limit refresh rate as number of points gets large
        stride = math.ceil(float(pointCount) / 30)
        if pointCount % stride == 0:
            self.updateObjPlot()

    def updateObjPlot(self):
        """Redraw the objective plot from the points recorded so far.

        The first call after the axes were created or cleared plots the
        points as blue circles and remembers the line. Later calls update
        that line in place and rescale the axes. The previous version had a
        dead ``else`` branch behind ``if True:`` and plotted a brand-new line
        with all points on every refresh, so the axes collected one artist
        per refresh.
        """
        axes = self.objFigAx
        line = self.objLine
        # The remembered line is reusable only while it is still attached to
        # the current axes; clearPlots can detach it.
        if line is not None and line in axes.lines:
            line.set_data(self.objXPoints, self.objYPoints)
            axes.relim()
            axes.autoscale_view()
        else:
            self.objLine, = axes.plot(self.objXPoints, self.objYPoints, 'bo')
        self.objCanvas.draw()

    def addToInputPlots(self, x):
        """Append ``x[i]`` to the i-th input series and refresh visible plots.

        Series 0 is only recorded (updateInputPlot uses it as the x data).
        Series i > 0 is redrawn when i is among the plots being updated, and
        then only on every ceil(n / 30)-th point.
        """
        for seriesIndex, series in enumerate(self.inputPoints):
            series.append(x[seriesIndex])
            if seriesIndex <= 0 or seriesIndex not in self.plotsToUpdate:
                continue
            pointCount = len(series)
            # limit refresh rate as number of points gets large
            stride = math.ceil(float(pointCount) / 30)
            if pointCount % stride == 0:
                self.updateInputPlot(seriesIndex)

    def updateInputPlot(self, index):
        """Plot input series ``index`` against series 0 and redraw its canvas.

        ``index`` starts at 1 for the first input plot; the matching entry
        of ``self.inputPlots`` is ``index - 1``.
        """
        plot = self.inputPlots[index - 1]
        xValues = self.inputPoints[0]
        yValues = self.inputPoints[index]
        plot['ax'].plot(xValues, yValues, 'bo')
        plot['canvas'].draw()

    def managePlots(self):
        """Make the input progress plots match the current primary variables.

        Missing plots are created and added to the plots layout, and surplus
        plots are cleared and scheduled for deletion. Every remaining plot is
        labelled with its variable name, the plots group is resized, and all
        recorded points are reset.

        Surplus plots used to be removed with ``del self.inputPlots[i]``
        inside a forward index loop. Each deletion shifted the remaining
        entries, so some plots were skipped and the loop raised IndexError
        when two or more had to go. They are now torn down first and dropped
        with one slice deletion.
        """
        names, indices = self.input_table.getPrimaryVariables()
        numNames = len(names)
        if len(self.inputPlots) < numNames:  #add plots
            for i in xrange(len(self.inputPlots), numNames):
                fig = Figure(figsize=(400, 200),
                             dpi=72,
                             facecolor=(1, 1, 1),
                             edgecolor=(0, 0, 0),
                             tight_layout=True)
                canvas = FigureCanvas(fig)
                ax = fig.add_subplot(111)
                ax.set_xlabel('Iteration')
                self.inputPlots.append({
                    'fig': fig,
                    'canvas': canvas,
                    'ax': ax
                })
                self.plotsLayout.addWidget(canvas)
                canvas.setParent(self.plots_group)
        elif len(self.inputPlots) > numNames:  #remove plots
            for surplus in self.inputPlots[numNames:]:
                surplus['fig'].clf()
                surplus['canvas'].deleteLater()
            del self.inputPlots[numNames:]

        for plot, name in zip(self.inputPlots, names):
            plot['ax'].set_ylabel(name)

        # 190 is the per-plot height that scrollProgressPlots also uses.
        self.plots_group.setMinimumHeight(190 * (numNames + 1))

        # One series per primary variable plus series 0 for the x values.
        self.inputPoints = [[] for _ in xrange(numNames + 1)]
        self.clearPlots()

    def clearPlots(self):
        """Discard all recorded points and wipe the lines from the plots.

        The objective axes and the input plots are touched only if they
        already exist on the instance, and each canvas is redrawn only when
        its axes actually held lines.
        """
        self.objXPoints = []
        self.objYPoints = []

        attributes = vars(self)
        if 'objFigAx' in attributes and len(self.objFigAx.lines) > 0:
            self.objFigAx.lines = []
            self.objFigAx.relim()
            self.objCanvas.draw()

        if 'inputPoints' not in attributes:
            return

        # Same number of series as before, all empty.
        self.inputPoints = [[] for _ in self.inputPoints]
        # Series 0 has no plot of its own, hence one plot fewer than series.
        for plotIndex in xrange(len(self.inputPoints) - 1):
            plot = self.inputPlots[plotIndex]
            if len(plot['ax'].lines) > 0:
                plot['ax'].lines = []
                plot['canvas'].draw()

    def scrollProgressPlots(self, value):
        """Pick the plots to keep live for scroll position ``value`` and
        redraw them.

        The first live plot is ``int(value / 190)``, capped at the last
        plot; the plot after it is included when there is one. Plot 0 is the
        objective plot and higher indices are input plots.
        """
        names, _ = self.input_table.getPrimaryVariables()
        lastPlot = len(names)  # number of plots minus one
        firstLive = min(int(value / 190), lastPlot)

        self.plotsToUpdate = [firstLive]
        if firstLive < lastPlot:
            self.plotsToUpdate.append(firstLive + 1)

        for plotIndex in self.plotsToUpdate:
            if plotIndex == 0:
                self.updateObjPlot()
            else:
                self.updateInputPlot(plotIndex)

    def setFixed(self):
        """Set the checked rows of the input table to type code 0."""
        inputTable = self.input_table
        inputTable.setCheckedToType(0)

    def setX1(self):
        """Set the checked rows of the input table to type code 1."""
        inputTable = self.input_table
        inputTable.setCheckedToType(1)

    def setX1d(self):
        """Set the checked rows of the input table to type code 2."""
        inputTable = self.input_table
        inputTable.setCheckedToType(2)

    def setX2(self):
        """Set the checked rows of the input table to type code 3."""
        inputTable = self.input_table
        inputTable.setCheckedToType(3)

    def setX3(self):
        """Set the checked rows of the input table to type code 4.

        The value returned by the table is not used.
        """
        self.input_table.setCheckedToType(4)

    def setX4(self):
        """Set the checked rows of the input table to type code 5.

        The value returned by the table is not used.
        """
        self.input_table.setCheckedToType(5)

    def setCounts(self):
        """Refresh everything that depends on how the inputs are classified.

        Updates the five count labels, hides the secondary solver when there
        are no recourse variables, shows or removes the 'UQ Setup' tab and
        its Z3/Z4 groups depending on whether any random variables exist,
        relabels the Z3/Z4 table columns, and adjusts the minimum values of
        the sample-size spin boxes.
        """
        inputTable = self.input_table
        numFixed = len(inputTable.getFixedVariables()[0])
        numPrimary = len(inputTable.getPrimaryVariables()[0])
        numRecourse = len(inputTable.getRecourseVariables()[0])
        z3Names = inputTable.getUQDiscreteVariables()[0]
        numZ3 = len(z3Names)
        z4Names = inputTable.getContinuousVariables()[0]
        numZ4 = len(z4Names)

        self.fixedCount_static.setText('# Fixed: %d' % numFixed)
        self.x1Count_static.setText('# Primary Opt Vars: %d' % numPrimary)
        self.x2Count_static.setText('# Recourse Opt Vars: %d' % numRecourse)
        self.x3Count_static.setText('# Discrete RVs: %d' % numZ3)
        self.x4Count_static.setText('# Continuous RVs: %d' % numZ4)

        # The inner (secondary) solver is only shown with recourse variables.
        noRecourse = (numRecourse == 0)
        self.secondarySolver_label.setHidden(noRecourse)
        self.secondarySolver_combo.setHidden(noRecourse)

        noZ3 = (numZ3 == 0)
        noZ4 = (numZ4 == 0)
        if noZ3 and noZ4:  #Hide tab
            if self.tabs.widget(2) == self.uqTab:
                if self.tabs.currentIndex() == 2:
                    self.tabs.setCurrentIndex(0)
                self.tabs.removeTab(2)
        else:  #Show tab
            if self.tabs.widget(2) != self.uqTab:
                self.tabs.insertTab(2, self.uqTab, 'UQ Setup')

        # Tables never shrink; surplus columns are labelled 'Unused'.
        for table, varNames in ((self.z3_table, z3Names),
                                (self.z4_table, z4Names)):
            columnCount = max(len(varNames), table.columnCount())
            table.setColumnCount(columnCount)
            unusedLabels = ['Unused'] * (columnCount - len(varNames))
            table.setHorizontalHeaderLabels(varNames + unusedLabels)

        self.z3_group.setHidden(noZ3)
        self.z4_group.setHidden(noZ4)

        factName = SamplingMethods.getFullName(SamplingMethods.FACT)
        if self.x4SampleScheme_combo.currentText() == factName:
            smallestSample = 3
        else:
            smallestSample = numZ4 + 1
        self.x4SampleSize_spin.setMinimum(smallestSample)

        self.z4SubsetSize_spin.setMinimum(numZ4 + 1)

    def chooseZ4NewSample(self, value):
        """Switch the Z4 section to "generate a new sample" mode.

        Does nothing when the option is being switched off. Otherwise the
        sampling scheme and size widgets are enabled, the browse button and
        the Z4 table are disabled, the subset widgets are disabled, and the
        response-surface label follows the sample-size spin box.
        """
        if not value:
            return
        schemeWidgets = (self.x4SampleScheme_label, self.x4SampleScheme_combo,
                         self.x4SampleSize_label, self.x4SampleSize_spin)
        for widget in schemeWidgets:
            widget.setEnabled(True)
        self.x4FileBrowse_button.setEnabled(False)
        self.showZ4Subset(False)
        self.setZ4RS(self.x4SampleSize_spin.value())
        self.z4_table.setEnabled(False)

    def chooseZ4LoadSample(self, value):
        """Switch the Z4 section to "load an existing sample" mode.

        Does nothing when the option is being switched off. Otherwise the
        sampling scheme and size widgets are disabled, the browse button and
        the Z4 table are enabled, the subset widgets are re-evaluated, and
        the response-surface label follows the subset-size spin box.
        """
        if not value:
            return
        schemeWidgets = (self.x4SampleScheme_label, self.x4SampleScheme_combo,
                         self.x4SampleSize_label, self.x4SampleSize_spin)
        for widget in schemeWidgets:
            widget.setEnabled(False)
        self.x4FileBrowse_button.setEnabled(True)
        self.showZ4Subset(True)
        self.setZ4RS(self.z4SubsetSize_spin.value())
        self.z4_table.setEnabled(True)

    def setZ4RS(self, value):
        """Name the response surface in the checkbox label for size ``value``.

        Up to 300 the label names KRIGING, up to 600 RBF, and above that MARS.
        """
        if value <= 300:
            surface = ResponseSurfaces.KRIGING
        elif value <= 600:
            surface = ResponseSurfaces.RBF
        else:
            surface = ResponseSurfaces.MARS
        surfaceName = ResponseSurfaces.getFullName(surface)
        self.x4RSMethod_check.setText('Use Response Surface (%s)' % surfaceName)

    def showZ4Subset(self, show):
        """Enable the Z4 subset-size widgets only when ``show`` is true, the
        response-surface checkbox is checked and the load-sample option is
        selected. Disable them otherwise.
        """
        subsetActive = bool(show and self.x4RSMethod_check.isChecked()
                            and self.z4LoadSample_radio.isChecked())
        self.z4SubsetSize_label.setEnabled(subsetActive)
        self.z4SubsetSize_spin.setEnabled(subsetActive)

    def setupPSUADEClient(self):
        """Copy foqusPSUADEClient.py next to this module into the current
        working directory and restrict it to owner read/write/execute.

        Returns:
            str: Path of the copied script in the working directory.
        """
        clientName = 'foqusPSUADEClient.py'
        #Copy needed files
        dest = os.path.join(os.getcwd(), clientName)
        src = os.path.join(os.path.dirname(__file__), clientName)
        shutil.copyfile(src, dest)
        # "0o700" is accepted by Python 2.6+ and Python 3; the old "0700"
        # spelling is a syntax error on Python 3.
        os.chmod(dest, 0o700)
        return dest

    def analyze(self):
        """Start an OUU run, or stop the one that is in progress.

        The branch is chosen from the run button's current label.  When it
        reads 'Run OUU' the input table is validated, the arguments for
        ``OUU.ouu`` are assembled from the widgets and the run is launched;
        the method returns early (after showing a message box) whenever a
        validation step fails.  With any other label the running
        ``self.OUUobj`` is asked to stop and the frame is frozen.
        """

        if self.run_button.text() == 'Run OUU':  # Run OUU
            # --- Validation: each failed check informs the user and aborts.
            names, indices = self.input_table.getPrimaryVariables()
            if len(names) == 0:
                QMessageBox.information(
                    self, 'No Primary Variables',
                    'At least one input must be a primary variable!')
                return

            valid, error = self.input_table.checkValidInputs()
            if not valid:
                QMessageBox.information(
                    self, 'Input Table Distributions',
                    'Input table distributions are either not correct or not filled out completely! %s'
                    % error)
                return

            if self.compressSamples_chk.isChecked(
            ) and not self.scenariosCalculated:
                QMessageBox.information(
                    self, 'Compress Samples Not Calculated',
                    'You have elected to compress samples for discrete random variables (Z3), but have not selected the sample size to use!'
                )
                return

            # Number of variables in each category of the input table.
            # NOTE(review): M1 and M2 are not read again in this method.
            M1 = len(self.input_table.getPrimaryVariables()[0])
            M2 = len(self.input_table.getRecourseVariables()[0])
            M3 = len(self.input_table.getUQDiscreteVariables()[0])
            M4 = len(self.input_table.getContinuousVariables()[0])

            Common.initFolder(OUU.dname)

            # Reset the result display before the new run.
            self.managePlots()
            self.clearPlots()
            self.manageBestValueTable()
            self.summary_group.setTitle('Best So Far')

            xtable = self.input_table.getTableValues()

            # get arguments for ouu()
            # A deep copy is modified so that marking inputs as fixed does not
            # touch self.model.
            model = copy.deepcopy(self.model)
            inputNames = model.getInputNames()
            inputTypes = list(model.getInputTypes())
            defaultValues = model.getInputDefaults()
            for row in xtable:
                if row['type'] == 'Fixed':
                    modelIndex = inputNames.index(row['name'])
                    inputTypes[modelIndex] = Model.FIXED
                    defaultValues[modelIndex] = row['value']
            #print inputTypes
            #print defaultValues
            model.setInputTypes(inputTypes)
            model.setInputDefaults(defaultValues)
            data = SampleData(model)
            # The adjusted model is written to this file, whose name is the
            # first argument passed to ouu() below.
            fname = 'ouuTemp.dat'
            data.writeToPsuade(fname)

            # Outputs
            # y collects the 1-based row numbers of the objective-function
            # outputs; the two flag lists are indexed by output row.
            numOutputs = self.outputs_table.rowCount()
            y = []
            self.useAsConstraint = [False] * numOutputs
            self.useAsDerivative = [False] * numOutputs
            # NOTE(review): xrange exists only on Python 2.
            for r in xrange(numOutputs):
                # NOTE(review): this local shadows the builtin ``type``.
                type = self.outputs_table.cellWidget(r, 0).currentText()
                if type == ouuSetupFrame.ObjFuncText:
                    y.append(r + 1)
                elif type == ouuSetupFrame.ConstraintText:
                    self.useAsConstraint[r] = True
                elif type == ouuSetupFrame.DerivativeText:
                    self.useAsDerivative[r] = True

            # Objective description chosen by the radio buttons.
            # NOTE(review): phi stays unbound if none of the three radio
            # buttons is checked, which would raise at the ouu() call.
            if self.mean_radio.isChecked():
                phi = {'type': 1}
            elif self.meanWithBeta_radio.isChecked():
                beta = self.betaDoubleSpin.value()
                phi = {'type': 2, 'beta': beta}
            elif self.alpha_radio.isChecked():
                alpha = self.alphaDoubleSpin.value()
                phi = {'type': 3, 'alpha': alpha}
            # --- Z3 (discrete) sample file, only when such variables exist.
            x3sample = None
            if M3 > 0:
                if self.compressSamples_chk.isChecked():
                    selectedSamples = int(
                        self.scenarioSelect_combo.currentText())
                    sfile = self.scenarioFiles[selectedSamples][0]
                else:
                    sfile = 'z3Samples.smp'
                    success = self.writeTableToFile(self.z3_table, sfile, M3)
                    if not success:
                        return
                    # NOTE(review): sfile was just set to a literal ending in
                    # '.smp', so this branch cannot run as written; it also
                    # builds the new name from fname rather than sfile.
                    if sfile.endswith(
                            '.csv'):  # Convert .csv file to simple file
                        newFileName = OUU.dname + os.sep + os.path.basename(
                            fname)[:-4] + '.smp'
                        inData = LocalExecutionModule.readDataFromCsvFile(
                            sfile, askForNumInputs=False)
                        LocalExecutionModule.writeSimpleFile(
                            newFileName, inData[0])
                        sfile = newFileName

                #print 'x3 file is', sfile
                x3sample = {'file': sfile}  # x3sample file
                # The file is read back to check its column count against M3.
                # Note that this rebinds ``data`` (previously the SampleData).
                data, _, numInputs, _ = LocalExecutionModule.readDataFromSimpleFile(
                    sfile, hasColumnNumbers=False)
                if numInputs != M3:
                    QMessageBox.critical(
                        self, "Number of variables don't match",
                        'The number of variables from the file (%d) does not match the number of Z3 discrete variables (%d).  You will not be able to perform analysis until this is corrected.'
                        % (numInputs, M3))
                    return
            # --- Z4 (continuous) sample: either a sampling scheme or a file.
            useRS = self.x4RSMethod_check.isChecked()
            x4sample = None
            if self.z4NewSample_radio.isChecked():
                method = self.x4SampleScheme_combo.currentText()
                method = SamplingMethods.getEnumValue(method)
                N = self.x4SampleSize_spin.value()
                # x4sample stays None for any scheme not handled below.
                if method in [SamplingMethods.LH, SamplingMethods.LPTAU]:
                    x4sample = {
                        'method': method,
                        'nsamples': N
                    }  # number of samples (range: [M1+1,1000])
                elif method == SamplingMethods.FACT:
                    x4sample = {
                        'method': method,
                        'nlevels': N
                    }  # number of levels (range: [3,100])
            else:
                sfile = 'z4Samples.smp'
                success = self.writeTableToFile(self.z4_table, sfile, M4)
                if not success:
                    return
                # NOTE(review): sfile is a non-empty literal here, so this
                # check and the '.csv' conversion below cannot trigger as
                # written.
                if len(sfile) == 0:
                    QMessageBox.critical(self, 'Missing file',
                                         'Z4 sample file not specified!')
                    return

                if sfile.endswith('.csv'):  # Convert .csv file to simple file
                    newFileName = OUU.dname + os.sep + os.path.basename(
                        fname)[:-4] + '.smp'
                    inData = LocalExecutionModule.readDataFromCsvFile(
                        sfile, askForNumInputs=False)
                    LocalExecutionModule.writeSimpleFile(
                        newFileName, inData[0])
                    sfile = newFileName

                # Read the file back to validate it.  This rebinds numOutputs,
                # which until now held the outputs table row count.
                inData, outData, numInputs, numOutputs = LocalExecutionModule.readDataFromSimpleFile(
                    sfile, hasColumnNumbers=False)
                numSamples = inData.shape[0]
                if numInputs != M4:
                    QMessageBox.critical(
                        self, "Number of variables don't match",
                        'The number of input variables from the file (%d) does not match the number of Z4 continuous variables (%d).  You will not be able to perform analysis until this is corrected.'
                        % (numInputs, M4))
                    return
                if numSamples <= M4:
                    QMessageBox.critical(
                        self, 'Not enough samples',
                        'Z4 sample file must have at least %d samples!' %
                        (M4 + 1))
                    return

                x4sample = {
                    'file': sfile
                }  # x4sample file, must have at least M4+1 samples

                if useRS:
                    Nrs = self.z4SubsetSize_spin.value(
                    )  # add spinbox to get number of samples to generate RS
                    x4sample[
                        'nsamplesRS'] = Nrs  # TO DO: make sure spinbox has M4+1 as min and x4sample's sample size as max

            #  TODO: Get rid of usebobyqa option. Behavior should be as if usebobyqa is always false
            # TODO: Change GUI to display optimizer and optimizing with bobyqa
            # TODO: Get rid of primarySolver_combo completely
            useBobyqa = False
            # The primary solver selection currently has no effect: both
            # branches below are empty.
            method = self.primarySolver_combo.currentText()
            if method == "BOBYQA":
                pass
            elif method == "NEWUOA":
                pass
            if 'simulator' in self.secondarySolver_combo.currentText():
                useBobyqa = True  # use BOBYQA if driver is a simulator, not an optimizer

            # From here on the button acts as the stop control (see the
            # ``else`` branch at the bottom).
            self.run_button.setText('Stop')
            optDriver = None
            ensembleOptDriver = None
            if self.node_radio.isChecked():
                # The copied PSUADE client script is used for both driver
                # arguments, and a listener is started with the names of the
                # variable inputs followed by the fixed ones.
                ensembleOptDriver = self.setupPSUADEClient()
                optDriver = ensembleOptDriver
                listener = listen.foqusListener(self.dat)
                variableNames = []
                fixedNames = []
                for row in xtable:
                    #print row
                    if row['type'] == 'Fixed':
                        fixedNames.append(row['name'])
                    else:
                        variableNames.append(row['name'])
                #print fixedNames, variableNames
                #print variableNames + fixedNames
                listener.inputNames = variableNames + fixedNames
                outputNames = self.model.getOutputNames()
                # y holds 1-based output rows, hence the ``- 1``.
                listener.outputNames = [outputNames[yItem - 1] for yItem in y]
                listener.failValue = -111111
                # Remembered so the listener can be told to quit later.
                self.listenerAddress = listener.address
                listener.start()

            # print M1, M2, M3, M4, useBobyqa
            self.OUUobj = OUU()
            try:
                # NOTE(review): ``results`` is not used afterwards in this
                # method; completion is reported through endFunction.
                results = self.OUUobj.ouu(fname,
                                          y,
                                          self.useAsConstraint,
                                          self.useAsDerivative,
                                          xtable,
                                          phi,
                                          x3sample=x3sample,
                                          x4sample=x4sample,
                                          useRS=useRS,
                                          useBobyqa=useBobyqa,
                                          optDriver=optDriver,
                                          ensOptDriver=ensembleOptDriver,
                                          plotSignal=self.plotSignal,
                                          endFunction=self.finishOUU)
            # NOTE(review): the bare ``except`` catches everything, prints the
            # traceback and returns.  The button label is left at 'Stop' on
            # this path because only setEnabled(True) is called.
            except:
                import traceback
                traceback.print_exc()
                if self.node_radio.isChecked():
                    # stop the listener
                    conn = Client(self.listenerAddress)
                    conn.send(['quit'])
                    conn.close()

                # enable run button
                self.run_button.setEnabled(True)
                return
        else:  # Stop OUU
            # The button stays disabled until finishOUU re-enables it.
            self.OUUobj.stopOUU()
            self.run_button.setEnabled(False)
            self.freeze()

    def finishOUU(self):
        """Restore the frame after an OUU run has ended.

        Tells the listener to quit when the node option is selected, puts the
        run button back into its 'Run OUU' state (unfreezing the frame first
        if the button had been disabled) and, unless the run reported an
        error, shows a completion message box whose ``exec_()`` return value
        is stored in ``self.result``.
        """
        if self.node_radio.isChecked():
            # Ask the listener to shut down.
            listenerConn = Client(self.listenerAddress)
            listenerConn.send(['quit'])
            listenerConn.close()

        # A disabled button means the frame is still frozen; undo that before
        # restoring the button itself.
        buttonWasDisabled = not self.run_button.isEnabled()
        if buttonWasDisabled:
            self.unfreeze()
        self.run_button.setText('Run OUU')
        self.run_button.setEnabled(True)

        # Nothing more to report when the run ended with an error.
        if self.OUUobj.getHadError():
            return

        self.summary_group.setTitle('Best Solution')

        finishedBox = QMessageBox()
        finishedBox.setWindowTitle('FOQUS OUU Finished')
        finishedBox.setText('Optimization under Uncertainty analysis finished')
        self.result = finishedBox.exec_()

    def getResult(self):
        """Return the value stored in ``self.result``.

        The attribute is assigned by ``finishOUU`` from the completion message
        box; reading it before it has been assigned raises ``AttributeError``.
        """
        dialogResult = self.result
        return dialogResult
Beispiel #32
0
class SignalLabeler(QMainWindow):
    """Qt main window for hand-labelling positions on signals from Fast5 files.

    Labelling mode (keys F / D): every signal is plotted and each mouse click
    inside the axes stores an x position.  After the third click the record
    ``[file, id, p1, p2, p3]`` is queued and the next signal is loaded.
    Queued records are appended to a cache CSV, which is copied to
    ``result.csv`` in the chosen directory on save (key Q).  X skips the
    current signal.

    Review mode (key R): a result CSV is loaded; D shows the next signal with
    its stored positions, A goes back one signal (only once), R toggles
    plotting the signal reversed and C prints the current file name.
    """

    # Hint drawn on the empty canvas while no job is running.
    _IDLE_HINT = (
        'Press F to open File, D to open dir, R to review a result, Q to save result, and X to '
        'skip current read')

    def __init__(self, signal_folder=None, save_file=None):
        """Build the window and initialise all labelling / review state.

        Args:
            signal_folder: Optional directory or file to read signals from.
                When given, F / D start a job without showing a chooser.
            save_file: Optional directory ``result.csv`` is written to.  When
                given, saving does not show a directory chooser.
        """
        super().__init__()
        self._init_ui()
        self._show_centered_text(self._IDLE_HINT)
        self.canvas.draw()
        self.line = None
        self.axvlines = []
        self.cursor_line = None
        self.colors = ['#990000', '#ffa500', '#0a0451']
        self.curr_color = self.colors[0]
        self.pos = []
        self.signal_dir = signal_folder
        self.sig_iter = None
        self.sig_div_dict = {}
        self.signal = []
        self.sig_list = []
        self.curr_file = None
        self.id = None
        self.records = []
        self.save = save_file
        self.start = False
        self.reverse = False
        self.review = False
        self.review_count = 0
        self.cache_dir = os.path.join(os.path.dirname(boostnano.__file__),
                                      'cache')
        self.cache_file = os.path.join(self.cache_dir, 'cache.csv')
        self.cache_size = 1  # Auto save the result to cache file after this number of labeled.
        self.cache_fns = []

        self.prev_file = None
        self.prev_signal = None
        self.prev_id = None
        if not os.path.isdir(self.cache_dir):
            os.mkdir(self.cache_dir)

    def _init_ui(self):
        """Create the figure, canvas and layout, hook up mouse events and show the window."""
        self.dialog_p = {
            'title': 'PolyA label tool',
            'left': 350,
            'top': 1200,
            'width': 1280,
            'height': 480
        }
        self.setWindowTitle(self.dialog_p['title'])
        self._main = QWidget()
        self.setCentralWidget(self._main)
        self.fig = Figure(figsize=(5, 3))
        # a figure instance to plot on
        self.canvas = FigureCanvasQTAgg(self.fig)
        self.ax = self.fig.add_subplot(111)
        self.setGeometry(self.dialog_p['left'], self.dialog_p['top'],
                         self.dialog_p['width'], self.dialog_p['height'])
        self.gridLayout = QGridLayout(self._main)
        self.gridLayout.addWidget(self.canvas)
        self._main.setLayout(self.gridLayout)
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_move)
        self.show()

    def _show_centered_text(self, text):
        """Write ``text`` in the middle of the axes (does not trigger a draw)."""
        self.ax.text(0.5,
                     0.5,
                     text,
                     horizontalalignment='center',
                     verticalalignment='center',
                     transform=self.ax.transAxes)

    def _axv_pairs(self):
        """Return ``(x, color)`` pairs for the positions currently in ``self.pos``.

        Colors are cycled so that a reviewed record holding more positions
        than there are colors does not raise ``IndexError``.
        """
        n_colors = len(self.colors)
        return [(p, self.colors[idx % n_colors])
                for idx, p in enumerate(self.pos)]

    def _reinit(self):
        """Return to the idle screen and clear the per-job state."""
        self.ax.clear()
        self._show_centered_text(self._IDLE_HINT)
        self.line = None
        self.curr_color = self.colors[0]
        self.canvas.draw_idle()
        self.pos = []
        self.signal_dir = None
        self.sig_iter = None
        self.sig_div_dict = {}
        self.signal = []
        self.sig_list = []
        self.curr_file = None
        self.records = []
        self.axvlines = []
        self.cursor_line = None
        self.save = None
        self.start = False
        self.review = False
        self.review_count = 0
        self.cache_fns = []

    def refresh(self, xpos):
        """Move the cursor line to ``xpos`` and repaint only the cheap artists."""
        # A vertical line has two x values; current matplotlib rejects a bare
        # scalar here, and the sequence form works on older versions too.
        self.cursor_line.set_xdata([xpos, xpos])
        self.cursor_line.set_color(self.curr_color)
        self.ax.draw_artist(self.ax.patch)
        self.ax.draw_artist(self.line)
        self.ax.draw_artist(self.cursor_line)
        for axvline in self.axvlines:
            self.ax.draw_artist(axvline)
        self.canvas.update()
        self.canvas.flush_events()

    def redraw(self, cursor_x, xpos):
        """Clear the axes and plot the current signal with its markers.

        Args:
            cursor_x: x position for the cursor line.  When ``None`` the
                cursor is re-created at its previous position if one existed,
                otherwise no cursor line is drawn.
            xpos: Iterable of ``(x, color)`` pairs drawn as vertical lines.
        """
        self.ax.clear()
        self.fig.suptitle(os.path.basename(self.curr_file), fontsize=10)
        trace = self.signal[::-1] if self.reverse else self.signal
        self.line = self.ax.plot(denoise(trace))[0]
        # ax.clear() detached the old cursor artist; rebuild it on the fresh
        # axes so refresh() never paints a stale artist.
        if cursor_x is None and self.cursor_line is not None:
            cursor_x = self.cursor_line.get_xdata()[0]
        if cursor_x is not None:
            self.cursor_line = self.ax.axvline(x=cursor_x,
                                               color=self.curr_color)
        self.axvlines = [
            self.ax.axvline(x=x, color=color) for x, color in xpos
        ]
        self.canvas.draw_idle()
        self.canvas.flush_events()

    def on_move(self, event):
        """Matplotlib motion handler: let the cursor line follow the mouse."""
        if not self.start:
            return
        x = event.xdata
        # xdata is None while the pointer is outside the axes.
        if x is None or self.cursor_line is None or self.line is None:
            return
        self.refresh(x)

    def on_click(self, event: QMouseEvent):
        """Matplotlib click handler: store a position; the third one completes the record."""
        if not self.start:
            return
        x = event.xdata
        if x is None:
            # Click outside the axes: there is no data coordinate to record.
            return
        self.pos.append(x)
        axv_pairs = self._axv_pairs()
        if len(self.pos) >= 3:
            self.records.append([self.curr_file, self.id] + self.pos)
            self.next_signal()
            self.pos = []
            self.axvlines = []
            axv_pairs = []
        if len(self.records) > self.cache_size:
            self._cache()
        self.curr_color = self.colors[len(self.pos)]
        self.redraw(None, axv_pairs)

    def keyPressEvent(self, event):
        """Dispatch the keyboard shortcuts (Q, D, F, X, R, A, C)."""
        key = event.key()
        if key == Qt.Key_Q:
            self._save_or_quit()
        elif key == Qt.Key_D:
            # D is shared by both modes.  The review meaning has to be checked
            # first, otherwise "view next signal" can never be reached.
            if self.review:
                self._review_next()
            else:
                self._begin_job(choose_directory=True)
        elif key == Qt.Key_F:
            self._begin_job(choose_directory=False)
        elif key == Qt.Key_X:
            self._skip_signal()
        elif key == Qt.Key_R:
            if self.review:
                # Already reviewing: R flips the plotting direction.
                self.reverse = not self.reverse
                self.redraw(None, self._axv_pairs())
            else:
                self._choose_review_file()
        elif key == Qt.Key_A:
            self._review_previous()
        elif key == Qt.Key_C:
            # This will print out the current file name in the terminal
            if self.review:
                print(self.curr_file)

    def _save_or_quit(self):
        """Q key: close the window when idle, otherwise save and return to the idle screen."""
        if not self.start:
            self._quit()
            return
        if self.save is None:
            # The caption is the second argument; the third one is the start
            # directory.
            chosen = str(
                QFileDialog.getExistingDirectory(
                    self, "Choose directory to save."))
            if chosen == '':
                # Dialog cancelled: keep the job running, nothing is saved.
                return
            self.save = chosen
        self._save()
        self._reinit()

    def _begin_job(self, choose_directory):
        """F / D key: pick a signal file or directory if needed and start labelling."""
        if self.start:
            return
        if self.signal_dir is None:
            if choose_directory:
                chosen = str(
                    QFileDialog.getExistingDirectory(
                        self,
                        "Select Signal Directory",
                    ))
            else:
                chosen = QFileDialog.getOpenFileName(
                    self, 'Select Signal File', '', '*.fast5')[0]
            if chosen == '':
                return
            self.signal_dir = chosen
        self._start()

    def _skip_signal(self):
        """X key: drop the positions clicked so far and load the next signal."""
        if self.sig_iter is None:
            # No job or review has been started, so there is nothing to skip.
            return
        self.pos = []
        self.next_signal()
        self.curr_color = self.colors[0]
        self.redraw(None, [])

    def _choose_review_file(self):
        """R key outside review mode: pick a result CSV and enter review mode."""
        # getOpenFileName returns (file name, selected filter) here, as the
        # '*.fast5' chooser above already relies on.
        read_file = QFileDialog.getOpenFileName(
            self, "Select result file to review", './',
            'CSV Files(*.csv)')[0]
        if len(read_file) == 0:
            # Cancelled: leave self.save untouched.
            return
        self.save = read_file
        self._review()

    def _review_next(self):
        """D key in review mode: show the next signal with its stored positions."""
        # Remember the current signal so that A can step back once.
        self.prev_file = self.curr_file
        self.prev_signal = self.signal
        self.prev_id = self.id
        self.next_signal()
        self.review_count += 1
        if self.curr_file is not None:
            print("Review %d file" % self.review_count)
        else:
            print("Reading %d file fail" % (self.review_count + 1))
            return
        # Files listed without positions are absent from the dictionary.
        self.pos = self.sig_div_dict.get(self.curr_file, [])
        print("Segmentation:" + ",".join([str(x) for x in self.pos]))
        axv_pairs = self._axv_pairs()
        print(axv_pairs)
        self.redraw(None, axv_pairs)

    def _review_previous(self):
        """A key in review mode: go back to the previous signal (one step only)."""
        if not self.review or self.prev_file is None:
            return
        self.curr_file = self.prev_file
        self.id = self.prev_id
        self.signal = self.prev_signal
        self.pos = self.sig_div_dict.get(self.curr_file, [])
        self.redraw(None, self._axv_pairs())

    def _iter_signals(self):
        """Yield ``(file, identifier, signal)`` for every read of every file in ``sig_list``."""
        for file in self.sig_list:
            fast5 = Fast5(file)
            for _, signal, _, identifier in fast5:
                yield file, identifier, signal

    def _save(self):
        """Flush pending records to the cache and copy it to ``result.csv`` in ``self.save``."""
        self._cache()
        copyfile(self.cache_file, os.path.join(self.save, 'result.csv'))

    def _cache(self):
        """Append the pending records to the cache file and clear the queue."""
        with open(self.cache_file, 'a') as f:
            for record in self.records:
                f.write(','.join([str(x) for x in record]))
                f.write('\n')
        self.records = []

    def _start(self):
        """Build the signal list, load the cache entries and show the first signal."""
        if self.start:
            print("Already start a job")
            return
        print("Begin reading signal file list")
        if os.path.isdir(self.signal_dir):
            file_list = os.listdir(self.signal_dir)
            self.sig_list = [
                os.path.join(self.signal_dir, x) for x in file_list
            ]
        else:
            self.sig_list = [self.signal_dir]
        self.sig_iter = self._iter_signals()
        print("Try to read cache file.")
        if os.path.isfile(self.cache_file):
            with open(self.cache_file) as cache_f:
                for line in cache_f:
                    cols = line.strip().split(',')
                    if len(cols) < 2:
                        # Blank or truncated line: no (file, id) pair in it.
                        continue
                    self.cache_fns.append((cols[0], cols[1]))
            print(
                "Sucessfully read cache file, load {entries} cache entries, delete the cache.csv under {path} to not "
                "use cache record.".format(entries=len(self.cache_fns),
                                           path=self.cache_dir))
        else:
            print("No cache file found.")
        self.next_signal()
        self.start = True
        self.redraw(0, [])
        print(
            "Reading finished, press Q to save result, press X to skip current read."
        )

    def _review(self):
        """Load the result CSV named by ``self.save`` and switch to review mode."""
        if self.review:
            print("A review process already begin.")
            return
        print("Begin reading result csv file.")
        with open(self.save, 'r') as csv_f:
            for line in tqdm(csv_f):
                split_line = line.strip().split(',')
                if 'None' in split_line:
                    continue
                self.sig_list.append(split_line[0])
                # NOTE(review): records written by on_click carry the read id
                # in the second column; this conversion assumes every column
                # after the file name is numeric -- confirm the id format.
                temp_pos = [float(x) for x in split_line[1:]]
                if len(temp_pos) == 0:
                    print("%s has no segmentation information" %
                          (split_line[0]))
                    continue
                self.sig_div_dict[split_line[0]] = temp_pos
        self.sig_iter = self._iter_signals()
        self.review = True
        self.ax.clear()
        self._show_centered_text('Press D to view next signal.')
        self.canvas.draw()

    def next_signal(self):
        """Advance to the next signal that is not already in the cache.

        When the iterator is exhausted the pending work is persisted and the
        process exits with status 0.
        """
        try:
            self.curr_file, self.id, self.signal = next(self.sig_iter)
            while (self.curr_file, self.id) in self.cache_fns:
                self.curr_file, self.id, self.signal = next(self.sig_iter)
        except StopIteration:
            print("End of file")
            if self.review:
                # In review mode self.save names the reviewed CSV, not an
                # output directory, so only flush whatever is pending.
                self._cache()
            else:
                if self.save is None:
                    self.save = str(
                        QFileDialog.getExistingDirectory(
                            self, "Choose directory to save."))
                self._save()
            # Same effect as the interactive-only ``exit`` helper, without
            # depending on the ``site`` module having installed it.
            raise SystemExit(0)

    def _quit(self):
        """Close the window."""
        self.close()
Beispiel #33
0
class BaseTS:
    """Time-series plot of tracked variables drawn on a matplotlib figure.

    Each tracked variable is an indexable ``lam`` for which
    ``lam[0](lam[1]).tracked_variables[lam[2]]`` yields an ``(x values,
    y values)`` pair.  When ``lam`` has a fourth element, every y value is
    reduced to a scalar with ``value.item(lam[3])``.
    """
    frontend = "matplotlib"

    def __init__(self,
                 vars_=(),
                 xrange=None,
                 update_r=0,
                 figsize=(800, 800),
                 dpi=100):
        """Create the figure, canvas and axes.

        Args:
            vars_: Iterable of tracked-variable descriptors (see the class
                docstring).  ``None`` is treated as empty.
            xrange: Fixed x limits, or ``None`` to follow the data.
            update_r: Redraw once every ``update_r`` calls to ``iterate``;
                ``0`` disables the automatic redraw.
            figsize: Figure size in pixels as ``(width, height)``.
            dpi: Resolution used to convert ``figsize`` to inches.
        """
        width_px, height_px = figsize
        self.fig = Figure((width_px / dpi, height_px / dpi), dpi)
        self.canvas = FigureCanvas(self.fig)
        self.ax = self.fig.add_subplot(111)
        self._tracked_vars = []
        self.lines = []
        self.update_rate = update_r
        self.update_i = 0
        for var in (vars_ or ()):
            self.add_var(var)
        self.xrange = xrange
        self.ax.set_xlim(xrange)
        self.ax.set_xlabel("time")
        self.set_colors()
        # self.ax.set_ylabel('x')
        self.minx = 0
        self.maxx = 1
        self.maxy = -np.inf
        self.miny = np.inf

    def add_var(self, var):
        """Register one more tracked-variable descriptor."""
        self._tracked_vars.append(var)

    # override if neccesary
    def update_triggered(self):
        """Redraw every tracked variable and update the axis limits."""
        # Remove the lines of the previous draw.  ``Axes.lines`` cannot be
        # assigned on current matplotlib, while removing each artist works on
        # old and new versions alike.
        for stale_line in list(self.ax.lines):
            stale_line.remove()
        for i, lam in enumerate(self._tracked_vars):
            # Look the series up once instead of once per use.
            series = lam[0](lam[1]).tracked_variables[lam[2]]
            x = series[0]
            if len(lam) == 4:
                y = [sample.item(lam[3]) for sample in series[1]]
            else:
                y = series[1]
            self.ax.add_line(Line2D(x, y, color=COLORS[i % len(COLORS)]))
            if len(x) == 0 or len(y) == 0:
                # Nothing recorded yet for this variable.
                continue
            # The running limits follow the most recent sample of each series.
            last_x, last_y = x[-1], y[-1]
            if last_x > self.maxx:
                self.maxx = last_x
            if last_x < self.minx:
                self.minx = last_x
            if last_y > self.maxy:
                self.maxy = last_y
            if last_y < self.miny:
                self.miny = last_y
        if self.xrange is not None:
            # NOTE(review): with a fixed xrange the y limits are never set
            # here -- confirm that this is intended.
            self.ax.set_xlim(self.xrange)
        else:
            self.ax.set_xlim((self.minx, self.maxx))
            # The y limits start at +/-inf and stay there until a sample has
            # been seen; matplotlib rejects non-finite limits.
            if np.isfinite(self.miny) and np.isfinite(self.maxy):
                self.ax.set_ylim((self.miny, self.maxy))
        self.canvas.draw()
        self.canvas.flush_events()

    def reset(self):
        """Forget the running axis limits and redraw."""
        self.minx = 0
        self.maxx = 1
        self.maxy = -np.inf
        self.miny = np.inf
        self.update_triggered()

    def iterate(self):
        """Count one step and redraw on every ``update_rate``-th call."""
        if not self.update_rate:
            return
        self.update_i += 1
        if self.update_i % self.update_rate == 0:
            self.update_i = 0
            self.update_triggered()

    def set_colors(self):
        """Apply the white-on-black color scheme to the figure and axes."""
        line_color = "#FFFFFF"
        face_color = "#000000"

        self.fig.patch.set_facecolor(face_color)
        self.ax.patch.set_facecolor(face_color)

        for side in ('bottom', 'top', 'right', 'left'):
            self.ax.spines[side].set_color(line_color)

        self.ax.xaxis.label.set_color(line_color)
        self.ax.tick_params(axis='x', colors=line_color)
        self.ax.yaxis.label.set_color(line_color)
        self.ax.tick_params(axis='y', colors=line_color)
Beispiel #34
0
class LabellingTool(qtw.QMainWindow):
    """"""
    def __init__(
        self,
        experiment_id: Optional[int] = None,
        db_folder: Optional[str] = None,
        db_name: Optional[str] = None,
        start_over: bool = False,
        figure_fontsize: int = 8,
    ) -> None:
        """Select the database, pick the first dataset to label and build the GUI.

        Args:
            experiment_id: ID handed to ``load_experiment``. Must not be
                None; labelling a whole database is not implemented.
            db_folder: Database folder; defaults to
                ``nt.config["db_folder"]``.
            db_name: Database name; defaults to ``nt.config["main_db"]``.
            start_over: Forwarded to ``get_data_ids``.
            figure_fontsize: Font size applied through ``matplotlib.rc``.

        Raises:
            NotImplementedError: If ``experiment_id`` is None.
            StopIteration: If ``get_data_ids`` yields no ID to label.
            ValueError: Re-raised after a warning dialog was shown.
            IndexError: Re-raised after a warning dialog was shown.
        """
        if db_folder is None:
            db_folder = nt.config["db_folder"]

        label_names = list(dict(nt.config["core"]["labels"]).keys())

        if db_name is None:
            logger.warning("Labelling default main database.")
            db_name = nt.config["main_db"]
        nt.set_database(db_name)
        self.db_name = db_name

        self.db_folder = db_folder

        matplotlib.rc("font", size=figure_fontsize)
        super(LabellingTool, self).__init__()

        self.current_label = dict.fromkeys(label_names, 0)
        self.experiment_id = experiment_id

        if self.experiment_id is None:
            msg = ("Please select an experiment. Labelling entire " +
                   " database is not supported yet.")
            logger.error(msg)
            raise NotImplementedError(msg)

        try:
            self.experiment = load_experiment(self.experiment_id)

            (
                self._iterator_list,
                self.labelled_ids,
                self.n_total,
            ) = self.get_data_ids(start_over)

            self._id_iterator = iter(self._iterator_list)
            try:
                self.current_id = next(self._id_iterator)
            except StopIteration:
                logger.warning("All data of this experiment is already " +
                               "labelled")
                raise

        except ValueError:
            qtw.QMessageBox.warning(self,
                                    "Error instantiating LabellingTool.",
                                    "Unable to load experiment.",
                                    qtw.QMessageBox.Ok)
            # The attributes initUI() reads were never assigned on this
            # path; propagate the real cause instead of continuing into an
            # unrelated AttributeError.
            raise
        except IndexError:
            msg = "Did not find any unlabelled data in experiment "
            msg += self.experiment.name + "."
            qtw.QMessageBox.warning(self,
                                    "Error instantiating LabellingTool.",
                                    msg, qtw.QMessageBox.Ok)
            # Same reasoning as above: the tool is unusable from here on.
            raise

        self._main_widget = qtw.QWidget(self)
        self.setCentralWidget(self._main_widget)

        self.initUI()
        self.show()

    def get_data_ids(
        self,
        start_over: bool = False,
    ) -> Tuple[List[int], List[int], int]:
        """Split the candidate dataset IDs into unlabelled and labelled ones.

        Args:
            start_over: If True, every candidate ID is reported as
                unlabelled. It is also forced to True when reading the
                ``"good"`` metadata of dataset 1 raises
                ``OperationalError``.

        Returns:
            ``(unlabelled_ids, labelled_ids, n_total)``, where ``n_total``
            is the number of candidate IDs.
        """
        print("getting datasets")

        last_id = nt.get_last_dataid(self.db_name, db_folder=self.db_folder)
        # NOTE(review): range() stops before last_id. If get_last_dataid
        # returns the ID of the final dataset, that dataset is never a
        # candidate -- confirm against nt.get_last_dataid.
        all_ids = list(range(1, last_id))
        print("len(all_ids): " + str(len(all_ids)))

        if not start_over:
            # Probe only: the value is irrelevant, the call is made to see
            # whether it raises OperationalError.
            try:
                load_by_id(1).get_metadata("good")
            except OperationalError:
                logger.warning("""No nanotune_label column found in current
                                database. Probably because no data has been
                                labelled yet. Hence starting over. """)
                start_over = True
        print("start_over: " + str(start_over))

        if start_over:
            return all_ids, [], len(all_ids)

        unlabelled_ids = nt.get_unlabelled_ids(self.db_name)
        # Build the lookup set once so the filter below is linear rather
        # than quadratic in the number of IDs.
        pending = set(unlabelled_ids)
        labelled_ids = [x for x in all_ids if x not in pending]

        return unlabelled_ids, labelled_ids, len(all_ids)

    def initUI(self) -> None:
        """Create the progress bar, set window geometry and title, install the layout.

        Reads ``self.labelled_ids`` and ``self.n_total`` for the initial
        progress value and ``self._main_widget`` as the layout host.
        """
        self.progressbar = qtw.QProgressBar()
        self.progressbar.setMinimum(1)
        self.progressbar.setMaximum(100)

        self.statusBar().addPermanentWidget(self.progressbar)

        # setValue() takes an int: a float is rejected with TypeError on
        # Python >= 3.10. An empty candidate list must not divide by zero.
        if self.n_total:
            pp = int(len(self.labelled_ids) / self.n_total * 100)
        else:
            pp = 0
        self.progressbar.setValue(pp)

        self.statusBar().showMessage("Are we there yet?")

        # Arguments: x, y, width, height of the frame.
        self.setGeometry(300, 250, 700, 1100)
        self.setWindowTitle("nanotune Labelling Tool")

        self._main_layout = self.initMainLayout()
        self._main_widget.setLayout(self._main_layout)

    def initMainLayout(self) -> qtw.QVBoxLayout:
        """Build the main layout: figure row, label buttons, Clear/Save row, Exit row.

        Creates ``self.l1``, ``self._figure``, ``self._axes``, ``self._cb``,
        ``self._canvas``, ``self._buttons`` and ``self._quality_group``.
        Datasets for which ``plot_by_id`` raises ``RuntimeError`` or
        ``IndexError`` are appended to ``self.labelled_ids`` and skipped.

        Returns:
            The vertical layout, created with ``self._main_widget`` as
            its parent.
        """
        # -----------  Main Layout  ----------- #
        layout = qtw.QVBoxLayout(self._main_widget)

        # -----------  Figure row  ----------- #
        figure_row = qtw.QVBoxLayout()

        self.l1 = qtw.QLabel()
        self.l1.setAlignment(qtc.Qt.AlignCenter)
        figure_row.addWidget(self.l1)

        rcParams.update({"figure.autolayout": True})

        self._figure = Figure(tight_layout=True)
        self._axes = self._figure.add_subplot(111)
        self._cb = [None]

        # Try IDs until one can be plotted or the iterator runs dry.
        while True:
            try:
                _, self._cb = plot_by_id(self.current_id, axes=self._axes)
                break
            except (RuntimeError, IndexError) as err:
                logger.warning("Skipping current dataset" + str(err))
                self.labelled_ids.append(self.current_id)
                try:
                    self.current_id = next(self._id_iterator)
                except StopIteration:
                    logger.error("All datasets labelled.")
                    break

        # Caption is written after the loop so that it names the dataset
        # that ended up being current, not one that was skipped.
        self.l1.setText("Plot ID: {}".format(self.current_id))

        self._canvas = FigureCanvasQTAgg(self._figure)
        self._canvas.setParent(self._main_widget)

        # NOTE(review): drawn twice on purpose in the original; presumably
        # so the automatic layout settles -- confirm before removing.
        self._canvas.draw()
        self._canvas.draw()

        self._canvas.update()
        self._canvas.flush_events()

        figure_row.addWidget(self._canvas)
        figure_row.addStretch(10)

        # -----------  Buttons row  ----------- #
        self._buttons = []
        self._quality_group = qtw.QButtonGroup(self)
        button_row = qtw.QHBoxLayout()

        def make_toggle(text, name):
            """Create a checkable button and register it in ``self._buttons``."""
            btn = qtw.QPushButton(text)
            btn.setObjectName(name)
            btn.setCheckable(True)
            self._buttons.append(btn)
            return btn

        # left part of button row: the two quality buttons share one group
        quality_column = qtw.QVBoxLayout()
        quality_buttons = (
            make_toggle("Good", "good"),
            make_toggle(label_bad, label_bad),
        )
        for btn in quality_buttons:
            self._quality_group.addButton(btn)
            quality_column.addWidget(btn)
        button_row.addLayout(quality_column)

        # right part of button row: one button per configured label,
        # except "good" which already has its button on the left
        labels_column = qtw.QVBoxLayout()
        labels_map = dict(nt.config["core"]["labels"])
        for label, text in labels_map.items():
            if label != "good":
                labels_column.addWidget(make_toggle(text, label))
        button_row.addLayout(labels_column)

        # -----------   Finalize row   ----------- #
        finalize_row = qtw.QHBoxLayout()

        clear_btn = qtw.QPushButton("Clear")
        finalize_row.addWidget(clear_btn)
        clear_btn.clicked.connect(self.clear)

        save_btn = qtw.QPushButton("Save")
        finalize_row.addWidget(save_btn)
        save_btn.clicked.connect(self.save_labels)

        # -----------   Exit row   ----------- #
        exit_row = qtw.QHBoxLayout()

        exit_btn = qtw.QPushButton("Exit")
        exit_row.addWidget(exit_btn)
        exit_btn.clicked.connect(self.exit)

        empty_space = qtw.QHBoxLayout()
        empty_space.addStretch(1)
        exit_row.addLayout(empty_space)

        # -----------   Add all rows to main vertical box   ----------- #
        for row in (figure_row, button_row, finalize_row, exit_row):
            layout.addLayout(row)

        return layout

    def next(self) -> None:
        """Record the current dataset as labelled and show the next plottable one.

        The axes and any colorbars of the previous plot are cleared first.
        IDs are then taken from ``self._id_iterator`` until one can be
        plotted; IDs for which plotting raises ``RuntimeError`` or
        ``IndexError`` are appended to ``self.labelled_ids`` and skipped.
        When the iterator is exhausted an information dialog is shown and
        the method returns.
        """
        self._axes.clear()
        self._axes.relim()
        for cbar in self._cb:
            if cbar is not None:
                cbar.ax.clear()
                cbar.ax.relim()
                cbar.remove()
        # Forget the removed colorbars so that a later call cannot try to
        # remove them a second time.
        self._cb = [None]
        self._figure.tight_layout()

        # The dataset that was on screen counts as labelled exactly once.
        self.labelled_ids.append(self.current_id)

        while True:
            # setValue() takes an int (a float raises TypeError on
            # Python >= 3.10); also avoid dividing by an empty total.
            if self.n_total:
                pp = int(len(self.labelled_ids) / self.n_total * 100)
            else:
                pp = 0
            self.progressbar.setValue(pp)

            try:
                # Spelled via __next__ because this method itself is
                # called ``next``; keeps the call unambiguous to readers.
                self.current_id = self._id_iterator.__next__()
            except StopIteration:
                msg1 = "You are done!"
                msg2 = "All datasets of " + self.experiment.name
                msg2 += " are labelled."
                qtw.QMessageBox.information(self, msg1, msg2,
                                            qtw.QMessageBox.Ok)
                return

            self.l1.setText("Plot ID: {}".format(self.current_id))

            try:
                _, self._cb = plot_by_id(self.current_id, axes=self._axes)

                self._figure.tight_layout()
                self._canvas.draw()
                self._canvas.update()
                self._canvas.flush_events()
            except (RuntimeError, IndexError) as err:
                logger.warning("Skipping this dataset " + str(err))
                # Only the dataset that failed is skipped; the loop head
                # then fetches the following ID and tries to plot it.
                self.labelled_ids.append(self.current_id)
                continue
            break

    def clear(self) -> None:
        """Uncheck every label button and set all entries of ``current_label`` to 0."""
        group = self._quality_group

        # Exclusivity is switched off while unchecking and restored after;
        # presumably an exclusive group would keep one button checked.
        group.setExclusive(False)
        for button in self._buttons:
            button.setChecked(False)
        group.setExclusive(True)

        configured = dict(nt.config["core"]["labels"])
        self.current_label = {name: 0 for name in configured}

    def save_labels(self) -> None:
        """Write the selected labels to the current dataset and move on.

        The checked state of every button except the ``label_bad`` one is
        copied into ``self.current_label``. If no quality button is
        selected a warning dialog is shown and nothing is saved; otherwise
        each label is stored with ``add_metadata``, the buttons are
        cleared and the next dataset is loaded.
        """
        self.current_label.update(
            (btn.objectName(), int(btn.isChecked()))
            for btn in self._buttons
            if btn.objectName() != label_bad
        )

        nothing_selected = self._quality_group.checkedId() == -1
        if nothing_selected:
            text = (
                'Please choose quality. \n \n Either "Good" or "'
                + label_bad
                + '" has to be selected.'
            )
            qtw.QMessageBox.warning(self, "Cannot save label.", text,
                                    qtw.QMessageBox.Ok)
            return

        dataset = load_by_id(self.current_id)
        for name, flag in self.current_label.items():
            dataset.add_metadata(name, flag)

        self.clear()
        self.next()

    # def go_back(self):
    #     """
    #     """
    #     self._axes.clear()
    #     if self._cb[0] is not None:
    #         self._cb[0].ax.clear()

    #     while True:
    #         try:
    #             self.labelled_ids = self.labelled_ids[:-1]
    #             self.current_id -= 1

    #             # Update GUI
    #             self.l1.setText('Plot ID: {}'.format(self.current_id))
    #             pp = (len(self.labelled_ids)/self.n_total*100)
    #             self.progressbar.setValue(pp)
    #             if self._cb[0] is not None:
    #                 _, self._cb = plot_by_id(self.current_id, axes=self._axes,
    #                                          colorbars=self._cb[0])
    #             else:
    #                 _, self._cb = plot_by_id(self.current_id, axes=self._axes)
    #             self._figure.tight_layout()

    #             self._canvas.draw()
    #             self._canvas.update()
    #             self._canvas.flush_events()
    #             break

    #         except (RuntimeError, IndexError) as r:
    #             logger.warning('Skipping a dataset ' +
    #                             str(r))
    #             # self.labelled_ids.append(self.current_id)
    #             # self.current_id = self._id_iterator.__next__()
    #             self.current_id -= 1
    #             self.go_back()
    #             pass

    def exit(self) -> None:
        """Ask whether to keep going and close the window if the answer is "No"."""
        quit_msg1 = "Please don't go"
        quit_msg2 = "Would you like to give it another try?"
        reply = qtw.QMessageBox.question(self, quit_msg1, quit_msg2,
                                         qtw.QMessageBox.Yes,
                                         qtw.QMessageBox.No)

        # "Yes" means the user wants to continue labelling, so only "No"
        # closes the tool.
        if reply == qtw.QMessageBox.No:
            self.close()