Example #1
0
   def __createFigure(self):  ## Create the plotting system
      """Create the Matplotlib figure, its canvas and a customised navigation toolbar.

      Builds an 8x5-inch Figure wrapped in a FigureCanvas and attaches a
      NavigationToolbar to it. The toolbar's built-in actions are rearranged:
      some are hidden and window actions from ``self.ui`` are inserted. The
      toolbar is installed on this main window, the canvas becomes the central
      widget, and canvas mouse/pick/scroll events are connected to this
      window's handlers.
      """
      self.__fig=mpl.figure.Figure(figsize=(8, 5)) # figure size in inches
      figCanvas = FigureCanvas(self.__fig)         # create the FigureCanvas; it must be given a Figure
      self.__naviBar=NavigationToolbar(figCanvas, self)  # create the NavigationToolbar

      # Snapshot of the toolbar's original actions. Later insertAction() calls
      # do not change this Python list, so the indices below always refer to
      # the toolbar's original action order.
      actList=self.__naviBar.actions()  # list of the toolbar's actions
      for act in actList:     # print each action's text and tooltip (debug output; can be commented out)
         print ("text=%s,\ttoolTip=%s"%(act.text(),act.toolTip()))
      self.__changeActionLanguage() # translate the toolbar to Chinese
   ## Toolbar customisation
   # NOTE(review): indices 6/7/8 assume a specific NavigationToolbar action
   # layout (Subplots, Customize, separator); verify against the Matplotlib
   # version in use.
      actList[6].setVisible(False)  # hide the "Subplots" button
      actList[7].setVisible(False)  # hide the "Customize" button
      act8=actList[8] # separator
      self.__naviBar.insertAction(act8,self.ui.actTightLayout)    # "Tight layout" button, inserted before the separator
      self.__naviBar.insertAction(act8,self.ui.actSetCursor)      # "Crosshair cursor" button, inserted before the separator
      
      count=len(actList)       # number of original actions
      lastAction=actList[count-1]   # last original action
      self.__naviBar.insertAction(lastAction,self.ui.actScatterAgain)  # "Redraw scatter" button

      lastAction.setVisible(False) # hide it: per the original comment it is the built-in coordinate readout
      self.__naviBar.addSeparator()
      self.__naviBar.addAction(self.ui.actQuit)    # "Quit" button
      self.__naviBar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon) # show text under each icon
      
      self.addToolBar(self.__naviBar)  # install as the main-window toolbar
      self.setCentralWidget(figCanvas)

      figCanvas.setCursor(Qt.CrossCursor)
   ## Original note: keep the cid variables, otherwise they may be garbage-collected.
   ## NOTE(review): mpl_connect returns a connection id; storing it is presumably
   ## what allows a later mpl_disconnect. Confirm whether it is needed for anything else.
      self._cid1=figCanvas.mpl_connect("motion_notify_event",self.do_canvas_mouseMove)
      self._cid2=figCanvas.mpl_connect("axes_enter_event",self.do_axes_mouseEnter)
      self._cid3=figCanvas.mpl_connect("axes_leave_event",self.do_axes_mouseLeave)
      self._cid4=figCanvas.mpl_connect("pick_event",self.do_series_pick)
      self._cid5=figCanvas.mpl_connect("scroll_event",self.do_scrollZoom)
Example #2
0
    def create_figure_widget(cls, parent=None):
        """Create a QWidget holding a Matplotlib canvas with its toolbar below it.

        Key presses on the canvas are forwarded to ``key_press_handler``
        (presumably Matplotlib's default key-binding handler) together with
        the canvas and toolbar.

        :param parent: parent of the returned container widget (may be None)
        :rtype: tuple(QWidget, Figure)
        """
        # Inspiration:
        # http://matplotlib.org/examples/user_interfaces/embedding_in_qt4_wtoolbar.html
        figure = Figure()

        container = QtWidgets.QWidget(parent)
        vbox = QtWidgets.QVBoxLayout()
        container.setLayout(vbox)

        # The canvas expands in both directions and accepts keyboard focus.
        figure_canvas = FigureCanvas(figure)
        grow = QtWidgets.QSizePolicy.Expanding
        figure_canvas.setSizePolicy(grow, grow)
        figure_canvas.setFocusPolicy(QtCore.Qt.StrongFocus)
        figure_canvas.updateGeometry()

        toolbar = NavigationToolbar(figure_canvas, container)
        figure_canvas.mpl_connect(
            'key_press_event',
            lambda event: key_press_handler(event, figure_canvas, toolbar))

        # Canvas on top, toolbar underneath.
        for child in (figure_canvas, toolbar):
            vbox.addWidget(child)
        return container, figure
Example #3
0
    def __init__(self, tsne, cell_clusters, colors, out_dir, pr_res=None, trends_win=None):
        """Build the tSNE window: a Done button, optional pseudotime controls and a tSNE canvas.

        :param tsne: embedding drawn with ``palantir.plot.plot_tsne``
        :param cell_clusters: stored as ``self.cell_clusters``
        :param colors: stored as ``self.colors``
        :param out_dir: stored as ``self.out_dir``
        :param pr_res: when truthy, the window gets a pseudotime line edit,
            a "set pseudotime" button, a slider and a
            "Set pseudotime cluster as cluster" button; otherwise it is
            titled for defining custom clusters
        :param trends_win: stored as ``self.trends_win`` (the gene-trends
            window passes itself here)
        """
        super().__init__()
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        # Have Qt delete this window object when it is closed.
        self.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        layout = QtWidgets.QVBoxLayout(self._main)

        self.tsne = tsne
        self.pr_res = pr_res
        self.cell_clusters = cell_clusters
        self.colors = colors
        self.out_dir = out_dir
        self.pseudotime = 0
        self.trends_win = trends_win

        # Widgets are stacked vertically in the order they are added below.
        self.done_button = QtWidgets.QPushButton("Done")
        self.done_button.clicked.connect(self.on_click_done)
        layout.addWidget(self.done_button)

        if pr_res:
            self.setWindowTitle("tSNE -- pseudotime")

            self.set_pseudotime = QtWidgets.QLineEdit(self)
            layout.addWidget(self.set_pseudotime)
            self.set_pseudotime_button = QtWidgets.QPushButton('set pseudotime')
            self.set_pseudotime_button.clicked.connect(self._update_pseudotime)
            layout.addWidget(self.set_pseudotime_button)

            # Slider range 0..10000 in steps of 1.
            # NOTE(review): how slider values map to pseudotime is decided in
            # _update_tsne, which is not visible here.
            self.slider = QtWidgets.QSlider(Qt.Horizontal)
            self.slider.setFocusPolicy(Qt.StrongFocus)
            self.slider.setTickPosition(QtWidgets.QSlider.TicksBothSides)
            self.slider.setMinimum(0)
            self.slider.setMaximum(10000)
            self.slider.setTickInterval(1)
            self.slider.setSingleStep(1)
            self.slider.valueChanged.connect(self._update_tsne)
            layout.addWidget(self.slider)

            self.pseudotime_cluster_button = QtWidgets.QPushButton('Set pseudotime cluster as cluster')
            self.pseudotime_cluster_button.clicked.connect(self._on_pseudotime_cluster_click)
            layout.addWidget(self.pseudotime_cluster_button)
        else:
            self.setWindowTitle("tSNE -- define custom clusters")

        # tSNE canvas, with its navigation toolbar docked at the bottom of the window.
        tsne_canvas = FigureCanvas(Figure(figsize=(5, 5)))
        layout.addWidget(tsne_canvas)
        self.addToolBar(QtCore.Qt.BottomToolBarArea,
                        NavigationToolbar(tsne_canvas, self))
        self._tsne_fig = tsne_canvas.figure
        self._tsne_ax = tsne_canvas.figure.subplots()

        tsne_canvas.mpl_connect('button_release_event', self.on_mouse_click)

        # plot_tsne returns a (figure, axes) pair; the figure is not used here.
        fig, self._tsne_ax = palantir.plot.plot_tsne(self.tsne, fig=self._tsne_fig, ax=self._tsne_ax)
        self._tsne_ax.figure.canvas.draw()
Example #4
0
    def __init__(self, trends, pr_res, clusters, id_to_name, out_dir, goea_dir,
                 tsne, cell_clusters, colors):
        """Build the gene-trends window and open its companion tSNE window.

        The window contains a cluster drop-down, a gene-trend canvas with a
        navigation toolbar and a "GO analysis" button. A ``TsneWindow`` is
        created with ``trends_win=self`` and shown immediately.

        :param trends: stored as ``self.trends``
        :param pr_res: forwarded to the tSNE window
        :param clusters: stored as ``self.clusters``; its distinct values fill the drop-down
        :param id_to_name: stored as ``self.id_to_name``
        :param out_dir: stored as ``self.out_dir`` and forwarded to the tSNE window
        :param goea_dir: stored as ``self.goea_dir``
        :param tsne: forwarded to the tSNE window
        :param cell_clusters: forwarded to the tSNE window
        :param colors: stored as ``self.colors`` and forwarded to the tSNE window
        """
        super().__init__()
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        # Have Qt delete this window object when it is closed.
        self.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        # NOTE(review): the title ends with "for " -- presumably completed
        # elsewhere (e.g. when a cluster is selected); confirm.
        self.setWindowTitle("Gene trends for ")
        layout = QtWidgets.QVBoxLayout(self._main)

        self.trends = trends
        self.clusters = clusters
        self.id_to_name = id_to_name
        self.out_dir = out_dir
        self.goea_dir = goea_dir

        self.line = False
        self.lines = list()
        self.pseudotime = 0
        self.n_lines = 0

        self.colors = colors

        ##self.colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']

        ## Drop-down to choose the cluster; entry 0 is a placeholder prompt.
        self.select_cluster = QtWidgets.QComboBox()
        layout.addWidget(self.select_cluster)
        self.select_cluster.addItem('Select cluster')

        # NOTE(review): iteration order of set() is not guaranteed to be
        # stable (e.g. for strings across runs); consider sorted() if the
        # drop-down order matters.
        for cluster in set(self.clusters):
            self.select_cluster.addItem(str(cluster))
        self.select_cluster.currentIndexChanged.connect(self._select_cluster)

        gene_trend_canvas = FigureCanvas(Figure(figsize=(5, 5)))
        layout.addWidget(gene_trend_canvas)
        self.addToolBar(NavigationToolbar(gene_trend_canvas, self))
        self._gt_fig = gene_trend_canvas.figure
        self._gt_ax = gene_trend_canvas.figure.subplots()

        ## Add listener for mouse motion to update tsne canvas
        gene_trend_canvas.mpl_connect('motion_notify_event',
                                      self._on_mouse_move)

        go_button = QtWidgets.QPushButton("GO analysis")
        go_button.clicked.connect(self._on_click_go)
        layout.addWidget(go_button)

        # Companion tSNE window; it receives this window as trends_win.
        self.tsne_win = TsneWindow(tsne,
                                   cell_clusters,
                                   colors,
                                   out_dir,
                                   pr_res=pr_res,
                                   trends_win=self)
        self.tsne_win.show()
Example #5
0
    def __init__(self, parent=None, toolbarVisible=True, showHint=False):
        """Build a widget with a large (50x50 inch) Figure, its canvas and a toolbar.

        :param parent: parent widget forwarded to the base-class constructor
        :param toolbarVisible: whether the navigation toolbar is visible
        :param showHint: whether the toolbar's last action is visible
            (presumably its coordinate readout -- confirm)
        """
        super().__init__(parent)

        self.figure = mpl.figure.Figure(figsize=(50, 50))
        canvas = FigureCanvas(self.figure)

        self.naviBar = NavigationToolbar(canvas, self)

        # Show or hide the toolbar's last action according to showHint.
        self.__lastActtionHint = self.naviBar.actions()[-1]
        self.__showHint = showHint
        self.__lastActtionHint.setVisible(self.__showHint)

        self.__showToolbar = toolbarVisible
        self.naviBar.setVisible(self.__showToolbar)

        # Toolbar above the canvas, flush to the widget edges.
        vbox = QVBoxLayout(self)
        for child in (self.naviBar, canvas):
            vbox.addWidget(child)
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.setAlignment(Qt.AlignTop)
        vbox.setSpacing(0)

        self.__cid = canvas.mpl_connect("scroll_event", self.do_scrollZoom)
Example #6
0
    def __init__(self, parent=None, toolbarVisible=True, showHint=False):
        """Build a widget with a pyplot figure, its canvas and a navigation toolbar.

        :param parent: parent widget forwarded to the base-class constructor
        :param toolbarVisible: whether the navigation toolbar is visible
        :param showHint: whether the toolbar's last action (its coordinate
            hint) is visible
        """
        super().__init__(parent)  # let the base class create the widget

        # Public figure attribute.
        # NOTE(review): plt.figure() registers the figure with pyplot's global
        # state; matplotlib.figure.Figure() is usually preferred for embedded
        # canvases -- confirm no other method relies on pyplot's current figure.
        self.figure = plt.figure()
        canvas = FigureCanvas(self.figure)  # a FigureCanvas must wrap a Figure

        self.naviBar = NavigationToolbar(canvas, self)
        self.__changeActionLanguage()  # translate the toolbar actions to Chinese

        # The last toolbar action is the coordinate hint; show/hide it as requested.
        self.__lastActtionHint = self.naviBar.actions()[-1]
        self.__showHint = showHint
        self.__lastActtionHint.setVisible(self.__showHint)

        self.__showToolbar = toolbarVisible
        self.naviBar.setVisible(self.__showToolbar)

        vbox = QVBoxLayout(self)
        vbox.addWidget(self.naviBar)  # toolbar on top
        vbox.addWidget(canvas)  # canvas below
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.setSpacing(0)

        # Mouse-wheel zoom; the connection id is kept on the instance.
        self.__cid = canvas.mpl_connect("scroll_event", self.do_scrollZoom)
Example #7
0
    def __init__(self, parent=None, toolbarVisible=True, showHint=False):
        """Build a widget with a Figure, its canvas and a navigation toolbar.

        :param parent: parent widget forwarded to the base-class constructor
        :param toolbarVisible: whether the navigation toolbar is visible
        :param showHint: whether the coordinate-hint label (the toolbar's last
            action) is visible
        """
        super().__init__(parent)

        self.figure = Figure()  # public figure attribute
        canvas = FigureCanvas(self.figure)  # a FigureCanvas must wrap a Figure

        self.naviBar = NavigationToolbar(canvas, self)  # public toolbar attribute
        self.__changeActionLanguage()  # translate the toolbar actions to Chinese

        # Last toolbar action = coordinate-hint label.
        self.__lastActtionHint = self.naviBar.actions()[-1]
        self.__showHint = showHint  # whether to show the coordinate hint on the toolbar
        self.__lastActtionHint.setVisible(self.__showHint)

        self.__showToolbar = toolbarVisible  # whether to show the toolbar
        self.naviBar.setVisible(self.__showToolbar)

        vbox = QVBoxLayout(self)
        vbox.addWidget(self.naviBar)  # toolbar on top
        vbox.addWidget(canvas)  # canvas below
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.setSpacing(0)

        # Mouse-wheel zoom
        self.__cid = canvas.mpl_connect("scroll_event", self.do_scrollZoom)
Example #8
0
class MplWidget(QWidget):
    """Widget embedding a Matplotlib canvas; whitespace keys toggle an animation.

    The canvas wraps a ``Figure(tight_layout=True)``. ``self.canvas.axes``
    starts as ``None`` (presumably assigned by the code that draws into this
    widget -- confirm).
    """

    def __init__(self, parent=None):

        QWidget.__init__(self, parent)

        self.canvas = FigureCanvas(Figure(tight_layout=True))

        vertical_layout = QVBoxLayout()
        vertical_layout.addWidget(self.canvas)

        self.setLayout(vertical_layout)

        self.canvas.axes = None  #self.canvas.figure.add_subplot(111)

        # Space-bar (any whitespace key) pauses/resumes the animation.
        def on_press(event):
            """Toggle ``anim`` between running and stopped on a whitespace key."""
            # Matplotlib's KeyEvent.key may be None; calling .isspace() on it
            # would raise AttributeError inside the callback.
            if event.key is None or not event.key.isspace():
                return
            # NOTE(review): `anim` is not defined in this class; it must be a
            # module-level object (with `running` and `event_source`) bound
            # elsewhere, otherwise this raises NameError.
            if anim.running:
                anim.event_source.stop()
            else:
                anim.event_source.start()
            anim.running ^= True

        self.canvas.mpl_connect('key_press_event', on_press)
Example #9
0
    def __init__(self):
        """Create the main window: a static, pickable sin/cos plot plus toolbar."""
        super(ApplicationWindow, self).__init__()
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        vbox = QtWidgets.QVBoxLayout(self._main)

        canvas = FigureCanvas(Figure(figsize=(5, 3)))
        vbox.addWidget(canvas)
        self.addToolBar(NavigationToolbar(canvas, self))

        self._static_ax = canvas.figure.add_subplot(111)
        samples = np.linspace(0, 10, 51)
        # sin drawn with "." markers, cos with "o"; picker=5 enables picking
        # with a 5-point pick radius.
        for func, marker in ((np.sin, "."), (np.cos, "o")):
            self._static_ax.plot(samples, func(samples), marker, picker=5)
        self.cid = canvas.mpl_connect('pick_event', self.on_pick)
        self.curve_dialog = CurvePropertiesDialog()
Example #10
0
class BoardView(QDialog, QObject):
    """Dialog showing a Matplotlib canvas with a board preview.

    Declares a ``frame`` signal carrying a QImage (it is not emitted in this
    class).
    """
    frame = pyqtSignal(QImage)

    def __init__(self, parent=None):
        # Forward ``parent``: it was previously accepted but silently ignored,
        # so the dialog was never owned by / positioned over its parent.
        QDialog.__init__(self, parent)
        self.setWindowTitle("Board View")

        # Constructing the layout with ``self`` already installs it on this
        # dialog, so a separate setLayout() call is unnecessary.
        self.gridLayout = QGridLayout(self)

        self.makeWindow()

    def makeWindow(self):
        """Create the canvas and its axes, add them to the grid and draw the demo plot."""
        btnline = 0

        self.dynamic_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        self.plotWidgetax = self.dynamic_canvas.figure.subplots()
        # NOTE(review): rowSpan is btnline == 0 here; Qt treats a span that ends
        # before the start row as "extend to the bottom edge". Use an explicit
        # span if buttons are added to later rows.
        self.gridLayout.addWidget(self.dynamic_canvas, 0, 0, btnline, 1)
        self.dynamic_canvas.mpl_connect('button_press_event',
                                        self.boardclick)
        self.temp()

    def boardclick(self, event):
        """Print the data coordinates of a left-button double click."""
        if event.dblclick and event.button == 1:
            print("dblckicked board", event.xdata, event.ydata)

    def temp(self):
        """Draw a placeholder colour mesh cubic-interpolated from 400 random samples."""
        x = np.linspace(-1, 1, 100)
        y = np.linspace(-1, 1, 100)
        npts = 400
        px, py = np.random.choice(x, npts), np.random.choice(y, npts)
        X, Y = np.meshgrid(x, y)
        # NOTE(review): `f` is not defined in this class; it must be a
        # module-level function of (px, py) defined elsewhere.
        Ti = griddata((px, py), f(px, py), (X, Y), method='cubic')
        self.plotWidgetax.pcolormesh(X, Y, Ti)

    def new_cam_files(self, gcodefile):
        """Add the G-code preview paths (cutting moves, rapid moves in red) and redraw."""
        (lines, rapidlines) = gcodefile.previewGcode()
        lc = mc.LineCollection(lines)
        rlc = mc.LineCollection(rapidlines, color="red")

        print(lc)
        print(rlc)
        self.plotWidgetax.add_collection(lc)
        self.plotWidgetax.add_collection(rlc)
        # add_collection updates the data limits but not the view; rescale the
        # view explicitly (the original relied on an argument-less plot() call
        # for this side effect).
        self.plotWidgetax.autoscale_view()
        self.dynamic_canvas.draw()
Example #11
0
class Plot(QtWidgets.QWidget):
    """Scatter plot of an item model's points with click-to-select.

    Records from ``item_model.get_data()`` are ``((x, y, _), image)`` pairs.
    Points with an image are drawn green and the rest blue; the selection
    model's current row is circled in orange. The plot is redrawn when the
    model's data or rows change or the current row changes. A left click close
    to a point makes that point's row current.
    """

    def __init__(self, item_model, selection_model):
        super().__init__()

        self.item_model = item_model
        self.item_model.dataChanged.connect(self.update_plot)
        self.item_model.rowsInserted.connect(self.update_plot)
        self.item_model.rowsRemoved.connect(self.update_plot)

        self.selection_model = selection_model
        self.selection_model.currentRowChanged.connect(self.update_plot)

        self.canvas = FigureCanvas(Figure())
        self.canvas.mpl_connect('button_press_event', self.add_or_show)

        self.axes = self.canvas.figure.subplots()
        self.axes.axis('equal')

        self.update_plot()
        # update_plot() turns autoscaling off; turn it on once so the initial
        # view fits the data.
        self.axes.autoscale()

        # NOTE(review): this attribute shadows QWidget.layout(); kept because
        # callers may rely on ``plot.layout``.
        self.layout = QtWidgets.QVBoxLayout(self)
        self.layout.addWidget(self.canvas)
        self.setLayout(self.layout)

    @QtCore.Slot()
    def update_plot(self):
        """Redraw all points while keeping the current x/y view limits."""
        xlim, ylim = self.axes.get_xlim(), self.axes.get_ylim()
        self.axes.clear()
        # Freeze autoscaling so the limits restored below are kept.
        self.axes.autoscale(False)

        data = self.item_model.get_data()
        have_image = []
        no_image = []
        for point, image in data:
            if image:
                have_image.append(point)
            else:
                no_image.append(point)

        if have_image:
            x, y, _ = zip(*have_image)
            self.axes.plot(x, y, 'o', markersize=5, color='#2ca02c')

        if no_image:
            x, y, _ = zip(*no_image)
            self.axes.plot(x, y, 'o', markersize=5, color='#1f77b4')

        # Circle the current row's point, if any.
        current_index = self.selection_model.currentIndex()
        if current_index.isValid() and current_index.row() < len(data):
            x, y, _ = data[current_index.row()][0]
            self.axes.plot([x], [y],
                           'o',
                           markersize=10,
                           fillstyle='none',
                           markeredgewidth=2,
                           color='#ff7f0e')

        self.axes.set_xlim(xlim)
        self.axes.set_ylim(ylim)
        self.canvas.draw()

    # NOTE(review): this is a Matplotlib callback, not a Qt slot; the
    # @QtCore.Slot() decorator is unnecessary here.
    @QtCore.Slot()
    def add_or_show(self, event):
        """Make the row of the first point within 1% of the x-range of a left click current."""
        # Ignore clicks while a toolbar pan/zoom mode is active. The canvas only
        # has a toolbar if one was created for it elsewhere (otherwise
        # ``canvas.toolbar`` is None), so guard instead of raising on every click.
        toolbar = self.canvas.toolbar
        if (toolbar is not None and toolbar.mode != '') or event.button != 1:
            return

        click_x, click_y = event.xdata, event.ydata
        # xdata/ydata are None when the click is outside the axes.
        if click_x is None or click_y is None:
            return

        x_min, x_max = self.axes.get_xlim()
        threshold = (x_max - x_min) * 0.01
        for idx, ((x, y, _), _) in enumerate(self.item_model.get_data()):
            distance = sqrt((x - click_x) ** 2 + (y - click_y) ** 2)
            if distance < threshold:
                self.selection_model.setCurrentIndex(
                    self.item_model.createIndex(idx, 0),
                    QtCore.QItemSelectionModel.Current,
                )
                return
Example #12
0
class DataView(QtWidgets.QWidget):
    """Base widget for a Matplotlib data plot with a focus-activated toolbar.

    The navigation toolbar is re-parented onto the canvas so it floats over
    the plot. It is shown while this widget has keyboard focus and hidden
    otherwise. ``load_data`` / ``plot_data`` are no-ops here (presumably
    overridden by subclasses). Attribute names listed in ``save_props`` are
    exported by ``get_state`` and restored by ``set_state``.

    NOTE(review): ``resize`` shadows QWidget.resize(w, h) and ``self.layout``
    shadows QWidget.layout(); both are kept for compatibility.
    """

    def __init__(self, parent=None, figure=None, femb=0):
        super().__init__(parent=parent)
        self.femb = femb
        chosen_figure = Figure(tight_layout=True) if figure is None else figure
        self.setFocusPolicy(QtCore.Qt.StrongFocus)
        self.figure = chosen_figure
        self.fig_ax = self.figure.subplots()
        self.fig_canvas = FigureCanvas(self.figure)
        self.fig_canvas.draw()

        # Floating toolbar: parented to the canvas rather than the layout.
        self.fig_toolbar = CustomNavToolbar(self.fig_canvas, self,
                                            coordinates=False)
        self.fig_toolbar.setParent(self.fig_canvas)
        self.fig_toolbar.setMinimumWidth(300)

        # Re-anchor the toolbar whenever the canvas is resized, and once now.
        self.fig_canvas.mpl_connect("resize_event", self.resize)
        self.resize(None)

        self.layout = QtWidgets.QVBoxLayout(self)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.layout.setSpacing(0)
        self.layout.addWidget(self.fig_canvas)

        self.toolbar_shown(False)

        self.save_props = []
        self.last_lims = None
        self.times = None
        self.data = None

    def resize(self, event):
        """Place the toolbar so its bottom edge sits at the first axes' lower-left corner."""
        # Axes origin in Matplotlib display pixels (origin at the bottom-left).
        left, bottom = self.figure.axes[0].transAxes.transform((0, 0.0))
        _, height_inches = self.figure.get_size_inches()
        toolbar_height = self.fig_toolbar.frameGeometry().height()
        # Flip to a top-left origin for move(), then lift by the toolbar height.
        top = height_inches * self.figure.dpi - bottom - toolbar_height
        self.fig_toolbar.move(int(left), int(top))

    def focusInEvent(self, *args, **kwargs):
        """Re-anchor and show the toolbar when the widget gains focus."""
        super().focusInEvent(*args, **kwargs)
        self.resize(None)
        self.toolbar_shown(True)

    def focusOutEvent(self, *args, **kwargs):
        """Hide the toolbar when the widget loses focus."""
        super().focusOutEvent(*args, **kwargs)
        self.toolbar_shown(False)

    def toolbar_shown(self, shown):
        """Show the toolbar if *shown* is truthy, otherwise hide it."""
        toggle = self.fig_toolbar.show if shown else self.fig_toolbar.hide
        toggle()

    def get_state(self):
        """Return {name: value} for each name in save_props that is an instance attribute."""
        present = self.__dict__
        state = {}
        for name in self.save_props:
            if name in present:
                state[name] = getattr(self, name)
        return state

    def set_state(self, state):
        """Assign values from *state*, skipping names that are not existing instance attributes."""
        present = self.__dict__
        for name, value in state.items():
            if name not in present:
                continue
            setattr(self, name, value)

    def load_data(self, timestamps, samples):
        """Load data into the view (no-op in this base class)."""

    def plot_data(self, rescale=False, save_to=None):
        """Plot the loaded data (no-op in this base class)."""
Example #13
0
class PlotInterpolation(QtWidgets.QMainWindow):
    """
    Plot time series interpolation tool

    Shows the selected data variable (top left), its TATSSI QA-masked
    version (top right) and, for a clicked pixel, the original time series
    with one interpolated series per user-selected method (bottom).
    Call ``_plot`` with a TATSSI QA analytics object to build the window.
    """
    def __init__(self, parent=None):
        super(PlotInterpolation, self).__init__(parent)

    def __set_variables(self, qa_analytics):
        """
        Set variables from TATSSI QA analytics, connect the UI widgets to
        their handlers, fill the drop-down lists and draw the initial plots

        :param qa_analytics: TATSSI QA analytics object; its ``ts``,
            ``source_dir``, ``product``, ``version`` and ``mask`` attributes
            are read
        """
        # Set qa_analytics
        self.qa_analytics = qa_analytics

        # imshow plots. They stay None until __populate_plots runs; the
        # combo-box handlers check this to ignore the index changes fired
        # while the lists are being filled.
        self.left_imshow = None
        self.right_imshow = None
        self.projection = None

        # Set widgets connections with methods
        self.data_vars.currentIndexChanged.connect(
                self.__on_data_vars_change)

        self.time_steps.currentIndexChanged.connect(
                self.__on_time_steps_change)

        self.pb_Interpolate.clicked.connect(
                self.on_pbInterpolate_click)

        self.interpolation_methods.clicked.connect(
                self.on_interpolation_methods_click)

        # Nothing can be interpolated until a method is selected
        self.progressBar.setEnabled(False)
        self.pb_Interpolate.setEnabled(False)

        # Change combobox stylesheet and add scrollbar
        self.time_steps.setStyleSheet("combobox-popup: 0")
        self.time_steps.view().setVerticalScrollBarPolicy(
                Qt.ScrollBarAsNeeded)

        # Time series object
        self.ts = qa_analytics.ts
        # Source dir
        self.source_dir = qa_analytics.source_dir
        # Product and version
        self.product = qa_analytics.product
        self.version = qa_analytics.version

        # Mask
        self.mask = qa_analytics.mask

        # Data variables
        self.data_vars.addItems(self.__fill_data_variables())
        # Time steps (of the data variable selected above)
        self.time_steps.addItems(self.__fill_time_steps())

        # Create plot objects
        self.__create_plot_objects()

        # Populate plots
        self.__populate_plots()

        # Set interpolation methods
        self.__fill_interpolation_methods()

    def __fill_interpolation_methods(self):
        """
        Fill interpolation methods
        """
        # TODO document interpolation methods
        # interpolation_methods = ['linear', 'nearest', 'slinear',
        #                          'quadratic', 'cubic', 'krog',
        #                          'pchip', 'spline', 'akima']

        interpolation_methods = ['linear', 'nearest', 'spline']

        self.interpolation_methods.addItems(interpolation_methods)

    def __load_datasets(self):
        """
        Set self.left_ds to the currently selected data variable and
        self.right_ds to its QA-masked counterpart (a deep copy when no
        mask is set)
        """
        self.left_ds = getattr(self.ts.data, self.data_vars.currentText())
        if self.mask is None:
            self.right_ds = self.left_ds.copy(deep=True)
        else:
            self.right_ds = self.left_ds * self.mask
            # Arithmetic may drop attrs; restore them so plots of right_ds
            # are labelled like left_ds
            self.right_ds.attrs = self.left_ds.attrs

    def on_pbInterpolate_click(self):
        """
        Performs interpolation using the user-selected method, taking into
        account the TATSSI QA analytics mask
        """
        # Wait cursor
        QtWidgets.QApplication.setOverrideCursor(Qt.WaitCursor)
        try:
            # Variables for interpolation
            self.qa_analytics.selected_data_var = \
                    self.data_vars.currentText()
            self.qa_analytics.selected_interpolation_method = \
                    self.interpolation_methods.selectedItems()[0].text()

            # TATSSI interpolation
            tsi = TimeSeriesInterpolation(self.qa_analytics, isNotebook=False)
            # Enable progress bar
            self.progressBar.setEnabled(True)
            self.progressBar.setValue(1)
            tsi.interpolate(progressBar=self.progressBar)
        finally:
            # Always restore the standard cursor and reset the progress bar,
            # even if the interpolation fails
            QtWidgets.QApplication.restoreOverrideCursor()
            self.progressBar.setValue(0)
            self.progressBar.setEnabled(False)

    def on_interpolation_methods_click(self):
        """
        Enable the pbInterpolate push button when there is an
        interpolation method selected
        """
        self.pb_Interpolate.setEnabled(
                len(self.interpolation_methods.selectedItems()) > 0)

    @pyqtSlot(int)
    def __on_time_steps_change(self, index):
        """
        Handles a change in the time step to display

        :param index: index of the newly selected time step
        """
        if len(self.time_steps.currentText()) == 0 or \
                self.left_imshow is None or \
                self.right_imshow is None:
            return None

        # Wait cursor; restored in ``finally`` even if loading fails
        QtWidgets.QApplication.setOverrideCursor(Qt.WaitCursor)
        try:
            self.__load_datasets()

            self.left_imshow.set_data(self.left_ds.data[index])
            self.right_imshow.set_data(self.right_ds.data[index])

            # Set titles
            self.left_p.set_title(self.time_steps.currentText())
            self.right_p.set_title(self.time_steps.currentText())

            self.fig.canvas.draw()
            self.fig.canvas.flush_events()
        finally:
            # Standard cursor
            QtWidgets.QApplication.restoreOverrideCursor()

    @pyqtSlot(int)
    def __on_data_vars_change(self, index):
        """
        Handles a change in the data variable to display; shows its first
        time step
        """
        if len(self.data_vars.currentText()) == 0 or \
                self.left_imshow is None or \
                self.right_imshow is None:
            return None

        # Wait cursor; restored in ``finally`` even if loading fails
        QtWidgets.QApplication.setOverrideCursor(Qt.WaitCursor)
        try:
            self.__load_datasets()

            self.left_imshow.set_data(self.left_ds.data[0])
            self.right_imshow.set_data(self.right_ds.data[0])

            self.fig.canvas.draw()
            self.fig.canvas.flush_events()
        finally:
            # Standard cursor
            QtWidgets.QApplication.restoreOverrideCursor()

    def __fill_data_variables(self):
        """
        Return the data variable names for the data variables dropdown list
        """
        return list(self.ts.data.data_vars)

    def __fill_time_steps(self):
        """
        Return the time steps of the selected data variable, formatted to
        minute precision, for the time steps dropdown list
        """
        tmp_ds = getattr(self.ts.data, self.data_vars.currentText())

        time_steps = np.datetime_as_string(tmp_ds.time.data, 'm').tolist()

        return time_steps

    def on_click(self, event):
        """
        Event handler for mouse clicks on the canvas

        A click on the left or right image marks the pixel on both images
        and plots, in the bottom panel, the pixel's original time series,
        the values kept by the QA mask and one interpolated series per
        selected interpolation method.
        """
        # Ignore clicks on the time series plot, and clicks outside any
        # axes (xdata/ydata are None there and .sel() would fail)
        if event.inaxes is None or event.inaxes is self.ts_p:
            return
        if event.xdata is None or event.ydata is None:
            return

        # Clear subplot
        self.ts_p.clear()

        # Delete last reference point. Recent Matplotlib versions do not
        # allow ``del ax.lines[i]``; Artist.remove() works in all versions.
        for _axis in (self.left_p, self.right_p):
            if len(_axis.lines) > 0:
                _axis.lines[0].remove()

        # Draw a point as a reference
        for _axis in (self.left_p, self.right_p):
            _axis.plot(event.xdata, event.ydata,
                    marker='o', color='red', markersize=7, alpha=0.7)

        # Non-masked data
        left_plot_sd = self.left_ds.sel(longitude=event.xdata,
                                        latitude=event.ydata,
                                        method='nearest')
        if left_plot_sd.chunks is not None:
            left_plot_sd = left_plot_sd.compute()

        # Masked data
        right_plot_sd = self.right_ds.sel(longitude=event.xdata,
                                          latitude=event.ydata,
                                          method='nearest')
        if right_plot_sd.chunks is not None:
            right_plot_sd = right_plot_sd.compute()

        # Plots
        left_plot_sd.plot(ax=self.ts_p, color='black',
                linestyle = '-', linewidth=1, label='Original data')

        # Values kept by the QA mask (masked-out values are 0 -> NaN)
        right_plot_sd_masked = right_plot_sd.where(right_plot_sd != 0)
        right_plot_sd_masked.plot(ax = self.ts_p, color='blue',
                marker='o', linestyle='None', alpha=0.7, markersize=4,
                label='Kept by user QA selection')

        # For every interpol method selected by the user
        for method in self.interpolation_methods.selectedItems():
            _method = method.text()
            tmp_ds = right_plot_sd_masked.interpolate_na(dim='time',
                    method=_method)

            # Plot
            tmp_ds.plot(ax = self.ts_p, label=_method, linewidth=2)

        # Change ylimits: pad the original data range by 20% on each side.
        # nanmax/nanmin ignore missing values; leave the limits alone when
        # there is no finite data (set_ylim rejects NaN/inf).
        max_val = np.nanmax(left_plot_sd.data)
        min_val = np.nanmin(left_plot_sd.data)
        if np.isfinite(max_val) and np.isfinite(min_val):
            data_range = max_val - min_val
            self.ts_p.set_ylim([min_val - (data_range * 0.2),
                                max_val + (data_range * 0.2)])

        # Legend
        self.ts_p.legend(loc='best', fontsize='small',
                         fancybox=True, framealpha=0.5)

        # Grid
        self.ts_p.grid(axis='both', alpha=.3)

        # Redraw this window's figure (plt.draw() would target whatever
        # figure pyplot considers current)
        self.fig.canvas.draw_idle()

    def __populate_plots(self):
        """
        Draw the first time step of the selected variable in the left
        (original) and right (QA-masked) panels, and the centre pixel's
        time series in the bottom panel
        """
        self.__load_datasets()

        # Left plot
        self.left_imshow = self.left_ds[0].plot.imshow(cmap='Greys_r',
                ax=self.left_p, add_colorbar=False,
                transform=self.projection)

        # Turn off axis
        self.left_p.axis('off')
        self.left_p.set_aspect('equal')
        self.fig.canvas.draw_idle()

        # Plot the time series of the centre pixel
        _layers, _rows, _cols = self.left_ds.shape

        plot_sd = self.left_ds[:, int(_rows / 2), int(_cols / 2)]
        plot_sd.plot(ax = self.ts_p, color='black',
                linestyle = '--', linewidth=1, label='Original data')

        # Right plot
        self.right_imshow = self.right_ds[0].plot.imshow(cmap='Greys_r',
                ax=self.right_p, add_colorbar=False,
                transform=self.projection)

        # Turn off axis
        self.right_p.axis('off')
        self.right_p.set_aspect('equal')

        # Legend
        self.ts_p.legend(loc='best', fontsize='small',
                         fancybox=True, framealpha=0.5)

    def __create_plot_objects(self):
        """
        Create the figure and its three panels: two map panels sharing x/y
        in the chosen projection, and a full-width time series panel

        :raises ValueError: if the time series has no data variables
        """
        # Get projection from first data variable
        data_vars = list(self.qa_analytics.ts.data.data_vars)
        if not data_vars:
            raise ValueError("Time series has no data variables")
        proj4_string = getattr(self.qa_analytics.ts.data, data_vars[0]).crs

        # If projection is Sinusoidal use a spherical globe, otherwise
        # fall back to Mollweide on WGS84
        srs = get_projection(proj4_string)
        if srs.GetAttrValue('PROJECTION') == 'Sinusoidal':
            globe = ccrs.Globe(ellipse=None,
                semimajor_axis=6371007.181,
                semiminor_axis=6371007.181)

            self.projection = ccrs.Sinusoidal(globe=globe)
        else:
            globe = ccrs.Globe(ellipse='WGS84')
            self.projection = ccrs.Mollweide(globe=globe)

        # Figure
        self.fig = plt.figure(figsize=(8.0, 7.0))

        # Left plot
        self.left_p = plt.subplot2grid((2, 2), (0, 0), colspan=1,
                projection=self.projection)

        # Right plot
        self.right_p = plt.subplot2grid((2, 2), (0, 1), colspan=1,
                sharex=self.left_p, sharey=self.left_p,
                projection=self.projection)

        if self.projection is not None:
            for _axis in [self.left_p, self.right_p]:
                _axis.coastlines(resolution='10m', color='white')
                _axis.add_feature(cfeature.BORDERS, edgecolor='white')
                _axis.gridlines()

        # Time series plot
        self.ts_p = plt.subplot2grid((2, 2), (1, 0), colspan=2)

    def _plot(self, qa_analytics, cmap='viridis', dpi=72):
        """
        Load the UI, build the interpolation plots for the TATSSI QA
        analytics object and embed them, with a toolbar, in this window

        :param qa_analytics: TATSSI QA analytics object
        :param cmap: currently unused
        :param dpi: currently unused
        """
        # Load UI
        uic.loadUi('plot_interpolation.ui', self)

        # Set plot variables
        self.__set_variables(qa_analytics)

        # Set plot on the plot widget
        self.plotWidget = FigureCanvas(self.fig)
        # Set focus
        self.plotWidget.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.plotWidget.setFocus()
        # Connect the canvas with the event
        self.plotWidget.mpl_connect('button_press_event',
                self.on_click)

        lay = QtWidgets.QVBoxLayout(self.content_plot)
        lay.setContentsMargins(0, 100, 0, 0)
        lay.addWidget(self.plotWidget)

        # Add toolbar
        font = QFont()
        font.setPointSize(12)

        toolbar = NavigationToolbar(self.plotWidget, self)
        toolbar.setFont(font)

        self.addToolBar(QtCore.Qt.BottomToolBarArea, toolbar)

        # Needed in order to use a tight layout with Cartopy axes
        self.fig.canvas.draw()
        self.fig.tight_layout()
Example #14
0
class ResponseGUI(QtWidgets.QMainWindow):
    """Main window for fitting an instrumental response function.

    The loaded 1D spectrum is averaged within the wavelength bands of the
    selected standard star's reference table. Each band is converted to a
    response value: 2.5*log10 of observed over reference flux, plus an
    airmass * extinction term. A smoothing spline is then fitted through the
    unmasked values.

    Left-clicking a blue point masks it, and it is drawn as a red cross.
    Right-clicking a red cross unmasks it. The fitted response can be saved
    as a FITS table.
    """

    def __init__(self, fname=None, output_fname='', star_name='', order=3, smoothing=0.02, parent=None, locked=False, **kwargs):
        """Build the widgets and layout, and optionally load a spectrum.

        :param fname: spectrum filename loaded at start-up (optional).
        :param output_fname: file written by the "Done" button in locked mode.
        :param star_name: standard star pre-selected in the chooser
            (case-insensitive).
        :param order: initial spline degree (the validator allows 1-5).
        :param smoothing: initial spline smoothing factor.
        :param parent: Qt parent widget.
        :param locked: if True, Load and Save are disabled and the first
            button becomes "Done", which saves to ``output_fname`` and closes.
        :param kwargs: accepted and ignored.
        """
        QtWidgets.QMainWindow.__init__(self, parent)
        self.setWindowTitle('PyNOT: Response')
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)

        # Set attributes:
        self.spectrum = None
        self.response = None
        self.filename = fname
        self.output_fname = output_fname
        self.first_time_open = True
        # Extinction table attributes:
        self.ext_fname = alfosc.path + '/calib/lapalma.ext'
        try:
            ext_wl, ext = np.loadtxt(self.ext_fname, unpack=True)
        except (OSError, ValueError):
            # Best effort: a missing or malformed table leaves the extinction
            # undefined; calculate_response_bins() reports it when needed.
            ext_wl, ext = None, None
        self.ext_wl = ext_wl
        self.ext = ext
        # Reference table attributes:
        self.ref_tab = None
        self.flux_bins = None
        self.wl_bins = None
        self.mag_bins = None
        self.resp_bins = None
        self.mask = None

        # Fitting Parameters:
        self.star_chooser = QtWidgets.QComboBox()
        self.all_names = sorted([name.upper() for name in alfosc.standard_stars] + [''])
        self.star_chooser.addItems(self.all_names)
        # Chooser items are upper-case, so match the requested name the same way.
        self.star_chooser.setCurrentText(star_name.upper())
        self.star_chooser.currentTextChanged.connect(self.set_star)

        self.exptime_edit = QtWidgets.QLineEdit("")
        self.exptime_edit.setValidator(QtGui.QDoubleValidator())
        self.airmass_edit = QtWidgets.QLineEdit("")
        self.airmass_edit.setValidator(QtGui.QDoubleValidator())

        self.order_edit = QtWidgets.QLineEdit("%i" % order)
        self.order_edit.setValidator(QtGui.QIntValidator(1, 5))
        self.order_edit.returnPressed.connect(self.fit_response)

        self.smooth_edit = QtWidgets.QLineEdit("%.2f" % smoothing)
        self.smooth_edit.setValidator(QtGui.QDoubleValidator())
        self.smooth_edit.returnPressed.connect(self.fit_response)

        self.fit_btn = QtWidgets.QPushButton("Fit Response")
        self.fit_btn.setShortcut("ctrl+F")
        self.fit_btn.clicked.connect(self.fit_response)

        # -- Plotting
        self.figure_points = Figure(figsize=(8, 6))
        self.axes = self.figure_points.subplots(2, 1)
        self.canvas_points = FigureCanvas(self.figure_points)
        self.canvas_points.mpl_connect('pick_event', self.pick_points)
        self.figp_mpl_toolbar = NavigationToolbar(self.canvas_points, self)
        self.figp_mpl_toolbar.setFixedHeight(20)
        # Cached artists; None means "not drawn yet" (see clear_all()).
        self.data_line = None
        self.fit_line = None
        self.data_points = None
        self.masked_data = None
        self.response_points = None
        self.masked_response = None

        # == TOP MENU BAR:
        self.save_btn = QtWidgets.QPushButton("Save")
        self.save_btn.clicked.connect(self.save_response)
        self.load_btn = QtWidgets.QPushButton("Load")
        self.load_btn.clicked.connect(self.load_spectrum)
        if locked:
            self.close_btn = QtWidgets.QPushButton("Done")
            self.close_btn.clicked.connect(self.done)
            self.load_btn.setEnabled(False)
            self.save_btn.setEnabled(False)
            self.save_btn.setText("")
        else:
            self.close_btn = QtWidgets.QPushButton("Close")
            self.close_btn.clicked.connect(self.close)

        # == Layout ===========================================================
        main_layout = QtWidgets.QVBoxLayout(self._main)
        main_layout.setSpacing(5)
        main_layout.setContentsMargins(5, 5, 5, 5)

        top_menubar = QtWidgets.QHBoxLayout()
        top_menubar.addWidget(self.close_btn)
        top_menubar.addWidget(self.save_btn)
        top_menubar.addWidget(self.load_btn)
        top_menubar.addStretch(1)

        central_layout = QtWidgets.QHBoxLayout()

        main_layout.addLayout(top_menubar)
        main_layout.addLayout(central_layout)

        # TabWidget Layout:
        fig_layout = QtWidgets.QVBoxLayout()
        fig_layout.addWidget(self.canvas_points, 1)
        fig_layout.addWidget(self.figp_mpl_toolbar)
        central_layout.addLayout(fig_layout)

        # Right Panel Layout:
        right_panel = QtWidgets.QVBoxLayout()
        right_panel.setContentsMargins(50, 0, 50, 10)

        right_panel.addWidget(self._make_separator())

        row_model = QtWidgets.QHBoxLayout()
        row_model.addWidget(QtWidgets.QLabel("Star Name: "))
        row_model.addWidget(self.star_chooser)
        right_panel.addLayout(row_model)

        right_panel.addWidget(self._make_separator())

        row_exptime = QtWidgets.QHBoxLayout()
        row_exptime.addWidget(QtWidgets.QLabel("Exposure Time: "))
        row_exptime.addWidget(self.exptime_edit)
        right_panel.addLayout(row_exptime)

        row_airmass = QtWidgets.QHBoxLayout()
        row_airmass.addWidget(QtWidgets.QLabel("Airmass: "))
        row_airmass.addWidget(self.airmass_edit)
        right_panel.addLayout(row_airmass)

        right_panel.addWidget(self._make_separator())

        row_orders = QtWidgets.QHBoxLayout()
        row_orders.addWidget(QtWidgets.QLabel("Spline Degree:"))
        row_orders.addWidget(self.order_edit)
        row_orders.addStretch(1)
        right_panel.addLayout(row_orders)

        row_smooth = QtWidgets.QHBoxLayout()
        row_smooth.addWidget(QtWidgets.QLabel("Smoothing factor:"))
        row_smooth.addWidget(self.smooth_edit)
        row_smooth.addStretch(1)
        right_panel.addLayout(row_smooth)

        right_panel.addWidget(self._make_separator())

        row_fit = QtWidgets.QHBoxLayout()
        row_fit.addStretch(1)
        row_fit.addWidget(self.fit_btn)
        row_fit.addStretch(1)
        right_panel.addLayout(row_fit)

        right_panel.addStretch(1)
        central_layout.addLayout(right_panel)

        self.canvas_points.setFocus()

        self.create_menu()

        # The initial star was selected before the signal was connected, so
        # load its reference table explicitly. No spectrum is loaded yet, so
        # set_star() only stores the table; load_spectrum() uses it below.
        self.set_star(self.star_chooser.currentText())

        # -- Set Data:
        if fname:
            self.load_spectrum(fname)

    @staticmethod
    def _make_separator():
        """Return a sunken horizontal line used between right-panel sections."""
        line = QtWidgets.QFrame()
        line.setFrameShape(QtWidgets.QFrame.HLine)
        line.setFrameShadow(QtWidgets.QFrame.Sunken)
        line.setMinimumSize(3, 20)
        return line

    def done(self):
        """Save to ``self.output_fname`` and close the window if saving succeeded."""
        success = self.save_response(self.output_fname)
        if success:
            self.close()

    def save_response(self, fname=''):
        """Write the fitted response to a FITS file.

        :param fname: output filename. If empty or falsy (Qt's ``checked=False``
            from the Save button/action), a save dialog is shown.
        :returns: True if a file was written, False otherwise.
        """
        if self.response is None:
            msg = "No response function has been fitted. Nothing to save..."
            QtWidgets.QMessageBox.critical(None, "Save Error", msg)
            return False

        if not fname:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            basename = os.path.join(current_dir, "response_%s.fits" % (self.spectrum.header['OBJECT']))
            filters = "FITS Files (*.fits *.fit)"
            fname, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'Save Response Function', basename, filters)

        if fname:
            hdu = fits.HDUList()
            prim_hdr = fits.Header()
            prim_hdr['AUTHOR'] = 'PyNOT version %s' % __version__
            prim_hdr['OBJECT'] = self.spectrum.header['OBJECT']
            prim_hdr['DATE-OBS'] = self.spectrum.header['DATE-OBS']
            prim_hdr['EXPTIME'] = self.spectrum.header['EXPTIME']
            prim_hdr['AIRMASS'] = self.spectrum.header['AIRMASS']
            prim_hdr['ALGRNM'] = self.spectrum.header['ALGRNM']
            prim_hdr['ALAPRTNM'] = self.spectrum.header['ALAPRTNM']
            prim_hdr['RA'] = self.spectrum.header['RA']
            prim_hdr['DEC'] = self.spectrum.header['DEC']
            prim_hdr['COMMENT'] = 'PyNOT response function'
            prim = fits.PrimaryHDU(header=prim_hdr)
            hdu.append(prim)
            col_wl = fits.Column(name='WAVE', array=self.spectrum.wl, format='D', unit=self.spectrum.wl_unit)
            col_resp = fits.Column(name='RESPONSE', array=self.response, format='D', unit='-2.5*log(erg/s/cm2/A)')
            tab = fits.BinTableHDU.from_columns([col_wl, col_resp])
            hdu.append(tab)
            hdu.writeto(fname, overwrite=True, output_verify='silentfix')
            return True
        else:
            return False

    def clear_all(self):
        """Clear both axes, the exposure/airmass fields and all derived data."""
        for ax in self.axes:
            ax.clear()
        # ax.clear() detaches every artist from the axes, so drop the cached
        # handles. Otherwise update_plot()/update_points() would call
        # set_data() on orphaned lines and nothing would be drawn.
        self.data_line = None
        self.fit_line = None
        self.data_points = None
        self.masked_data = None
        self.response_points = None
        self.masked_response = None
        self.exptime_edit.setText("")
        self.airmass_edit.setText("")
        self.flux_bins = None
        self.wl_bins = None
        self.mag_bins = None
        self.resp_bins = None
        self.response = None
        self.mask = None
        self.filename = ""
        self.canvas_points.draw()

    def load_spectrum(self, fname=''):
        """Load a 1D spectrum (FITS table in extension 1) and redraw.

        :param fname: path of the file. ``False`` is the ``checked`` argument
            sent by the connected Load button/action; in that case a file
            dialog is shown. Non-existing paths are ignored.
        """
        if fname is False:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            # Qt name filters separate patterns with spaces.
            filters = "FITS files (*.fits *.fit)"
            fname, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open 1D Spectrum', current_dir, filters)
            fname = str(fname)
            if self.first_time_open:
                print(" [INFO] - Don't worry about the warning above. It's an OS warning that can not be suppressed.")
                print("          Everything works as it should")
                self.first_time_open = False

        if not os.path.exists(fname):
            return

        # Clear all models:
        self.clear_all()
        self.filename = fname
        hdr = fits.getheader(fname, 1)
        table = fits.getdata(fname, 1)
        self.spectrum = Spectrum(wl=table['WAVE'], data=table['FLUX'], header=hdr,
                                 wl_unit=table.columns['WAVE'].unit,
                                 flux_unit=table.columns['FLUX'].unit)
        # Header values take precedence; the fields are editable only when missing.
        if 'EXPTIME' in hdr:
            self.exptime_edit.setText("%.1f" % hdr['EXPTIME'])
            self.exptime_edit.setEnabled(False)
        else:
            self.exptime_edit.setEnabled(True)
        if 'AIRMASS' in hdr:
            self.airmass_edit.setText("%.1f" % hdr['AIRMASS'])
            self.airmass_edit.setEnabled(False)
        else:
            self.airmass_edit.setEnabled(True)

        # Changing the chooser text emits currentTextChanged -> set_star().
        if 'TCSTGT' in hdr:
            TCSname = hdr['TCSTGT']
            TCSname = alfosc.lookup_std_star(TCSname)
            if TCSname:
                star_name = alfosc.standard_star_names[TCSname]
                self.star_chooser.setCurrentText(star_name.upper())
        elif 'OBJECT' in hdr:
            object_name = hdr['OBJECT']
            if object_name.upper() in self.all_names:
                self.star_chooser.setCurrentText(object_name.upper())

        if self.ref_tab is not None:
            self.calculate_flux_in_bins()
            self.calculate_response_bins()
        self.update_plot()

    def set_star(self, text):
        """Load the reference table of the chosen standard star.

        The blank chooser entry clears the reference table. When a spectrum
        is loaded, the bins and the plot are recomputed; otherwise the table
        is stored and used by the next load_spectrum() call.
        """
        star_name = str(text).lower()
        if not star_name:
            self.ref_tab = None
            return
        self.ref_tab = np.loadtxt(alfosc.path+'/calib/std/%s.dat' % star_name)
        if self.spectrum is None:
            return
        self.calculate_flux_in_bins()
        self.calculate_response_bins()
        self.update_plot()

    def calculate_flux_in_bins(self):
        """Average the spectrum within each reference band.

        Bands with 3 or fewer pixels, or with a negative mean, are dropped.
        This sets ``flux_bins``, ``wl_bins``, ``mag_bins``, ``dw`` and an
        all-True ``mask``.
        """
        if self.spectrum is None:
            WarningDialog(self, "No spectrum loaded!", "No spectral data has been loaded.")
            return
        wl = self.spectrum.wl
        flux = self.spectrum.data
        flux_bins = list()
        # Reference rows: (central wavelength, magnitude, bandwidth)
        for wl_ref, mag_ref, bandwidth in self.ref_tab:
            l1 = wl_ref - bandwidth/2
            l2 = wl_ref + bandwidth/2
            band = (wl >= l1) * (wl <= l2)
            if np.sum(band) > 3:
                f0 = np.nanmean(flux[band])
                if f0 < 0:
                    f0 = np.nan
            else:
                f0 = np.nan
            flux_bins.append(f0)

        self.flux_bins = np.array(flux_bins)
        mask = ~np.isnan(self.flux_bins)
        self.wl_bins = self.ref_tab[:, 0][mask]
        self.mag_bins = self.ref_tab[:, 1][mask]
        self.dw = self.ref_tab[:, 2][mask]
        self.flux_bins = self.flux_bins[mask]
        self.mask = np.ones_like(self.flux_bins, dtype=bool)

    def calculate_response_bins(self):
        """Convert the binned fluxes into response values (``resp_bins``).

        Requires the exposure time, the airmass and the extinction table.
        A warning is shown and nothing is computed if any of them is missing.
        """
        # Presumably AB magnitudes -> f_lambda (wavelength in Angstrom); TODO confirm
        ref_flux = 10**(-(self.mag_bins + 2.406)/2.5) / (self.wl_bins)**2
        exp_str = self.exptime_edit.text()
        airm_str = self.airmass_edit.text()
        if exp_str == '' or airm_str == '':
            WarningDialog(self, "No exposure time or airmass!", "Please set both exposure time and airmass.")
            return
        if self.ext_wl is None or self.ext is None:
            WarningDialog(self, "No extinction table!", "Could not load the extinction table:\n%s" % self.ext_fname)
            return
        exptime = float(exp_str)
        airmass = float(airm_str)
        # Calculate Sensitivity:
        extinction = np.interp(self.wl_bins, self.ext_wl, self.ext)
        # Assumes a uniform wavelength grid (first pixel step used as width)
        cdelt = np.diff(self.spectrum.wl)[0]
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.resp_bins = 2.5*np.log10(self.flux_bins / (exptime * cdelt * ref_flux)) + airmass*extinction

    def update_plot(self):
        """Draw (or update) the spectrum in the top panel, then the points/fit."""
        if self.spectrum is None:
            WarningDialog(self, "No spectrum loaded!", "No spectral data has been loaded.")
            return
        if 'OBJECT' in self.spectrum.header:
            object_name = self.spectrum.header['OBJECT']
        else:
            object_name = ''

        if self.data_line is None:
            self.data_line, = self.axes[0].plot(self.spectrum.wl, self.spectrum.data,
                                                color='k', alpha=0.9, lw=0.8, label=object_name)
        else:
            self.data_line.set_data(self.spectrum.wl, self.spectrum.data)
            self.data_line.set_label(object_name)
        xunit = self.spectrum.wl_unit
        yunit = self.spectrum.flux_unit
        self.axes[1].set_xlabel("Wavelength  [%s]" % xunit, fontsize=11)
        self.axes[0].set_ylabel("Flux  [%s]" % yunit, fontsize=11)
        self.axes[1].set_ylabel("Response", fontsize=11)
        self.axes[0].legend()
        self.figure_points.tight_layout()
        self.update_points()

    def update_points(self):
        """Draw the binned flux/response points and the response fit.

        Unmasked points are drawn as blue circles and masked points as red
        crosses. All points are pickable. The canvas is redrawn at the end.
        """
        if self.resp_bins is not None:
            mask = self.mask
            if self.response_points is None:
                self.response_points, = self.axes[1].plot(self.wl_bins[mask], self.resp_bins[mask], 'bo', picker=True, pickradius=5)
                self.masked_response, = self.axes[1].plot(self.wl_bins[~mask], self.resp_bins[~mask], 'rx', picker=True, pickradius=5)
            else:
                self.response_points.set_data(self.wl_bins[mask], self.resp_bins[mask])
                self.masked_response.set_data(self.wl_bins[~mask], self.resp_bins[~mask])

            if self.data_points is None:
                self.data_points, = self.axes[0].plot(self.wl_bins[mask], self.flux_bins[mask], 'bo', picker=True, pickradius=5)
                self.masked_data, = self.axes[0].plot(self.wl_bins[~mask], self.flux_bins[~mask], 'rx', picker=True, pickradius=5)
            else:
                self.data_points.set_data(self.wl_bins[mask], self.flux_bins[mask])
                self.masked_data.set_data(self.wl_bins[~mask], self.flux_bins[~mask])

        if self.response is not None:
            if self.fit_line is None:
                self.fit_line, = self.axes[1].plot(self.spectrum.wl, self.response,
                                                   color='Crimson', lw=1.5, alpha=0.8)
            else:
                self.fit_line.set_data(self.spectrum.wl, self.response)
        self.canvas_points.draw()

    def fit_response(self):
        """Fit a ``UnivariateSpline`` through the unmasked response points.

        The degree and the smoothing factor are read from the input fields.
        The fit is evaluated on the spectrum's wavelength grid and stored in
        ``self.response``.
        """
        if self.resp_bins is None:
            WarningDialog(self, "No response data!", "No response data to fit.\nMake sure to load a spectrum and reference star data.")
            return
        wl = self.spectrum.wl
        try:
            order = int(self.order_edit.text())
            smoothing = float(self.smooth_edit.text())
        except ValueError:
            # The validators still allow empty/intermediate text when Fit is clicked.
            WarningDialog(self, "Invalid fit parameters!", "Please give a valid spline degree and smoothing factor.")
            return
        mask = self.mask
        # The spline needs more data points than its degree.
        if np.sum(mask) <= order:
            WarningDialog(self, "Too few points!", "The spline fit needs more unmasked points than the spline degree (%i)." % order)
            return
        resp_fit = UnivariateSpline(self.wl_bins[mask], self.resp_bins[mask], k=order, s=smoothing)
        self.response = resp_fit(wl)
        self.update_points()

    def pick_points(self, event):
        """Toggle the mask of the bin nearest to a picked point.

        Left click on an unmasked point masks it; right click on a masked
        point unmasks it. "Nearest" uses a distance normalised by the
        x and y ranges of the panel that was clicked.
        """
        x0 = event.mouseevent.xdata
        y0 = event.mouseevent.ydata
        is_left_press = event.mouseevent.button == 1
        is_right_press = event.mouseevent.button == 3
        is_on = (event.artist is self.data_points) or (event.artist is self.response_points)
        is_off = (event.artist is self.masked_data) or (event.artist is self.masked_response)
        is_data = (event.artist is self.masked_data) or (event.artist is self.data_points)
        if is_data:
            xrange = self.wl_bins.max() - self.wl_bins.min()
            yrange = self.flux_bins.max() - self.flux_bins.min()
            dist = (self.wl_bins - x0)**2 / xrange**2 + (self.flux_bins - y0)**2 / yrange**2
        else:
            xrange = self.wl_bins.max() - self.wl_bins.min()
            yrange = self.resp_bins.max() - self.resp_bins.min()
            dist = (self.wl_bins - x0)**2 / xrange**2 + (self.resp_bins - y0)**2 / yrange**2
        index = np.argmin(dist)
        if is_on and is_left_press:
            self.mask[index] = ~self.mask[index]
        elif is_off and is_right_press:
            self.mask[index] = ~self.mask[index]
        else:
            return
        self.update_points()

    def create_menu(self):
        """Create the File (Load, Save) and View (Display Header) menus."""
        load_file_action = QtWidgets.QAction("Load Spectrum", self)
        load_file_action.setShortcut("ctrl+O")
        load_file_action.triggered.connect(self.load_spectrum)

        save_1d_action = QtWidgets.QAction("Save", self)
        save_1d_action.setShortcut("ctrl+S")
        save_1d_action.triggered.connect(self.save_response)

        view_hdr_action = QtWidgets.QAction("Display Header", self)
        view_hdr_action.setShortcut("ctrl+shift+H")
        view_hdr_action.triggered.connect(self.display_header)

        main_menu = self.menuBar()
        file_menu = main_menu.addMenu("File")
        file_menu.addAction(load_file_action)
        file_menu.addAction(save_1d_action)

        view_menu = main_menu.addMenu("View")
        view_menu.addAction(view_hdr_action)

    def display_header(self):
        """Show the FITS header of the loaded spectrum, or warn if none is loaded."""
        if self.spectrum is not None:
            HeaderViewer(self.spectrum.header, parent=self)
        else:
            msg = "No Data Loaded"
            info = "Load a spectrum first"
            WarningDialog(self, msg, info)
Example #15
0
class SampleLogsView(QSplitter):
    """Window listing the sample logs of a workspace.

    The left side holds a table of logs. The right side holds:
      - an experiment-info selector (MD workspaces only);
      - a "Relative Time" toggle;
      - a plot of the selected logs;
      - read-only statistics for the selected log.
    User interaction is forwarded to ``presenter``.
    """
    def __init__(self, presenter, parent = None, name = '', isMD=False, noExp = 0):
        super(SampleLogsView, self).__init__(parent)

        self.presenter = presenter

        self.setWindowTitle("{} sample logs".format(name))
        self.setWindowFlags(Qt.Window)

        # Left side: table of logs
        log_table = QTableView()
        self.table = log_table
        log_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        log_table.clicked.connect(self.presenter.clicked)
        log_table.doubleClicked.connect(self.presenter.doubleClicked)
        log_table.contextMenuEvent = self.tableMenu
        self.addWidget(log_table)

        right_frame = QFrame()
        right_column = QVBoxLayout()

        # Options row: experiment-info spin box (MD only) and the time toggle
        options_row = QHBoxLayout()

        if isMD:
            options_row.addWidget(QLabel("Experiment Info #"))
            exp_spin = QSpinBox()
            self.experimentInfo = exp_spin
            exp_spin.setMaximum(noExp-1)
            exp_spin.valueChanged.connect(self.presenter.changeExpInfo)
            options_row.addWidget(exp_spin)

        time_box = QCheckBox("Relative Time")
        self.full_time = time_box
        # Checked before connecting, so this initial state does not trigger plot_logs.
        time_box.setChecked(True)
        time_box.stateChanged.connect(self.presenter.plot_logs)
        options_row.addWidget(time_box)
        right_column.addLayout(options_row)

        # Plot of the selected logs
        figure = Figure()
        self.fig = figure
        canvas = FigureCanvas(figure)
        self.canvas = canvas
        canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        canvas.mpl_connect('button_press_event', self.presenter.plot_clicked)
        self.ax = figure.add_subplot(111, projection='mantid')
        right_column.addWidget(canvas)

        # Statistics of the selected log
        self.create_stats_widgets()
        stats_form = QFormLayout()
        stats_form.addRow('', QLabel("Log Statistics"))
        stat_rows = (('Min:', "minimum"),
                     ('Max:', "maximum"),
                     ('Mean:', "mean"),
                     ('Median:', "median"),
                     ('Std Dev:', "standard_deviation"),
                     ('Time Avg:', "time_mean"),
                     ('Time Std Dev:', "time_standard_deviation"),
                     ('Duration:', "duration"))
        for row_label, stat_key in stat_rows:
            stats_form.addRow(row_label, self.stats_widgets[stat_key])
        right_column.addLayout(stats_form)
        right_frame.setLayout(right_column)

        self.addWidget(right_frame)
        self.setStretchFactor(0, 1)

        self.resize(1200, 800)
        self.show()

    def tableMenu(self, event):
        """Right click menu for table, can plot or print selected logs"""
        context_menu = QMenu(self)
        plot_action = context_menu.addAction("Plot selected")
        plot_action.triggered.connect(self.presenter.new_plot_logs)
        print_action = context_menu.addAction("Print selected")
        print_action.triggered.connect(self.presenter.print_selected_logs)
        context_menu.exec_(event.globalPos())

    def set_model(self, model):
        """Attach ``model`` to the table and size its columns."""
        self.model = model
        table = self.table
        table.setModel(self.model)
        table.resizeColumnsToContents()
        header = table.horizontalHeader()
        header.setSectionResizeMode(2, QHeaderView.Stretch)

    def plot_selected_logs(self, ws, exp, rows):
        """Redraw the embedded plot with the logs of ``rows``."""
        axes = self.ax
        axes.clear()
        self.create_ax_by_rows(axes, ws, exp, rows)
        self.fig.canvas.draw()

    def new_plot_selected_logs(self, ws, exp, rows):
        """Plot the logs of ``rows`` in a new, separate figure window."""
        figure, axes = plt.subplots(subplot_kw=dict(projection='mantid'))
        self.create_ax_by_rows(axes, ws, exp, rows)
        figure.show()

    def create_ax_by_rows(self, ax, ws, exp, rows):
        """Plot one line per row in ``rows`` on ``ax``; add a legend if anything was drawn."""
        use_full_time = not self.full_time.isChecked()
        for row in rows:
            log_name = self.get_row_log_name(row)
            ax.plot(ws,
                    LogName=log_name,
                    label=log_name,
                    marker='.',
                    FullTime=use_full_time,
                    ExperimentInfo=exp)

        ax.set_ylabel('')
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

    def get_row_log_name(self, i):
        """Return the log name shown in column 0 of row ``i``."""
        name_item = self.model.item(i, 0)
        return str(name_item.text())

    def get_exp(self):
        """Return the selected experiment-info index (MD workspaces only)."""
        return self.experimentInfo.value()

    def get_selected_row_indexes(self):
        """Return the row numbers currently selected in the table."""
        selected = self.table.selectionModel().selectedRows()
        return [index.row() for index in selected]

    def set_selected_rows(self, rows):
        """Add every row in ``rows`` to the table selection."""
        flags = QItemSelectionModel.Select | QItemSelectionModel.Rows
        selection_model = self.table.selectionModel()
        for row in rows:
            selection_model.select(self.model.index(row, 0), flags)

    def create_stats_widgets(self):
        """Create one read-only line edit per statistic, keyed by statistic name."""
        stat_names = ("minimum", "maximum", "mean", "median",
                      "standard_deviation", "time_mean",
                      "time_standard_deviation", "duration")
        self.stats_widgets = {stat: QLineEdit() for stat in stat_names}
        for line_edit in self.stats_widgets.values():
            line_edit.setReadOnly(True)

    def set_statistics(self, stats):
        """Fill each statistics field from the same-named attribute of ``stats``."""
        for stat, line_edit in self.stats_widgets.items():
            line_edit.setText('{:.6}'.format(getattr(stats, stat)))

    def clear_statistics(self):
        """Empty all statistics fields."""
        for line_edit in self.stats_widgets.values():
            line_edit.clear()
Example #16
0
class MRIPlotWidget(QtWidgets.QWidget):
    """Widget showing one slice/echo of ``imageData.ImageDataT2`` with ROI overlays.

    Two sliders choose the slice and the echo. A radio button hides the
    background image. Clicking a pixel forwards that pixel's values across
    all echoes to the registered T2 plot widget. Hovering shows the fitted
    parameter of the ROI under the cursor (via mplcursors).
    """

    def __init__(self, parent=None, showToolbar=True, imageData=None):
        super().__init__(parent)
        self.fig, self.ax = plt.subplots()
        self.fig.set_tight_layout(True)
        self.plot_canvas = FigureCanvas(self.fig)

        # NOTE(review): this attribute shadows QWidget.layout(); kept because
        # other code may access it under this name.
        self.layout = QtWidgets.QVBoxLayout(self)

        self.axesList = []
        self.imageData = imageData
        # Hover cursor created by update_plot(); replaced on every redraw.
        self.mpl_cursor = None

        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed,
                                           QtWidgets.QSizePolicy.Fixed)

        self.toggleImage = QtWidgets.QRadioButton("Hide background Image")
        self.toggleImage.toggled.connect(
            lambda: self.toggleImageChanged(self.toggleImage))

        self.layout.addWidget(self.toggleImage)
        self.toggleImage.setSizePolicy(sizePolicy)

        self.sliceLabel = QtWidgets.QLabel("slices")
        self.layout.addWidget(self.sliceLabel)
        self.sliceLabel.setSizePolicy(sizePolicy)

        self.slicesSlider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slicesSlider.setMinimum(0)
        self.slicesSlider.setMaximum(4)
        self.slicesSlider.setValue(0)
        self.slicesSlider.setTickPosition(QtWidgets.QSlider.TicksBelow)
        self.slicesSlider.setTickInterval(1)
        self.slicesSlider.valueChanged.connect(self.valuechangedSlider)

        self.slicesSlider.setSizePolicy(
            QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                  QtWidgets.QSizePolicy.Fixed))
        self.layout.addWidget(self.slicesSlider)

        self.echoesLabel = QtWidgets.QLabel("echoes")
        self.echoesLabel.setSizePolicy(sizePolicy)
        self.layout.addWidget(self.echoesLabel)

        self.echoesSlider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.echoesSlider.setMinimum(0)
        self.echoesSlider.setMaximum(16)
        self.echoesSlider.setValue(0)
        self.echoesSlider.setTickPosition(QtWidgets.QSlider.TicksBelow)
        self.echoesSlider.setTickInterval(1)
        self.echoesSlider.valueChanged.connect(self.valuechangedSlider)

        self.echoesSlider.setSizePolicy(
            QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                  QtWidgets.QSizePolicy.Fixed))
        self.layout.addWidget(self.echoesSlider)

        self.layout.addWidget(self.plot_canvas)

        if showToolbar:
            self.toolbar = NavigationToolbar(self.plot_canvas, self)
            self.layout.addWidget(self.toolbar)

        self.setSizePolicy(QtWidgets.QSizePolicy.Expanding,
                           QtWidgets.QSizePolicy.Expanding)
        self.updateGeometry()

        self.plot_canvas.mpl_connect('button_press_event', self.onclick)

        # Placeholder image shown until real data is loaded (path is relative
        # to the working directory).
        self.ax.imshow(matplotlib.image.imread('vision.png')[:, :, 0])
        self.ax.grid(False)

    def valuechangedSlider(self):
        """Redraw the image, ROI overlay and summary plots for the selected slice/echo."""
        slice_ = self.slicesSlider.value()
        echo = self.echoesSlider.value()

        if self.imageData is None:
            print("No images to update")
            return

        self.imageData.currentSlice = slice_
        self.imageData.currentEcho = echo
        print("slicesSlider Value =", slice_, "echoesSlider Value =", echo)
        if isinstance(self.imageData.ImageDataT2, np.ndarray):
            print("updating image slice")
            if self.toggleImage.isChecked():
                # Background hidden: show an all-zero image with the slice's shape.
                self.imageData.mriSliceIMG = np.zeros(
                    self.imageData.ImageDataT2.shape[:2])
            else:
                self.imageData.mriSliceIMG = self.imageData.ImageDataT2[
                    :, :, slice_, echo].copy()

            self.imageData.overlayRoisOnImage(slice_ + 1,
                                              self.imageData.fittingParam)
            self.update_plot(
                self.imageData.mriSliceIMG,
                self.imageData.maskedROIs.reshape(
                    self.imageData.mriSliceIMG.shape))

            self.histPlotWidget.update_plot([
                slice_ + 1, self.imageData.T2slices, self.imageData.dixonSlices
            ], [
                self.imageData.t2_data_summary_df,
                self.imageData.dixon_data_summary_df
            ], self.imageData.fittingParam)

            self.barPlotWidget.update_plot([
                slice_ + 1, self.imageData.T2slices, self.imageData.dixonSlices
            ], [
                self.imageData.t2_data_summary_df,
                self.imageData.dixon_data_summary_df
            ], self.imageData.fittingParam)
        else:
            print("No images to update")

    def on_fittingParams_rbtn_toggled(self, fittingParam):
        """Switch the displayed fitting parameter and redraw."""
        print(fittingParam)
        self.imageData.fittingParam = fittingParam
        self.valuechangedSlider()

    def register_PlotWidgets(self, T2PlotWidget, histPlotWidget, barPlotWidget,
                             radioButtonsWidget):
        """Store the companion widgets that this widget updates."""
        self.T2PlotWidget = T2PlotWidget
        self.histPlotWidget = histPlotWidget
        self.barPlotWidget = barPlotWidget
        self.radioButtonsWidget = radioButtonsWidget

    def onclick(self, event):
        """Send the clicked pixel's values across all echoes to the T2 plot widget."""
        # Clicks outside the axes carry no data coordinates.
        if event.xdata is None or event.ydata is None:
            return

        xcoord = int(round(event.xdata))
        ycoord = int(round(event.ydata))

        print("MRI Plot window On Click")

        print('ycoord =', ycoord)

        if self.imageData is None:
            return

        print(type(self.imageData.ImageDataT2))

        if self.imageData.ImageDataT2 is not None:

            image_shape = self.imageData.ImageDataT2.shape

            print(image_shape[0], image_shape[0] - ycoord, ycoord)

            # Ignore clicks outside the image; negative indices would
            # otherwise silently wrap around to the opposite edge.
            if not (0 <= ycoord < image_shape[0]
                    and 0 <= xcoord < image_shape[1]):
                return

            t2data = self.imageData.ImageDataT2[
                ycoord, xcoord,
                int(self.slicesSlider.value()), :]

            self.T2PlotWidget.update_plot(xcoord, ycoord, t2data)

    def update_plot(self, img, maskedROIs):
        """Draw ``img`` in gray with ``maskedROIs`` overlaid and install a hover cursor.

        :param img: 2-D image to display.
        :param maskedROIs: ROI map, reshaped to ``img.shape`` and overlaid
            semi-transparently when it has any positive value.
        """
        # Every call creates a new hover cursor. Remove the previous one first
        # (before clearing the axes), otherwise callbacks and annotations keep
        # accumulating.
        previous_cursor = getattr(self, 'mpl_cursor', None)
        if previous_cursor is not None:
            previous_cursor.remove()
            self.mpl_cursor = None

        self.ax.cla()
        self.ax.imshow(img, cmap=plt.cm.gray, interpolation='nearest')

        print("maskedROIs.shape", maskedROIs.shape)
        print("img.shape", img.shape)

        print("maskedROIs.max()", maskedROIs.max())

        if maskedROIs.max() > 0:

            self.ax.imshow(maskedROIs.reshape(img.shape),
                           cmap=plt.cm.jet,
                           alpha=.5,
                           interpolation='bilinear')

        mpl_cursor = mplcursors.cursor(self.plot_canvas.figure.axes,
                                       hover=True)
        self.mpl_cursor = mpl_cursor

        @mpl_cursor.connect("add")
        def _(sel):

            ann = sel.annotation
            ttt = ann.get_text()
            # Assumes the default image annotation text: "x=..", "y=..", "[value]".
            xc, yc, zl = [s.split('=') for s in ttt.splitlines()]

            x = round(float(xc[1]))
            y = round(float(yc[1]))

            print("x", x, "y", y)

            nrows, ncols = img.shape
            cslice = self.imageData.currentSlice
            fitParam = self.imageData.fittingParam

            print("cslice", cslice, "nrows", nrows, "ncols", ncols)
            print("fitParam", fitParam)

            # Figure out which data set (T2 or Dixon) holds the parameter,
            # and select the rows of the current slice.
            slice_df = None

            if fitParam in self.imageData.t2_data_summary_df.columns:
                print(fitParam, "T2 dataFrame chosen")
                data_df = self.imageData.t2_data_summary_df
                slice_df = data_df[data_df.slice == cslice + 1]
            elif fitParam in self.imageData.dixon_data_summary_df.columns:
                print(fitParam, "Dixon dataFrame chosen")
                data_df = self.imageData.dixon_data_summary_df
                if cslice + 1 in self.imageData.T2slices:
                    dixonSliceIndex = self.imageData.dixonSlices[
                        self.imageData.T2slices.index(cslice + 1)]
                    slice_df = data_df[data_df.slice == dixonSliceIndex]
                else:
                    slice_df = data_df[data_df.slice == cslice]

            roiList = []
            valueList = []

            if slice_df is not None:
                print("type(slice_df)", type(slice_df))

                print("slice_df.shape", slice_df.shape)

                # Pixels are identified by their row-major flat index.
                roiList = slice_df[slice_df['pixel_index'] == y * ncols +
                                   x]['roi'].values
                valueList = slice_df[slice_df['pixel_index'] == y * ncols +
                                     x][fitParam].values

                print("roiList", roiList)
                print("valueList", valueList)

                fitParamLabel = parameterNames[fitParam][1]

            if len(roiList) > 0:
                roi = roiList[0]
                value = valueList[0]
                ann.set_text(fitParamLabel.format(roi, value))
            else:
                ann.set_text("x = {:d}\ny = {:d}".format(x, y))

        self.ax.grid(False)

        self.plot_canvas.draw()

    def toggleImageChanged(self, b1):
        """Hide the background image (draw zeros) or restore it from the sliders."""
        print("Entered toggleImageChanged")
        if self.imageData is None:
            return
        if self.imageData.mriSliceIMG is not None:
            if self.toggleImage.isChecked():
                print("Clear background image")
                self.update_plot(
                    np.zeros((self.imageData.mriSliceIMG.shape)),
                    self.imageData.maskedROIs.reshape(
                        (self.imageData.mriSliceIMG.shape)))
            else:
                self.valuechangedSlider()
Example #17
0
class ApplicationWindow(QtWidgets.QMainWindow):
    """Three-view NIfTI slice viewer with a small control panel.

    Layout:
    - Left: the ``axis_*`` canvas, drawn with ``Slicer.draw_z_sec``.
    - Middle: the ``cron_*`` canvas (``draw_y_sec``) above the ``sagi_*``
      canvas (``draw_x_sec``). These are presumably the axial, coronal and
      sagittal views.
    - Right: buttons to load an image volume and a label volume.

    Interaction:
    - Clicking a view passes the click position to the slicer
      (``view_to_data_*``) and redraws the other two views.
    - Scrolling a view steps that view's ``data_indice`` entry and redraws it.

    Canvas events are ignored until an image has been loaded.
    """

    # Qt name-filter string. Both extensions must share one "(...)" group;
    # Qt only reads the last parenthesised group of an entry.
    _NII_FILTER = "NII Files (*.nii *.nii.gz);;All Files (*)"

    def __init__(self):
        super().__init__()
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        self.main_layout = QtWidgets.QHBoxLayout(self._main)
        self.init_left_space(self.main_layout)
        self.init_middle_space(self.main_layout)
        self.init_panel_space(self.main_layout)
        self.init_data()
        self.cmap = cm.gray

    def init_data(self):
        """Reset the data state: no volumes loaded, no slicer."""
        self.data_list = []
        # Created by on_img_btn_clicked; every canvas handler checks it first.
        self.current_slicer = None

    def init_left_space(self, layout):
        """Create the left (z-section) canvas and connect its mouse events."""
        self.axis_can = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.axis_can)
        self.axis_axe = self.axis_can.figure.subplots()

        self.axis_can.mpl_connect('button_press_event', self.on_axis_clicked)
        self.axis_can.mpl_connect('scroll_event', self.on_axis_scroll)

    # -- helpers -------------------------------------------------------------

    def _has_slicer(self):
        """Return True once an image volume has been loaded."""
        return getattr(self, 'current_slicer', None) is not None

    def _click_position(self, event):
        """Return the click's ``(xdata, ydata)``, or None if it is unusable.

        A click is unusable before an image is loaded, or when it lands
        outside the axes (matplotlib then reports ``xdata``/``ydata`` as None).
        """
        if not self._has_slicer() or event.xdata is None or event.ydata is None:
            return None
        return (event.xdata, event.ydata)

    def _step_slice(self, event, dim):
        """Move ``current_slicer.data_indice[dim]`` by +1 (scroll up) or -1.

        Returns False, changing nothing, when no image is loaded.
        """
        if not self._has_slicer():
            return False
        step = 1 if event.button == 'up' else -1
        # NOTE(review): the index is not clamped to the volume bounds here;
        # confirm that Slicer.draw_*_sec copes with out-of-range indices.
        self.current_slicer.data_indice[dim] += step
        return True

    def _update_all_views(self):
        """Redraw the three views from the current slicer."""
        self.update_axis()
        self.update_sagi()
        self.update_cron()

    # -- canvas event handlers -----------------------------------------------

    def on_axis_clicked(self, event):
        position = self._click_position(event)
        if position is None:
            return
        self.current_slicer.view_to_data_z(position)
        self.update_cron()
        self.update_sagi()

    def on_cron_clicked(self, event):
        position = self._click_position(event)
        if position is None:
            return
        self.current_slicer.view_to_data_y(position)
        self.update_axis()
        self.update_sagi()

    def on_sagi_clicked(self, event):
        position = self._click_position(event)
        if position is None:
            return
        self.current_slicer.view_to_data_x(position)
        self.update_axis()
        self.update_cron()

    def on_axis_scroll(self, event):
        if self._step_slice(event, 0):
            self.update_axis()

    def on_cron_scroll(self, event):
        if self._step_slice(event, 1):
            self.update_cron()

    def on_sagi_scroll(self, event):
        if self._step_slice(event, 2):
            self.update_sagi()

    # -- layout ----------------------------------------------------------------

    def init_middle_space(self, layout):
        """Create the middle column: y-section canvas above x-section canvas."""
        self.middle_layout = QtWidgets.QVBoxLayout()
        layout.addLayout(self.middle_layout)

        self.cron_can = FigureCanvas(Figure(figsize=(5, 3)))
        self.cron_axe = self.cron_can.figure.subplots()
        self.middle_layout.addWidget(self.cron_can)

        self.sagi_can = FigureCanvas(Figure(figsize=(5, 3)))
        self.sagi_axe = self.sagi_can.figure.subplots()
        self.middle_layout.addWidget(self.sagi_can)

        self.cron_can.mpl_connect('button_press_event', self.on_cron_clicked)
        self.cron_can.mpl_connect('scroll_event', self.on_cron_scroll)
        self.sagi_can.mpl_connect('button_press_event', self.on_sagi_clicked)
        self.sagi_can.mpl_connect('scroll_event', self.on_sagi_scroll)

    def init_panel_space(self, layout):
        """Create the fixed-width right panel with the two load buttons."""
        topWidget = QtWidgets.QWidget()
        topWidget.setFixedWidth(100)
        layout.addWidget(topWidget)
        self.panel_layout = QtWidgets.QVBoxLayout(topWidget)

        load_img_btn = QtWidgets.QPushButton('Load NII')
        self.panel_layout.addWidget(load_img_btn, 1)
        load_img_btn.clicked.connect(self.on_img_btn_clicked)

        load_label_btn = QtWidgets.QPushButton('Load Label')
        self.panel_layout.addWidget(load_label_btn, 1)
        load_label_btn.clicked.connect(self.on_lbl_btn_clicked)

    # -- loading ---------------------------------------------------------------

    def on_img_btn_clicked(self):
        """Ask for an image, build a ``Slicer`` from its first file, redraw all views."""
        options = QtWidgets.QFileDialog.Options()
        options |= QtWidgets.QFileDialog.DontUseNativeDialog
        files, _ = QtWidgets.QFileDialog.getOpenFileNames(
            self,
            "QFileDialog.getOpenFileNames()",
            "",
            self._NII_FILTER,
            options=options)
        if not files:  # dialog cancelled: Qt returns an empty list
            return
        print(files)
        data_nii = nib.load(files[0])
        self.current_slicer = Slicer(data_nii)
        self._update_all_views()

    def on_lbl_btn_clicked(self):
        """Ask for a label volume and attach it to the current slicer."""
        # The label is stored on the slicer, so an image must be loaded first.
        if not self._has_slicer():
            print("Load an image before loading a label")
            return
        files, _ = QtWidgets.QFileDialog.getOpenFileNames(
            self, "QFileDialog.getOpenFileNames()", "", self._NII_FILTER)
        if not files:
            return
        print(files)
        data_nii = nib.load(files[0])
        self.current_slicer.load_label(data_nii)
        self._update_all_views()

    # -- drawing ---------------------------------------------------------------

    def update_axis(self):
        self.axis_axe.clear()
        self.current_slicer.draw_z_sec(self.axis_axe)
        self.axis_axe.figure.canvas.draw()

    def update_cron(self):
        self.cron_axe.clear()
        self.current_slicer.draw_y_sec(self.cron_axe)
        self.cron_axe.figure.canvas.draw()

    def update_sagi(self):
        self.sagi_axe.clear()
        self.current_slicer.draw_x_sec(self.sagi_axe)
        self.sagi_axe.figure.canvas.draw()
Example #18
0
class PlotConfusion(QWidget):
    """Widget showing the normalised confusion matrix of the last epoch.

    The matrix returned by ``file.getNormalized()`` is drawn as a seaborn
    heatmap on an embedded matplotlib canvas. Hovering over a cell outlines it
    with a dashed red rectangle. ``xint``/``yint`` hold the hovered cell's
    column and row, or -1 when the pointer is not over the heatmap.
    """

    #------------------------------------
    # Constructor PlotConfusion
    #-------------------

    def __init__(self, parent, file):
        """Build the heatmap for ``file`` in a widget parented to ``parent``.

        ``file`` must provide ``getNormalized()``. Its rows and columns are
        labelled with ``self.axis_labels``; presumably they have one entry per
        label (confirm with the producer of ``file``).
        """
        # A plain Figure is not registered with pyplot's global figure
        # manager, so it is not leaked the way a pyplot.figure() would be.
        from matplotlib.figure import Figure

        # Run QWidget's own initialiser. super(QWidget, self) would start the
        # MRO lookup *after* QWidget and skip it.
        super().__init__(parent)

        # axis_labels = ['AMADEC', 'ARRAUR','CORALT','DYSMEN', 'EUPIMI','HENLES','HYLDEC','LOPPIT', 'TANGYR', 'TANICT']
        self.axis_labels = [
            'AMADEC_CALL', 'AMADEC_SONG', 'ARRAUR_CALL', 'ARRAUR_SONG',
            'CORALT_CALL', 'CORALT_SONG', 'DYSMEN_CALL', 'DYSMEN_SONG',
            'EUPIMI_CALL', 'EUPIMI_SONG', 'HENLES_CALL', 'HENLES_SONG',
            'HYLDEC_CALL', 'HYLDEC_SONG', 'LOPPIT_CALL', 'LOPPIT_SONG',
            'TANGYR_CALL', 'TANGYR_SONG', 'TANICT_CALL', 'TANICT_SONG'
        ]

        self.figure = Figure(figsize=(10, 5))
        self.resize(400, 400)

        self.canvas = FigureCanvas(self.figure)
        self.xint = -1
        self.yint = -1
        self.file = file

        ax = self.figure.add_subplot(111)
        ax.set_title('Confusion Matrix on Last Epoch')
        # Draw into this axes explicitly: no pyplot "current axes" is involved.
        ax = sns.heatmap(file.getNormalized(),
                         xticklabels=self.axis_labels,
                         yticklabels=self.axis_labels,
                         center=0.45,
                         ax=ax)
        ax.set_xlabel('actual species')
        ax.set_ylabel('predicted species')
        ax.tick_params(axis='x', labelrotation=90)
        # Set before connecting the motion handler, which compares against it.
        self.heatmap = ax

        # One reusable hover outline, hidden until the pointer is over a cell.
        self.rect = mpatches.Rectangle((0, 0),
                                       1,
                                       1,
                                       fill=False,
                                       linestyle='dashed',
                                       edgecolor='red',
                                       linewidth=2.0,
                                       visible=False)
        self.heatmap.add_patch(self.rect)

        self.canvas.draw()

        layout = QVBoxLayout()
        layout.addWidget(self.canvas)
        self.setLayout(layout)
        self.cid = self.canvas.mpl_connect("motion_notify_event",
                                           self.onMotion)

    #------------------------------------
    # onMotion
    #-------------------

    def onMotion(self, event):
        """Move the hover outline to the heatmap cell under the pointer.

        Redraws only when the hovered cell changes. Outside the heatmap axes
        (for instance over the colour bar seaborn adds next to it) the outline
        is hidden and ``xint``/``yint`` are reset to -1.
        """
        if event.inaxes is not self.heatmap or event.xdata is None:
            if self.rect.get_visible():
                # Redraw once so no stale outline is left on screen.
                self.rect.set_visible(False)
                self.canvas.draw()
            self.xint = -1
            self.yint = -1
            return

        xint = int(event.xdata)
        yint = int(event.ydata)
        if self.rect.get_visible() and (xint, yint) == (self.xint, self.yint):
            return  # still on the same cell: nothing to redraw
        self.xint = xint
        self.yint = yint
        self.rect.set_xy((xint, yint))
        self.rect.set_visible(True)
        self.canvas.draw()
Example #19
0
def show(time, data, com, com_dot, com_ddot, com_i, grf, angles, stick):
    """Open a Qt window to browse a biomechanical analysis frame by frame.

    Window columns:
    - Left: the body stick figure, with optional kinogram ghosts of the
      frames just before and after the current one.
    - Middle: the centre-of-mass height, vertical velocity and vertical
      acceleration, plus the ground reaction force, over time.
    - Right: the joint angles over time.

    A slider selects the frame. A red vertical bar marks the current time on
    every time series.

    Args:
        time: 1-D array of sample times in seconds; its length sets the slider range.
        data: marker positions indexed ``data[axis, marker, frame]``. Axis 0 is
            plotted horizontally (frontal), axis 1 vertically.
        com: centre of mass indexed ``com[axis, 0, frame]``.
        com_dot: CoM velocity, same layout as ``com``.
        com_ddot: CoM acceleration, same layout as ``com``.
        com_i: per-segment centres of mass, ``com_i[axis, segment, frame]``.
        grf: ground reaction force, ``grf[axis, 0, frame]``.
        angles: mapping of joint name -> angle series. Each series goes through
            ``KinoveaReader.to_degree`` before plotting.
        stick: marker indices used to draw the stick figure.

    Blocks in the Qt event loop until the window is closed.
    """
    # Reuse a running QApplication: constructing a second one raises.
    qapp = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
    window = QtWidgets.QMainWindow()
    window.setWindowTitle("Analyse biomécanique de Kinovea")

    central = QtWidgets.QWidget()
    window.setCentralWidget(central)
    main_layout = QtWidgets.QHBoxLayout(central)

    # ---- Body position column ------------------------------------------------
    body_position_layout = QtWidgets.QVBoxLayout()
    main_layout.addLayout(body_position_layout)

    body_position_canvas = FigureCanvas(Figure(figsize=(5, 3)))
    body_position_layout.addWidget(body_position_canvas)
    body_position_ax = body_position_canvas.figure.subplots()
    body_position_ax.set_ylabel("Axe vertical (m)")
    body_position_ax.set_xlabel("Axe frontal (m)")
    body_position_text = body_position_ax.text(
        0.5, 0.99, "", fontsize=12, horizontalalignment='center', verticalalignment='top',
        transform=body_position_ax.transAxes
    )

    # Kinogram ghosts: kino_n_image stick figures before/after the current
    # frame. The nearest frame is black, farther ones fade towards light grey.
    kino_n_image = 5
    kino_colors = np.linspace(0.88, 0, kino_n_image)
    kino_pre_plot = []
    kino_post_plot = []
    for i in range(kino_n_image):
        grey = kino_colors[kino_n_image - 1 - i]
        kino_pre_plot.append(body_position_ax.plot(np.nan, np.nan, color=[grey, grey, grey]))
        kino_post_plot.append(body_position_ax.plot(np.nan, np.nan, color=[grey, grey, grey]))

    stick_plot = body_position_ax.plot(np.nan, np.nan, 'r')
    comi_plot = body_position_ax.plot(np.nan, np.nan, 'k.')
    com_plot = body_position_ax.plot(np.nan, np.nan, 'k.', markersize=20)

    # Line2D.set_data requires sequences; recent matplotlib rejects bare scalars.
    hidden = ([np.nan], [np.nan])

    def move_stick_figure(time_idx):
        """Update the stick figure, ghosts, CoM markers and title for one frame."""
        body_position_ax.set_title(f"Position du corps à l'instant {time[time_idx]:.2f} s")

        body_position_text.set_text(f"CoM = [{str(np.round(com[0, 0, time_idx], 2))}; {str(np.round(com[1, 0, time_idx], 2))}]")

        n_frames = data.shape[2]
        for i in range(kino_n_image):
            pre_idx = time_idx - i - 1
            post_idx = time_idx + i + 1
            if pre_idx >= 0 and kino_pre_check.isChecked():
                kino_pre_plot[i][0].set_data(data[0, stick, pre_idx], data[1, stick, pre_idx])
            else:
                kino_pre_plot[i][0].set_data(*hidden)
            if post_idx < n_frames and kino_post_check.isChecked():
                kino_post_plot[i][0].set_data(data[0, stick, post_idx], data[1, stick, post_idx])
            else:
                kino_post_plot[i][0].set_data(*hidden)

        stick_plot[0].set_data(data[0, stick, time_idx], data[1, stick, time_idx])
        comi_plot[0].set_data(com_i[0, :, time_idx], com_i[1, :, time_idx])
        com_plot[0].set_data([com[0, 0, time_idx]], [com[1, 0, time_idx]])

    # White points at the data extremes (10 % vertical margin) so that
    # axis('equal') frames the whole movement.
    x_min, x_max = np.min(data[0, :, :]), np.max(data[0, :, :])
    y_min, y_max = np.min(data[1, :, :]), np.max(data[1, :, :])
    y_margin = (y_max - y_min) * 0.1
    body_position_ax.plot([x_min, x_max], [y_min - y_margin, y_max + y_margin], 'w.')
    body_position_ax.axis('equal')

    time_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
    body_position_layout.addWidget(time_slider)

    kinogram_layout = QtWidgets.QHBoxLayout()
    body_position_layout.addLayout(kinogram_layout)
    kino_pre_check = QtWidgets.QCheckBox()
    kino_pre_check.setText("Kinogramme pre")
    kinogram_layout.addWidget(kino_pre_check)
    kino_post_check = QtWidgets.QCheckBox()
    kino_post_check.setText("Kinogramme post")
    kinogram_layout.addWidget(kino_post_check)

    # ---- Trajectory column ---------------------------------------------------
    trajectory_canvas = FigureCanvas(Figure(figsize=(5, 3)))
    main_layout.addWidget(trajectory_canvas)
    x_margin = (time[-1] - time[0]) * 0.005
    xlim = (time[0] - x_margin, time[-1] + x_margin)

    def add_time_bar(ax):
        """Add the red current-time bar (positioned later) and freeze the limits."""
        ylim = ax.get_ylim()
        vbar = ax.plot((np.nan, np.nan), ylim, 'r')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        return vbar

    def add_trace(position, title, ylabel, values, is_last=False):
        """Add one time-series subplot. Only the last one shows the time axis."""
        ax = trajectory_canvas.figure.add_subplot(position)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        if is_last:
            ax.set_xlabel("Temps (s)")
        else:
            ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.plot(time, values)
        return add_time_bar(ax)

    trajectory_vbars = [
        add_trace(411, "Hauteur du CoM", "Hauteur (m)", com[1, 0, :]),
        add_trace(412, "Vitesse verticale", "Vitesse (m/s)", com_dot[1, 0, :]),
        add_trace(413, "Accélération verticale", "Accélération (m/s²)", com_ddot[1, 0, :]),
        add_trace(414, "GRF", "GRF (N)", grf[1, 0, :], is_last=True),
    ]

    trajectory_canvas.figure.tight_layout(h_pad=-0.5, w_pad=-6)

    # ---- Angles column -------------------------------------------------------
    angles_canvas = FigureCanvas(Figure(figsize=(5, 3)))
    main_layout.addWidget(angles_canvas)
    ax_angles = angles_canvas.figure.subplots()
    ax_angles.set_title("Angles articulaire au cours du temps")
    ax_angles.set_ylabel("Angle (°)")
    ax_angles.set_xlabel("Temps (s)")
    for joint in angles.values():
        ax_angles.plot(time, KinoveaReader.to_degree(joint))
    angles_vbar = add_time_bar(ax_angles)
    ax_angles.legend(angles.keys())

    def change_time(*_):
        """Redraw every view for the slider's frame. Signal/event args are ignored."""
        time_idx = time_slider.value()

        move_stick_figure(time_idx)
        body_position_canvas.draw()

        bar_x = [time[time_idx], time[time_idx]]
        for vbar in trajectory_vbars:
            vbar[0].set_xdata(bar_x)
        trajectory_canvas.draw()

        angles_vbar[0].set_xdata(bar_x)
        angles_canvas.draw()

    time_slider.setMinimum(0)
    time_slider.setMaximum(time.shape[0] - 1)
    time_slider.setPageStep(1)
    time_slider.setValue(0)
    time_slider.valueChanged.connect(change_time)
    # mpl_connect takes the event *name*; the handler receives the event.
    body_position_canvas.mpl_connect("resize_event", change_time)
    kino_pre_check.stateChanged.connect(change_time)
    kino_post_check.stateChanged.connect(change_time)

    change_time()
    window.show()
    qapp.exec_()
Example #20
0
class ProgressViewer(QtWidgets.QMainWindow):
    '''GUI to track progress of EMC reconstruction
    Shows orthogonal volumes slices, plots of metrics vs iteration and log file
    Can periodically poll log file for updates and automatically update plots

    Can also be used to view slices through other 3D volumes using the '-f' option
    '''
    def __init__(self, config='config.ini', model=None):
        super(ProgressViewer, self).__init__()
        self.config = config
        self.model_name = model
        self.max_iternum = 0
        plt.style.use('dark_background')

        # Two independent lists. A chained assignment would alias them.
        self.beta_change = []
        self.num_rot_change = []
        # Poll timer for 'Keep checking'. Its timeout is connected exactly once
        # here; connecting in _keep_checking would stack a duplicate handler
        # every time the checkbox is ticked.
        self.checker = QtCore.QTimer(self)
        self.checker.timeout.connect(self._check_for_new)
        # Connection id of the 2D click-to-select handler (set in _plot_vol).
        self._select_mode_cid = None

        self._read_config(config)
        self._init_ui()
        if model is not None:
            self._parse_and_plot()
        self.old_fname = self.fname.text()

    def _init_ui(self):
        '''Build the window: menu bar, plot area (left) and options panel (right)'''
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'py_src/style.css'), 'r') as f:
            self.setStyleSheet(f.read())
        self.setWindowTitle('Dragonfly Progress Viewer')
        self.setGeometry(100, 100, 1600, 800)
        overall = QtWidgets.QWidget()
        self.setCentralWidget(overall)
        layout = QtWidgets.QHBoxLayout(overall)
        layout.setContentsMargins(0, 0, 0, 0)

        self._init_menubar()
        plot_splitter = self._init_plotarea()
        options_widget = self._init_optionsarea()

        main_splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal)
        main_splitter.setObjectName('frame')
        layout.addWidget(main_splitter)
        main_splitter.addWidget(plot_splitter)
        main_splitter.addWidget(options_widget)

        self.show()

    def _init_menubar(self):
        '''Create the File menu and the exclusive colour-map picker'''
        menubar = self.menuBar()
        menubar.setNativeMenuBar(False)

        # File Menu
        filemenu = menubar.addMenu('&File')
        action = QtWidgets.QAction('&Load Volume', self)
        action.triggered.connect(self._load_volume)
        filemenu.addAction(action)
        action = QtWidgets.QAction('&Save Image', self)
        action.triggered.connect(self._save_plot)
        filemenu.addAction(action)
        action = QtWidgets.QAction('Save Log &Plot', self)
        action.triggered.connect(self._save_log_plot)
        filemenu.addAction(action)
        action = QtWidgets.QAction('&Quit', self)
        action.triggered.connect(self.close)
        filemenu.addAction(action)

        # Color map picker
        cmapmenu = menubar.addMenu('&Color Map')
        self.color_map = QtWidgets.QActionGroup(self, exclusive=True)
        for i, cmap in enumerate(['coolwarm', 'cubehelix', 'CMRmap', 'gray', 'gray_r', 'jet']):
            action = self.color_map.addAction(QtWidgets.QAction(cmap, self, checkable=True))
            if i == 0:
                action.setChecked(True)
            action.triggered.connect(self._cmap_changed)
            cmapmenu.addAction(action)

    def _init_plotarea(self):
        '''Create the vertical splitter holding the volume figure and the log-plot figure'''
        plot_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical)
        plot_splitter.setObjectName('plots')

        # Volume slices figure
        self.fig = matplotlib.figure.Figure(figsize=(14, 5))
        self.fig.subplots_adjust(left=0.0, bottom=0.00, right=0.99, wspace=0.0)
        self.fig.set_facecolor('#112244')
        self.canvas = FigureCanvas(self.fig)
        self.canvas.show()
        plot_splitter.addWidget(self.canvas)
        self.vol_plotter = VolumePlotter(self.fig, self.recon_type, self.num_modes)
        self.need_replot = self.vol_plotter.need_replot

        # Progress plots figure
        self.log_fig = matplotlib.figure.Figure(figsize=(14, 5), facecolor='w')
        self.log_fig.set_facecolor('#112244')
        self.plotcanvas = FigureCanvas(self.log_fig)
        self.plotcanvas.show()
        plot_splitter.addWidget(self.plotcanvas)
        self.log_plotter = LogPlotter(self.log_fig, self.folder)

        return plot_splitter

    def _init_optionsarea(self):
        '''Create the options panel: file fields, sliders, buttons and log display'''
        options_widget = QtWidgets.QWidget()
        vbox = QtWidgets.QVBoxLayout()
        options_widget.setLayout(vbox)

        # -- Log file
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('Log file name:', self)
        hbox.addWidget(label)
        self.logfname = QtWidgets.QLineEdit(self.logfname, self)
        self.logfname.setMinimumWidth(160)
        hbox.addWidget(self.logfname)
        label = QtWidgets.QLabel('VRange:', self)
        hbox.addWidget(label)
        self.rangemin = QtWidgets.QLineEdit('0', self)
        self.rangemin.setFixedWidth(48)
        self.rangemin.returnPressed.connect(self._range_changed)
        hbox.addWidget(self.rangemin)
        self.rangestr = QtWidgets.QLineEdit('1', self)
        self.rangestr.setFixedWidth(48)
        self.rangestr.returnPressed.connect(self._range_changed)
        hbox.addWidget(self.rangestr)

        # -- Volume file
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('File name:', self)
        hbox.addWidget(label)
        if self.model_name is None:
            self.fname = QtWidgets.QLineEdit(self.folder+'/output/intens_001.bin', self)
        else:
            self.fname = QtWidgets.QLineEdit(self.model_name, self)
        self.fname.setMinimumWidth(160)
        hbox.addWidget(self.fname)
        label = QtWidgets.QLabel('Exp:', self)
        hbox.addWidget(label)
        self.expstr = QtWidgets.QLineEdit('1', self)
        self.expstr.setFixedWidth(48)
        self.expstr.returnPressed.connect(self._range_changed)
        hbox.addWidget(self.expstr)

        # -- Sliders
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('Layer num.', self)
        hbox.addWidget(label)
        self.layer_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
        self.layer_slider.setRange(0, 200)
        self.layer_slider.sliderMoved.connect(self._layerslider_moved)
        self.layer_slider.sliderReleased.connect(self._layernum_changed)
        hbox.addWidget(self.layer_slider)
        self.layernum = MySpinBox(self)
        self.layernum.setValue(self.layer_slider.value())
        self.layernum.setMinimum(0)
        self.layernum.setMaximum(200)
        self.layernum.valueChanged.connect(self._layernum_changed)
        self.layernum.editingFinished.connect(self._layernum_changed)
        self.layernum.setFixedWidth(48)
        hbox.addWidget(self.layernum)
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('Iteration', self)
        hbox.addWidget(label)
        self.iter_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
        self.iter_slider.setRange(0, 1)
        self.iter_slider.sliderMoved.connect(self._iterslider_moved)
        self.iter_slider.sliderReleased.connect(self._iternum_changed)
        hbox.addWidget(self.iter_slider)
        self.iternum = MySpinBox(self)
        self.iternum.setValue(self.iter_slider.value())
        self.iternum.setMinimum(0)
        self.iternum.setMaximum(1)
        self.iternum.valueChanged.connect(self._iternum_changed)
        self.iternum.editingFinished.connect(self._iternum_changed)
        self.iternum.setFixedWidth(48)
        hbox.addWidget(self.iternum)
        if self.num_modes > 1:
            hbox = QtWidgets.QHBoxLayout()
            vbox.addLayout(hbox)
            label = QtWidgets.QLabel('Mode', self)
            hbox.addWidget(label)
            self.mode_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
            self.mode_slider.setRange(0, self.num_modes-1)
            self.mode_slider.sliderMoved.connect(self._modeslider_moved)
            self.mode_slider.sliderReleased.connect(self._modenum_changed)
            hbox.addWidget(self.mode_slider)
            self.modenum = MySpinBox(self)
            # Initialise from the mode slider (not the iteration slider).
            self.modenum.setValue(self.mode_slider.value())
            self.modenum.setMinimum(0)
            self.modenum.setMaximum(self.num_modes-1)
            self.modenum.valueChanged.connect(self._modenum_changed)
            self.modenum.editingFinished.connect(self._modenum_changed)
            self.modenum.setFixedWidth(48)
            hbox.addWidget(self.modenum)
            self.old_modenum = self.modenum.value()

        # -- Buttons
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        button = QtWidgets.QPushButton('Check', self)
        button.clicked.connect(self._check_for_new)
        hbox.addWidget(button)
        self.ifcheck = QtWidgets.QCheckBox('Keep checking', self)
        self.ifcheck.stateChanged.connect(self._keep_checking)
        self.ifcheck.setChecked(False)
        hbox.addWidget(self.ifcheck)
        hbox.addStretch(1)
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        hbox.addStretch(1)
        button = QtWidgets.QPushButton('Plot', self)
        button.clicked.connect(self._parse_and_plot)
        hbox.addWidget(button)
        button = QtWidgets.QPushButton('Reparse', self)
        button.clicked.connect(self._force_plot)
        hbox.addWidget(button)
        button = QtWidgets.QPushButton('Quit', self)
        button.clicked.connect(self.close)
        hbox.addWidget(button)

        # -- Log file display
        log_area = QtWidgets.QScrollArea(self)
        vbox.addWidget(log_area)
        log_area.setMinimumWidth(300)
        log_area.setWidgetResizable(True)
        self.emclog_text = QtWidgets.QTextEdit(
            'Press \'Check\' to synchronize with log file<br>'
            'Select \'Keep Checking\' to periodically synchronize<br><br>'
            'The top half of the display area will show three orthogonal<br>'
            'slices of the 3D volume. The bottom half will show plots of<br>'
            'various parameters vs iteration.', self)
        self.emclog_text.setReadOnly(True)
        self.emclog_text.setFontPointSize(8)
        self.emclog_text.setFontFamily('Courier')
        self.emclog_text.setFontWeight(QtGui.QFont.DemiBold)
        self.emclog_text.setTabStopWidth(22)
        self.emclog_text.setLineWrapMode(QtWidgets.QTextEdit.NoWrap)
        self.emclog_text.setObjectName('logtext')
        log_area.setWidget(self.emclog_text)

        return options_widget

    def _layernum_changed(self, value=None):
        '''Slot for layer spin box / slider release; value is None for release/editing-finished'''
        if value is None:
            # Slider released or editing finished
            self.need_replot = True
        elif value == self.layernum.value():
            self.layer_slider.setValue(value)
        self._parse_and_plot()

    def _layerslider_moved(self, value):
        self.layernum.setValue(value)

    def _iternum_changed(self, value=None):
        '''Slot for iteration spin box / slider release; points File name at that iteration'''
        if value is None:
            self.fname.setText(self.folder+'/output/intens_%.3d.bin' % self.iternum.value())
        elif value == self.iternum.value():
            self.iter_slider.setValue(value)
            if self.need_replot:
                self.fname.setText(self.folder+'/output/intens_%.3d.bin' % value)
        self._parse_and_plot()

    def _iterslider_moved(self, value):
        self.iternum.setValue(value)

    def _modenum_changed(self, value=None):
        if value == self.modenum.value():
            self.mode_slider.setValue(value)
        self._parse_and_plot()

    def _modeslider_moved(self, value):
        self.modenum.setValue(value)

    def _range_changed(self):
        self.need_replot = True

    def _read_config(self, config):
        '''Read output folder, log file, reconstruction type and mode count, with defaults'''
        try:
            self.folder = read_config.get_filename(config, 'emc', 'output_folder')
        except read_config.configparser.NoOptionError:
            self.folder = 'data/'

        try:
            self.logfname = read_config.get_filename(config, 'emc', 'log_file')
        except read_config.configparser.NoOptionError:
            self.logfname = 'EMC.log'

        try:
            self.recon_type = read_config.get_param(config, 'emc', 'recon_type').lower()
        except read_config.configparser.NoOptionError:
            self.recon_type = '3d'
        try:
            self.num_modes = int(read_config.get_param(config, 'emc', 'num_modes'))
        except read_config.configparser.NoOptionError:
            self.num_modes = 1

    def _update_layers(self, size, center):
        '''Resize the layer controls to size layers and move them to center'''
        self.layer_slider.setRange(0, size-1)
        self.layernum.setMaximum(size-1)
        self.layer_slider.setValue(center)
        self._layerslider_moved(center)

    def _plot_vol(self, num=None):
        '''Plot layer num (default: spin box value) with the current range, exponent and colour map'''
        if num is None:
            num = int(self.layernum.text())
        try:
            vrange = (float(self.rangemin.text()), float(self.rangestr.text()))
            exponent = float(self.expstr.text())
        except ValueError:
            # Report an unparsable entry instead of raising out of a Qt slot.
            sys.stderr.write('Invalid VRange/Exp value, not replotting\n')
            return
        self.vol_plotter.plot(num,
                              vrange,
                              exponent,
                              self.color_map.checkedAction().text())
        # In 2D mode a click on a panel selects that mode. Connect only once;
        # connecting on every replot would fire _select_mode repeatedly.
        if self.recon_type == '2d' and self._select_mode_cid is None:
            self._select_mode_cid = self.canvas.mpl_connect('button_press_event',
                                                            self._select_mode)

    def _parse_volume(self):
        '''Parse the file in the File name box (at the selected mode, if any) and reset layers'''
        if self.num_modes > 1:
            modenum = self.modenum.value()
            self.old_fname, size, center = self.vol_plotter.parse(self.fname.text(),
                                                                  modenum=modenum)
            # Record the mode now loaded so later calls don't re-parse it.
            self.old_modenum = modenum
        else:
            self.old_fname, size, center = self.vol_plotter.parse(self.fname.text())
        self._update_layers(size, center)

    def _parse_and_plot(self):
        '''Re-parse if the file or mode changed, else replot if requested; otherwise do nothing'''
        if not self.vol_plotter.image_exists or self.old_fname != self.fname.text():
            self._parse_volume()
            self._plot_vol()
        elif self.num_modes > 1 and self.modenum.value() != self.old_modenum:
            self._parse_volume()
            self._plot_vol()
        elif self.need_replot:
            self._plot_vol()

    def _check_for_new(self):
        '''Read the last log line; if it reports a new iteration, load and plot it'''
        try:
            with open(self.logfname.text(), 'r') as fptr:
                lines = fptr.readlines()
        except OSError as err:
            # Missing/unreadable log: report and wait for the next poll.
            sys.stderr.write('Could not read log file: %s\n' % err)
            return
        last_line = lines[-1].rstrip().split() if lines else []
        try:
            iteration = int(last_line[0])
        except (ValueError, IndexError):
            # Header line, blank last line or empty file: no iteration yet.
            iteration = 0

        if iteration > 0 and self.max_iternum != iteration:
            self.fname.setText(self.folder+'/output/intens_%.3d.bin' % iteration)
            self.max_iternum = iteration
            self.iter_slider.setRange(0, self.max_iternum)
            self.iternum.setMaximum(self.max_iternum)
            self.iter_slider.setValue(iteration)
            self._iterslider_moved(iteration)
            log_text = self.log_plotter.plot(self.logfname.text(),
                 self.color_map.checkedAction().text())
            self._parse_and_plot()
            self.emclog_text.setText(log_text)

    def _keep_checking(self):
        '''Start (check immediately, then every 5 s) or stop periodic log polling'''
        if self.ifcheck.isChecked():
            self._check_for_new()
            self.checker.start(5000)
        else:
            self.checker.stop()

    def _select_mode(self, event):
        '''2D mode: show the mode whose panel was clicked'''
        curr_mode = -1
        for i, subp in enumerate(self.vol_plotter.subplot_list):
            if event.inaxes is subp:
                curr_mode = i
        if curr_mode >= 0 and curr_mode != self.layernum.value():
            self.layer_slider.setValue(curr_mode)
            self.layernum.setValue(curr_mode)
            self._plot_vol(curr_mode)

    def _force_plot(self):
        '''Re-parse the current file unconditionally and plot it'''
        self._parse_volume()
        self._plot_vol()

    def _load_volume(self):
        fname, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Load 3D Volume',
                                                         'data/', 'Binary data (*.bin)')
        if fname:
            self.fname.setText(fname)
            self._parse_and_plot()

    def _save_plot(self):
        default_name = 'images/'+os.path.splitext(os.path.basename(self.fname.text()))[0]+'.png'
        fname, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'Save Volume Image',
                                                         default_name, 'Image (*.png)')
        if fname:
            self.fig.savefig(fname, bbox_inches='tight', dpi=120)
            sys.stderr.write('Saved to %s\n'%fname)

    def _save_log_plot(self):
        default_name = 'images/log_fig.png'
        fname, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'Save Log Plots',
                                                         default_name, 'Image (*.png)')
        if fname:
            self.log_fig.savefig(fname, bbox_inches='tight', dpi=120)
            sys.stderr.write("Saved to %s\n"%fname)

    def _cmap_changed(self):
        if self.vol_plotter.image_exists:
            self.need_replot = True
            self._parse_and_plot()

    def keyPressEvent(self, event): # pylint: disable=C0103
        '''Override of default keyPress event handler'''
        key = event.key()
        mod = int(event.modifiers())

        if key == QtCore.Qt.Key_Return or key == QtCore.Qt.Key_Enter:
            self._parse_and_plot()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+Q'):
            self.close()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+S'):
            self._save_plot()
        else:
            event.ignore()
class SignalView(QtWidgets.QWidget):
    '''Widget plotting selected FEMB/ADC/channel waveforms on a matplotlib canvas.

    Waveforms are read from ``data_source.timestamps`` and
    ``data_source.samples``. A navigation toolbar is overlaid on the canvas
    and is only visible while the widget has keyboard focus. The attributes
    named in ``save_props`` form the persistable view state
    (see ``get_state`` / ``set_state``).
    '''
    def __init__(self,parent=None,figure=None,data_source=None):
        super().__init__(parent=parent)
        if figure is None:
            figure = Figure(tight_layout=True)
        self.setFocusPolicy(QtCore.Qt.StrongFocus)
        self.figure = figure
        self.data_source = data_source
        self.fig_ax = self.figure.subplots()
        self.fig_canvas = FigureCanvas(self.figure)
        self.fig_canvas.draw()

        # The toolbar is a child of the canvas (not in a layout), so it is
        # positioned manually over the plot by resize()
        self.fig_toolbar = CustomNavToolbar(self.fig_canvas,self,coordinates=False)
        self.fig_toolbar.setParent(self.fig_canvas)
        self.fig_toolbar.setMinimumWidth(300)

        # Re-pin the toolbar whenever matplotlib reports a canvas resize
        self.fig_canvas.mpl_connect("resize_event", self.resize)
        self.resize(None)

        self.layout = QtWidgets.QVBoxLayout(self)
        self.layout.setContentsMargins(0,0,0,0)
        self.layout.setSpacing(0)
        self.layout.addWidget(self.fig_canvas)

        self.toolbar_shown(False)

        # Persistable view options (see save_props)
        self.legend = False
        self.raw_adc = True
        self.raw_time = True
        self.pedestal = None      # (start, stop) sample window, or None
        self.distribute = None    # vertical offset between traces, or None
        self.fft = False
        self.selected = None      # list of (femb, adc, ch) tuples, or None

        self.save_props = ['legend','selected','raw_adc','raw_time','pedestal','distribute','fft']

        # Zoom tracking: autoscale stays on until the user changes the limits
        self.autoscale = True
        self.last_lims = None

        self.times,self.data = None,None

    def resize(self, event=None, *args):
        '''Reposition the overlaid toolbar; also honour Qt-style resize calls.

        This method receives matplotlib ``resize_event`` objects (or ``None``)
        but shadows ``QWidget.resize``. Calls of the Qt form ``resize(w, h)``
        or ``resize(QSize)`` are therefore forwarded to ``QWidget.resize``
        before the toolbar is moved.
        '''
        if args or isinstance(event, QtCore.QSize):
            super().resize(event, *args)
        self._place_toolbar()

    def _place_toolbar(self):
        '''Move the toolbar so its bottom edge sits at the axes' lower-left corner.'''
        x, y = self.figure.axes[0].transAxes.transform((0, 0.0))
        _, fig_h = self.figure.get_size_inches()
        # matplotlib display y grows upward, Qt widget y grows downward
        ynew = fig_h*self.figure.dpi - y - self.fig_toolbar.frameGeometry().height()
        self.fig_toolbar.move(int(x), int(ynew))

    def focusInEvent(self, *args, **kwargs):
        super().focusInEvent(*args, **kwargs)
        self.resize(None)
        self.toolbar_shown(True)

    def focusOutEvent(self, *args, **kwargs):
        super().focusOutEvent(*args, **kwargs)
        self.toolbar_shown(False)

    def toolbar_shown(self,shown):
        '''Show the overlaid toolbar when ``shown`` is truthy, otherwise hide it.'''
        if shown:
            self.fig_toolbar.show()
        else:
            self.fig_toolbar.hide()

    def get_state(self):
        '''Return a dict of the save_props attributes that are currently set.'''
        all_props = self.__dict__
        return {prop:getattr(self,prop) for prop in self.save_props if prop in all_props}

    def set_state(self, state):
        '''Assign every key of ``state`` that already exists as an instance attribute.'''
        all_props = self.__dict__
        for prop,val in state.items():
            if prop in all_props:
                setattr(self,prop,val)

    def load_data(self):
        '''Rebuild self.times / self.data for the selected channels.

        Each selected (femb, adc, ch) contributes one x array (sample index,
        timestamp, or FFT frequency) and one y array (samples, optionally
        pedestal-subtracted and vertically offset, or their power spectrum).
        Leaves both lists empty when there is no data source, no data, or no
        selection.
        '''
        self.times = []
        self.data = []
        self.raw_data = []  # kept for attribute compatibility; not filled here

        source = self.data_source
        if (source is None or source.timestamps is None
                or source.samples is None or self.selected is None):
            return

        for sig_idx, (femb, adc, ch) in enumerate(self.selected):
            if self.raw_time:
                times = np.arange(source.timestamps.shape[-1])
            else:
                times = source.timestamps[femb//2]
            samples = source.samples[femb, adc*16+ch]
            if self.pedestal is not None:
                samples = self._subtract_pedestal(samples, len(times))
            if self.distribute is not None:
                samples = samples + sig_idx*self.distribute
            if self.fft:
                freq, power = self._power_spectrum(samples)
                self.times.append(freq)
                self.data.append(power)
            else:
                self.times.append(times)
                self.data.append(samples)

    def _subtract_pedestal(self, samples, n_times):
        '''Return samples minus their mean over the self.pedestal = (start, stop) window.

        The window is clipped to [0, n_times). If it is empty or cannot be used
        to slice ``samples``, a message is printed and ``samples`` is returned
        unchanged.
        '''
        ped_min, ped_max = self.pedestal
        start = ped_min if ped_min > 0 else 0
        # Clip to n_times (slicing already excludes `stop`); clipping to
        # n_times-1 could never include the final sample.
        stop = ped_max if ped_max < n_times else n_times
        if stop <= start:
            # The mean of an empty slice is NaN and would turn every sample into NaN
            print('Pedestal correction failed')
            return samples
        try:
            return samples - np.mean(samples[start:stop])
        except (TypeError, ValueError, IndexError):
            print('Pedestal correction failed')
            return samples

    @staticmethod
    def _power_spectrum(samples):
        '''Return (non-negative frequencies, |FFT|^2) of samples, unit sample spacing.'''
        spectrum = np.fft.fft(samples)
        freq = np.fft.fftfreq(len(samples), 1)
        # After sorting, the upper half of the indices are the frequencies >= 0
        idx = np.argsort(freq)[len(freq)//2:]
        return freq[idx], np.square(np.abs(spectrum[idx]))

    def select_signals(self):
        '''Open a SignalSelector dialog, adopt its settings, then reload and replot.

        NOTE(review): settings are applied regardless of what exec_() returns,
        i.e. also when the dialog is cancelled -- confirm this is intended.
        '''
        current_props = {x:self.__dict__[x] for x in self.save_props}
        selector = SignalSelector(parent=self, **current_props)
        selector.exec_()
        self.selected = selector.get_selected()
        self.raw_adc = selector.get_raw_adc()
        self.raw_time = selector.get_raw_time()
        self.pedestal = selector.get_pedestal()
        self.distribute = selector.get_distribute()
        self.fft = selector.get_fft()
        self.load_data()
        self.plot_signals()

    def plot_signals(self,rescale=False):
        '''Redraw all loaded signals on the axes.

        rescale: when True, discard any user zoom and autoscale. Otherwise,
        once the user has changed the limits left by the previous draw,
        autoscaling stays off and the user's limits are restored after drawing.
        '''
        ax = self.fig_ax
        if rescale:
            self.autoscale = True
        else:
            next_lims = (ax.get_xlim(), ax.get_ylim())
            self.autoscale = self.autoscale and (next_lims == self.last_lims or self.last_lims is None)
        ax.clear()

        if self.selected and not (self.times and self.data):
            self.load_data()
        # Fall through even when nothing was loaded so the cleared axes are
        # still labelled and redrawn instead of leaving a stale image on screen
        if self.selected and self.times and self.data:
            for t,v,(femb,adc,ch) in zip(self.times,self.data,self.selected):
                label = 'FEMB%i ADC%i CH%i (%i)'%(femb,adc,ch,adc*16+ch)
                ax.plot(t,v,drawstyle='steps' if not self.fft else None,label=label)

        if self.fft:
            ax.set_yscale('log')
            ax.set_xlabel('Frequency (1/sample)')
            ax.set_ylabel('Power Spectrum')
        else:
            ax.set_xlabel('Sample' if self.raw_time else 'Timestamp')
            ax.set_ylabel('ADC Counts' if self.raw_adc else ('Voltage (mV)' if not self.distribute else 'Arb. Shifted Voltage (mV)'))
        if not self.autoscale:
            # autoscale can only be False when rescale is False, so next_lims is set
            ax.set_xlim(*next_lims[0])
            ax.set_ylim(*next_lims[1])
        self.last_lims = (ax.get_xlim(), ax.get_ylim())
        if self.legend:
            ax.legend()

        ax.figure.canvas.draw()
        self.resize(None)
Example #22
0
class ProgressViewer(QtWidgets.QMainWindow):
    '''GUI to track progress of EMC reconstruction
    Shows orthogonal volumes slices, plots of metrics vs iteration and log file
    Can periodically poll log file for updates and automatically update plots

    Can also be used to view slices through other 3D volumes using the '-f' option
    '''
    def __init__(self, config='config.ini', model=None):
        '''config: path to the reconstruction config file.
        model: optional volume file to plot immediately instead of the default
        iteration file.
        '''
        super(ProgressViewer, self).__init__()
        self.config = config
        self.model_name = model
        self.max_iternum = 0
        self.fviewer = None
        plt.style.use('dark_background')

        # Two separate lists; the chained form `a = b = []` aliased one list
        self.beta_change = []
        self.num_rot_change = []
        # Poll timer used by 'Keep checking'. The timeout is connected exactly
        # once here; connecting it in _keep_checking stacked a duplicate
        # handler every time the checkbox was re-ticked.
        self.checker = QtCore.QTimer(self)
        self.checker.timeout.connect(self._check_for_new)
        # matplotlib connection id of the 2D mode-selection click handler
        self._mode_click_cid = None

        self._read_config(config)
        self._init_ui()
        if model is not None:
            self._parse_and_plot(rots=False)
        self.old_fname = self.fname.text()

    def _init_ui(self):
        '''Build the main window: menu bar, plot splitter (left), options panel (right).'''
        css_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'py_src/style.css')
        with open(css_path, 'r', encoding='utf-8') as f:
            self.setStyleSheet(f.read())
        self.setWindowTitle('Dragonfly Progress Viewer')
        self.setGeometry(100, 100, 1600, 800)
        overall = QtWidgets.QWidget()
        self.setCentralWidget(overall)
        layout = QtWidgets.QHBoxLayout(overall)
        layout.setContentsMargins(0, 0, 0, 0)

        self._init_menubar()
        plot_splitter = self._init_plotarea()
        options_widget = self._init_optionsarea()

        main_splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal)
        main_splitter.setObjectName('frame')
        layout.addWidget(main_splitter)
        main_splitter.addWidget(plot_splitter)
        main_splitter.addWidget(options_widget)

        self.show()

    def _init_menubar(self):
        '''Create the File, Image (with colour-map submenu) and Analysis menus.'''
        menubar = self.menuBar()
        menubar.setNativeMenuBar(False)

        # File Menu
        filemenu = menubar.addMenu('&File')
        action = QtWidgets.QAction('&Load Volume', self)
        action.triggered.connect(self._load_volume)
        action.setToolTip('Load 3D volume (h5 or bin)')
        filemenu.addAction(action)
        action = QtWidgets.QAction('&Quit', self)
        action.triggered.connect(self.close)
        filemenu.addAction(action)

        # Image Menu
        imagemenu = menubar.addMenu('&Image')
        action = QtWidgets.QAction('&Save Slices Image', self)
        action.triggered.connect(self._save_plot)
        action.setToolTip('Save current plot of slices as image')
        imagemenu.addAction(action)
        action = QtWidgets.QAction('Save Log &Plot', self)
        action.triggered.connect(self._save_log_plot)
        action.setToolTip('Save panel of metrics plots as image')
        imagemenu.addAction(action)
        action = QtWidgets.QAction('Save &Layer Movie', self)
        action.triggered.connect(self._save_layer_movie)
        action.setToolTip('Save slices plot animation as a function of layer')
        # The layer slider used by _save_layer_movie only exists for 3D volumes
        if self.recon_type != '3d':
            action.setEnabled(False)
        imagemenu.addAction(action)
        action = QtWidgets.QAction('Save &Iteration Movie', self)
        action.triggered.connect(self._save_iter_movie)
        action.setToolTip('Save slices plot animation as a function of iteration')
        imagemenu.addAction(action)

        # -- Color map picker (exclusive group; the first entry starts checked)
        cmapmenu = imagemenu.addMenu('&Color Map')
        self.color_map = QtWidgets.QActionGroup(self, exclusive=True)
        for i, cmap in enumerate(['coolwarm', 'cubehelix', 'CMRmap', 'gray', 'gray_r', 'jet']):
            action = self.color_map.addAction(QtWidgets.QAction(cmap, self, checkable=True))
            if i == 0:
                action.setChecked(True)
            action.triggered.connect(self._cmap_changed)
            action.setToolTip('Set color map')
            cmapmenu.addAction(action)

        # Analysis menu
        analysismenu = menubar.addMenu('&Analysis')
        action = QtWidgets.QAction('Open &Frameviewer', self)
        action.triggered.connect(self._open_frameviewer)
        action.setToolTip('View frames related to given mode')
        if self.recon_type == '3d':
            action.setEnabled(False)
        analysismenu.addAction(action)
        action = QtWidgets.QAction('Subtract radial minimum', self)
        action.triggered.connect(self._subtract_radmin)
        action.setToolTip('Subtract radial minimum from intensities')
        analysismenu.addAction(action)

    def _init_plotarea(self):
        '''Create the vertical splitter holding the volume-slice and metrics canvases.'''
        plot_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical)
        plot_splitter.setObjectName('plots')

        # Volume slices figure
        self.fig = matplotlib.figure.Figure(figsize=(14, 5))
        self.fig.subplots_adjust(left=0.0, bottom=0.00, right=0.99, wspace=0.0)
        self.fig.set_facecolor('#222222')
        self.canvas = FigureCanvas(self.fig)
        self.canvas.show()
        plot_splitter.addWidget(self.canvas)
        self.vol_plotter = VolumePlotter(self.fig, self.recon_type, self.num_modes, self.num_nonrot, self.num_rot)
        self.need_replot = self.vol_plotter.need_replot

        # Progress plots figure
        self.log_fig = matplotlib.figure.Figure(figsize=(14, 5), facecolor='w')
        self.log_fig.set_facecolor('#222222')
        self.plotcanvas = FigureCanvas(self.log_fig)
        self.plotcanvas.show()
        plot_splitter.addWidget(self.plotcanvas)
        self.log_plotter = LogPlotter(self.log_fig, self.folder)

        return plot_splitter

    def _init_optionsarea(self):
        '''Create the options panel: file/range inputs, sliders, buttons and log view.'''
        options_widget = QtWidgets.QWidget()
        vbox = QtWidgets.QVBoxLayout()
        options_widget.setLayout(vbox)

        # -- Log file
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('Log file name:', self)
        hbox.addWidget(label)
        # Replaces the log file path string read in _read_config with its editor
        self.logfname = QtWidgets.QLineEdit(self.logfname, self)
        self.logfname.setMinimumWidth(160)
        self.logfname.setToolTip('Path to log file to get metrics and latest iterations')
        hbox.addWidget(self.logfname)
        label = QtWidgets.QLabel('VRange:', self)
        hbox.addWidget(label)
        self.rangemin = QtWidgets.QLineEdit('0', self)
        self.rangemin.setFixedWidth(48)
        self.rangemin.returnPressed.connect(self._range_changed)
        self.rangemin.setToolTip('Minimum value of color scale')
        hbox.addWidget(self.rangemin)
        self.rangestr = QtWidgets.QLineEdit('1', self)
        self.rangestr.setFixedWidth(48)
        self.rangestr.returnPressed.connect(self._range_changed)
        self.rangestr.setToolTip('Maximum value of color scale')
        hbox.addWidget(self.rangestr)

        # -- Volume file
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('File name:', self)
        hbox.addWidget(label)
        if self.model_name is None:
            self.fname = QtWidgets.QLineEdit(self.folder+'/output/intens_001.bin', self)
        else:
            self.fname = QtWidgets.QLineEdit(self.model_name, self)
        self.fname.setMinimumWidth(160)
        self.fname.setToolTip('Path to volume to be plotted')
        hbox.addWidget(self.fname)
        label = QtWidgets.QLabel('Exp:', self)
        hbox.addWidget(label)
        self.expstr = QtWidgets.QLineEdit('1', self)
        self.expstr.setFixedWidth(48)
        self.expstr.returnPressed.connect(self._range_changed)
        self.expstr.setToolTip('Exponent, or gamma, for color scale. Enter the string "log" for the symlog normalization')
        hbox.addWidget(self.expstr)

        # -- Sliders
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        label = QtWidgets.QLabel('Iteration', self)
        hbox.addWidget(label)
        self.iter_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
        self.iter_slider.setRange(0, 1)
        self.iter_slider.sliderMoved.connect(self._iterslider_moved)
        self.iter_slider.sliderReleased.connect(self._iternum_changed)
        self.iter_slider.setToolTip('Set iteration to view')
        hbox.addWidget(self.iter_slider)
        self.iternum = MySpinBox(self)
        self.iternum.setValue(self.iter_slider.value())
        self.iternum.setMinimum(0)
        self.iternum.setMaximum(1)
        # valueChanged is deliberately not connected: the plot updates only on
        # editingFinished or slider release
        self.iternum.editingFinished.connect(self._iternum_changed)
        self.iternum.setFixedWidth(60)
        self.iternum.setToolTip('Set iteration to view')
        hbox.addWidget(self.iternum)
        if self.recon_type == '3d':
            hbox = QtWidgets.QHBoxLayout()
            vbox.addLayout(hbox)
            label = QtWidgets.QLabel('Layer num.', self)
            hbox.addWidget(label)
            self.layer_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
            self.layer_slider.setRange(0, 200)
            self.layer_slider.sliderMoved.connect(self._layerslider_moved)
            self.layer_slider.sliderReleased.connect(self._layernum_changed)
            self.layer_slider.setToolTip('Set layer number in 3D volume')
            hbox.addWidget(self.layer_slider)
            self.layernum = MySpinBox(self)
            self.layernum.setValue(self.layer_slider.value())
            self.layernum.setMinimum(0)
            self.layernum.setMaximum(200)
            self.layernum.valueChanged.connect(self._layernum_changed)
            self.layernum.editingFinished.connect(self._layernum_changed)
            self.layernum.setFixedWidth(60)
            self.layernum.setToolTip('Set layer number in 3D volume')
            hbox.addWidget(self.layernum)
        if self.num_modes > 1:
            hbox = QtWidgets.QHBoxLayout()
            vbox.addLayout(hbox)
            label = QtWidgets.QLabel('Mode', self)
            hbox.addWidget(label)
            self.mode_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
            self.mode_slider.setRange(0, self.num_modes-1)
            self.mode_slider.sliderMoved.connect(self._modeslider_moved)
            self.mode_slider.sliderReleased.connect(self._modenum_changed)
            self.mode_slider.setToolTip('Set mode number')
            hbox.addWidget(self.mode_slider)
            self.modenum = MySpinBox(self)
            # Initialise from the mode slider (was mistakenly the iteration slider)
            self.modenum.setValue(self.mode_slider.value())
            self.modenum.setMinimum(0)
            self.modenum.setMaximum(self.num_modes-1)
            # valueChanged is deliberately not connected (see iternum above)
            self.modenum.editingFinished.connect(self._modenum_changed)
            self.modenum.setFixedWidth(60)
            self.modenum.setToolTip('Set mode number')
            hbox.addWidget(self.modenum)
            self.old_modenum = self.modenum.value()

        # -- Buttons
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        button = QtWidgets.QPushButton('Check', self)
        button.clicked.connect(self._check_for_new)
        button.setToolTip('Examine log file to see whether any new iterations have been completed')
        hbox.addWidget(button)
        self.ifcheck = QtWidgets.QCheckBox('Keep checking', self)
        self.ifcheck.stateChanged.connect(self._keep_checking)
        self.ifcheck.setChecked(False)
        self.ifcheck.setToolTip('Check log file every 5 seconds')
        hbox.addWidget(self.ifcheck)
        hbox.addStretch(1)
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        hbox.addStretch(1)
        button = QtWidgets.QPushButton('Plot', self)
        button.clicked.connect(self._parse_and_plot)
        button.setToolTip('Plot volume (shortcut: ENTER)')
        hbox.addWidget(button)
        button = QtWidgets.QPushButton('Reparse', self)
        button.clicked.connect(lambda: self._parse_and_plot(force=True))
        button.setToolTip('Force reparsing of file and plot')
        hbox.addWidget(button)
        button = QtWidgets.QPushButton('Quit', self)
        button.clicked.connect(self.close)
        hbox.addWidget(button)

        # -- Log file display
        log_area = QtWidgets.QScrollArea(self)
        vbox.addWidget(log_area)
        log_area.setMinimumWidth(300)
        log_area.setWidgetResizable(True)
        self.emclog_text = QtWidgets.QTextEdit(
            'Press \'Check\' to synchronize with log file<br>'
            'Select \'Keep Checking\' to periodically synchronize<br><br>'
            'The top half of the display area will show three orthogonal<br>'
            'slices of the 3D volume. The bottom half will show plots of<br>'
            'various parameters vs iteration.', self)
        self.emclog_text.setReadOnly(True)
        self.emclog_text.setFontPointSize(8)
        self.emclog_text.setFontFamily('Courier')
        self.emclog_text.setFontWeight(QtGui.QFont.DemiBold)
        self.emclog_text.setTabStopWidth(22)
        self.emclog_text.setLineWrapMode(QtWidgets.QTextEdit.NoWrap)
        self.emclog_text.setObjectName('logtext')
        self.emclog_text.setToolTip('Log file contents')
        log_area.setWidget(self.emclog_text)

        return options_widget

    def _layernum_changed(self, value=None):
        '''Layer spin box/slider handler.

        value is None for sliderReleased / editingFinished (forces a replot);
        for valueChanged it is the new value, which is mirrored to the slider.
        '''
        if value is None:
            # Slider released or editing finished
            self.need_replot = True
        elif value == self.layernum.value():
            self.layer_slider.setValue(value)
        self._parse_and_plot()

    def _layerslider_moved(self, value):
        self.layernum.setValue(value)

    def _iternum_changed(self, value=None):
        '''Point the file box at the selected iteration's model and plot it.'''
        if value is None:
            self.fname.setText(self._gen_model_fname(self.iternum.value()))
        elif value == self.iternum.value():
            self.iter_slider.setValue(value)
            if self.need_replot:
                self.fname.setText(self._gen_model_fname(self.iternum.value()))
        self._parse_and_plot()

    def _iterslider_moved(self, value):
        self.iternum.setValue(value)

    def _modenum_changed(self, value=None):
        '''Mode selection handler: reparse (3D) or update the displayed mode (2D).'''
        if value == self.modenum.value():
            self.mode_slider.setValue(value)
        if self.recon_type == '3d':
            self._parse_and_plot()
        else:
            self._plot_vol(update=True)

    def _modeslider_moved(self, value):
        self.modenum.setValue(value)

    def _range_changed(self):
        self.need_replot = True

    def _gen_model_fname(self, num):
        '''Return <folder>/output_NNN.h5 if it exists, else <folder>/output/intens_NNN.bin.'''
        h5_fname = self.folder+'/output_%.3d.h5' % num
        if os.path.isfile(h5_fname):
            return h5_fname
        return self.folder+'/output/intens_%.3d.bin' % num

    def _read_config(self, config):
        '''Read folder, log file, recon type and mode counts from the [emc] section.

        Missing options fall back to defaults. Each mode-count option is read
        independently, so one missing key no longer hides the ones after it.
        '''
        try:
            self.folder = read_config.get_filename(config, 'emc', 'output_folder')
        except read_config.configparser.NoOptionError:
            self.folder = 'data/'

        try:
            self.logfname = read_config.get_filename(config, 'emc', 'log_file')
        except read_config.configparser.NoOptionError:
            self.logfname = 'EMC.log'

        try:
            self.recon_type = read_config.get_param(config, 'emc', 'recon_type').lower()
        except read_config.configparser.NoOptionError:
            self.recon_type = '3d'

        for attr, key, default in (('num_modes', 'num_modes', 1),
                                   ('num_nonrot', 'num_nonrot_modes', 0),
                                   ('num_rot', 'num_rot', None)):
            try:
                setattr(self, attr, int(read_config.get_param(config, 'emc', key)))
            except read_config.configparser.NoOptionError:
                setattr(self, attr, default)

    def _init_sliders(self, slider_type, numvals, init):
        '''Set the 'layer' or 'mode' slider/spin box range to [0, numvals-1] and value to init.'''
        if slider_type == 'layer':
            self.layer_slider.setRange(0, numvals-1)
            self.layernum.setMaximum(numvals-1)
            self.layer_slider.setValue(init)
            self._layerslider_moved(init)
        elif slider_type == 'mode':
            self.mode_slider.setRange(0, numvals-1)
            self.modenum.setMaximum(numvals-1)
            self.mode_slider.setValue(init)
            self._modeslider_moved(init)

    def _plot_vol(self, num=None, update=False):
        '''Draw the parsed volume with the current range, exponent and colour map.

        num: layer (3D) or mode (2D) index; read from the spin boxes when None.
        update: use vol_plotter.update_mode instead of a full vol_plotter.plot.
        '''
        if self.recon_type == '2d':
            # Clicking a subplot selects that mode. Connect once; the original
            # reconnected on every redraw and stacked duplicate handlers.
            if self._mode_click_cid is None:
                self._mode_click_cid = self.canvas.mpl_connect('button_press_event', self._select_mode)
            if num is None:
                if self.num_modes > 1:
                    num = int(self.modenum.text())
                else:
                    num = 0
        elif num is None:
            num = int(self.layernum.text())
        argsdict = {'vrange': (float(self.rangemin.text()), float(self.rangestr.text())),
                    'exponent': self.expstr.text(),
                    'cmap': self.color_map.checkedAction().text()}
        if update:
            self.vol_plotter.update_mode(num, **argsdict)
        else:
            self.vol_plotter.plot(num, **argsdict)
        if self.num_modes > 1:
            self.old_modenum = self.modenum.value()

    def _parse_and_plot(self, force=False, rots=True):
        '''Parse and plot the volume named in the file box when needed.

        The file is (re)parsed when force is set, nothing is plotted yet, the
        file name changed, or (multi-mode) the selected mode changed; otherwise
        the volume is only redrawn when need_replot is set. rots is forwarded
        to vol_plotter.parse.
        '''
        if force or not self.vol_plotter.image_exists or self.old_fname != self.fname.text():
            if self.num_modes > 1:
                self._init_sliders('mode', self.num_modes+self.num_nonrot, self.modenum.value())
                modenum = self.modenum.value()
            else:
                modenum = 0
            self.old_fname, size, center = self.vol_plotter.parse(self.fname.text(),
                                                                  modenum=modenum, rots=rots)
            if self.recon_type == '3d':
                self._init_sliders('layer', size, center)
            self._plot_vol()
        elif self.num_modes > 1 and self.modenum.value() != self.old_modenum:
            self.old_fname, size, center = self.vol_plotter.parse(self.fname.text(),
                                                                  modenum=self.modenum.value(), rots=rots)
            if self.recon_type == '3d':
                self._init_sliders('layer', size, center)
            else:
                self._init_sliders('mode', self.num_modes+self.num_nonrot, self.modenum.value())
            self._plot_vol()
        elif self.need_replot:
            self._plot_vol()

    def _check_for_new(self):
        '''Poll the log file; on a new iteration number, load that model and its metrics.

        The iteration is the first field of the last log line (0 when the file
        is empty, the line is blank, or the field is not an integer).
        '''
        if not os.path.isfile(self.logfname.text()):
            return
        with open(self.logfname.text(), 'r') as fptr:
            lines = fptr.readlines()
        if not lines:
            return
        last_line = lines[-1].rstrip().split()
        try:
            iteration = int(last_line[0])
        except (IndexError, ValueError):
            iteration = 0

        if iteration > 0 and self.max_iternum != iteration:
            self.fname.setText(self._gen_model_fname(iteration))
            self.max_iternum = iteration
            self.iter_slider.setRange(0, self.max_iternum)
            self.iternum.setMaximum(self.max_iternum)
            self.iter_slider.setValue(iteration)
            self._iterslider_moved(iteration)
            log_text = self.log_plotter.plot(self.logfname.text(),
                                             self.color_map.checkedAction().text())
            self._parse_and_plot()
            self.emclog_text.setText(log_text)

    def _keep_checking(self):
        '''Start (every 5 s) or stop log polling according to the checkbox.'''
        if self.ifcheck.isChecked():
            self._check_for_new()
            self.checker.start(5000)
        else:
            self.checker.stop()

    def _select_mode(self, event):
        '''Canvas click handler (2D): select the mode whose subplot was clicked.'''
        curr_mode = -1
        for i, subp in enumerate(self.vol_plotter.subplot_list):
            if event.inaxes is subp:
                curr_mode = i

        if curr_mode >= 0 and curr_mode != self.modenum.value():
            self.mode_slider.setValue(curr_mode)
            self.modenum.setValue(curr_mode)
            self._modenum_changed()

            if self.fviewer is not None:
                self.fviewer.mode = curr_mode
                self.fviewer.label.setText('Class %d frames'%curr_mode)
                self.fviewer.numlist = np.where(self.vol_plotter.modes == curr_mode)[0]

    @staticmethod
    def _dialog_fname(result):
        '''Extract the file name from a QFileDialog.get*FileName result.

        Depending on the Qt binding the result is a (filename, filter) tuple
        or a bare string; both are accepted, so QT_API need not be set.
        '''
        if isinstance(result, tuple):
            return result[0]
        return result

    def _load_volume(self):
        '''Ask for a volume file (bin or h5) and plot it.'''
        fname = self._dialog_fname(QtWidgets.QFileDialog.getOpenFileName(
            self, 'Load 3D Volume', 'data/', 'Binary data (*.bin);;HDF5 data (*.h5)'))
        if fname:
            self.fname.setText(fname)
            self._parse_and_plot()

    def _save_plot(self):
        '''Save the volume-slice figure as a PNG.'''
        default_name = 'images/'+os.path.splitext(os.path.basename(self.fname.text()))[0]+'.png'
        fname = self._dialog_fname(QtWidgets.QFileDialog.getSaveFileName(
            self, 'Save Volume Image', default_name, 'Image (*.png)'))
        if fname:
            self.fig.savefig(fname, bbox_inches='tight', dpi=120)
            sys.stderr.write('Saved to %s\n'%fname)

    def _save_log_plot(self):
        '''Save the metrics figure as a PNG.'''
        default_name = 'images/log_fig.png'
        fname = self._dialog_fname(QtWidgets.QFileDialog.getSaveFileName(
            self, 'Save Log Plots', default_name, 'Image (*.png)'))
        if fname:
            self.log_fig.savefig(fname, bbox_inches='tight', dpi=120)
            sys.stderr.write("Saved to %s\n"%fname)

    def _plot_layer(self, num):
        '''FuncAnimation frame callback: draw layer num and caption it.'''
        self._plot_vol(num=num)
        self.fig.suptitle('Layer %d'%num, y=0.01, va='bottom')
        return self.fig,

    def _save_layer_movie(self):
        '''Render every layer into an mp4 via ffmpeg, then restore the normal plot.'''
        default_name = 'images/'+os.path.splitext(os.path.basename(self.fname.text()))[0]+'_layers.mp4'
        fname = self._dialog_fname(QtWidgets.QFileDialog.getSaveFileName(
            self, 'Save Layer Animation Movie', default_name, 'Movie (*.mp4)'))
        if fname:
            sys.stderr.write('Saving layer animation to %s ...' % fname)
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=20, codec='h264', bitrate=1800)
            anim = animation.FuncAnimation(self.fig, self._plot_layer, self.layer_slider.maximum()+1, interval=50, repeat=False)
            anim.save(fname, writer=writer)
            self._parse_and_plot(force=True)
            sys.stderr.write('done\n')

    def _plot_iter(self, num):
        '''FuncAnimation frame callback: load and draw iteration num and caption it.'''
        self.fname.setText(self._gen_model_fname(num))
        self._parse_and_plot()
        self.fig.suptitle('Iteration %d'%num, y=0.01, va='bottom')
        return self.fig,

    def _save_iter_movie(self):
        '''Render every iteration into an mp4 via ffmpeg, then restore the normal plot.'''
        default_name = 'images/iterations.mp4'
        fname = self._dialog_fname(QtWidgets.QFileDialog.getSaveFileName(
            self, 'Save Iteration Animation Movie', default_name, 'Movie (*.mp4)'))
        if fname:
            sys.stderr.write('Saving iteration animation to %s ...' % fname)
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=10, codec='h264', bitrate=1800)
            anim = animation.FuncAnimation(self.fig, self._plot_iter, self.iter_slider.maximum()+1, interval=50, repeat=False)
            anim.save(fname, writer=writer)
            self._parse_and_plot(force=True)
            sys.stderr.write('done\n')

    def _cmap_changed(self):
        if self.vol_plotter.image_exists:
            self.need_replot = True
            self._parse_and_plot()

    def _open_frameviewer(self):
        '''Open (at most one) frame viewer, filtered to the current mode when known.'''
        if self.fviewer is not None:
            return
        if self.num_modes > 1 and self.vol_plotter.rots is not None:
            mode = self.modenum.value()
            numlist = np.where(self.vol_plotter.modes == mode)[0]
            self.fviewer = MyFrameviewer(self.config, mode, numlist)
        else:
            self.fviewer = MyFrameviewer(self.config, -1, [])
        self.fviewer.windowClosed.connect(self._fviewer_closed)

    def _subtract_radmin(self):
        self.vol_plotter.subtract_radmin()
        self._plot_vol()

    @QtCore.Slot()
    def _fviewer_closed(self):
        self.fviewer = None

    def closeEvent(self, event): # pylint: disable=C0103
        '''Close any open frame viewer together with this window.'''
        if self.fviewer is not None:
            self.fviewer.close()
        event.accept()

    def keyPressEvent(self, event): # pylint: disable=C0103
        '''Override of default keyPress event handler

        Enter plots, Ctrl+Q quits, Ctrl+S saves slices, Ctrl+K checks the log.
        '''
        key = event.key()
        mod = int(event.modifiers())

        if key == QtCore.Qt.Key_Return or key == QtCore.Qt.Key_Enter:
            self._parse_and_plot()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+Q'):
            self.close()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+S'):
            self._save_plot()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+K'):
            self._check_for_new()
        else:
            event.ignore()
Example #23
0
class SampleLogsView(QSplitter):
    """Sample Logs View

    This contains a table of the logs, a plot of the currently
    selected logs, and the statistics of the selected log.

    :param presenter: object whose slots are wired to the widgets here
        (changeExpInfo, doubleClicked, plot_logs, filtered_changed,
        plot_clicked, new_plot_logs, print_selected_logs, update).
    :param parent: optional parent widget.
    :param name: workspace name, shown in the window title.
    :param isMD: if True, add an "Experiment Info #" spin box.
    :param noExp: number of experiment infos; the spin box maximum is noExp-1.
    """
    def __init__(self, presenter, parent = None, name = '', isMD=False, noExp = 0):
        super(SampleLogsView, self).__init__(parent)

        self.presenter = presenter
        # Width of the right-hand frame when it was last visible. Initialised
        # so show_plot_and_stats works even if its first call sees width 0.
        self.last_width = 0

        self.setWindowTitle("{} sample logs".format(name))
        self.setWindowFlags(Qt.Window)
        self.setAttribute(Qt.WA_DeleteOnClose, True)

        # left hand side
        self.frame_left = QFrame()
        layout_left = QVBoxLayout()

        # Spin box for MD workspaces. It is created exactly once because
        # get_exp() reads self.experimentInfo; a second instance would leave
        # one visible box disconnected from get_exp(). It lives in the left
        # frame, which (unlike the right frame) show_plot_and_stats never hides.
        if isMD:
            layout_mult_expt_info = QHBoxLayout()
            layout_mult_expt_info.addWidget(QLabel("Experiment Info #"))
            self.experimentInfo = QSpinBox()
            self.experimentInfo.setMaximum(noExp-1)
            self.experimentInfo.valueChanged.connect(self.presenter.changeExpInfo)
            layout_mult_expt_info.addWidget(self.experimentInfo)
            layout_mult_expt_info.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Expanding))
            layout_left.addLayout(layout_mult_expt_info)

        # Create sample log table
        self.table = QTableView()
        self.table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.table.doubleClicked.connect(self.presenter.doubleClicked)
        self.table.contextMenuEvent = self.tableMenu
        layout_left.addWidget(self.table)
        self.frame_left.setLayout(layout_left)
        self.addWidget(self.frame_left)

        #right hand side
        self.frame_right = QFrame()
        layout_right = QVBoxLayout()

        # Plot options: relative time and filtered data
        layout_options = QHBoxLayout()

        #check boxes
        self.full_time = QCheckBox("Relative Time")
        self.full_time.setToolTip(
            "Shows relative time in seconds from the start of the run.")
        self.full_time.setChecked(True)
        self.full_time.stateChanged.connect(self.presenter.plot_logs)
        layout_options.addWidget(self.full_time)
        self.show_filtered = QCheckBox("Filtered Data")
        self.show_filtered.setToolTip(
            "Filtered data only shows data while running and in this period.\nInvalid values are also filtered.")
        self.show_filtered.setChecked(True)
        self.show_filtered.stateChanged.connect(self.presenter.filtered_changed)
        layout_options.addWidget(self.show_filtered)
        self.spaceItem = QSpacerItem(10, 10, QSizePolicy.Expanding)
        layout_options.addSpacerItem(self.spaceItem)
        layout_right.addLayout(layout_options)

        # Sample log plot
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setSizePolicy(QSizePolicy.Expanding,QSizePolicy.Expanding)
        self.canvas.mpl_connect('button_press_event', self.presenter.plot_clicked)
        self.ax = self.fig.add_subplot(111, projection='mantid')
        layout_right.addWidget(self.canvas)

        # Sample stats
        self.create_stats_widgets()
        layout_stats = QFormLayout()
        layout_stats.addRow('', QLabel("Log Statistics"))
        layout_stats.addRow('Min:', self.stats_widgets["minimum"])
        layout_stats.addRow('Max:', self.stats_widgets["maximum"])
        layout_stats.addRow('Time Avg:', self.stats_widgets["time_mean"])
        layout_stats.addRow('Time Std Dev:', self.stats_widgets["time_standard_deviation"])
        layout_stats.addRow('Mean (unweighted):', self.stats_widgets["mean"])
        layout_stats.addRow('Median (unweighted):', self.stats_widgets["median"])
        layout_stats.addRow('Std Dev:', self.stats_widgets["standard_deviation"])
        layout_stats.addRow('Duration:', self.stats_widgets["duration"])
        layout_right.addLayout(layout_stats)
        self.frame_right.setLayout(layout_right)

        self.addWidget(self.frame_right)
        self.setStretchFactor(0,1)

        self.resize(1200,800)
        self.show()

    def closeEvent(self, event):
        """Schedule deletion of the widget, then run the default close handling."""
        self.deleteLater()
        super(SampleLogsView, self).closeEvent(event)

    def tableMenu(self, event):
        """Right click menu for table, can plot or print selected logs"""
        menu = QMenu(self)
        plotAction = menu.addAction("Plot selected")
        plotAction.triggered.connect(self.presenter.new_plot_logs)
        printAction = menu.addAction("Print selected")
        printAction.triggered.connect(self.presenter.print_selected_logs)
        menu.exec_(event.globalPos())

    def set_model(self, model):
        """Set the model onto the table"""
        self.model = model
        self.table.setModel(self.model)
        self.table.resizeColumnsToContents()
        self.table.horizontalHeader().setSectionResizeMode(2, QHeaderView.Stretch)
        # setModel creates a new selection model, so (re)connect its signal here.
        self.table.selectionModel().selectionChanged.connect(self.presenter.update)

    def show_plot_and_stats(self, show_plot_and_stats):
        """Sets whether the plot and stats section should be visible.

        The window is grown or shrunk by the width of the right-hand frame so
        the table keeps its size.
        """
        if self.frame_right.isVisible() != show_plot_and_stats:
            # the desired state is not the current state
            self.setUpdatesEnabled(False)
            current_width = self.frame_right.width()
            if current_width:
                self.last_width = current_width
            else:
                current_width = self.last_width

            if show_plot_and_stats:
                self.resize(self.width() + current_width, self.height())
            else:
                self.resize(self.width() - current_width, self.height())
            self.frame_right.setVisible(show_plot_and_stats)
            self.setUpdatesEnabled(True)

    def plot_selected_logs(self, ws, exp, rows):
        """Update the plot with the selected rows"""
        if self.frame_right.isVisible():
            self.ax.clear()
            self.create_ax_by_rows(self.ax, ws, exp, rows)
            try:
                self.fig.canvas.draw()
            except ValueError as ve:
                #this can throw an error if the plot has recently been hidden, but the error does not matter
                if not str(ve).startswith("Image size of"):
                    raise

    def new_plot_selected_logs(self, ws, exp, rows):
        """Create a new plot, in a separate window for selected rows"""
        fig, ax = plt.subplots(subplot_kw={'projection': 'mantid'})
        self.create_ax_by_rows(ax, ws, exp, rows)
        fig.show()

    def create_ax_by_rows(self, ax, ws, exp, rows):
        """Creates the plots for given rows onto axis ax"""
        for row in rows:
            log_text = self.get_row_log_name(row)
            ax.plot(ws,
                    LogName=log_text,
                    label=log_text,
                    FullTime=not self.full_time.isChecked(),
                    Filtered=self.show_filtered.isChecked(),
                    ExperimentInfo=exp)

        ax.set_ylabel('')
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    def set_log_controls(self,are_logs_filtered):
        """Sets log specific settings based on the log clicked on"""
        self.show_filtered.setEnabled(are_logs_filtered)

    def get_row_log_name(self, i):
        """Returns the log name of particular row"""
        return str(self.model.item(i, 0).text())

    def get_exp(self):
        """Get the selected experiment info number.

        The spin box only exists for MD workspaces (isMD=True).
        """
        return self.experimentInfo.value()

    def get_selected_row_indexes(self):
        """Return a list of selected row from table"""
        return [row.row() for row in self.table.selectionModel().selectedRows()]

    def set_selected_rows(self, rows):
        """Set selected rows in table"""
        mode = QItemSelectionModel.Select | QItemSelectionModel.Rows
        for row in rows:
            self.table.selectionModel().select(self.model.index(row, 0), mode)

    def create_stats_widgets(self):
        """Creates the statistics widgets"""
        self.stats_widgets = {"minimum": QLineEdit(),
                              "maximum": QLineEdit(),
                              "mean": QLineEdit(),
                              "median": QLineEdit(),
                              "standard_deviation": QLineEdit(),
                              "time_mean": QLineEdit(),
                              "time_standard_deviation": QLineEdit(),
                              "duration": QLineEdit()}
        for widget in self.stats_widgets.values():
            widget.setReadOnly(True)

    def set_statistics(self, stats):
        """Updates the statistics widgets from the attributes of stats.

        Each widget key names the attribute of stats that is shown in it.
        """
        for param, widget in self.stats_widgets.items():
            widget.setText('{:.6}'.format(getattr(stats, param)))

    def clear_statistics(self):
        """Clears the values in statistics widgets"""
        for widget in self.stats_widgets.values():
            widget.clear()
Example #24
0
class makeManualMask(QDialog):
    """Dialog for drawing a polygonal mask by hand on the first plane of an image.

    Mouse handling on the image:
      * left click adds a vertex;
      * right click removes the last vertex;
      * a double left click (with more than two vertices) closes the polygon.
    "Save mask" rasterises the polygon and writes it as a uint16 image to
    ``<folder of file_in>/<subfolder>/<fn>``.
    """
    def __init__(self,
                 file_in,
                 subfolder='result_segmentation',
                 fn=None,
                 parent=None,
                 wsize=(1000, 1000)):
        super(makeManualMask, self).__init__(parent)
        self.setWindowTitle('Manual mask: ' + file_in)
        QApplication.setStyle('Fusion')
        self.setWindowFlag(Qt.WindowCloseButtonHint, False)

        self.file_in = file_in
        self.subfolder = subfolder
        self.fn = fn
        img = imread(file_in)
        if len(img.shape) == 2:
            img = np.expand_dims(img, 0)
        # If the last axis is the smallest one, treat it as the channel axis
        # and move it to the front.
        if img.shape[-1] == np.min(img.shape):
            img = np.moveaxis(img, -1, 0)
        self.img = img[0]
        # Display contrast limits, computed once instead of on every click.
        self._vmin = np.percentile(self.img, 1.)
        self._vmax = np.percentile(self.img, 99.)
        # Polygon vertices in pixel coordinates.
        self.x = []
        self.y = []

        # Figure and the Qt canvas widget that displays it.
        self.figure = Figure()
        self.canvas = FigureCanvas(self.figure)

        self.plotImage()

        self.button = QPushButton('Save mask')
        self.button.clicked.connect(self.saveMask)

        # set the layout
        layout = QVBoxLayout()
        layout.addWidget(self.canvas)
        layout.addWidget(self.button)
        self.setLayout(layout)

        self.resize(wsize[0], wsize[1])

        # Keep the connection id referenced for the lifetime of the dialog.
        self.__cid2 = self.canvas.mpl_connect('button_press_event',
                                              self.__button_press_callback)

    def _show_image(self):
        """Draw the image on self.ax in grey scale using the cached contrast limits."""
        self.ax.imshow(self.img,
                       cmap='gray',
                       vmin=self._vmin,
                       vmax=self._vmax)

    def plotImage(self):
        """Create the axes and draw the image with no polygon on it."""
        self.ax = self.figure.add_subplot(111)
        # discards the old graph
        self.ax.clear()
        self._show_image()
        self.line = None
        # refresh canvas
        self.canvas.draw()

    def saveMask(self):
        """Rasterise the drawn polygon, write it to disk and close the dialog.

        When ``fn`` is None the file name defaults to ``<stem>_manual<ext>`` of
        ``file_in``. The output subfolder is created if it does not exist. If no
        vertex has been drawn yet, nothing is saved and the dialog stays open.
        """
        if not self.x:
            # No polygon to rasterise; indexing self.x[0] below would fail.
            return
        ny, nx = np.shape(self.img)
        poly_verts = ([(self.x[0], self.y[0])] +
                      list(zip(reversed(self.x), reversed(self.y))))
        # Pixel coordinates of every grid cell
        # (<0,0> is at the top left of the grid in this system)
        x, y = np.meshgrid(np.arange(nx), np.arange(ny))
        points = np.vstack((x.flatten(), y.flatten())).T

        roi_path = MplPath(poly_verts)
        mask = roi_path.contains_points(points).reshape((ny, nx))

        folder, filename = os.path.split(self.file_in)
        stem, extension = os.path.splitext(filename)
        if self.fn is None:
            self.fn = stem + '_manual' + extension
        out_dir = os.path.join(folder, self.subfolder)
        os.makedirs(out_dir, exist_ok=True)
        imsave(os.path.join(out_dir, self.fn), mask.astype(np.uint16))

        self.close()

    def __button_press_callback(self, event):
        """Edit the polygon according to a mouse click inside the image axes.

        The image and the current polyline are redrawn after every click, so
        clicks that do not change the polygon no longer erase it.
        """
        if event.inaxes != self.ax:
            return
        x, y = int(event.xdata), int(event.ydata)
        n_p = len(self.x)

        if (event.button == 1) and (event.dblclick is False):
            self.x.append(x)
            self.y.append(y)

        elif (event.button == 3) and (n_p > 1):
            self.x = self.x[:-1]
            self.y = self.y[:-1]

        elif (event.button == 3) and (n_p == 1):
            self.x = []
            self.y = []

        elif (((event.button == 1) and (event.dblclick is True))
              and (n_p > 2)):
            # Close the loop. The first press of the double click has already
            # been handled as a single click and appended a vertex, so drop
            # it, re-add the final vertex, and repeat the first vertex.
            self.x = self.x[:-1]
            self.y = self.y[:-1]
            self.x.append(x)
            self.x.append(self.x[0])
            self.y.append(y)
            self.y.append(self.y[0])

        self.ax.clear()
        self._show_image()
        self.line = self.ax.plot(self.x, self.y, '-or') if self.x else None
        self.canvas.draw()
Example #25
0
class ApplicationWindow(QtWidgets.QMainWindow):
    """Main application window."""
    def __init__(self):
        """Initialise the application - includes loading settings from disc, initialising a lakeator, and setting up the GUI."""
        super().__init__()
        self._load_settings()

        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)
        layout = QtWidgets.QHBoxLayout(self._main)

        self.setWindowTitle('Locator')
        self.setWindowIcon(QtGui.QIcon("./kiwi.png"))

        self.loadAction = QtWidgets.QAction("&Load File", self)
        self.loadAction.setShortcut("Ctrl+L")
        self.loadAction.setStatusTip("Load a multichannel .wav file.")
        self.loadAction.triggered.connect(self.file_open)

        self.saveAction = QtWidgets.QAction("&Save Image", self)
        self.saveAction.setShortcut("Ctrl+S")
        self.saveAction.setStatusTip("Save the current display to a PNG file.")
        self.saveAction.triggered.connect(self.save_display)
        self.saveAction.setDisabled(True)

        self.saveGisAction = QtWidgets.QAction("&Save to GIS", self)
        self.saveGisAction.setShortcut("Ctrl+G")
        self.saveGisAction.setStatusTip(
            "Save the heatmap as a QGIS-readable georeferenced TIFF file.")
        self.saveGisAction.triggered.connect(self.exportGIS)
        self.saveGisAction.setDisabled(True)

        self.statusBar()

        mainMenu = self.menuBar()
        fileMenu = mainMenu.addMenu("&File")
        fileMenu.addAction(self.loadAction)
        fileMenu.addAction(self.saveAction)
        fileMenu.addAction(self.saveGisAction)

        setArrayDesign = QtWidgets.QAction("&Configure Array Design", self)
        setArrayDesign.setShortcut("Ctrl+A")
        setArrayDesign.setStatusTip(
            "Input relative microphone positions and array bearing.")
        setArrayDesign.triggered.connect(self.get_array_info)

        setGPSCoords = QtWidgets.QAction("&Set GPS Coordinates", self)
        setGPSCoords.setShortcut("Ctrl+C")
        setGPSCoords.setStatusTip(
            "Set the GPS coordinates for the array, and EPSG code for the CRS."
        )
        setGPSCoords.triggered.connect(self.get_GPS_info)

        arrayMenu = mainMenu.addMenu("&Array")
        arrayMenu.addAction(setArrayDesign)
        arrayMenu.addAction(setGPSCoords)

        setDomain = QtWidgets.QAction("&Set Heatmap Domain", self)
        setDomain.setShortcut("Ctrl+D")
        setDomain.setStatusTip(
            "Configure distances left/right up/down at which to generate the heatmap."
        )
        setDomain.triggered.connect(self.getBoundsInfo)

        self.refreshHeatmap = QtWidgets.QAction("&Calculate", self)
        self.refreshHeatmap.setShortcut("Ctrl+H")
        self.refreshHeatmap.setStatusTip("(Re)calculate heatmap.")
        self.refreshHeatmap.triggered.connect(self.generate_heatmap)
        self.refreshHeatmap.setDisabled(True)

        self.refreshView = QtWidgets.QAction("&Recalculate on View", self)
        self.refreshView.setShortcut("Ctrl+R")
        self.refreshView.setStatusTip(
            "Recalculate heatmap at current zoom level.")
        self.refreshView.triggered.connect(self.recalculateOnView)
        self.refreshView.setDisabled(True)

        heatmapMenu = mainMenu.addMenu("&Heatmap")
        heatmapMenu.addAction(setDomain)

        # Initialise canvas
        self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
        layout.addWidget(self.static_canvas)
        # Connection id of the 'draw_event' handler made in generate_heatmap.
        self._draw_cid = None

        # Add a navbar
        navbar = NavigationToolbar(self.static_canvas, self)
        self.addToolBar(QtCore.Qt.BottomToolBarArea, navbar)

        # Override the default mpl save functionality to change default filename
        navbar._actions['save_figure'].disconnect()
        navbar._actions['save_figure'].triggered.connect(self.save_display)

        self.img = None

        # Dynamically generate menu full of all available colourmaps. Do not add the inverted ones.
        # The stored cmap may carry an "_r" (inverted) suffix; compare menu
        # entries against its base name only.
        current_cmap = self.settings["heatmap"]["cmap"]
        if current_cmap.endswith("_r"):
            current_cmap = current_cmap[:-2]
        self.colMenu = heatmapMenu.addMenu("&Choose Colour Map")
        self.colMenu.setDisabled(True)
        colGroup = QtWidgets.QActionGroup(self)
        for colour in sorted(colormaps(), key=str.casefold):
            if colour[-2:] != "_r":
                cm = self.colMenu.addAction(colour)
                cm.setCheckable(True)
                if colour == current_cmap:
                    cm.setChecked(True)
                # Bind colour as a default argument to avoid late binding.
                receiver = lambda checked, cmap=colour: self.img.set_cmap(cmap)
                cm.triggered.connect(receiver)
                cm.triggered.connect(self._setcol)
                cm.triggered.connect(self.static_canvas.draw)
                colGroup.addAction(cm)

        self.invert = QtWidgets.QAction("&Invert Colour Map", self)
        self.invert.setShortcut("Ctrl+I")
        self.invert.setStatusTip("Invert the current colourmap.")
        self.invert.triggered.connect(self.invert_heatmap)
        self.invert.setCheckable(True)
        self.invert.setDisabled(True)
        heatmapMenu.addAction(self.invert)

        heatmapMenu.addSeparator()
        heatmapMenu.addAction(self.refreshHeatmap)
        heatmapMenu.addAction(self.refreshView)

        algoMenu = mainMenu.addMenu("Algorithm")
        self.algChoice = algoMenu.addMenu("&Change Algorithm")
        # Separate exclusive group, so picking an algorithm does not uncheck
        # the selected colour map (and vice versa).
        algGroup = QtWidgets.QActionGroup(self)
        for alg in sorted(["GCC", "MUSIC", "AF-MUSIC"], key=str.casefold):
            cm = self.algChoice.addAction(alg)
            cm.setCheckable(True)
            if alg == self.settings["algorithm"]["current"]:
                cm.setChecked(True)
            receiver = lambda checked, al=alg: self.setAlg(al)
            cm.triggered.connect(receiver)
            algGroup.addAction(cm)

        self.params = QtWidgets.QAction("&Algorithm Settings", self)
        self.params.setStatusTip("Alter algorithm-specific settings.")
        self.params.triggered.connect(self.getAlgoInfo)
        algoMenu.addAction(self.params)

        # Display a "ready" message
        self.statusBar().showMessage('Ready')

        # Boolean to keep track of whether we have GPS information for the array, and an image
        self._has_GPS = False
        self._has_heatmap = False

        # Keep track of the currently opened file
        self.open_filename = ""

        self.loc = lakeator.Lakeator(self.settings["array"]["mic_locations"])

    def _show_error(self, title, text):
        """Show a modal critical-error message box with the given title and text."""
        msg = QtWidgets.QMessageBox()
        msg.setIcon(QtWidgets.QMessageBox.Critical)
        msg.setText(text)
        msg.setWindowTitle(title)
        msg.setMinimumWidth(200)
        msg.exec_()

    def setAlg(self, alg):
        """Change the current algorithm to `alg', and write settings to disc."""
        self.settings["algorithm"]["current"] = alg
        self._save_settings()

    def ondraw(self, event):
        """Record the axis limits after each draw in `last_zoomed`.

        If a heatmap is shown and the limits differ from the stored heatmap
        bounds (i.e. the user zoomed or panned), "Recalculate on View" is
        enabled. A plain window resize leaves the limits unchanged.
        """
        xlim = self._static_ax.get_xlim()
        ylim = self._static_ax.get_ylim()
        bounds = self.settings["heatmap"]
        if self._has_heatmap and (bounds["xlim"][0] != xlim[0] or
                                  bounds["xlim"][1] != xlim[1] or
                                  bounds["ylim"][0] != ylim[0] or
                                  bounds["ylim"][1] != ylim[1]):
            self.refreshView.setDisabled(False)
        self.last_zoomed = [xlim, ylim]

    def recalculateOnView(self):
        """If the image has been zoomed, calling this method will recalculate the heatmap on the current zoom level."""
        if hasattr(self, "last_zoomed"):
            self.settings["heatmap"]["xlim"] = self.last_zoomed[0]
            self.settings["heatmap"]["ylim"] = self.last_zoomed[1]
            self._save_settings()
            self.generate_heatmap()

    def invert_heatmap(self):
        """Add or remove _r to the current colourmap before setting it (to invert the colourmap), then redraw the canvas."""
        if self.settings["heatmap"]["cmap"][-2:] == "_r":
            self.settings["heatmap"]["cmap"] = self.settings["heatmap"][
                "cmap"][:-2]
            self.img.set_cmap(self.settings["heatmap"]["cmap"])
            self.static_canvas.draw()
        else:
            try:
                self.img.set_cmap(self.settings["heatmap"]["cmap"] + "_r")
                self.settings["heatmap"][
                    "cmap"] = self.settings["heatmap"]["cmap"] + "_r"
                self.static_canvas.draw()
            except ValueError as inst:
                print(type(inst), inst)
        self._save_settings()

    def _setcol(self, c):
        """Set the colourmap attribute to the name of the cmap - needed as I'm using strings to set the cmaps rather than cmap objects."""
        self.settings["heatmap"]["cmap"] = self.img.get_cmap().name
        self._save_settings()

    def generate_heatmap(self):
        """Calculate and draw the heatmap."""
        # Initialise the axis on the canvas, refresh the screen
        self.static_canvas.figure.clf()
        self._static_ax = self.static_canvas.figure.subplots()

        # figure.clf() does not remove canvas event connections, so drop the
        # handler from a previous call instead of stacking another one.
        if self._draw_cid is not None:
            self.static_canvas.mpl_disconnect(self._draw_cid)
        self._draw_cid = self.static_canvas.mpl_connect('draw_event', self.ondraw)

        # Show a loading message while the user waits
        self.statusBar().showMessage('Calculating heatmap...')

        dom = self.loc.estimate_DOA_heatmap(
            self.settings["algorithm"]["current"],
            xrange=self.settings["heatmap"]["xlim"],
            yrange=self.settings["heatmap"]["ylim"],
            no_fig=True,
            freq=self.settings["algorithm"]["MUSIC"]["freq"],
            AF_freqs=(self.settings["algorithm"]["AF-MUSIC"]["f_min"],
                      self.settings["algorithm"]["AF-MUSIC"]["f_max"]),
            f_0=self.settings["algorithm"]["AF-MUSIC"]["f_0"])

        # Show the image and set axis labels & title
        self.img = self._static_ax.imshow(
            dom,
            cmap=self.settings["heatmap"]["cmap"],
            interpolation='none',
            origin='lower',
            extent=[
                self.settings["heatmap"]["xlim"][0],
                self.settings["heatmap"]["xlim"][1],
                self.settings["heatmap"]["ylim"][0],
                self.settings["heatmap"]["ylim"][1]
            ])
        self._static_ax.set_xlabel("Horiz. Dist. from Center of Array [m]")
        self._static_ax.set_ylabel("Vert. Dist. from Center of Array [m]")
        self._static_ax.set_title("{}-based Source Location Estimate".format(
            self.settings["algorithm"]["current"]))

        # Add a colourbar and redraw the screen
        self.static_canvas.figure.colorbar(self.img)
        self.static_canvas.draw()

        # Once there's an image being displayed, you can save it and change the colours
        self.saveAction.setDisabled(False)
        if self._has_GPS:
            self.saveGisAction.setDisabled(False)
        self.statusBar().showMessage('Ready.')
        self.colMenu.setDisabled(False)
        self.invert.setDisabled(False)
        self._has_heatmap = True

    def file_open(self):
        """Let the user pick a file to open, and then calculate the cross-correlations.

        The status bar is reset to 'Ready.' on every exit path, including a
        cancelled dialog and load errors.
        """
        self.statusBar().showMessage('Loading...')
        name, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Load .wav file", "./", "Audio *.wav")
        if not name:
            self.statusBar().showMessage('Ready.')
            return
        try:
            self.loc.load(name,
                          rho=self.settings["algorithm"]["GCC"]["rho"])
        except IndexError:
            self._show_error(
                "File Error",
                "File Error\nThe number of microphones in the current array configuration ({0}) is greater than the number of tracks in the selected audio file. Please select a {0}-track audio file."
                .format(self.loc.mics.shape[0]))
            self.statusBar().showMessage('Ready.')
            return
        if self.loc.mics.shape[0] < self.loc.data.shape[1]:
            self._show_error(
                "File Error",
                "File Error\nThe number of microphones in the current array configuration ({0}) is less than the number of tracks in the selected audio file ({1}). Please select a {0}-track audio file, or configure the microphone locations to match the current file."
                .format(self.loc.mics.shape[0], self.loc.data.shape[1]))
            self.statusBar().showMessage('Ready.')
            return
        self.open_filename = name
        self.refreshHeatmap.setDisabled(False)
        self.statusBar().showMessage('Ready.')

    def save_display(self):
        """Save the heatmap and colourbar with a sensible default filename.

        ".png" is appended only when the chosen name does not already end in
        it (the default name already does).
        """
        defaultname = self.open_filename[:-4] + "_" + self.settings[
            "algorithm"]["current"] + "_heatmap.png"
        name, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save image", defaultname, "PNG files *.png;; All Files *")
        if name:
            if not name.lower().endswith(".png"):
                name = name + ".png"
            self.static_canvas.figure.savefig(name)

    def get_GPS_info(self):
        """Create a popup to listen for the GPS info, and connect the listener."""
        self.setGPSInfoDialog = Dialogs.GPSPopUp(
            coords=self.settings["array"]["GPS"]["coordinates"],
            EPSG=self.settings["array"]["GPS"]["EPSG"]["input"],
            pEPSG=self.settings["array"]["GPS"]["EPSG"]["projected"],
            tEPSG=self.settings["array"]["GPS"]["EPSG"]["target"])
        self.setGPSInfoDialog.activate.clicked.connect(self.changeGPSInfo)
        self.setGPSInfoDialog.exec()

    def changeGPSInfo(self):
        """Listener for the change GPS info dialog - writes the new information to disc and enables the ExportToGIS option."""
        try:
            lat, long, EPSG, projEPSG, targetEPSG = self.setGPSInfoDialog.getValues(
            )
            self.settings["array"]["GPS"]["EPSG"]["input"] = EPSG
            self.settings["array"]["GPS"]["EPSG"]["projected"] = projEPSG
            self.settings["array"]["GPS"]["EPSG"]["target"] = targetEPSG
            self.settings["array"]["GPS"]["coordinates"] = (lat, long)
            self._save_settings()

            self._has_GPS = True
            if self._has_heatmap:
                self.saveGisAction.setDisabled(False)
            self.setGPSInfoDialog.close()
        except EPSGError:
            self._show_error(
                "Error with EPSG code",
                "Value Error\nPlease enter EPSG codes for coordinate systems as integers, e.g. 4326 or 2193. To find the EPSG of a given coordinate system, visit https://epsg.io/"
            )
        except GPSError:
            self._show_error(
                "Error with GPS input.",
                "Value Error\nPlease enter only the numerical portion of the coordinates, in the order governed by ISO19111 (see https://proj.org/faq.html#why-is-the-axis-ordering-in-proj-not-consistent)"
            )

    def get_array_info(self):
        """Create a popup to listen for the mic position info, and connect the listener."""
        self.setMicsInfoDialog = Dialogs.MicPositionPopUp(
            cur_locs=self.settings["array"]["mic_locations"])
        self.setMicsInfoDialog.activate.clicked.connect(self.changeArrayInfo)
        self.setMicsInfoDialog.exec()

    def changeArrayInfo(self):
        """Listener for the change array info dialog - writes the information to disc and re-initialises the locator.

        The new locator has not had load() called, so actions that need loaded
        audio or a computed heatmap are disabled until a file is loaded again.
        """
        try:
            miclocs = self.setMicsInfoDialog.getValues()
            self.settings["array"]["mic_locations"] = miclocs
            self._save_settings()
            self.loc = lakeator.Lakeator(
                self.settings["array"]["mic_locations"])
            self._has_heatmap = False
            self.refreshHeatmap.setDisabled(True)
            self.refreshView.setDisabled(True)
            self.saveGisAction.setDisabled(True)
            self.statusBar().showMessage(
                'Array changed - please reload the audio file.')
            self.setMicsInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with microphone location input",
                "Value Error\nPlease enter microphone coordinates in meters as x,y pairs, one per line; e.g.\n0.0, 0.0\n0.1, 0.0\n0.0, -0.1\n-0.1, 0.0\n0.0, 0.1"
            )

    def getBoundsInfo(self):
        """Create a popup to listen for the change heatmap bounds info, and connect the listener."""
        l, r = self.settings["heatmap"]["xlim"]
        d, u = self.settings["heatmap"]["ylim"]
        self.setBoundsInfoDialog = Dialogs.HeatmapBoundsPopUp(l, r, u, d)
        self.setBoundsInfoDialog.activate.clicked.connect(
            self.changeBoundsInfo)
        self.setBoundsInfoDialog.exec()

    def changeBoundsInfo(self):
        """ Listener change heatmap bounds info dialog - save the information to disc and regenerate the heatmap on the new zoom area."""
        try:
            l_new, r_new, u_new, d_new = self.setBoundsInfoDialog.getValues()
            self.settings["heatmap"]["xlim"] = [l_new, r_new]
            self.settings["heatmap"]["ylim"] = [d_new, u_new]
            self._save_settings()
            # if self.open_filename:
            # self.generate_heatmap()
            self.setBoundsInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with heatmap bounds",
                "Value Error\nPlease ensure that all distances are strictly numeric, e.g. enter '5' or '5.0', rather than '5m' or 'five'."
            )
        except NegativeDistanceError:
            self._show_error(
                "Error; impossible region",
                "Value Error\nPlease ensure that Left/West < Right/East, \nand Up/North < Down/South."
            )

    def getAlgoInfo(self):
        """Create a popup to listen for the algorithm settings, and attach the listener."""
        self.setAlgoInfoDialog = Dialogs.AlgorithmSettingsPopUp(
            self.settings["algorithm"])
        self.setAlgoInfoDialog.activate.clicked.connect(self.changeAlgoInfo)
        self.setAlgoInfoDialog.cb.currentIndexChanged.connect(self.procChange)
        self.setAlgoInfoDialog.exec()

    def procChange(self):
        """Store the GCC processor chosen in the algorithm dialog and save settings."""
        self.settings["algorithm"]["GCC"][
            "processor"] = self.setAlgoInfoDialog.cb.currentText()
        self._save_settings()

    def changeAlgoInfo(self):
        """ Listener for the change algorithm settings dialog - saves to disc after obtaining new information."""
        try:
            self.settings["algorithm"] = self.setAlgoInfoDialog.getValues()
            self._save_settings()
            self.setAlgoInfoDialog.close()
        except ValueError:
            self._show_error(
                "Error with frequency input",
                "Value Error\nPlease ensure that all frequencies are strictly numeric, e.g. enter '100' or '100.0', rather than '100 Hz' or 'one hundred'."
            )

    def exportGIS(self):
        """Export the current heatmap to disc as a TIF file, with associated {}.tif.points georeferencing data.

        This is handled by the lakeator - this method is simply a wrapper and filepath selector.
        ".tif" is appended only when the chosen name does not already end in it."""
        defaultname = self.open_filename[:-4] + "_" + self.settings[
            "algorithm"]["current"] + "_heatmap"
        name, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save image & GIS Metadata", defaultname,
            "TIF files *.tif;; All Files *")
        if name:
            if not name.lower().endswith(".tif"):
                name = name + ".tif"
            self.loc.heatmap_to_GIS(
                self.settings["array"]["GPS"]["coordinates"],
                self.settings["array"]["GPS"]["EPSG"]["input"],
                projected_EPSG=self.settings["array"]["GPS"]["EPSG"]
                ["projected"],
                target_EPSG=self.settings["array"]["GPS"]["EPSG"]["target"],
                filepath=name)

    def _load_settings(self, settings_file="./settings.txt"):
        """Load settings from disc."""
        with open(settings_file, "r") as f:
            self.settings = json.load(f)

    def _save_settings(self, settings_file="./settings.txt"):
        """Save settings to disc."""
        with open(settings_file, "w") as f:
            stngsstr = json.dumps(self.settings, sort_keys=True, indent=4)
            f.write(stngsstr)
Example #26
0
class double_pendulum_window(QW.QMainWindow):
    """
    The main window which houses both options and plot canvas.

    The left-hand side holds a restart button and one slider per simulation
    parameter (r1, m1, m2, g).  The right-hand side holds the matplotlib
    canvas returned by ``dp.animation_window()`` plus a navigation toolbar.
    Mouse clicks on the canvas are forwarded to ``dp._on_mouse`` with the
    parameter values that were current when the plot was (re)created.
    """
    def __init__(self):
        super().__init__()

        # Create the main Widget and layout
        self._main = QW.QWidget()
        self.setWindowTitle('Double Pendulum Simulation')
        self.setCentralWidget(self._main)
        self.layout_main = QW.QHBoxLayout(self._main)
        # A shortcut to close the app.
        self.closer = QW.QShortcut(QG.QKeySequence('Ctrl+Q'), self, self.quit)

        self.create_options()
        self.create_plot_window()

    def create_options(self):
        """Create the parameter backend (value grids) and frontend (sliders)."""

        # Backend - here are all the parameters
        # Since QSlider only works for integers, we create a linspace vector
        # for each parameter and use the QSlider value as the index for the
        # linspace vector.
        self.param_names = ['r1', 'm1', 'm2', 'g']
        self.param_min = [0.05, 0.1, 0.1, 1]
        self.param_max = [0.95, 10, 10, 100]
        # Initial slider positions (indices into each linspace vector).
        self.param_start = [45, 9, 9, 18]
        self.param_intervals = [0.01, 0.1, 0.1, 0.5]
        self.param_values = []
        self.current_values = []
        # Number of grid points per parameter: (max - min) / step + 1, rounded
        # to absorb floating point error in the division.
        self.param_nums = [
            ((max_ - min_) / int_ + 1) for min_, max_, int_ in zip(
                self.param_min, self.param_max, self.param_intervals)
        ]
        self.param_nums = [np.round(i).astype(int) for i in self.param_nums]

        for min_, max_, nums, start in zip(self.param_min, self.param_max,
                                           self.param_nums, self.param_start):
            # Here we create the actual linspace vectors and add them to the
            # backend
            values = np.linspace(min_, max_, nums)
            self.param_values.append(values)
            self.current_values.append(values[start])

        # Frontend
        self.param_labels = []
        self.param_fields = []
        self.param_value_labels = []

        self.layout_options = QW.QVBoxLayout()
        self.button_restart = QW.QPushButton('Restart program', self)
        self.button_restart.clicked.connect(self.restart_plot)

        # Create each line in the parameter layout
        for i, (name, max_, start, values) in enumerate(
                zip(self.param_names, self.param_nums, self.param_start,
                    self.param_values)):
            label = QW.QLabel(name, self)
            field = QW.QSlider(QC.Qt.Horizontal)
            field.setMinimum(0)
            field.setMaximum(max_ - 1)
            field.setValue(start)
            # Bind i as a default argument so each slider updates its own
            # parameter (avoids the late-binding closure pitfall).
            field.valueChanged.connect(
                lambda sv, i=i: self.update_param_value(sv, i))
            value_label = QW.QLabel(f'{values[start]:.2f}')
            self.param_labels.append(label)
            self.param_fields.append(field)
            self.param_value_labels.append(value_label)

        # Add the parameters to the layout: one row per parameter with
        # columns (name, slider, current value).
        self.layout_parameters = QW.QGridLayout()
        for n in range(len(self.param_fields)):
            self.layout_parameters.addWidget(self.param_labels[n], n, 0)
            self.layout_parameters.addWidget(self.param_fields[n], n, 1)
            self.layout_parameters.addWidget(self.param_value_labels[n], n, 2)

        self.layout_options.addWidget(self.button_restart)
        self.layout_options.addLayout(self.layout_parameters)
        self.layout_main.addLayout(self.layout_options)

    def create_plot_window(self):
        """Create the plot canvas and toolbar and initialise the animation."""
        self.fig, self.ax, self.ax2 = dp.animation_window()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setFixedSize(600, 800)

        self.initialize_plot()
        self.tool = NavigationToolbar(self.canvas, self)
        self.addToolBar(self.tool)

        self.layout_main.addWidget(self.canvas)

    def update_param_value(self, slider_index, i):
        """Slider callback: store and display the i'th parameter's new value.

        :param slider_index: new slider position (index into param_values[i]).
        :param i: index of the parameter that changed.
        """
        new_value = self.param_values[i][slider_index]
        self.param_value_labels[i].setText(f'{new_value:.2f}')
        self.current_values[i] = new_value

    def restart_plot(self):
        """Tear down the current plot window and build a new one.

        The new animation uses the parameter values currently selected on the
        sliders.
        """
        # Stop the animation by firing the canvas 'close_event' callbacks.
        # FigureCanvasBase.close_event() was deprecated in matplotlib 3.6 and
        # removed in 3.8; dispatching a CloseEvent through the canvas callback
        # registry is what that method did and works on old and new versions.
        from matplotlib.backend_bases import CloseEvent
        self.canvas.callbacks.process(
            'close_event', CloseEvent('close_event', self.canvas))

        # Delete the animation connection ID, figure and axes objects
        del self.cid
        del self.fig
        del self.ax
        del self.ax2

        # Remove and delete the toolbar
        self.removeToolBar(self.tool)
        del self.tool

        # Delete the canvas
        self.layout_main.removeWidget(self.canvas)
        self.canvas.deleteLater()
        self.canvas = None

        # Create the new window
        self.create_plot_window()

    def initialize_plot(self):
        """Connect mouse clicks on the canvas to ``dp._on_mouse``.

        The current parameter values are captured when this is called.
        Slider changes made afterwards only take effect after a restart.
        """
        r1, m1, m2, g = self.current_values
        # The two r values always sum to 1 (presumably rod lengths — confirm in dp).
        r2 = 1 - r1
        N = 3001
        dt = 0.01
        self.cid = self.canvas.mpl_connect(
            'button_press_event', lambda event: dp._on_mouse(event,
                                                             r1=r1,
                                                             r2=r2,
                                                             ax=self.ax,
                                                             ax2=self.ax2,
                                                             fig=self.fig,
                                                             N=N,
                                                             dt=dt,
                                                             m1=m1,
                                                             m2=m2,
                                                             g=g))

    def quit(self):
        """Exit the application (bound to Ctrl+Q)."""
        sys.exit()
class MatplotlibWidget(QtWidgets.QWidget, DisplayOptionsListener):
    """Qt widget wrapping a matplotlib canvas that plots a 2D array.

    Depending on the plot type passed to update(), the figure shows one of:
    - a map (imshow of the array);
    - a histogram;
    - a map with a histogram below it;
    - a map with a transect (values along a line drawn on the map) below it.

    Mouse events are forwarded to a mode handler (zoom, region selection or
    transect) chosen by set_mode_type(). The widget registers itself as a
    listener on its display options object.
    """

    def __init__(self, parent):
        super(MatplotlibWidget, self).__init__(parent)

        self._canvas = FigureCanvas(Figure())
        self._layout = QtWidgets.QVBoxLayout()
        self._layout.addWidget(self._canvas)
        self.setLayout(self._layout)

        self._display_options = None
        self._map_axes = None
        self._histogram_axes = None
        self._transect_axes = None

        # Variables set by call to update()
        self._plot_type = PlotType.INVALID
        self._array_type = ArrayType.INVALID
        self._array = None
        self._array_stats = None
        self._title = None
        self._name = None
        self._map_zoom = None  # None or float array of shape (2,2).
        self._map_pixel_zoom = None  # None or int array ((imin,imax), (jmin,jmax))

        self._image = None  # Image used for element map.
        self._bar = None  # Bar used for histogram.
        self._bar_norm_x = None  # Normalised x-positions of centres of bars
        #   in range 0 (= min) to 1.0 (= max value).
        self._cmap_int_max = None  # One beyond end, as in numpy slicing.
        self._scale_bar = None
        self._colourbar = None
        self._histogram = None  # Latest (histogram, bin_edges, bin_width).

        self._map_line_points = None  # Line drawn on map (showing transect).
        self._map_line = None
        self._transect = None  # Latest transect (xs, ys, values) before
        #   interpolation.

        # Scale and units initially from display options, but may need to
        # change them if distances are too large, e.g. 1000 mm goes to 1 m.
        self._scale = None
        self._units = None

        # Initialised in initialise().
        self._owning_window = None
        self._mode_type = ModeType.INVALID
        self._mode_handler = None
        # Also set in initialise(). set_mode_type() reads it, so it must exist
        # even if update() is called before initialise().
        self._status_callback = None

        # Created when first needed.
        self._black_colourmap = None
        self._white_colourmap = None

    def __del__(self):
        # Unregister from the display options listener list.
        self.set_display_options(None)

    def _adjust_layout(self):
        # Intentionally a no-op; tight_layout was disabled here:
        #self._canvas.figure.tight_layout(pad=1.5)
        pass

    @staticmethod
    def _get_named_colourmap(name, lut=None):
        # Look up a registered colourmap, optionally resampled to `lut` colours.
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9. Use it while it
        # exists (unchanged behaviour), otherwise fall back to the colormap
        # registry (matplotlib.colormaps) and Colormap.resampled.
        get_cmap = getattr(cm, 'get_cmap', None)
        if get_cmap is not None:
            return get_cmap(name, lut)
        cmap = mpl.colormaps[name]
        return cmap if lut is None else cmap.resampled(lut)

    def _create_black_colourmap(self):
        # Lazily create (and cache) a single-colour black colourmap.
        if self._black_colourmap is None:
            colours = [(0, 0, 0), (0, 0, 0)]
            self._black_colourmap = \
                LinearSegmentedColormap.from_list('black', colours, N=1)
        return self._black_colourmap

    def _create_scale_bar(self):
        # (Re)create the scale bar on the map axes. Its length is the spacing
        # between the first two x ticks, in the current units.
        if self._scale_bar:
            # Clear old scale bar.
            if self._map_axes and self._scale_bar in self._map_axes.artists:
                try:
                    self._scale_bar.remove()
                except Exception:
                    # Best-effort: the artist may already be detached.
                    pass
            self._scale_bar = None

        if self._map_axes is not None:
            xticks = self._map_axes.get_xticks()
            if len(xticks) < 2:
                # Cannot size the bar without at least two ticks.
                return
            options = self._display_options
            size = xticks[1] - xticks[0]
            label = '{:g} {}'.format(size, self._units)
            self._scale_bar = ScaleBar(
                ax=self._map_axes, size=size, label=label,
                loc=options.scale_bar_location, colour=options.scale_bar_colour)
            self._map_axes.add_artist(self._scale_bar)

    def _create_white_colourmap(self):
        # Lazily create (and cache) a single-colour white colourmap.
        if self._white_colourmap is None:
            colours = [(1, 1, 1), (1, 1, 1)]
            self._white_colourmap = \
                LinearSegmentedColormap.from_list('white', colours, N=1)
        return self._white_colourmap

    def _get_scaled_extent(self):
        # Return extent rectangle (for imshow), corrected for scale and units
        # if using physical units.  self._scale and self._units are updated,
        # and may be different from display_options if showing a particularly
        # large or small extent.
        options = self._display_options
        ny, nx = self._array.shape
        extent = np.array([0.0, nx, ny, 0.0])

        if options.use_scale:
            self._scale = options.scale
            self._units = options.units
            extent *= self._scale
            if self._map_zoom is not None:
                display_rectangle = self._map_zoom * self._scale
            else:
                display_rectangle = extent
            max_dimension = np.absolute(
                np.diff(display_rectangle.reshape((2, 2)))).max()
            if max_dimension > 1000.0:
                # Switch to the next larger unit (e.g. mm -> m).
                self._scale /= 1000.0
                self._units = options.get_next_larger_units(self._units)
                extent /= 1000.0
        else:
            self._scale = 1.0
            self._units = 'pixels'

        return extent

    def _redraw(self):
        # Redraw the canvas, silencing the NaN comparison warning below.
        with warnings.catch_warnings():
            # Ignore RuntimeWarning when determining colour from colourmap if
            # value is NaN.
            warnings.filterwarnings(
                'ignore', message='invalid value encountered in less')
            self._canvas.draw()

    def _update_draw(self, refresh=True):
        # Rebuild the whole figure from the cached variables set by update().
        # The canvas is only redrawn when refresh is True.

        options = self._display_options
        mpl.rcParams.update({'font.size': options.font_size})

        # Derived quantities.
        show_colourbar = True
        cmap_int_max = None
        if self._array_type == ArrayType.CLUSTER:
            if 'k' in self._array_stats:
                cmap_int_max = self._array_stats['k'] + 1
            else:
                cmap_int_max = self._array_stats['max'] + 1
        elif self._array_type in (ArrayType.PHASE, ArrayType.REGION):
            show_colourbar = False
            cmap_int_max = 2

        self._cmap_int_max = cmap_int_max

        figure = self._canvas.figure
        figure.clear()

        self._map_axes = None
        self._histogram_axes = None
        self._transect_axes = None
        if self._plot_type == PlotType.INVALID:
            return
        elif self._plot_type == PlotType.MAP:
            self._map_axes = figure.subplots()
        elif self._plot_type == PlotType.HISTOGRAM:
            self._histogram_axes = figure.subplots()
        elif self._plot_type == PlotType.MAP_AND_HISTOGRAM:
            self._map_axes, self._histogram_axes = figure.subplots(
                nrows=2, gridspec_kw={'height_ratios': (3,1)})
        elif self._plot_type == PlotType.MAP_AND_TRANSECT:
            self._map_axes, self._transect_axes = figure.subplots(
                nrows=2, gridspec_kw={'height_ratios': (3,1)})
        else:
            raise RuntimeError('Invalid plot type')

        cmap = self.create_colourmap()

        cmap_limits = None
        if cmap_int_max is None:
            # Continuous data: normalise between stats or user-set limits.
            show_stats = True
            if options.manual_colourmap_zoom:
                cmap_limits = (options.lower_colourmap_limit,
                               options.upper_colourmap_limit)
            else:
                cmap_limits = (self._array_stats['min'],
                               self._array_stats['max'])
            norm = Normalize(cmap_limits[0], cmap_limits[1])
            cmap_ticks = None
        else:
            # Integer data: one colour per integer, centred on the integer.
            show_stats = False
            norm = Normalize(-0.5, cmap_int_max - 0.5)
            cmap_ticks = np.arange(0, cmap_int_max)
            if cmap_int_max >= 15:
                cmap_ticks = cmap_ticks[::2]

        if show_stats:
            show_stats = options.show_mean_median_std_lines

        if self._map_axes is None:
            self._image = None
            self._scale = None
            self._units = None
            self._colourbar = None
        else:
            extent = self._get_scaled_extent()
            self._image = self._map_axes.imshow(self._array,
                                                cmap=cmap,
                                                norm=norm,
                                                extent=extent)

            if self._map_line is not None and self.has_transect_axes():
                # Redraw existing map_line on new map_axes.
                try:
                    self._map_line.remove()
                except Exception:
                    # Best-effort: the old axes have already been cleared.
                    pass
                path_effects = self._map_line.get_path_effects()
                self._map_line = self._map_axes.plot(
                    self._map_line_points[:, 0]*self._scale,
                    self._map_line_points[:, 1]*self._scale, '-', c='k',
                    path_effects=path_effects)[0]

            if self._map_zoom is not None:
                self._map_axes.set_xlim(self._map_zoom[0] * self._scale)
                self._map_axes.set_ylim(self._map_zoom[1] * self._scale)

            if show_colourbar:
                self._colourbar = figure.colorbar(self._image,
                                                  ax=self._map_axes,
                                                  ticks=cmap_ticks)

            if self._title is not None:
                self._map_axes.set_title(self._title + ' map')

            if options.use_scale and options.show_scale_bar:
                self._create_scale_bar()

            # Hide ticks only after creating scale bar as use tick locations
            # to determine scale bar size.
            self._update_map_axes_ticks_and_labels()

        if self._histogram_axes is None:
            self._bar = None
            self._bar_norm_x = None
            self._histogram = None
        else:
            # May only want histogram of zoomed sub array.
            subarray = self._array
            if options.zoom_updates_stats and self._map_pixel_zoom is not None:
                ((imin, imax), (jmin, jmax)) = self._map_pixel_zoom
                subarray = subarray[jmin:jmax, imin:imax]

            if cmap_int_max is not None:
                bins = np.arange(0, cmap_int_max + 1) - 0.5
            elif options.use_histogram_bin_count:
                bins = options.histogram_bin_count
            else:
                # Use bin width, but only if max count not exceeded.
                bin_width = options.histogram_bin_width

                if options.manual_colourmap_zoom:
                    subarray_limits = cmap_limits
                else:
                    subarray_limits = (subarray.min(), subarray.max())
                if (subarray_limits[0] is np.ma.masked
                        or subarray_limits[1] is np.ma.masked
                        or subarray_limits[0] == subarray_limits[1]):
                    # Subarray is all masked out, so cannot display histogram.
                    bins = 1
                else:
                    min_index = math.floor(subarray_limits[0] / bin_width)
                    max_index = math.ceil(subarray_limits[1] / bin_width) - 1
                    bins = max_index - min_index
                    if bins < options.histogram_max_bin_count:
                        bins = bin_width * np.arange(min_index, max_index + 2)
                    else:
                        bins = options.histogram_max_bin_count

            hist, bin_edges = np.histogram(np.ma.compressed(subarray),
                                           bins=bins,
                                           range=cmap_limits)
            bin_width = bin_edges[1] - bin_edges[0]
            self._histogram = (hist, bin_edges, bin_width)
            bin_centres = bin_edges[:-1] + 0.5 * bin_width
            self._bar_norm_x = norm(bin_centres)
            colours = cmap(self._bar_norm_x)
            self._bar = self._histogram_axes.bar(bin_centres,
                                                 hist,
                                                 bin_width,
                                                 color=colours)
            if cmap_ticks is not None:
                self._histogram_axes.set_xticks(cmap_ticks)

            if show_stats:
                mean = self._array_stats.get('mean')
                median = self._array_stats.get('median')
                std = self._array_stats.get('std')
                if mean is not None:
                    label = 'mean'
                    if options.show_mean_median_std_values:
                        label += ' ({:g})'.format(mean)
                    self._histogram_axes.axvline(mean,
                                                 c='k',
                                                 ls='-',
                                                 label=label)
                    if std is not None:
                        label = 'mean \u00b1 std'
                        if options.show_mean_median_std_values:
                            label += ' ({:g})'.format(std)
                        self._histogram_axes.axvline(mean - std,
                                                     c='k',
                                                     ls='-.',
                                                     label=label)
                        self._histogram_axes.axvline(mean + std,
                                                     c='k',
                                                     ls='-.')
                if median is not None:
                    label = 'median'
                    if options.show_mean_median_std_values:
                        label += ' ({:g})'.format(median)
                    self._histogram_axes.axvline(median,
                                                 c='k',
                                                 ls='--',
                                                 label=label)

                if mean is not None or median is not None:
                    self._histogram_axes.legend()

            if self._map_axes is None and self._title is not None:
                self._histogram_axes.set_title(self._title + ' histogram')

        if self.has_transect_axes() and self._map_line is not None:
            # Recompute the transect from the (possibly re-scaled) map line.
            x, y = self._map_line.get_data()
            points = np.stack((x, y), axis=1) / self._scale
            self.set_transect(points)

        figure.suptitle(options.overall_title)
        if options.show_project_filename:
            figure.text(0.01, 0.01, options.project_filename)
        if options.show_date:
            figure.text(0.99, 0.01, options.date, ha='right')

        self._adjust_layout()

        if self._mode_handler:
            self._mode_handler.move_to_new_axes()

        if refresh:
            self._redraw()

    def _update_map_axes_ticks_and_labels(self):
        # Label the map axes with the current units, or hide the ticks.
        if self._display_options.show_ticks_and_labels:
            self._map_axes.set_xlabel(self._units)
            self._map_axes.set_ylabel(self._units)
        else:
            self._map_axes.set_xticks([])
            self._map_axes.set_yticks([])

    def clear(self):
        # Clear current plots and forget the data set by update(). The cached
        # map zoom is kept (see clear_all).
        self._canvas.figure.clear()
        self._map_axes = None
        self._histogram_axes = None
        self._transect_axes = None
        self._plot_type = PlotType.INVALID
        self._array_type = ArrayType.INVALID
        self._array = None
        self._array_stats = None
        self._title = None
        self._name = None
        self._image = None
        self._bar = None
        self._bar_norm_x = None
        self._cmap_int_max = None
        self._scale_bar = None
        self._colourbar = None
        self._histogram = None
        self._map_line_points = None
        self._map_line = None
        self._transect = None
        self._scale = None
        self._units = None

        self._redraw()

    def clear_all(self):
        # Clear everything, including cached zoom extent, etc.
        self.clear_map_zoom()
        self.clear()

    def clear_map_zoom(self):
        # Forget the cached map zoom (both scaled and pixel forms).
        self._map_zoom = None
        self._map_pixel_zoom = None

    def create_cluster_colourmap(self, k):
        # Return the display options' colourmap resampled to k colours.
        return self._get_named_colourmap(self._display_options.colourmap_name, k)

    def create_colourmap(self):
        # Return the colourmap appropriate for the current array type:
        # black for phase/region masks, k-colour for integer data, otherwise
        # the continuous colourmap named in the display options.
        if self._array_type in (ArrayType.PHASE, ArrayType.REGION):
            return self._create_black_colourmap()
        if self._cmap_int_max is None:
            return self._get_named_colourmap(self._display_options.colourmap_name)
        else:
            return self.create_cluster_colourmap(self._cmap_int_max)

    def create_map_line(self, points, path_effects):
        # Replace any existing transect line on the map with one through
        # `points` (pixel coordinates, one row per point).
        if self._map_line is not None:
            self._map_line_points = None
            try:
                self._map_line.remove()
            except Exception:
                # Best-effort: the line may already be detached.
                pass
            self._map_line = None

        if self._map_axes is not None:
            self._map_line_points = points
            self._map_line = self._map_axes.plot(
                points[:, 0]*self._scale, points[:, 1]*self._scale, '-', c='k',
                path_effects=path_effects)[0]

            self._redraw()

    def export_to_file(self, filename, dpi):
        # Save the current figure to `filename` at the given dpi.
        figure = self._canvas.figure
        figure.savefig(filename, dpi=dpi)

    def get_histogram_at_x(self, x):
        # Return histogram data at specified x value:
        # [bin_width, nbins] if x is outside all bins, otherwise
        # [bin_width, nbins, bin_low, bin_high, count]. None if no histogram.
        if self._histogram is None:
            return None

        hist, bin_edges, bin_width = self._histogram
        nbins = len(hist)
        i = math.floor((x - bin_edges[0]) / bin_width)
        if i < 0 or i >= len(bin_edges) - 1:
            # In histogram, but not within a bin.
            return [bin_width, nbins]
        else:
            # Within a histogram bin.
            bin_low = bin_edges[i]
            bin_high = bin_edges[i + 1]
            count = hist[i]
            return [bin_width, nbins, bin_low, bin_high, count]

    def get_transect_at_lambda(self, lambda_):
        # Return [x, y, value] at fraction lambda_ along the transect line
        # (0 = start point, 1 = end point), or None if there is no transect.
        if self._transect is not None:
            points = self._map_line_points
            # np.int was removed in NumPy 1.24; the builtin int is equivalent.
            xy = (points[0] + lambda_ * (points[1] - points[0])).astype(int)
            x, y = xy
            value = self.get_value_at_position(x, y)
            return [x, y, value]
        else:
            return None

    def get_value_at_position(self, x, y):
        # Return value at (x, y) indices of the current array, or None if there
        # is no such value (out of bounds) or the value is masked.
        if self._array is not None:
            ny, nx = self._array.shape
            # Reject out-of-range indices explicitly; negative indices would
            # otherwise silently wrap to the opposite edge.
            if not (0 <= x < nx and 0 <= y < ny):
                return None
            value = self._array[y, x]
            if value is not np.ma.masked:
                return value
        return None

    def has_content(self):
        return self._plot_type != PlotType.INVALID

    def has_histogram_axes(self):
        return self._histogram_axes is not None

    def has_map_axes(self):
        return self._map_axes is not None

    def has_transect_axes(self):
        return self._transect_axes is not None

    def has_transect_contents(self):
        return (self._transect_axes is not None
                and self._map_line_points is not None)

    def initialise(self,
                   owning_window,
                   display_options,
                   zoom_enabled=True,
                   status_callback=None):
        # Attach the owning window and display options. If zoom_enabled,
        # connect the mouse events and select the default mode handler.
        self._owning_window = owning_window
        self.set_display_options(display_options)
        self._status_callback = status_callback

        if zoom_enabled:
            self._canvas.mpl_connect('axes_enter_event', self.on_axes_enter)
            self._canvas.mpl_connect('axes_leave_event', self.on_axes_leave)
            self._canvas.mpl_connect('button_press_event', self.on_mouse_down)
            self._canvas.mpl_connect('button_release_event', self.on_mouse_up)
            self._canvas.mpl_connect('motion_notify_event', self.on_mouse_move)

            self.set_default_mode_type()
        else:
            self.set_mode_type(ModeType.INVALID)

        self._canvas.mpl_connect('resize_event', self.on_resize)

    def is_region_mode_type(self):
        return self._mode_type in (ModeType.REGION_RECTANGLE,
                                   ModeType.REGION_ELLIPSE,
                                   ModeType.REGION_POLYGON)

    def on_axes_enter(self, event):
        if self._mode_handler:
            self._mode_handler.on_axes_enter(event)

    def on_axes_leave(self, event):
        if self._mode_handler:
            self._mode_handler.on_axes_leave(event)

    def on_mouse_down(self, event):
        if self._mode_handler:
            self._mode_handler.on_mouse_down(event)

    def on_mouse_move(self, event):
        if self._mode_handler:
            self._mode_handler.on_mouse_move(event)

    def on_mouse_up(self, event):
        if self._mode_handler:
            self._mode_handler.on_mouse_up(event)

    def on_resize(self, event):
        # Ticks may have changed, so need to recalculate scale bar.
        if self._scale_bar:
            self._create_scale_bar()

        self._adjust_layout()

    def set_colourmap_limits(self, lower, upper):
        # Needed for new_phase_filtered_dialog only. Values outside
        # [lower, upper] are drawn white.
        if self._image is not None:
            self._image.set_clim(lower, upper)
            cmap = self._image.get_cmap()
            cmap.set_over('w')
            cmap.set_under('w')
            self._redraw()

    def set_default_mode_type(self):
        # Select the default interaction mode for the current plot type:
        # transect mode for map+transect plots, zoom mode otherwise.
        if self._plot_type == PlotType.MAP_AND_TRANSECT:
            self.set_mode_type(ModeType.TRANSECT)
        else:
            self.set_mode_type(ModeType.ZOOM)

    def set_display_options(self, display_options):
        # Swap listener registration from the old display options to the new
        # ones (either may be None) and pass them on to the mode handler.
        if self._display_options is not None:
            self._display_options.unregister_listener(self)
        self._display_options = display_options
        if self._display_options is not None:
            self._display_options.register_listener(self)

        if self._mode_handler is not None:
            self._mode_handler.set_display_options(display_options)

    def set_mode_type(self, mode_type, listener=None):
        # Replace the mode handler if the mode type changes. `listener` is
        # only passed to the region handlers.
        if mode_type != self._mode_type:
            self._mode_type = mode_type

            if self._mode_handler:
                self._mode_handler.clear()

            options = self._display_options

            if mode_type == ModeType.ZOOM:
                self._mode_handler = ZoomHandler(self, options,
                                                 self._status_callback)
            elif mode_type == ModeType.REGION_RECTANGLE:
                self._mode_handler = RectangleRegionHandler(
                    self, options, self._status_callback, listener)
            elif mode_type == ModeType.REGION_ELLIPSE:
                self._mode_handler = EllipseRegionHandler(
                    self, options, self._status_callback, listener)
            elif mode_type == ModeType.REGION_POLYGON:
                self._mode_handler = PolygonRegionHandler(
                    self, options, self._status_callback, listener)
            elif mode_type == ModeType.TRANSECT:
                self._mode_handler = TransectHandler(self, options,
                                                     self._status_callback)
            else:
                self._mode_handler = None

    def set_transect(self, points):
        # Compute and plot the transect between points[0] and points[1]
        # (pixel coordinates) on the transect axes.
        if not self.has_transect_axes():
            raise RuntimeError('MatplotlibWidget does not have transect axes')

        lambdas, xs, ys, values = calculate_transect(
            self._array, points[0], points[1])

        self._transect = (xs, ys, values)

        axes = self._transect_axes
        axes.clear()

        if self._display_options.transect_uses_colourmap:
            lambdas, values = adaptive_interp(lambdas, values, 19)

            # Build one line segment per consecutive pair of points so each
            # segment can be coloured by its mean value.
            points = np.array([lambdas, values]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lines = LineCollection(segments,
                                   cmap=self.create_colourmap(),
                                   norm=self._image.norm)
            # Set the array that is used to determine colours.
            lines.set_array(0.5 * (points[:-1, 0, 1] + points[1:, 0, 1]))

            lines = axes.add_collection(lines)
            axes.autoscale_view(scalex=False, scaley=True)
        else:
            axes.plot(lambdas, values)

        axes.set_xlim(0.0, 1.0)  # Needed in case end points are masked out.

        if self._owning_window is not None:
            self._owning_window.update_controls()

    def update(self,
               plot_type,
               array_type,
               array,
               array_stats,
               title,
               name,
               map_zoom=None,
               map_pixel_zoom=None,
               refresh=True):
        # Cache the data to display and rebuild the figure. Unless a region
        # mode is active, switch to the default mode for the plot type.
        self._plot_type = plot_type
        self._array_type = array_type
        self._array = array
        self._array_stats = array_stats
        self._title = title
        self._name = name
        self._map_zoom = map_zoom
        self._map_pixel_zoom = map_pixel_zoom

        if not self.is_region_mode_type():
            self.set_default_mode_type()

        self._update_draw(refresh)

    def update_colourmap_name(self):
        # Handler for DisplayOptions callback: apply the new colourmap to the
        # map image, histogram bars and transect line without a full rebuild.
        cmap = self.create_colourmap()

        if self._array_type == ArrayType.PHASE:
            return

        if self._image is not None:
            self._image.set_cmap(cmap)

        if self._bar is not None:
            colours = cmap(self._bar_norm_x)
            for index, item in enumerate(self._bar):
                item.set_color(colours[index])

        if (self.has_transect_axes()
                and self._display_options.transect_uses_colourmap):
            collections = self._transect_axes.collections
            # No LineCollection exists until a transect has been drawn.
            if len(collections) > 0:
                collections[0].set_cmap(cmap)

        self._redraw()

    def update_map_line(self, points):
        # Move the existing map transect line to `points` (pixel coordinates).
        if self._map_line is not None and self._map_axes is not None:
            self._map_line_points = points
            self._map_line.set_data(points[:, 0] * self._scale,
                                    points[:, 1] * self._scale)
            self._redraw()
Example #28
0
class CalibWin(QWidget, _calf.calib_functions_mixin, _ie.Imp_Exp_Mixin):
    """ Energy calibration window

    For the names of the children widgets, I tried to put suffixes that indicate clearly their types:
    *_btn -> QPushButton,
    *_le -> QLineEdit,
    *_lb -> QLabel,
    *layout -> QHBoxLayout, QVBoxLayout or QGridLayout,
    *_box -> QGroupBox,
    *_cb -> QCheckBox,
    *_rb -> QRadioButton

    The functions that are connected to a widget's event have the suffix _lr (for 'listener'). For example, a button named
    test_btn will be connected to a function test_lr.
    Some functions may be connected to widgets but without the suffix _lr in their names. It means that they are not only
    called when interacting with the widget.
    """

    ##################################################################################
    ############################ Widget Initialization ###############################

    def __init__(self, parent=None):
        """Initialization of the window

        The main layout is called mainlayout. It is divided in two:
            - graphlayout: the left part, contains all the figures
            - commandlayout: the right part, contains all the buttons, fields, checkboxes...
        Both graphlayout and commandlayout are divided into sub-layouts.

        This function calls several functions to initialize each part of the window. The name of these functions contains
        'init_*layout'.
        """
        super(CalibWin, self).__init__(parent=parent)

        self.setWindowTitle("Energy Calibration")
        self.mainlayout = QHBoxLayout()
        self.graphlayout = QVBoxLayout()
        self.commandlayout = QVBoxLayout()
        self.commandlayout.setSpacing(10)

        # initialization of all the widgets/layouts
        self.init_btnlayout()
        self.init_tof2evlayout()
        self.init_eparlayout()
        self.init_fitparlayout()
        self.init_envectlayout()
        self.init_tofgraphlayout()
        self.init_graphauxlayout()

        # making the buttons not resizable
        for widget in self.children():
            if isinstance(widget, QPushButton):
                widget.setSizePolicy(0, 0)

        self.mainlayout.addLayout(self.graphlayout)
        self.mainlayout.addLayout(self.commandlayout)
        self.setLayout(self.mainlayout)

        # init_var selects an entry of gas_combo, so it must run after the widgets exist
        self.init_var()

        self.show()

    def init_var(self):
        ''' Initialization of instance attributes: state flags, threshold values, data, and the default gas'''

        # state flags
        self.withsb_bool = False
        self.calibloaded = False
        self.dataloaded = False
        self.bgndremoved = False
        self.threshyBool = False
        self.threshxminBool = False
        self.threshxmaxBool = False
        self.showexppeaksBool = False
        self.calibBool = False
        self.peaksfound = False

        # threshold positions
        self.thxmin = 0
        self.thy = 0
        self.thxmax = 0

        self.counts = []

        # Changing the index emits currentIndexChanged, which runs gas_combo_lr
        self.gas_combo.setCurrentIndex(2)  # Argon

    def init_btnlayout(self):
        ''' In commandLayout - Initialization of the top right part of the layout, with 7 buttons (see just below)'''

        btnlayout = QGridLayout()
        btnlayout.setSpacing(10)

        self.load_btn = QPushButton("Load data", self)
        self.rmbgnd_btn = QPushButton("Remove bgnd", self)
        self.findpeaks_btn = QPushButton("Find peaks", self)
        self.rmpeaks_btn = QPushButton("Remove peaks", self)
        self.exportcalib_btn = QPushButton("Export calib", self)
        self.importcalib_btn = QPushButton("Import calib", self)
        self.exportXUV_btn = QPushButton("Export XUV", self)

        # only "Load data" is usable before any data is loaded
        self.rmbgnd_btn.setEnabled(False)
        self.findpeaks_btn.setEnabled(False)
        self.rmpeaks_btn.setEnabled(False)
        self.exportcalib_btn.setEnabled(False)
        self.importcalib_btn.setEnabled(False)
        self.exportXUV_btn.setEnabled(False)

        self.load_btn.clicked.connect(self.loadfile_lr)
        self.rmbgnd_btn.clicked.connect(self.rmbgnd_lr)
        self.findpeaks_btn.clicked.connect(self.findpeaks_lr)
        self.importcalib_btn.clicked.connect(self.importcalib_lr)
        self.rmpeaks_btn.clicked.connect(self.removepeaks_lr)
        self.exportXUV_btn.clicked.connect(self.exportXUV_lr)
        self.exportcalib_btn.clicked.connect(self.exportcalib_lr)

        btnlayout.addWidget(self.load_btn, 0, 0)
        btnlayout.addWidget(self.rmbgnd_btn, 0, 1)
        btnlayout.addWidget(self.findpeaks_btn, 1, 0)
        btnlayout.addWidget(self.rmpeaks_btn, 1, 1)
        btnlayout.addWidget(self.exportcalib_btn, 1, 3)
        btnlayout.addWidget(self.importcalib_btn, 1, 2)
        btnlayout.addWidget(self.exportXUV_btn, 0, 3)
        self.commandlayout.addLayout(btnlayout)

    def init_tof2evlayout(self):
        ''' In commandLayout - Initialization of the tof to eV section: parameters of the af.find_local_maxima function,
            'TOF to energy' button and 'with sidebands checkbox' '''

        tof2evlayout = QHBoxLayout()
        flmlayout = QGridLayout()
        flmlayout.setSpacing(10)
        self.flm_box = QGroupBox(self)
        self.flm_box.setTitle("Find local maxima parameters")

        self.sm1_le = QLineEdit("5", self)
        self.sm2_le = QLineEdit("100", self)
        self.mindt_le = QLineEdit("10", self)

        flmlayout.addWidget(QLabel("smooth1"), 0, 0)
        flmlayout.addWidget(self.sm1_le, 1, 0)
        flmlayout.addWidget(QLabel("smooth2"), 0, 1)
        flmlayout.addWidget(self.sm2_le, 1, 1)
        flmlayout.addWidget(QLabel("min dt"), 0, 2)
        flmlayout.addWidget(self.mindt_le, 1, 2)
        self.flm_box.setLayout(flmlayout)

        for widget in self.flm_box.children():
            if isinstance(widget, QLineEdit):
                widget.setSizePolicy(0, 0)
                widget.setFixedSize(50, 20)

        self.tof2en_btn = QPushButton("TOF to energy", self)
        self.tof2en_btn.clicked.connect(self.tof2en_lr)
        self.tof2en_btn.setEnabled(False)

        self.withsb_cb = QCheckBox("With sidebands", self)
        self.withsb_cb.stateChanged.connect(self.withsb_fn)

        tof2evlayout.addWidget(self.flm_box)
        tof2evlayout.addWidget(self.withsb_cb)
        tof2evlayout.addWidget(self.tof2en_btn)
        self.commandlayout.addLayout(tof2evlayout)

    def init_eparlayout(self):
        ''' In commandLayout - Initialization of the experimental parameters section: Retarding potential, TOF length,
            wavelength, gas and first harmonic expected to see.'''

        gases = cts.GASLIST

        epar_box = QGroupBox(self)
        epar_box.setTitle("Experimental parameters")
        epar_box.setSizePolicy(0, 0)
        eparlayout = QGridLayout()
        eparlayout.setSpacing(10)

        self.retpot_le = QLineEdit(str(cts.cur_Vp), self)
        self.toflength_le = QLineEdit(str(cts.cur_L), self)
        self.wvlength_le = QLineEdit(str(cts.lambda_start), self)
        self.gas_combo = QComboBox(self)
        self.gas_combo.addItems(gases)
        self.firstharm_le = QLineEdit(str(cts.first_harm), self)

        self.retpot_le.returnPressed.connect(self.update_cts_fn)
        self.toflength_le.returnPressed.connect(self.update_cts_fn)
        self.wvlength_le.returnPressed.connect(self.update_cts_fn)
        self.firstharm_le.returnPressed.connect(self.update_cts_fn)
        self.gas_combo.currentIndexChanged.connect(self.gas_combo_lr)

        eparlayout.addWidget(QLabel("Ret. pot. (V)"), 0, 0)
        eparlayout.addWidget(self.retpot_le, 1, 0)
        eparlayout.addWidget(QLabel("TOF length (m)"), 0, 1)
        eparlayout.addWidget(self.toflength_le, 1, 1)
        eparlayout.addWidget(QLabel("lambda (nm)"), 0, 2)
        eparlayout.addWidget(self.wvlength_le, 1, 2)
        eparlayout.addWidget(QLabel("gas"), 0, 3)
        eparlayout.addWidget(self.gas_combo, 1, 3)
        eparlayout.addWidget(QLabel("1st harm."), 0, 4)
        eparlayout.addWidget(self.firstharm_le, 1, 4)

        epar_box.setLayout(eparlayout)

        for widget in epar_box.children():
            if isinstance(widget, (QLineEdit, QComboBox)):
                widget.setFixedSize(50, 20)
                widget.setSizePolicy(0, 0)

        self.commandlayout.addWidget(epar_box)

    def init_fitparlayout(self):
        ''' In commandLayout - Initialization of the fit parameter section with a, t0 and c. First line: guess values calculated
            from the experimental parameters. Second line: fitted values.'''

        fitpar_box = QGroupBox()
        fitpar_box.setTitle("Calibration parameters")
        fitpar_box.setSizePolicy(0, 0)
        fitparlayout = QGridLayout()
        fitparlayout.setSpacing(10)

        self.aguess_lb = QLabel(self)
        self.afit_lb = QLabel(self)
        self.t0guess_lb = QLabel(self)
        self.t0fit_lb = QLabel(self)
        self.cguess_lb = QLabel(self)
        self.cfit_lb = QLabel(self)

        fitparlayout.addWidget(QLabel("a"), 0, 0)
        fitparlayout.addWidget(self.aguess_lb, 1, 0)
        fitparlayout.addWidget(self.afit_lb, 2, 0)
        fitparlayout.addWidget(QLabel("t0"), 0, 1)
        fitparlayout.addWidget(self.t0guess_lb, 1, 1)
        fitparlayout.addWidget(self.t0fit_lb, 2, 1)
        fitparlayout.addWidget(QLabel("c"), 0, 2)
        fitparlayout.addWidget(self.cguess_lb, 1, 2)
        fitparlayout.addWidget(self.cfit_lb, 2, 2)
        fitparlayout.addWidget(QLabel("guess"), 1, 3)
        fitparlayout.addWidget(QLabel("calc"), 2, 3)
        text = "q=a*1/(t-t0)²+c"
        fitparlayout.addWidget(QLabel(text), 1, 4)

        fitpar_box.setLayout(fitparlayout)

        # fix the size of every label except the (longer) formula label
        for widget in fitpar_box.children():
            if isinstance(widget, QLabel) and widget.text() != text:
                widget.setSizePolicy(0, 0)
                widget.setFixedSize(55, 15)

        self.commandlayout.addWidget(fitpar_box)

    def init_envectlayout(self):
        ''' In commandLayout - Initialization of the resulting energy vector section, with elow, ehigh and dE'''

        envect_box = QGroupBox()
        envect_box.setTitle("Energy vector parameters")
        envect_box.setSizePolicy(0, 0)
        envectlayout = QGridLayout()
        envectlayout.setSpacing(10)

        self.elow_le = QLineEdit("{:.2f}".format(cts.elow), self)
        self.ehigh_le = QLineEdit("{:.2f}".format(cts.ehigh), self)
        self.dE_le = QLineEdit(str(cts.dE), self)

        self.elow_le.returnPressed.connect(self.update_envect_fn)
        self.ehigh_le.returnPressed.connect(self.update_envect_fn)
        self.dE_le.returnPressed.connect(self.update_envect_fn)

        envectlayout.addWidget(QLabel("E low (eV)"), 0, 0)
        envectlayout.addWidget(self.elow_le, 1, 0)
        envectlayout.addWidget(QLabel("E high (eV)"), 0, 1)
        envectlayout.addWidget(self.ehigh_le, 1, 1)
        envectlayout.addWidget(QLabel("dE (eV)"), 0, 2)
        envectlayout.addWidget(self.dE_le, 1, 2)

        envect_box.setLayout(envectlayout)

        for widget in envect_box.children():
            if isinstance(widget, (QLabel, QLineEdit)):
                widget.setSizePolicy(0, 0)
                widget.setFixedSize(55, 20)

        self.commandlayout.addWidget(envect_box)
        self.update_envect_fn()

    def init_tofgraphlayout(self):
        ''' In graphLayout - Initialization of the top figure on the window, where the time of flight is plotted'''

        tof_fig = Figure(figsize=(4, 3), dpi=100)
        self.tof_fc = FigureCanvas(tof_fig)
        # every mouse-button press on the canvas goes to onclick
        self.tof_fc.mpl_connect('button_press_event', self.onclick)
        self.tof_ax = self.tof_fc.figure.add_subplot(111)
        self.tof_fc.draw()
        tof_nav = NavigationToolbar2QT(self.tof_fc, self)
        tof_nav.setStyleSheet("QToolBar { border: 0px }")

        tgparalayout = QGridLayout()
        self.threshtype_bgroup = QButtonGroup()
        rblayout = QGridLayout()
        self.setth_cb = QCheckBox("Set threshold", self)
        self.setth_cb.setEnabled(False)
        self.setth_cb.stateChanged.connect(self.setth_lr)
        self.addpeak_cb = QCheckBox("Add peak", self)
        self.addpeak_cb.setEnabled(False)
        self.addpeak_cb.stateChanged.connect(self.addpeak_lr)
        self.showexppeaks_cb = QCheckBox("Show expected peaks", self)
        self.showexppeaks_cb.stateChanged.connect(self.showexppeaks_lr)
        self.showexppeaks_cb.setEnabled(False)

        self.clear_btn = QPushButton("Clear", self)
        self.clear_btn.clicked.connect(self.clear_lr)
        self.clear_btn.setEnabled(False)

        # Each threshold radio button carries a `value` tag read by threshtype_lr.
        self.Y_rb = QRadioButton("Y", self)
        self.Y_rb.value = "Y"
        self.Y_rb.toggled.connect(self.threshtype_lr)
        self.Y_rb.toggle()
        self.Y_rb.setEnabled(False)

        self.xmin_rb = QRadioButton("X min", self)
        self.xmin_rb.value = "Xm"
        self.xmin_rb.toggled.connect(self.threshtype_lr)
        self.xmin_rb.setEnabled(False)

        self.xmax_rb = QRadioButton("X max", self)
        self.xmax_rb.value = "XM"
        self.xmax_rb.toggled.connect(self.threshtype_lr)
        self.xmax_rb.setEnabled(False)

        self.minus_sign_cb = QCheckBox("*(-1)", self)
        self.minus_sign_cb.setEnabled(False)
        self.minus_sign_cb.stateChanged.connect(self.minus_sign_lr)

        self.threshtype_bgroup.addButton(self.Y_rb)
        self.threshtype_bgroup.addButton(self.xmin_rb)
        self.threshtype_bgroup.addButton(self.xmax_rb)

        rblayout.addWidget(self.Y_rb, 0, 0)
        rblayout.addWidget(self.xmin_rb, 0, 1)
        rblayout.addWidget(self.xmax_rb, 0, 2)
        tgparalayout.addWidget(self.setth_cb, 0, 0)
        tgparalayout.addWidget(self.addpeak_cb, 0, 1)
        tgparalayout.addWidget(self.showexppeaks_cb, 0, 2)
        tgparalayout.addWidget(self.clear_btn, 0, 3)
        tgparalayout.addLayout(rblayout, 1, 0, 1, 3)
        tgparalayout.addWidget(self.minus_sign_cb, 1, 3)

        self.graphlayout.addWidget(self.tof_fc)
        self.graphlayout.addWidget(tof_nav)
        self.graphlayout.addLayout(tgparalayout)

    def init_graphauxlayout(self):
        ''' In graphLayout - Initialization the two bottom figures on the window (fit graph and energy graph)'''

        graphauxlayout = QHBoxLayout()
        ga1layout = QVBoxLayout()
        ga2layout = QVBoxLayout()

        fit_fig = Figure(figsize=(2, 2), dpi=100)
        self.fit_fc = FigureCanvas(fit_fig)
        self.fit_fc.setSizePolicy(1, 0)
        self.fit_ax = self.fit_fc.figure.add_subplot(111)
        self.fit_ax.tick_params(labelsize=8)
        ga1layout.addWidget(self.fit_fc)
        self.fit_fc.draw()

        en_fig = Figure(figsize=(2, 2), dpi=100)
        self.en_fc = FigureCanvas(en_fig)
        self.en_ax = self.en_fc.figure.add_subplot(111)
        self.en_ax.tick_params(labelsize=8)
        self.en_fc.draw()
        en_nav = NavigationToolbar2QT(self.en_fc, self)
        en_nav.setStyleSheet("QToolBar { border: 0px }")
        ga2layout.addWidget(self.en_fc)
        ga2layout.addWidget(en_nav)

        graphauxlayout.addLayout(ga1layout)
        graphauxlayout.addLayout(ga2layout)
        self.graphlayout.addLayout(graphauxlayout)

    #################################################################################
    ############################ Other methods ######################################

    def update_fitpar_fn(self):
        ''' Updating the fitted parameters (a, t0, c) shown on the window, once a calibration exists'''

        if self.calibBool:
            self.afit_lb.setText("{:.3e}".format(cts.afit))
            self.t0fit_lb.setText("{:.3e}".format(cts.t0fit))
            self.cfit_lb.setText("{:.3f}".format(cts.cfit))

    def update_envect_fn(self):
        ''' Updating the energy vector parameters (cts.elow, cts.ehigh, cts.dE) with the values written in the
            associated QLineEdit objects.

            The three values are parsed before any of them is stored. If one of the fields is not a number, the
            global values are left untouched and the fields are reset to them.'''

        try:
            elow = float(self.elow_le.text())
            ehigh = float(self.ehigh_le.text())
            dE = float(self.dE_le.text())
        except ValueError:
            # An exception escaping a PyQt5 slot aborts the application: reject the input instead.
            self.elow_le.setText("{:.2f}".format(cts.elow))
            self.ehigh_le.setText("{:.2f}".format(cts.ehigh))
            self.dE_le.setText(str(cts.dE))
            return

        cts.elow, cts.ehigh, cts.dE = elow, ehigh, dE
        self.elow_le.setText("{:.2f}".format(cts.elow))
        self.ehigh_le.setText("{:.2f}".format(cts.ehigh))
        self.window().updateglobvar_fn()

    def gas_combo_lr(self, i):
        ''' Gas QCombobox listener: loads the ionization potential and first harmonic of gas number i and updates
            the low energy bound accordingly'''
        cts.cur_Ip = cts.IPLIST[i]
        cts.first_harm = cts.FIRST_HARMLIST[i]
        self.firstharm_le.setText(str(cts.first_harm))
        cts.elow = (cts.first_harm - 1) * cts.HEV * cts.cur_nu
        self.elow_le.setText("{:.2f}".format(cts.elow))
        self.update_cts_fn()

    def threshtype_lr(self, checked=True):
        ''' Listener of the threshold radiobuttons: Y, Xmin or Xmax.

            `toggled` is emitted both by the button being checked and by the one being unchecked; only the checked
            one updates self.threshtype.'''
        if not checked:
            return
        self.threshtype = self.sender().value

    def withsb_fn(self, state):
        ''' "with sidebands" checkbox listener '''
        self.withsb_bool = state == Qt.Checked

    def showexppeaks_lr(self, state):
        ''' "show expected peaks" checkbox listener '''
        self.showexppeaksBool = state == Qt.Checked
        self.refreshplot_fn()

    def setth_lr(self):
        ''' "set threshold" checkbox listener: "set threshold" and "add peak" are mutually exclusive '''
        if self.setth_cb.isChecked():
            self.addpeak_cb.setCheckState(Qt.Unchecked)

    def addpeak_lr(self):
        ''' "add peak" checkbox listener: "add peak" and "set threshold" are mutually exclusive '''
        if self.addpeak_cb.isChecked():
            self.setth_cb.setCheckState(Qt.Unchecked)

    def removepeaks_lr(self):
        ''' "remove peaks" button listener '''
        rmPeaksDialog(self)  # new class defined below

    def refreshplot_fn(self):
        ''' Updating the top left (TOF) graph, keeping the current axis limits.

            The y limits are not restored right after a background removal (self.bgndremoved), so the new
            signal gets its own scale.'''
        import math  # np.math was only an alias of this stdlib module and is gone in NumPy 2.0

        xmin, xmax = self.tof_ax.get_xlim()
        ymin, ymax = self.tof_ax.get_ylim()
        self.tof_ax.cla()

        self.tof_ax.xaxis.set_major_formatter(FormatStrFormatter('%2.e'))
        self.tof_ax.yaxis.set_major_formatter(FormatStrFormatter('%1.e'))
        self.tof_ax.set_ylabel("counts (arb. units)")
        self.tof_ax.set_xlabel("TOF (s)")

        if self.dataloaded:
            self.tof_ax.plot(self.counts[:, 0], self.counts[:, 1])
        if self.threshyBool:
            self.tof_ax.plot(self.threshyline[:, 0], self.threshyline[:, 1],
                             'k')
        if self.threshxminBool:
            self.tof_ax.plot(self.threshxminline[:, 0],
                             self.threshxminline[:, 1], 'k')
        if self.threshxmaxBool:
            self.tof_ax.plot(self.threshxmaxline[:, 0],
                             self.threshxmaxline[:, 1], 'k')
        if self.peaksfound:
            self.tof_ax.plot(self.maximaIndices, self.maximaIntensity, 'ro')
            self.tof_ax.plot(self.counts[:, 0], self.convolvedsignal)

        if self.showexppeaksBool:
            # Short vertical marks (top 20% of the y range) at the expected TOF of harmonics first_harm..34:
            # red for even orders, black for odd ones.
            qq2 = np.arange(cts.first_harm, 35, 1)

            y = np.linspace(ymax - (ymax - ymin) * 0.2, ymax, 100)
            for harm in qq2:
                c = 'r' if harm % 2 == 0 else 'k'
                try:
                    xval = float(
                        math.sqrt(0.5 * cts.ME * cts.cur_L**2 / cts.QE /
                                  (harm * cts.HEV * cts.cur_nu -
                                   cts.cur_Ip)) + 6e-8)
                    # NB: cts.QE is used to convert the energy from eV to Joules. It's not the electron's charge
                except Exception:
                    # e.g. harmonic energy below the ionization potential: no TOF exists, skip this harmonic
                    # instead of plotting an undefined/stale x
                    print(traceback.format_exception(*sys.exc_info()))
                    continue
                x = np.full((100, 1), xval)
                self.tof_ax.plot(x, y, color=c, linewidth=1.0)

        if self.bgndremoved:
            self.bgndremoved = False  #this means that when we remove bgnd, we don't keep the same scale
        else:
            self.tof_ax.set_ylim(ymin, ymax)

        self.tof_ax.set_xlim(xmin, xmax)
        self.tof_fc.draw()

    def onclick(self, event):
        ''' Called on every mouse-button press on the TOF graph ('button_press_event').

            - "Add peak" checked: inserts a peak at the clicked position, before the first existing peak that does not
              lie left of it.
            - "Set threshold" checked: sets the threshold selected by the radio buttons (Y, X min or X max) to the
              clicked position.
            Clicks outside the axes carry no data coordinates and are ignored.'''
        if event.xdata is None or event.ydata is None:
            return

        if self.dataloaded and self.addpeak_cb.isChecked():  #add a peak on clicking on the figure
            i = next((k for k, tof in enumerate(self.maximaIndices)
                      if not tof < event.xdata), len(self.maximaIndices))
            self.maximaIndices.insert(i, event.xdata)
            self.maximaIntensity.insert(i, event.ydata)
            self.refreshplot_fn()

        if self.dataloaded and self.setth_cb.isChecked():
            # threshold lines are (len(counts), 2) arrays of (x, y) points spanning the TOF signal
            if self.Y_rb.isChecked():
                self.thy = event.ydata
                self.threshyline = np.full((len(self.counts), 2), self.thy)
                self.threshyline[:, 0] = self.counts[:, 0]
                self.threshyBool = True
            elif self.xmin_rb.isChecked():
                self.thxmin = event.xdata
                self.threshxminline = np.full((len(self.counts), 2),
                                              self.thxmin)
                self.threshxminline[:, 1] = self.counts[:, 1]
                self.threshxminBool = True
            elif self.xmax_rb.isChecked():
                self.thxmax = event.xdata
                self.threshxmaxline = np.full((len(self.counts), 2),
                                              self.thxmax)
                self.threshxmaxline[:, 1] = self.counts[:, 1]
                self.threshxmaxBool = True
            self.refreshplot_fn()

    def exportXUV_lr(self):
        ''' "Export XUV" button listener. Saving the energy vector and the energy-converted TOF signal as two
            tab-separated columns'''
        fname, _ = QFileDialog.getSaveFileName(self, 'Save XUV')
        if fname:
            XUV_array = np.vstack((self.Eevlin, self.signal)).T
            np.savetxt(fname, XUV_array, delimiter='\t')

    def clear_lr(self):
        ''' "Clear" button listener. Resets all the objects, but not the global variables'''
        self.tof_ax.cla()
        self.fit_ax.cla()
        self.en_ax.cla()

        self.dataloaded = False
        self.threshyBool = False
        self.threshxminBool = False
        self.threshxmaxBool = False
        self.calibloaded = False
        self.calibBool = False
        # forget the old maxima, otherwise they would be plotted over the next loaded data
        self.peaksfound = False

        self.Y_rb.setEnabled(False)
        self.xmin_rb.setEnabled(False)
        self.xmax_rb.setEnabled(False)
        self.minus_sign_cb.setEnabled(False)

        # disable every button and checkbox of the window, except "Load data"
        for w in self.children():
            if isinstance(w, (QPushButton, QCheckBox)):
                w.setEnabled(False)
        self.load_btn.setEnabled(True)

        self.tof_fc.draw()
        self.fit_fc.draw()
        self.en_fc.draw()

        cts.clear_varlist()
        self.window().updateglobvar_fn()
class MyTableWidget(QWidget):
    def __init__(self, parent):
        super(QWidget, self).__init__(parent)
        self.layout = QVBoxLayout(self)

        # Initialize tabs ----------------------------------
        self.tabs = QTabWidget()
        self.Load_data = QWidget()  # create tab 1
        self.SettRzPlot_tab = QWidget()  # create tab 2
        self.Rz_tab = QWidget()  # create tab 3
        self.tabs.resize(300, 200)

        # Add tabs to the Main WIndow
        self.tabs.addTab(self.Load_data, "Load data")  # tab 1
        self.tabs.addTab(self.SettRzPlot_tab, "Rz plot Settings")  # tab 2
        self.tabs.addTab(self.Rz_tab, "Rz plot")  # tab 3

        # Add tabs to widget
        self.layout.addWidget(self.tabs)
        self.setLayout(self.layout)
        self.show()
        # ----------------------------------------------------------------------------------

        # Load_data tab - content
        self.data_loaded = False
        layout_load = QtWidgets.QVBoxLayout(self.Load_data)  # main layout
        sublayout_load = QtWidgets.QGridLayout()  # layout for inputs
        layout_load.addLayout(sublayout_load)

        # Input widgets
        # Shot
        self.Shot_lbl_load = QLabel(self.Load_data)
        self.Shot_lbl_load.setText('Shot # ')
        self.Shot_ed_load = QLineEdit(self.Load_data)
        self.Shot_ed_load.setText('25781')
        # Diag
        self.Diag_lbl_load = QLabel(self.Load_data)
        self.Diag_lbl_load.setText('Diag: ')
        self.Diag_load = QComboBox(self.Load_data)
        self.Diag_load.addItems(['ECI', 'TDI'])
        self.Diag_lbl_EQ_load = QLabel(self.Load_data)
        self.Diag_lbl_EQ_load.setText('Equilibrium: ')
        self.Diag_EQ_load = QComboBox(self.Load_data)
        self.Diag_EQ_load.addItems(['EQH'])
        # Load button
        self.Butt_load = QPushButton("Load ECEI and equilibrium data",
                                     self.Load_data)
        self.Butt_load.clicked.connect(self.Load_ECEI_data)
        # Monitor
        self.Monitor_load = QtWidgets.QTextBrowser(self.Load_data)
        self.Monitor_load.setText("Status:\nECEI data is not loaded")

        # Add widgets to layout
        sublayout_load.setSpacing(5)
        sublayout_load.addWidget(self.Shot_lbl_load, 0, 0)
        sublayout_load.addWidget(self.Diag_lbl_load, 1, 0)
        sublayout_load.addWidget(self.Diag_lbl_EQ_load, 2, 0)
        sublayout_load.addWidget(self.Shot_ed_load, 0, 1)
        sublayout_load.addWidget(self.Diag_load, 1, 1)
        sublayout_load.addWidget(self.Diag_EQ_load, 2, 1)
        sublayout_load.addWidget(self.Butt_load, 3, 1)

        sublayout_2_load = QtWidgets.QGridLayout()  # layout for inputs
        layout_load.addLayout(sublayout_2_load)
        sublayout_2_load.addWidget(self.Monitor_load, 1, 0)

        # stretch free space (compress widgets at the top)
        layout_load.addStretch()

        # ----------------------------------------------------------------------------------
        # Rz plot tab - content
        # Create layouts
        layout_RzPl = QtWidgets.QVBoxLayout(self.Rz_tab)  # main layout
        sublayout_RzPl = QtWidgets.QGridLayout()  # layout for inputs
        layout_RzPl.addLayout(sublayout_RzPl)

        # Input widgets
        # labels
        self.tB_lbl_RzPl = QLabel(self.Rz_tab)
        self.tB_lbl_RzPl.setText('tB [s]:')
        self.tE_lbl_RzPl = QLabel(self.Rz_tab)
        self.tE_lbl_RzPl.setText('tE [s]:')
        self.tCnt_lbl_RzPl = QLabel(self.Rz_tab)
        self.tCnt_lbl_RzPl.setText('tCenter [s] (optional):')
        self.dt_lbl_RzPl = QLabel(self.Rz_tab)
        self.dt_lbl_RzPl.setText('dt [s](optional) :')
        # filter labels
        self.Fourier_lbl0_RzPl = QLabel(self.Rz_tab)
        self.Fourier_lbl0_RzPl.setText('Fourier lowpass f [kHz]:')
        self.Fourier2_lbl0_RzPl = QLabel(self.Rz_tab)
        self.Fourier2_lbl0_RzPl.setText('Fourier highpass f [kHz]:')
        self.SavGol_lbl0_RzPl = QLabel(self.Rz_tab)
        self.SavGol_lbl0_RzPl.setText('SavGol win_len:')
        self.SavGol_lbl1_RzPl = QLabel(self.Rz_tab)
        self.SavGol_lbl1_RzPl.setText('SavGol pol_ord:')
        self.Binning_lbl_RzPl = QLabel(self.Rz_tab)
        self.Binning_lbl_RzPl.setText('Binning [kHz]:')
        self.Contour_lbl_RzPl = QLabel(self.Rz_tab)
        self.Contour_lbl_RzPl.setText('Contour [1 or 0]')
        self.NNcont_lbl_RzPl = QLabel(self.Rz_tab)
        self.NNcont_lbl_RzPl.setText('NNcont:')
        self.tplot_lbl_RzPl = QLabel(self.Rz_tab)
        self.tplot_lbl_RzPl.setText('t_plot [s](within tB and tE):')
        self.dtplot_lbl_RzPl = QLabel(self.Rz_tab)
        self.dtplot_lbl_RzPl.setText('dt_plot [s]:')
        self.FourMult_lbl_RzPl = QLabel(self.Rz_tab)
        self.FourMult_lbl_RzPl.setText('Fourier multiple f [kHz]:')

        # plot params labels
        self.vmin_lbl_RzPl = QLabel(self.Rz_tab)
        self.vmin_lbl_RzPl.setText('vmin:')
        self.vmax_lbl_RzPl = QLabel(self.Rz_tab)
        self.vmax_lbl_RzPl.setText('vmax:')
        self.chzz_lbl_RzPl = QLabel(self.Rz_tab)
        self.chzz_lbl_RzPl.setText('Remove LOS:')
        self.chRR_lbl_RzPl = QLabel(self.Rz_tab)
        self.chRR_lbl_RzPl.setText('Remove R chs:')

        # velocimetry specific labels
        self.rhop_lbl_RzPl = QLabel(self.Rz_tab)
        self.rhop_lbl_RzPl.setText('rho_pol:')

        # line edits
        # time edits
        self.tB_ed_RzPl = QLineEdit(self.Rz_tab)
        self.tB_ed_RzPl.setText('4.488525')
        self.tB_ed_RzPl.setMinimumSize(QtCore.QSize(55, 0))
        self.tE_ed_RzPl = QLineEdit(self.Rz_tab)
        self.tE_ed_RzPl.setText('4.489525')
        self.tE_ed_RzPl.setMinimumSize(QtCore.QSize(55, 0))
        self.tCnt_ed_RzPl = QLineEdit(self.Rz_tab)
        self.tCnt_ed_RzPl.setMinimumSize(QtCore.QSize(50, 0))
        self.dt_ed_RzPl = QLineEdit(self.Rz_tab)
        self.dt_ed_RzPl.setText('0.001')
        self.dt_ed_RzPl.setMinimumSize(QtCore.QSize(100, 0))
        self.Butt_dt_RzPl = QPushButton("Calc t", self.Rz_tab)
        self.Butt_dt_RzPl.clicked.connect(lambda: self.tBE_from_tCnt(9))
        # plot params edits
        self.vmin_ed_RzPl = QLineEdit(self.Rz_tab)
        self.vmin_ed_RzPl.setText('None')
        self.vmin_ed_RzPl.setMinimumSize(QtCore.QSize(40, 0))
        self.vmax_ed_RzPl = QLineEdit(self.Rz_tab)
        self.vmax_ed_RzPl.setText('None')
        self.vmax_ed_RzPl.setMinimumSize(QtCore.QSize(40, 0))
        self.chzz_ed_RzPl = QLineEdit(self.Rz_tab)
        self.chzz_ed_RzPl.setMinimumSize(QtCore.QSize(100, 0))
        self.chRR_ed_RzPl = QLineEdit(self.Rz_tab)
        self.chRR_ed_RzPl.setMinimumSize(QtCore.QSize(100, 0))
        # Filters edits
        self.Fourier_cut_RzPl = QLineEdit(self.Rz_tab)
        self.Fourier_cut_RzPl.setText('30.0')
        self.Fourier2_cut_RzPl = QLineEdit(self.Rz_tab)
        self.Fourier2_cut_RzPl.setText('2.0')
        self.SavGol_ed0_RzPl = QLineEdit(self.Rz_tab)
        self.SavGol_ed0_RzPl.setText('11')
        self.SavGol_ed0_RzPl.setMinimumSize(QtCore.QSize(20, 0))
        self.SavGol_ed1_RzPl = QLineEdit(self.Rz_tab)
        self.SavGol_ed1_RzPl.setText('3')
        self.Binning_ed_RzPl = QLineEdit(self.Rz_tab)
        self.Binning_ed_RzPl.setText('60.0')
        self.Binning_ed_RzPl.setMinimumSize(QtCore.QSize(40, 0))
        self.Contour_ed_RzPl = QLineEdit(self.Rz_tab)
        self.Contour_ed_RzPl.setText('0')
        self.NNcont_ed_RzPl = QLineEdit(self.Rz_tab)
        self.NNcont_ed_RzPl.setText('60')
        self.tplot_ed_RzPl = QLineEdit(self.Rz_tab)
        self.tplot_ed_RzPl.setText('4.488550')
        self.tplot_ed_RzPl.setMinimumSize(QtCore.QSize(50, 0))
        self.dtplot_ed_RzPl = QLineEdit(self.Rz_tab)
        self.dtplot_ed_RzPl.setText('5.0e-6')
        self.dtplot_ed_RzPl.setMinimumSize(QtCore.QSize(50, 0))
        self.FourMult_ed_RzPl = QLineEdit(self.Rz_tab)
        self.FourMult_ed_RzPl.setText('13.0,15.0;26,30')
        self.FourMult_ed_RzPl.setMinimumSize(QtCore.QSize(100, 0))
        # velocimetry specific line edits
        self.rhop_ed_RzPl = QLineEdit(self.Rz_tab)
        self.rhop_ed_RzPl.setText('0.3')
        self.sendpoints_butt_RzPl = QPushButton("Send t,R,z,r", self.Rz_tab)
        self.sendpoints_butt_RzPl.clicked.connect(self.send_points)
        self.clearpoints_butt_RzPl = QPushButton("Clear table", self.Rz_tab)
        self.clearpoints_butt_RzPl.clicked.connect(self.clear_table)

        # what to plot (type of filter)
        self.ImgType_plot_RzPl = QComboBox(self.Rz_tab)
        self.ImgType_plot_RzPl.addItems([
            'no Image filter', 'Gaussian', 'Median', 'Bilateral',
            'Conservative_smoothing'
        ])
        self.type_plot_RzPl = QComboBox(self.Rz_tab)
        self.type_plot_RzPl.addItems([
            'no 1D filter', 'Fourier highpass', 'Fourier lowpass',
            'Fourier both', 'Fourier multiple', 'SavGol', 'Binning'
        ])
        self.Interp_plot_RzPl = QComboBox(self.Rz_tab)
        self.Interp_plot_RzPl.addItems(
            ['no interpolation', 'with interpolation', 'set to zero'])
        # self.Interp_plot_RzPl.setMaximumSize(QtCore.QSize(90, 0))
        self.Save_plot_RzPl = QComboBox(self.Rz_tab)
        self.Save_plot_RzPl.addItems(
            ['do not save', 'save as pdf', 'save as png'])
        # plot buttom
        self.MinusTplot_butt_RzPl = QPushButton("< -dt", self.Rz_tab)
        self.PlusTplot_butt_RzPl = QPushButton("+dt >", self.Rz_tab)
        self.tplot_butt_RzPl = QPushButton("plot time", self.Rz_tab)
        self.MinusTplot_butt_RzPl.clicked.connect(lambda: self.f_Rz_plot(1))
        self.PlusTplot_butt_RzPl.clicked.connect(lambda: self.f_Rz_plot(2))
        self.tplot_butt_RzPl.clicked.connect(lambda: self.f_Rz_plot(3))

        # Add widgets to layout
        # First row
        sublayout_RzPl.setSpacing(2)
        sublayout_RzPl.addWidget(self.tB_lbl_RzPl, 0, 0)
        sublayout_RzPl.addWidget(self.tB_ed_RzPl, 0, 1)
        sublayout_RzPl.addWidget(self.tE_lbl_RzPl, 0, 2)
        sublayout_RzPl.addWidget(self.tE_ed_RzPl, 0, 3)
        sublayout_RzPl.addWidget(self.tCnt_lbl_RzPl, 0, 4)
        sublayout_RzPl.addWidget(self.tCnt_ed_RzPl, 0, 5)
        sublayout_RzPl.addWidget(self.dt_lbl_RzPl, 0, 6)
        sublayout_RzPl.addWidget(self.dt_ed_RzPl, 0, 7)
        sublayout_RzPl.addWidget(self.Butt_dt_RzPl, 0, 8)
        # Second row
        sublayout_RzPl.addWidget(self.Fourier2_lbl0_RzPl, 1, 0)
        sublayout_RzPl.addWidget(self.Fourier2_cut_RzPl, 1, 1)
        sublayout_RzPl.addWidget(self.Fourier_lbl0_RzPl, 1, 2)
        sublayout_RzPl.addWidget(self.Fourier_cut_RzPl, 1, 3)
        sublayout_RzPl.addWidget(self.FourMult_lbl_RzPl, 1, 4)
        sublayout_RzPl.addWidget(self.FourMult_ed_RzPl, 1, 5)
        ######
        sublayout_RzPl.addWidget(self.SavGol_lbl0_RzPl, 1, 6)
        sublayout_RzPl.addWidget(self.SavGol_ed0_RzPl, 1, 7)
        sublayout_RzPl.addWidget(self.SavGol_lbl1_RzPl, 1, 8)
        sublayout_RzPl.addWidget(self.SavGol_ed1_RzPl, 1, 9)
        sublayout_RzPl.addWidget(self.Binning_lbl_RzPl, 1, 10)
        sublayout_RzPl.addWidget(self.Binning_ed_RzPl, 1, 11)
        ######
        sublayout_RzPl.addWidget(self.chzz_lbl_RzPl, 2, 0)
        sublayout_RzPl.addWidget(self.chzz_ed_RzPl, 2, 1)
        sublayout_RzPl.addWidget(self.chRR_lbl_RzPl, 2, 2)
        sublayout_RzPl.addWidget(self.chRR_ed_RzPl, 2, 3)
        ######
        sublayout_RzPl.addWidget(self.vmin_lbl_RzPl, 2, 4)
        sublayout_RzPl.addWidget(self.vmin_ed_RzPl, 2, 5)
        sublayout_RzPl.addWidget(self.vmax_lbl_RzPl, 2, 6)
        sublayout_RzPl.addWidget(self.vmax_ed_RzPl, 2, 7)
        sublayout_RzPl.addWidget(self.Contour_lbl_RzPl, 2, 8)
        sublayout_RzPl.addWidget(self.Contour_ed_RzPl, 2, 9)
        sublayout_RzPl.addWidget(self.NNcont_lbl_RzPl, 2, 10)
        sublayout_RzPl.addWidget(self.NNcont_ed_RzPl, 2, 11)
        #####
        ######
        # Third row
        sublayout_RzPl.addWidget(self.tplot_lbl_RzPl, 3, 0)
        sublayout_RzPl.addWidget(self.tplot_ed_RzPl, 3, 1)
        sublayout_RzPl.addWidget(self.dtplot_lbl_RzPl, 3, 2)
        sublayout_RzPl.addWidget(self.dtplot_ed_RzPl, 3, 3)
        # Fourth row
        sublayout_RzPl.addWidget(self.rhop_lbl_RzPl, 4, 0)
        sublayout_RzPl.addWidget(self.rhop_ed_RzPl, 4, 1)
        sublayout_RzPl.addWidget(self.sendpoints_butt_RzPl, 4, 2)
        sublayout_RzPl.addWidget(self.clearpoints_butt_RzPl, 4, 3)
        # Plot control
        sublayout_RzPl.addWidget(self.ImgType_plot_RzPl, 1, 12)
        sublayout_RzPl.addWidget(self.type_plot_RzPl, 2, 12)
        sublayout_RzPl.addWidget(self.Save_plot_RzPl, 3, 7)
        sublayout_RzPl.addWidget(self.Interp_plot_RzPl, 3, 8)
        sublayout_RzPl.addWidget(self.MinusTplot_butt_RzPl, 3, 10)
        sublayout_RzPl.addWidget(self.PlusTplot_butt_RzPl, 3, 11)
        sublayout_RzPl.addWidget(self.tplot_butt_RzPl, 3, 12)

        # Add matplotlib plot
        self.figure_RzPl = Figure(figsize=(5, 3), constrained_layout=False)
        self.static_canvas_RzPl = FigureCanvas(self.figure_RzPl)
        layout_RzPl.addWidget(self.static_canvas_RzPl,
                              QtCore.Qt.AlignTop)  # align the plot up
        layout_RzPl.addStretch()  # stretch plot in all free space
        self.toolbar = NavigationToolbar(
            self.static_canvas_RzPl, self.Rz_tab,
            coordinates=True)  # add toolbar below the plot
        layout_RzPl.addWidget(self.toolbar)
        self._static_ax = self.static_canvas_RzPl.figure.subplots()  # add axes

        # velcimetry data
        self.Monitor_RzPl = QtWidgets.QTextBrowser(self.Rz_tab)
        self.Monitor_RzPl.setText("NN\tt\tR\tz\tr\n")
        self.counter = 1
        self.Monitor_RzPl.setMaximumSize(QtCore.QSize(1920, 50))
        sublayout2_RzPl = QtWidgets.QVBoxLayout()  # layout for monitor
        layout_RzPl.addLayout(sublayout2_RzPl)
        sublayout2_RzPl.addWidget(self.Monitor_RzPl, 0)

        # ----------------------------------------------------------------------------------
        # SettRz tab - content
        # Create layouts
        layout_RzSet = QtWidgets.QVBoxLayout(
            self.SettRzPlot_tab)  # main layout
        sublayout_RzSet = QtWidgets.QGridLayout()  # layout for inputs
        layout_RzSet.addLayout(sublayout_RzSet)

        # Input widgets
        # labels
        self.one_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.one_lbl_RzSet.setText('Gaussian filter:')
        self.two_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.two_lbl_RzSet.setText('Median filter:')
        self.three_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.three_lbl_RzSet.setText('Bilateral filter:')
        self.four_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.four_lbl_RzSet.setText('Conservative smoothing filter:')
        # filters parameters
        self.BilKernSize_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.BilKernSize_lbl_RzSet.setText('Kernel size:')
        self.BilS0_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.BilS0_lbl_RzSet.setText('s0:')
        self.BilS1_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.BilS1_lbl_RzSet.setText('s1:')
        self.MedKernSize_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.MedKernSize_lbl_RzSet.setText('Kernel size:')
        self.ConsSize_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.ConsSize_lbl_RzSet.setText('Neighborhood size:')
        self.GausSigma_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.GausSigma_lbl_RzSet.setText('sigma:')

        # Line edits (inputs)
        self.GausSigma_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.GausSigma_ed_RzSet.setText('1.0')
        self.BilKern_type_RzSet = QComboBox(self.SettRzPlot_tab)
        self.BilKern_type_RzSet.addItems(['disk', 'square'])
        self.BilKernSize_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.BilKernSize_ed_RzSet.setText('1')
        self.BilS0_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.BilS0_ed_RzSet.setText('100')
        self.BilS1_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.BilS1_ed_RzSet.setText('100')

        self.MedKern_type_RzSet = QComboBox(self.SettRzPlot_tab)
        self.MedKern_type_RzSet.addItems(['disk', 'square'])
        self.MedKernSize_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.MedKernSize_ed_RzSet.setText('1')
        self.ConsSize_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.ConsSize_ed_RzSet.setText('2')

        sublayout_RzSet.setSpacing(2)
        # First row
        sublayout_RzSet.addWidget(self.one_lbl_RzSet, 0, 0)
        sublayout_RzSet.addWidget(self.GausSigma_lbl_RzSet, 0, 2)
        sublayout_RzSet.addWidget(self.GausSigma_ed_RzSet, 0, 3)
        # Second row
        sublayout_RzSet.addWidget(self.two_lbl_RzSet, 1, 0)
        sublayout_RzSet.addWidget(self.MedKern_type_RzSet, 1, 1)
        sublayout_RzSet.addWidget(self.MedKernSize_lbl_RzSet, 1, 2)
        sublayout_RzSet.addWidget(self.MedKernSize_ed_RzSet, 1, 3)
        # Third row
        sublayout_RzSet.addWidget(self.three_lbl_RzSet, 2, 0)
        sublayout_RzSet.addWidget(self.BilKern_type_RzSet, 2, 1)
        sublayout_RzSet.addWidget(self.BilKernSize_lbl_RzSet, 2, 2)
        sublayout_RzSet.addWidget(self.BilKernSize_ed_RzSet, 2, 3)
        sublayout_RzSet.addWidget(self.BilS0_lbl_RzSet, 2, 4)
        sublayout_RzSet.addWidget(self.BilS0_ed_RzSet, 2, 5)
        sublayout_RzSet.addWidget(self.BilS1_lbl_RzSet, 2, 6)
        sublayout_RzSet.addWidget(self.BilS1_ed_RzSet, 2, 7)
        # Fourth row
        sublayout_RzSet.addWidget(self.four_lbl_RzSet, 3, 0)
        sublayout_RzSet.addWidget(self.ConsSize_lbl_RzSet, 3, 2)
        sublayout_RzSet.addWidget(self.ConsSize_ed_RzSet, 3, 3)

        sublayout1_RzSet = QtWidgets.QVBoxLayout()  # one more layout for title
        layout_RzSet.addLayout(sublayout1_RzSet)

        self.Info1_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.Info1_lbl_RzSet.setText(
            '====== Matrix for interpolation (scipy.interpolate.interp2d, type = cubic) or "set to zero" options ======'
        )
        sublayout1_RzSet.addWidget(self.Info1_lbl_RzSet)

        sublayout2_RzSet = QtWidgets.QGridLayout(
        )  # one more layout for interpolation
        layout_RzSet.addLayout(sublayout2_RzSet)

        LOSlabels = {}
        self.LOSlabels = {}
        for i_L in range(20):
            LOSlabels['%d' % (i_L)] = (i_L, 0)
        for sText, pos in LOSlabels.items():
            # QLabels
            self.LOSlabels[sText] = QLabel("LOS: %d" % (int(sText) + 1))
            sublayout2_RzSet.addWidget(self.LOSlabels[sText], pos[0] + 1,
                                       pos[1])

        checks = {}
        self.checks = {}
        for i_L in range(20):
            for i_R in range(8):
                checks['%d,%d' % (i_L, i_R)] = (i_L, i_R)
        for sText, pos in checks.items():
            # QCheckBoxes
            self.checks[sText] = QCheckBox("%d,%d" % (pos[0] + 1, pos[1] + 1))
            sublayout2_RzSet.addWidget(self.checks[sText], pos[0] + 1,
                                       pos[1] + 1)
        sublayout2_RzSet.setSpacing(2)

        sublayout3_RzSet = QtWidgets.QHBoxLayout()  # one more layout for path
        layout_RzSet.addLayout(sublayout3_RzSet)

        self.path_lbl_RzSet = QLabel(self.SettRzPlot_tab)
        self.path_lbl_RzSet.setText(
            'Path to save Rz plots (path should end with "/" symbol):')

        self.path_ed_RzSet = QLineEdit(self.SettRzPlot_tab)
        self.path_ed_RzSet.setText('/afs/ipp/home/o/osam/Documents/output/')
        sublayout3_RzSet.addWidget(self.path_lbl_RzSet)
        sublayout3_RzSet.addWidget(self.path_ed_RzSet)

        layout_RzSet.addStretch(
        )  # stretch free space (compress widgets at the top)
# ----------------------------------------------------------------------------------

# ----------------------------------------------------------------------------------
# ---------------METHODS-------------

    def tBE_from_tCnt(self, number):
        """Derive the time window [tB, tE] from its centre and width.

        Acts only when ``number == 9`` (the Rz-plot tab). Reads the window
        centre from ``tCnt_ed_RzPl`` and the width from ``dt_ed_RzPl``.
        It writes ``tB = tCnt - dt/2`` and ``tE = tCnt + dt/2`` back to
        ``tB_ed_RzPl`` / ``tE_ed_RzPl``, puts the window midpoint into
        ``tplot_ed_RzPl`` and redraws with ``f_Rz_plot(3)``.
        Any error (e.g. non-numeric input) is printed, not raised.
        """
        try:
            if number != 9:
                return
            centre = float(self.tCnt_ed_RzPl.text())
            half_width = float(self.dt_ed_RzPl.text()) / 2.0
            t_begin = centre - half_width
            t_end = centre + half_width
            updates = (
                (self.tB_ed_RzPl, t_begin),
                (self.tE_ed_RzPl, t_end),
                (self.tplot_ed_RzPl, np.mean([t_begin, t_end])),
            )
            for edit, value in updates:
                edit.setText('%0.7g' % (value))
            self.f_Rz_plot(3)
        except Exception as exc:
            print("!!! Incorrect input. ERROR: %s" % (exc))

    def Load_ECEI_data(self):
        """Load equilibrium and ECEI data for the shot entered in the Load tab.

        Reads the shot number (``Shot_ed_load``), the ECEI diagnostic
        (``Diag_load``) and the equilibrium source (``Diag_EQ_load``).
        If the equilibrium source is ``'EQH'``, the EQH data are stored in
        ``self.EQ_*``. If the diagnostic is ``'TDI'`` or ``'ECI'``, the data
        and the "fake" R/z channel geometry are copied into ``self.ECEId*``,
        and ``self.data_loaded`` is set to True.

        Progress and errors are reported in ``self.Monitor_load`` and on
        stdout. Nothing is raised to the (Qt) caller.
        """
        try:
            self.Shot = int(self.Shot_ed_load.text())
            self.Diag = self.Diag_load.currentText()
            self.Diag_EQ = self.Diag_EQ_load.currentText()
            self.Monitor_load.setText("Status:\nLoading %s: #%d ... " %
                                      (self.Diag, self.Shot))
        except Exception as exc:
            print("!!! Incorrect input. ERROR: %s" % (exc))
            self.Monitor_load.setText("Status:\nPlease enter shot number.")
            # Stop here: self.Diag / self.Diag_EQ may be unset (reading them
            # would raise AttributeError) or stale from a previous shot.
            return

        if self.Diag_EQ == 'EQH':
            self._load_EQH_data()

        # The loaders are built lazily so that a failing constructor (or a
        # missing module) is reported by the helper instead of escaping.
        if self.Diag == 'TDI':
            self._load_ECEI_diag(lambda: TDI.TDI(), 'TDI')
        elif self.Diag == 'ECI':
            self._load_ECEI_diag(lambda: ECI.ECI(), 'ECI')

    def _load_EQH_data(self):
        """Load the EQH equilibrium of ``self.Shot`` into the ``self.EQ_*`` attributes."""
        try:
            self.Monitor_load.setText("")
            EQ = EQH.EQH()
            EQ.Load(self.Shot)
            self.EQ_rhopM = EQ.rhopM
            self.EQ_time = EQ.time
            self.EQ_R = EQ.R
            self.EQ_z = EQ.z
            self.EQ_Rmag = EQ.Rmag
            self.EQ_zmag = EQ.zmag
            self.Monitor_load.insertPlainText(
                "EQH data has been loaded succesfully.\n")
            print("+++ EQH has been loaded +++")
        except Exception as exc:
            traceback.print_exc()
            print("!!! Coudn't load EQH. ERROR: %s" % (exc))
            self.Monitor_load.setText("Status:\nError in loading EQH data.")
            self.Monitor_load.insertPlainText("!!! EQH data NOT loaded.")

    def _load_ECEI_diag(self, make_loader, diag_name):
        """Load ECEI data of ``self.Shot`` through a TDI/ECI-style loader.

        ``make_loader`` is a zero-argument callable that returns an object
        with ``Load(shot)`` and ``Load_FakeRz()``. After loading, that object
        must expose ``ECEId``, ``time``, ``RR_fake``, ``zz_fake``, ``R_fake``
        and ``z_fake``, as the TDI and ECI classes do. ``diag_name`` is used
        in status and error messages.
        """
        try:
            loader = make_loader()
            loader.Load(self.Shot)
            loader.Load_FakeRz()
            self.ECEId = loader.ECEId.copy()
            self.ECEId_time = loader.time.copy()
            self.ECEId_RR = loader.RR_fake.copy()
            self.ECEId_zz = loader.zz_fake.copy()
            self.ECEId_R = loader.R_fake.copy()
            self.ECEId_z = loader.z_fake.copy()
            self.Monitor_load.insertPlainText(
                "Status:\n%s #%d\ntB = %g, tE = %g s\nLoaded succesfully." %
                (diag_name, self.Shot, loader.time[0], loader.time[-1]))
            self.data_loaded = True
            print("+++ The data has been loaded succesfully. +++")
        except Exception as exc:
            print("!!! Coudn't load %s. ERROR: %s" % (diag_name, exc))
            self.Monitor_load.insertPlainText(
                "Status:\nError in loading %s data." % (diag_name))

    def f_Rz_plot(self, which_plot):
        """Filter the loaded ECEI data and draw the R-z plot plus a time trace.

        ``which_plot`` selects the plot time, kept in ``tplot_ed_RzPl``:
          1 -- step back by ``dtplot`` and write the new time to the edit,
          2 -- step forward by ``dtplot`` and write the new time to the edit,
          3 -- use the time in the edit and reset the save counter.
        Any other value is reported as an error.

        Processing steps:
          * Cut the data to [tB, tE] and convert them with ``mf.relECEI``
            (plotted as deltaTrad/<Trad>).
          * Optionally apply the 1D filter from ``type_plot_RzPl``.
          * Remove the channels listed in the chzz/chRR edits.
          * At the plot time, optionally interpolate over or zero the
            channels ticked in the SettRz tab, then optionally apply the
            image filter from ``ImgType_plot_RzPl``.

        Figure layout:
          * Left panel: the R-z contour plot with rho_pol contours and the
            rho_pol surface chosen in ``rhop_ed_RzPl``.
          * Right panel: the time trace of LOS 7, R channel 4.

        The figure is optionally saved as pdf/png. Errors are printed (with a
        traceback), not raised.
        """
        if not getattr(self, 'data_loaded', False):  # ECEI data not loaded yet
            print("Please load the ECEI data (first tab)")
            return
        try:
            import os
            import matplotlib.pyplot as plt
            plt.rcParams.update({'font.size': 10})

            # ---- time window and plot time ----
            tB = float(self.tB_ed_RzPl.text())
            tE = float(self.tE_ed_RzPl.text())
            if which_plot in (1, 2):
                tplot_old = float(self.tplot_ed_RzPl.text())
                dtplot = float(self.dtplot_ed_RzPl.text())
                if which_plot == 1:
                    tplot = tplot_old - dtplot
                else:
                    tplot = tplot_old + dtplot
                self.tplot_ed_RzPl.setText("%0.7g" % tplot)
            elif which_plot == 3:
                tplot = float(self.tplot_ed_RzPl.text())
                self.counter_save = 0
            else:
                raise ValueError("which_plot must be 1, 2 or 3, got %r" %
                                 (which_plot, ))
            if not hasattr(self, 'counter_save'):
                # Stepping with 1/2 before any plot with 3 would otherwise
                # leave the save counter undefined.
                self.counter_save = 0
            self.tplot = tplot
            contour_check = self.Contour_ed_RzPl.text()

            # ---- data preparation ----
            mf = my_funcs.my_funcs()
            mf.CutDataECEI(self.ECEId_time, self.ECEId, tBegin=tB, tEnd=tE)
            mf.relECEI(mf.ECEId_C)
            mf.cutDataEQH(self.EQ_time, self.EQ_rhopM, self.EQ_R, self.EQ_z,
                          self.EQ_Rmag, self.EQ_zmag, tplot)
            time_plot, data_plot, filter_status = self._filter_1D_RzPl(
                mf, mf.time_C, mf.ECEId_rel)

            # ---- channel removal (edits hold 1-based channel numbers) ----
            removeLOS_ch = self._parse_channels_RzPl(self.chzz_ed_RzPl.text())
            removeRR_ch = self._parse_channels_RzPl(self.chRR_ed_RzPl.text())
            NN_LOS, NN_R = data_plot.shape[1], data_plot.shape[2]
            ch_zz = np.delete(np.arange(NN_LOS), removeLOS_ch)
            ch_RR = np.delete(np.arange(NN_R), removeRR_ch)

            # 1D trace (LOS 7, R channel 4) is taken before channel removal
            trace_1D = data_plot[:, 6, 3]
            RR_plot = np.delete(self.ECEId_RR, removeLOS_ch, axis=0)
            RR_plot = np.delete(RR_plot, removeRR_ch, axis=1)
            zz_plot = np.delete(self.ECEId_zz, removeLOS_ch, axis=0)
            zz_plot = np.delete(zz_plot, removeRR_ch, axis=1)
            data_plot = np.delete(data_plot, removeLOS_ch, axis=1)
            data_plot = np.delete(data_plot, removeRR_ch, axis=2)

            # ---- colour scale ----
            vmin = self._parse_float_RzPl(self.vmin_ed_RzPl.text())
            vmax = self._parse_float_RzPl(self.vmax_ed_RzPl.text())
            try:
                NN_cont = int(self.NNcont_ed_RzPl.text())
            except ValueError:
                NN_cont = 20
            if NN_cont < 1:
                NN_cont = 20

            # ---- 2D frame at the plot time ----
            idx_tplot = mf.find_nearest_idx(time_plot, tplot)
            time_plot_t = time_plot[idx_tplot]
            data_plot_t = data_plot[idx_tplot, :, :]

            interp_mode = self.Interp_plot_RzPl.currentText()
            use_mask = interp_mode in ('with interpolation', 'set to zero')
            if use_mask:
                interp_mask = self._interp_mask_RzPl(NN_LOS, NN_R,
                                                     removeLOS_ch, removeRR_ch)
                if interp_mode == 'with interpolation':
                    data_to_interp = data_plot_t.copy()
                    # np.nan (np.NaN was removed in NumPy 2.0)
                    data_to_interp[interp_mask] = np.nan
                    data_plot_t = mf.nan_interp_2d(data_to_interp)
                else:
                    data_plot_t[interp_mask] = 0.0

            data_plot_t, img_status = self._filter_image_RzPl(mf, data_plot_t)
            filter_status += img_status

            # ---- plotting ----
            self.figure_RzPl.clf()  # clear previous figure and axes
            self._static_ax = self.static_canvas_RzPl.figure.subplots(
                1, 2, sharex=False, sharey=False)
            ax2d, ax1d = self._static_ax[0], self._static_ax[1]
            if vmin is not None and vmax is not None:
                levels_to_plot = np.linspace(vmin, vmax, NN_cont)
            else:
                # np.linspace needs both limits; otherwise let matplotlib
                # choose NN_cont levels (vmin/vmax still clip the colours).
                levels_to_plot = NN_cont
            contours = ax2d.contourf(RR_plot,
                                     zz_plot,
                                     data_plot_t,
                                     vmin=vmin,
                                     vmax=vmax,
                                     levels=levels_to_plot,
                                     cmap='jet')
            cbar = self.figure_RzPl.colorbar(contours, ax=ax2d, pad=0.07)
            cbar.ax.set_ylabel('deltaTrad/<Trad>', rotation=90)
            if contour_check == '1':
                ax2d.contour(RR_plot,
                             zz_plot,
                             data_plot_t,
                             vmin=vmin,
                             vmax=vmax,
                             levels=levels_to_plot,
                             cmap='binary')
            ax2d.plot(RR_plot, zz_plot, "ko", ms=2)
            if use_mask:
                # mark the interpolated / zeroed channels
                ax2d.plot(RR_plot[interp_mask], zz_plot[interp_mask], "wo", ms=6)
            ax2d.set_xlabel("R [m]")
            ax2d.set_ylabel("z [m]")

            for i, txt in enumerate(ch_zz):
                ax2d.annotate(txt + 1, (RR_plot[i, 0], zz_plot[i, 0]),
                              fontsize=8)
            for i, txt in enumerate(ch_RR):
                ax2d.annotate(txt + 1, (RR_plot[0, i], zz_plot[0, i]),
                              fontsize=8)

            # EQ contours
            contours_rhop = ax2d.contour(mf.RR_t, mf.zz_t, mf.rhopM_t, 50)
            ax2d.clabel(contours_rhop, inline=True, fontsize=10)
            ax2d.plot(mf.Rmag_t, mf.zmag_t, 'bo')
            ax2d.set_xlim([mf.Rmag_t, RR_plot[0, -1]])
            ax2d.set_ylim([zz_plot[0, 0], zz_plot[-1, 0]])

            # Contour of the requested rho_pol, and the distance from the
            # magnetic axis to its first point.
            rhop_to_plot = float(self.rhop_ed_RzPl.text())
            equ_data = equ.equ_map(self.Shot, 'EQH', 'AUGD')
            data_rz = equ_data.rho2rz(rhop_to_plot, tplot, 'rho_pol')
            R_from_rhop = data_rz[0][0][0]
            z_from_rhop = data_rz[1][0][0]
            self.Rmag_t = mf.Rmag_t
            self.zmag_t = mf.zmag_t
            r_rhop = np.sqrt((self.Rmag_t - R_from_rhop[0])**2 +
                             (self.zmag_t - z_from_rhop[0])**2)
            ax2d.plot(R_from_rhop, z_from_rhop, 'g-', linewidth=4.0)
            self.my_line, = ax2d.plot([self.Rmag_t, R_from_rhop[0]],
                                      [self.zmag_t, z_from_rhop[0]],
                                      marker='o',
                                      color='b')
            ax2d.set_title(
                "t = %0.7g s, rhop(green) = %0.2g, r_rhop = %0.4g m" %
                (time_plot_t, rhop_to_plot, r_rhop))

            # 1D plot
            ax1d.plot(time_plot, trace_1D)
            ax1d.set_xlabel("t [s]")
            ax1d.set_ylabel("deltaTrad/<Trad>")
            ax1d.set_title("LOS = 7, R_ch = 4, dt resolut = %g s" %
                           (time_plot[1] - time_plot[0]))
            ax1d.axvline(x=time_plot_t, color="k")

            self.figure_RzPl.suptitle("ECEI, Shot #%d, Filter: %s" %
                                      (self.Shot, filter_status),
                                      fontsize=10)

            save_ext = {
                'save as pdf': 'pdf',
                'save as png': 'png'
            }.get(self.Save_plot_RzPl.currentText())
            if save_ext is not None:
                file_name = 'p_%03d.%s' % (self.counter_save, save_ext)
                # os.path.join also copes with a directory without trailing "/"
                self.figure_RzPl.savefig(os.path.join(self.path_ed_RzSet.text(),
                                                      file_name),
                                         bbox_inches='tight')
                self.counter_save += 1

            # Register the click handler once instead of on every redraw.
            if getattr(self, '_cid_click_Rz', None) is None:
                self._cid_click_Rz = self.static_canvas_RzPl.mpl_connect(
                    'button_press_event', self.mouse_click_Rz)
            self.static_canvas_RzPl.draw()
            print("+++ The data has been plotted succesfully. +++")

        except Exception as exc:
            traceback.print_exc()
            print("!!! Cannot plot. ERROR: %s" % (exc))

    @staticmethod
    def _parse_channels_RzPl(text):
        """Turn '3,5' (1-based channel numbers) into a 0-based int array; '' gives []."""
        if not text:
            return []
        return np.array(text.split(',')).astype(int) - 1

    @staticmethod
    def _parse_float_RzPl(text):
        """Return ``float(text)`` when it is a finite number, otherwise None."""
        try:
            value = float(text)
        except ValueError:
            return None
        return value if np.isfinite(value) else None

    def _interp_mask_RzPl(self, NN_LOS, NN_R, removeLOS_ch, removeRR_ch):
        """Boolean (LOS, R) mask of the channels ticked in the SettRz tab.

        Removed LOS/R channels are dropped, so the mask matches the plotted grid.
        """
        interp_mask = np.full((NN_LOS, NN_R), False)
        for i_L in range(NN_LOS):
            for i_R in range(NN_R):
                interp_mask[i_L, i_R] = self.checks['%d,%d' %
                                                    (i_L, i_R)].isChecked()
        interp_mask = np.delete(interp_mask, removeLOS_ch, axis=0)
        return np.delete(interp_mask, removeRR_ch, axis=1)

    def _filter_1D_RzPl(self, mf, time_plot, data_plot):
        """Apply the 1D (time) filter chosen in ``type_plot_RzPl``.

        Returns ``(time_plot, data_plot, filter_status)``. Only 'Binning'
        changes ``time_plot``. Cut-off frequencies are entered in kHz.
        """
        filter_type = self.type_plot_RzPl.currentText()
        filter_status = "None"
        noise_ampl = 1.0
        if filter_type == 'Fourier lowpass':
            f_cut = float(self.Fourier_cut_RzPl.text()) * 1.0e3
            mf.Fourier_analysis_ECEI_lowpass(time_plot, data_plot, noise_ampl,
                                             f_cut)
            data_plot = mf.ECEId_fft_f_ifft
            filter_status = "Fourier lowpass, freq_cut = %g kHz" % (f_cut *
                                                                    1.0e-3)
        elif filter_type == 'Fourier highpass':
            f_cut = float(self.Fourier2_cut_RzPl.text()) * 1.0e3
            mf.Fourier_analysis_ECEI_highpass(time_plot, data_plot, noise_ampl,
                                              f_cut)
            data_plot = mf.ECEId_fft_f_ifft
            filter_status = "Fourier highpass, freq_cut = %g kHz" % (f_cut *
                                                                     1.0e-3)
        elif filter_type == 'Fourier both':
            f_cut_lp = float(self.Fourier_cut_RzPl.text()) * 1.0e3
            f_cut_hp = float(self.Fourier2_cut_RzPl.text()) * 1.0e3
            mf.Fourier_analysis_ECEI_lowpass(time_plot, data_plot, noise_ampl,
                                             f_cut_lp)
            data_plot = mf.ECEId_fft_f_ifft.copy()
            mf.Fourier_analysis_ECEI_highpass(time_plot, data_plot, noise_ampl,
                                              f_cut_hp)
            data_plot = mf.ECEId_fft_f_ifft.copy()
            filter_status = ("Fourier high and low pass, freq_cut_hp = %g kHz, "
                             "freq_cut_lp = %g kHz" %
                             (f_cut_hp * 1.0e-3, f_cut_lp * 1.0e-3))
        elif filter_type == 'Fourier multiple':
            # format: "hp1,lp1;hp2,lp2;..." in kHz
            string = self.FourMult_ed_RzPl.text()
            bands = [band.split(",") for band in string.split(";")]
            f_hp = np.array([float(band[0]) for band in bands]) * 1.0e3
            f_lp = np.array([float(band[1]) for band in bands]) * 1.0e3
            mf.Fourier_analysis_ECEI_multiple(time_plot, data_plot, f_hp, f_lp)
            data_plot = mf.ECEId_fft_f_ifft
            filter_status = "Fourier multiple, freqs: %s kHz" % (string)
        elif filter_type == 'SavGol':
            win_len = int(self.SavGol_ed0_RzPl.text())
            pol_ord = int(self.SavGol_ed1_RzPl.text())
            mf.SavGol_filter_ECEI(data_plot, win_len, pol_ord)
            data_plot = mf.ECEId_savgol
            filter_status = "Savgol, win_len = %g, pol_ord = %g" % (win_len,
                                                                    pol_ord)
        elif filter_type == 'Binning':
            binning_freq = float(self.Binning_ed_RzPl.text())
            time_plot, data_plot = mf.dataBinningECEI(time_plot, data_plot,
                                                      binning_freq)
            filter_status = "Binning, freq = %g kHz" % (binning_freq)
        return time_plot, data_plot, filter_status

    def _filter_image_RzPl(self, mf, data_plot_t):
        """Apply the 2D image filter chosen in ``ImgType_plot_RzPl``.

        Returns ``(image, status_suffix)``. The suffix is '' when no image
        filter is selected.
        """
        img_filter = self.ImgType_plot_RzPl.currentText()
        if img_filter == 'Gaussian':
            sigma = float(self.GausSigma_ed_RzSet.text())
            return (mf.gaussian_filter(data_plot_t, sigma),
                    "; Img filt: Gaussian, sigma=%g" % (sigma))
        if img_filter == 'Bilateral':
            kernel = self.BilKern_type_RzSet.currentText()
            kern_size = int(self.BilKernSize_ed_RzSet.text())
            s0 = int(self.BilS0_ed_RzSet.text())
            s1 = int(self.BilS1_ed_RzSet.text())
            return (mf.bilateral_filter(data_plot_t, kernel, kern_size, s0, s1),
                    "; Img filt: Bilateral, %s, kern_size=%g, s0=%g, s1=%g" %
                    (kernel, kern_size, s0, s1))
        if img_filter == 'Median':
            kernel = self.MedKern_type_RzSet.currentText()
            kern_size = int(self.MedKernSize_ed_RzSet.text())
            return (mf.median_filter(data_plot_t, kernel, kern_size),
                    "; Img filt: Median, %s, kern_size=%g" % (kernel, kern_size))
        if img_filter == 'Conservative_smoothing':
            size_filt = int(self.ConsSize_ed_RzSet.text())
            return (mf.conservative_smoothing_filter(data_plot_t, size_filt),
                    "; Img filt: Conservative smoothing, filt_size=%g" %
                    (size_filt))
        return data_plot_t, ""

    def mouse_click_Rz(self, event):
        """Handle double clicks on the Rz-plot canvas.

        * Double left click on the time-trace axes (right panel,
          ``self._static_ax[1]``): move the plot time to the clicked time and
          redraw with ``f_Rz_plot(3)``.
        * Double right click on the R-z axes (left panel,
          ``self._static_ax[0]``):
          - store the clicked point as ``R_blue`` / ``z_blue``;
          - store its distance ``r_rhop_blue`` from (``Rmag_t``, ``zmag_t``);
          - replace the blue line from that point to the magnetic axis;
          - show the values in the x-label and redraw the canvas.

        Clicks outside any axes (``xdata``/``ydata`` is None), single clicks
        and clicks on the other panel are ignored.
        """
        if not event.dblclick or event.xdata is None or event.ydata is None:
            return
        ix, iy = event.xdata, event.ydata

        if event.button == 1 and event.inaxes is self._static_ax[1]:
            self.tplot_ed_RzPl.setText("%0.7g" % (ix))
            self.f_Rz_plot(3)
        elif event.button == 3 and event.inaxes is self._static_ax[0]:
            self.r_rhop_blue = np.sqrt((self.Rmag_t - ix)**2 +
                                       (self.zmag_t - iy)**2)
            self.R_blue = ix
            self.z_blue = iy
            self.my_line.remove()
            self.my_line, = self._static_ax[0].plot([self.Rmag_t, ix],
                                                    [self.zmag_t, iy],
                                                    marker='o',
                                                    color='b')
            self._static_ax[0].set_xlabel(
                "R [m]; blue: R = %0.4g m, z = %0.4g m, r_rhop = %0.4g m" %
                (self.R_blue, self.z_blue, self.r_rhop_blue))
            # make the new line and label visible immediately
            self.static_canvas_RzPl.draw_idle()

    def sync_tabs(self, number):
        """Re-apply the time-window and filter settings of tab ``number``.

        Only tab 9 (the Rz-plot tab) is handled. Its time-window and 1D
        filter edits are read and then written back to the same widgets, so
        with a single tab the call does not change any value. Other tab
        numbers are ignored. Widget errors are printed, not raised.
        """
        if number != 9:
            return
        field_names = ('tB_ed_RzPl', 'tE_ed_RzPl', 'tCnt_ed_RzPl',
                       'dt_ed_RzPl', 'Fourier_cut_RzPl', 'Fourier2_cut_RzPl',
                       'SavGol_ed0_RzPl', 'SavGol_ed1_RzPl', 'Binning_ed_RzPl')
        try:
            # read everything first, then write, as the original two-pass code did
            values = [getattr(self, name).text() for name in field_names]
            for name, value in zip(field_names, values):
                getattr(self, name).setText(value)
        except Exception as exc:
            print("!!! Couldn't synchronize tabs. ERROR: %s" % (exc))

    def send_points(self):
        """Append the last selected (t, R, z, r) point to the Rz monitor table.

        The row is built from:
          * ``self.tplot``, set by ``f_Rz_plot``;
          * ``R_blue``, ``z_blue`` and ``r_rhop_blue``, set by a double right
            click in ``mouse_click_Rz``.
        It is numbered with ``self.counter``, which is incremented afterwards.
        If no point has been selected yet, a hint is printed instead.
        """
        missing = [
            name for name in ('tplot', 'R_blue', 'z_blue', 'r_rhop_blue')
            if not hasattr(self, name)
        ]
        if missing:
            print("!!! No point to send yet (plot, then double right click on "
                  "the Rz plot). Missing: %s" % (", ".join(missing)))
            return
        try:
            self.Monitor_RzPl.moveCursor(QTextCursor.End)
            self.Monitor_RzPl.insertPlainText(
                "%d\t%0.7g\t%0.4g\t%0.4g\t%0.4g\n" %
                (self.counter, self.tplot, self.R_blue, self.z_blue,
                 self.r_rhop_blue))
            self.counter += 1
        except Exception as exc:
            traceback.print_exc()
            print("!!! Cannot send points. ERROR: %s" % (exc))

    def clear_table(self):
        """Reset the Rz monitor table to its header row and restart numbering at 1.

        Widget errors are printed (with a traceback), not raised.
        """
        try:
            self.Monitor_RzPl.setText("NN\tt\tR\tz\tr\n")
            self.counter = 1
        except Exception as exc:
            traceback.print_exc()
            print("!!! Cannot clear table. ERROR: %s" % (exc))
Example #30
0
class Points_Input(QWidget):
    def __init__(self, parent):
        """Build the point-input widget: a matplotlib canvas in a vertical layout.

        The plot lives in pyplot figure number 2, which the other methods
        re-select with ``plt.figure(2)``. Mouse presses on the canvas are
        routed to ``onclick``.
        """
        QWidget.__init__(self, parent)
        # NOTE(review): this stores the QWidget class itself, not a button
        # instance. onclick() calls setEnabled() on it, so presumably the owner
        # replaces it with the real train button -- TODO confirm.
        self.TRAIN_BUTTON = QWidget
        # Identity callback until the owner installs a real one.
        self.update_last_layer_input = lambda x: x

        vbox = QVBoxLayout()
        # NOTE(review): this attribute shadows QWidget.layout(); kept because
        # other code may rely on it.
        self.layout = vbox
        self.setLayout(vbox)
        vbox.setContentsMargins(0, 0, 0, 0)

        # Create the graph in pyplot figure 2 and embed it in a Qt canvas.
        self.fig = plt.figure(2)
        self.ax = plt.subplot()
        canvas = FigureCanvas(self.fig)
        self.canvas = canvas
        canvas.setFocus()
        canvas.mpl_connect('button_press_event', self.onclick)
        vbox.addWidget(canvas)

        self.init_graph()

        # Minimum number of classes with points before training is enabled.
        self.MIN_CLASSES = 2

        self.classes = []
        self.selected_class = []
        self.points = {}  # class id -> list of [x, y]

        self.plane = []
        self.maped = True

        canvas.draw()

    def onclick(self, event):
        """Handle a mouse press on the canvas.

        In map mode (``self.maped`` true), the clicked position is classified
        with ``self.algorithm.forwardPropagation`` and drawn in the colour of
        the class returned by ``self.class_type``.

        Otherwise, if a class is selected, the point is drawn in that
        class's colour and stored in ``self.points`` under the selected
        class id. When a new class id appears and more than two classes
        then have points, ``update_last_layer_input`` is called with the
        class count.

        The train button is enabled once at least ``MIN_CLASSES`` classes
        have points.

        Clicks outside the axes are ignored. Matplotlib reports
        ``xdata``/``ydata`` as ``None`` there, and those values would
        otherwise reach the network or be stored as a ``[None, None]``
        training point.

        :param event: matplotlib ``MouseEvent`` from ``button_press_event``.
        """
        plt.figure(2)

        x, y = event.xdata, event.ydata
        if x is None or y is None:
            return

        if self.maped:
            class_output = self.algorithm.forwardPropagation([x, y])
            class_type = self.class_type(list(class_output))
            self.ax.scatter(x, y, s=10, c=self.classes[class_type][1],
                            marker='o')
        elif self.selected_class:
            class_id = self.selected_class[0]
            color = self.selected_class[1]
            # self.ax is figure 2's current axes (set by init_graph and
            # update_scatter_colors), i.e. what plt.scatter used to target.
            self.ax.scatter(x, y, s=10, marker='o', c=color)

            if class_id in self.points:
                self.points[class_id].append([x, y])
            else:
                self.points[class_id] = [[x, y]]
                n_classes = len(self.points)
                if n_classes > 2:
                    try:
                        self.update_last_layer_input(n_classes)
                    except AttributeError:
                        # Best-effort: the callback may not be ready yet
                        # (presumably before a network exists) -- TODO confirm.
                        pass

        self.canvas.draw()

        if len(self.points) >= self.MIN_CLASSES:
            self.TRAIN_BUTTON.setEnabled(True)

    def update_scatter_colors(self):
        """Redraw every stored point in the colour of its class.

        If ``self.maped`` is false, the figure is wiped first via
        ``clearPlot``. Otherwise the points are drawn over the current
        contents. Each class id is converted with ``int(id) - 1`` to index
        ``self.classes``, whose second element is the colour.
        """
        plt.figure(2)
        self.ax = plt.gca()
        if not self.maped:
            self.clearPlot()
        for class_id, class_points in self.points.items():
            for pt in class_points:
                colour = self.classes[int(class_id) - 1][1]
                plt.scatter(pt[0], pt[1], s=10, marker='o', c=colour)

        self.canvas.draw()

    def init_graph(self):
        """Style figure 2 and its current axes, storing that axes in ``self.ax``.

        The figure background is set to ``#323232``. The axes get:
        - a grid drawn below the data;
        - x and y limits of [-5, 5] with integer ticks;
        - zero lines;
        - no spines;
        - ``#b1b1b1`` tick colours.
        """
        plt.figure(2)
        axes = plt.gca()
        self.ax = axes
        self.fig.set_facecolor('#323232')
        axes.grid(zorder=0)
        axes.set_axisbelow(True)
        axes.set_xlim([-5, 5])
        axes.set_ylim([-5, 5])
        axes.set_xticks(range(-5, 6))
        axes.set_yticks(range(-5, 6))
        axes.axhline(y=0, color='#323232')
        axes.axvline(x=0, color='#323232')
        for side in ('right', 'top', 'bottom', 'left'):
            axes.spines[side].set_visible(False)
        for axis_name in ('x', 'y'):
            axes.tick_params(axis=axis_name, colors='#b1b1b1')

    def clearPlot(self):
        """Leave map mode, wipe figure 2, re-apply the axes styling and redraw."""
        self.maped = False
        # Select figure 2 and clear it. This is equivalent to
        # plt.figure(2); plt.clf().
        plt.figure(2).clf()
        self.init_graph()
        self.canvas.draw()

    def set_donut(self):
        """Replace the current points with a built-in concentric-rings dataset.

        The selection and stored points are cleared and the plot is reset
        via ``clearPlot``. Then points are sampled on circles:
        - class 0 uses radius 3 with 60 angles;
        - class 1 uses radii 1.5 and 2 with 30 angles each;
        - class 2 uses radius 1 with 10 angles.

        Each point is drawn in its class colour (``self.classes[i][1]``) and
        stored in ``self.points`` under the class id (``self.classes[i][0]``).
        Finally ``update_last_layer_input`` is called with the number of
        classes (3).
        """
        self.selected_class.clear()
        self.points.clear()
        self.clearPlot()

        SIZE = 10

        # (index into self.classes, ring radii, angle samples per ring)
        rings = (
            (0, np.linspace(3, 3.5, 1), 60),
            (1, np.linspace(1.5, 2, 2), 30),
            (2, np.linspace(1, 1.5, 1), 10),
        )
        for class_index, radii, n_angles in rings:
            class_id = self.classes[class_index][0]
            color = self.classes[class_index][1]
            # endpoint is included, so angle 0 and 2*pi give the same point.
            theta = np.linspace(0, 2 * np.pi, n_angles)
            for rad in radii:
                for t in theta:
                    x1 = rad * np.cos(t)
                    x2 = rad * np.sin(t)
                    self.ax.scatter(x1, x2, s=SIZE, c=color)
                    self.points.setdefault(class_id, []).append([x1, x2])
        self.canvas.draw()

        self.update_last_layer_input(len(rings))

    def set_map(self):
        """Replace the current points with a built-in five-region "map" dataset.

        The selection and stored points are cleared and the plot is reset
        via ``clearPlot``. The hand-placed points of classes '1'..'5' are
        then drawn in the colours ``self.classes[0..4][1]``. The dataset
        dict becomes ``self.points``, and ``update_last_layer_input`` is
        called with the number of classes (5).
        """
        self.selected_class.clear()
        self.points.clear()
        self.clearPlot()

        # Named ``dataset`` rather than ``map`` to avoid shadowing the builtin.
        # It is rebuilt on every call, so later clicks that append to
        # self.points never mutate shared data.
        dataset = {
            '1': [[-0.4813362410180009, 0.6367826942149781],
                  [-0.7544243096977539, 0.24755223526818781],
                  [-0.9763083655000528, 0.12778901713071367],
                  [-1.3688724642271977, -0.26144144181607665],
                  [-1.522484502859558, -0.5009678780910249],
                  [-1.6419605329069498, -0.8901983370378153],
                  [-1.6078245243219813, -1.1896063823815002],
                  [-1.4030084728121661, -1.5488960367939226],
                  [-1.0275123783775069, -2.147712127481293],
                  [-0.7544243096977539, -2.147712127481293],
                  [-0.22531617663073256, -1.9980081048094505],
                  [0.13311191351144291, -1.788422473068871],
                  [0.6622200465784642, -1.5788368413282914],
                  [0.8329000895033092, -1.069843164244027],
                  [1.1059881581830622, -0.6207310962284991],
                  [1.0718521495980928, -0.021915005541128352],
                  [1.0547841453056082, 0.786486716886821],
                  [0.8158320852108245, 0.9361907395586639],
                  [0.4232679864836806, 1.355362003039823],
                  [-0.0546361337058876, 0.7266051078180844],
                  [-0.49840424531048555, 0.09784821259634491],
                  [-0.7373563054052692, -0.2913822463504454],
                  [-0.9421723569150835, -0.7105535098316045],
                  [-1.1128523998399293, -0.920139141572184],
                  [-0.9592403612075682, -1.3093696005189743],
                  [-0.7544243096977539, -1.5488960367939226],
                  [-0.3447922066781244, -1.638718450397028],
                  [0.1501799178039276, -1.3093696005189743],
                  [0.3549959693137419, -1.1896063823815002],
                  [0.5939480294085255, -0.7704351189003411],
                  [0.6622200465784642, -0.2913822463504454],
                  [0.5598120208235562, 0.18767062619945118],
                  [0.18431592638889605, 0.30743384433692533],
                  [-0.24238418092321723, -0.4710270735566562],
                  [-0.6008122710653927, -0.9800207506409206],
                  [-0.5325402538954549, -1.0399023597096582],
                  [0.01363588346405109, -0.8901983370378153],
                  [0.30379195643628876, -0.35126385541918204],
                  [0.18431592638889605, -0.32132305088481417],
                  [-0.5496082581879387, -0.5009678780910249]],
            '2': [[-1.2152604255948365, 0.9062499350242952],
                  [-1.6419605329069498, 0.9661315440930318],
                  [-2.0174566273416104, 1.1457763712992435],
                  [-2.3929527217762705, 1.2655395894367167],
                  [-2.6148367775785695, 1.5649476347804026],
                  [-2.8025848247958995, 1.774533266520982],
                  [-2.973264867720745, 2.22364533453651],
                  [-3.075672893475652, 2.5829349889489315],
                  [-3.161012914938075, 3.0320470569644584],
                  [-3.1098089020606214, 4.0200936065986195],
                  [-3.1098089020606214, 4.079975215667357],
                  [-3.161012914938075, 4.738672915423464],
                  [-2.870856841965838, 4.768613719957832],
                  [-2.666040790456023, 4.7087321108890965],
                  [-2.136932657389002, 4.738672915423464],
                  [-1.5907565200294966, 4.469205674614148],
                  [-1.07871639125496, 4.559028088217254],
                  [-0.6520162839428467, 4.5889688927516215],
                  [0.11604390921895824, 4.918317742629675],
                  [-0.22531617663073256, 4.529087283682884],
                  [0.5598120208235562, 4.529087283682884],
                  [1.174260175353, 4.738672915423464],
                  [1.2596001968154233, 4.738672915423464],
                  [-2.751380811918446, 3.810507974858041],
                  [-2.6148367775785695, 3.241632688705039],
                  [-2.290544696021363, 2.493112575345826],
                  [-1.983320618756641, 2.103882116399035],
                  [-1.7443685586618578, 1.774533266520982],
                  [-1.4883484942745895, 1.5350068302460338],
                  [-1.1128523998399293, 1.3254211985054543],
                  [-0.6690842882353305, 1.3254211985054543],
                  [-0.36186021097060905, 1.5948884393147704],
                  [0.20138393068138072, 2.463171770811458],
                  [0.5598120208235562, 2.373349357208353],
                  [1.0889201538905775, 2.0440005073302983],
                  [-2.1881366702664558, 3.7506263657893033],
                  [-1.6590285371994344, 2.7625798161551423],
                  [-1.2493964341798058, 1.9841188982615616],
                  [-0.8226963268676917, 2.103882116399035],
                  [-0.9421723569150835, 2.792520620689512],
                  [-1.1981924213023518, 3.4212775159112496],
                  [-1.4371444813971355, 3.780567170323671],
                  [-0.5496082581879387, 3.63086314765183],
                  [-0.31065619809315503, 3.241632688705039],
                  [-0.5496082581879387, 2.7326390116207744],
                  [0.6110160337010102, 3.9901528020642516],
                  [0.679288050870948, 3.6009223431174604],
                  [0.20138393068138072, 3.211691884170671],
                  [0.30379195643628876, 3.7506263657893033],
                  [1.0889201538905775, 2.8224614252238798],
                  [1.2596001968154233, 2.373349357208353],
                  [1.3961442311552998, 1.6248292438491392],
                  [1.4644162483252376, 1.2355987849023489],
                  [1.7033683084200213, 1.8942964846584562],
                  [1.7033683084200213, 2.253586139070878],
                  [1.6180282869575988, 2.6128757934832993],
                  [1.4132122354477836, 3.690744756720566],
                  [1.2937362054003918, 3.7506263657893033],
                  [1.6350962912500835, 4.109916020201725],
                  [1.8569803470523825, 3.4212775159112496],
                  [2.0105923856847427, 2.6128757934832993],
                  [1.942320368514805, 1.8643556801240875],
                  [1.6350962912500835, 0.45713786700876735],
                  [1.3790762268628152, -0.3812046599535508],
                  [1.5497562697876601, -1.1297247733127636],
                  [1.9764563770997743, -0.4410862690222874],
                  [2.1812724286095886, -0.17161902821297126],
                  [2.812788587431516, 1.1457763712992435],
                  [2.164204424317104, 1.2954803939710855],
                  [2.1300684157321346, 0.6068418896806103],
                  [2.386088480119403, 0.786486716886821],
                  [2.590904531629217, 0.8164275214211898],
                  [2.659176548799156, 1.654770048383508],
                  [2.437292492996857, 1.9841188982615616]],
            '3': [[1.9764563770997743, 4.559028088217254],
                  [2.164204424317104, 3.900330388461146],
                  [2.5226325144592803, 3.241632688705039],
                  [2.6762445530916406, 2.6128757934832993],
                  [2.915196613186424, 2.073941311864667],
                  [3.0688086518187845, 1.5350068302460338],
                  [2.983468630356363, 0.9062499350242952],
                  [2.7957205831390315, 0.06790740806197704],
                  [2.539700518751765, -0.5309086826253937],
                  [2.369020475826918, -0.8003759234347099],
                  [2.1812724286095886, -1.1596655778471323],
                  [2.044728394269712, -1.5788368413282914],
                  [1.9252523642223203, -1.908185691206345],
                  [1.8228443384674131, -2.0578897138781875],
                  [1.6863003041275366, -2.5369425864280837],
                  [1.4644162483252376, -3.40522591792477],
                  [1.3449402182778458, -3.824397181405929],
                  [1.3108042096928765, -3.9741012040777717],
                  [1.3278722139853611, -4.303450053955825],
                  [1.720436312712506, -4.782502926505722],
                  [1.993524381392259, -4.842384535574459],
                  [2.369020475826918, -4.872325340108827],
                  [3.137080668988723, -4.752562121971353],
                  [3.5467127720083518, -4.483094881162037],
                  [3.956344875027982, -4.692680512902616],
                  [4.451316999510034, -4.483094881162037],
                  [4.724405068189787, -4.542976490230774],
                  [4.792677085359724, -3.2555218952529277],
                  [4.912153115407117, -2.117771322946924],
                  [4.792677085359724, -1.638718450397028],
                  [4.639065046727364, 0.008025798993239519],
                  [4.673201055312333, 0.4271970624743986],
                  [4.707337063897302, 0.8463683259555577],
                  [4.843881098237178, 2.1637637254677724],
                  [4.843881098237178, 2.6128757934832993],
                  [4.929221119699601, 4.0200936065986195],
                  [4.809745089652209, 4.409324065545411],
                  [4.314772965170157, 4.678791306354727],
                  [3.973412879320467, 4.648850501820359],
                  [3.4443047462534455, 4.738672915423464],
                  [2.915196613186424, 4.7087321108890965],
                  [2.3007484586569813, 4.858436133560938],
                  [2.556768523044248, 4.1398568247360945],
                  [2.932264617478909, 3.5410407340487247],
                  [3.4443047462534455, 3.5709815385830925],
                  [3.922208866443013, 3.5709815385830925],
                  [4.348908973755126, 3.5410407340487247],
                  [3.358964724791022, 4.1398568247360945],
                  [3.751528823518168, 2.5829349889489315],
                  [3.137080668988723, 2.3134677481396153],
                  [3.6661888020557445, 2.1937045300021403],
                  [4.109956913660342, 2.0440005073302983],
                  [4.1611609265377965, 1.2355987849023489],
                  [3.751528823518168, 0.8164275214211898],
                  [3.4955087591308995, 0.3373746488712932],
                  [3.085876656111269, -0.1416782236786025],
                  [2.6933125573841252, -0.5907902916941303],
                  [2.4202244887043722, -1.4590736231908172],
                  [2.266612450072012, -2.566883390962452],
                  [2.0105923856847427, -3.1656994816498223],
                  [1.8228443384674131, -3.5249891360622443],
                  [1.993524381392259, -3.914219595009035],
                  [2.317816462949464, -4.033982813146509],
                  [2.6762445530916406, -3.914219595009035],
                  [2.932264617478909, -3.854337985940298],
                  [4.058752900782888, -3.5549299405966126],
                  [2.6933125573841252, -3.1357586771154535],
                  [3.2736247033286006, -3.1956402861841906],
                  [3.3931007333759915, -3.824397181405929],
                  [4.127024917952827, -3.9741012040777717],
                  [4.297704960877672, -2.4770609773593466],
                  [3.5467127720083518, -2.626765000031189],
                  [3.256556699036116, -2.3572977592218725],
                  [2.7445165702615792, -2.0578897138781875],
                  [3.085876656111269, -1.339310405053343],
                  [3.751528823518168, -0.4710270735566562],
                  [4.178228930830281, -0.4710270735566562],
                  [4.024616892197919, -1.339310405053343],
                  [3.4955087591308995, -2.1776529320156612],
                  [4.24650094800022, -0.7704351189003411],
                  [3.580848780593321, -1.2494879914502377]],
            '4': [[1.1913281796454847, -1.4291328186564485],
                  [1.2083961839379693, -1.9980081048094505],
                  [1.0377161410131235, -2.447120172824978],
                  [0.7987640809183398, -3.1656994816498223],
                  [0.6451520422859796, -3.794456376871561],
                  [0.4744719993611337, -3.8842787904746663],
                  [0.06483989634150422, -4.123805226749615],
                  [-0.0546361337058876, -4.21362764035272],
                  [-0.3277242023856397, -4.542976490230774],
                  [0.2696559478513194, -4.812443731040091],
                  [0.4744719993611337, -4.842384535574459],
                  [0.6622200465784642, -2.2974161501531354],
                  [0.3549959693137419, -2.507001781893715],
                  [0.01363588346405109, -2.746528218168663],
                  [-0.3277242023856397, -2.746528218168663],
                  [-0.8226963268676917, -2.5968241954968203],
                  [-1.5566205114445273, -2.507001781893715],
                  [-1.8297085801242803, -2.117771322946924],
                  [-1.9662526144641568, -1.5189552322595539],
                  [-2.2393406831439093, -1.7285408640001334],
                  [-2.478292743238693, -1.8483040821376076],
                  [-2.6831087947485077, -2.2974161501531354],
                  [-2.9903328720132296, -2.8962322408405057],
                  [-3.3658289664478893, -3.5848707451309814],
                  [-3.4853049964952816, -3.764515572337192],
                  [-3.8095970780524877, -4.123805226749615],
                  [-3.9120051038073953, -4.123805226749615],
                  [-4.321637206827024, -4.513035685696405],
                  [-4.714201305554169, -4.842384535574459],
                  [-4.475249245459385, -4.902266144643196],
                  [-3.7071890522975806, -4.812443731040091],
                  [-3.6047810265426734, -4.812443731040091],
                  [-2.870856841965838, -4.483094881162037],
                  [-2.358816713191301, -4.692680512902616],
                  [-1.846776584416765, -3.9741012040777717],
                  [-1.5395525071520426, -4.093864422215246],
                  [-1.2323284298873212, -3.9741012040777717],
                  [-1.07871639125496, -4.063923617680878],
                  [-0.7202883011127845, -4.303450053955825],
                  [-0.583744266772908, -4.782502926505722],
                  [-1.0445803826699906, -4.812443731040091],
                  [-1.300600447057259, -4.752562121971353],
                  [-1.6078245243219813, -4.722621317436984],
                  [-2.256408687436394, -4.513035685696405],
                  [-2.478292743238693, -4.722621317436984],
                  [-2.7172448033334766, -4.662739708368248],
                  [-3.041536884890683, -3.734574767802824],
                  [-2.785516820503415, -3.495048331527876],
                  [-2.2393406831439093, -2.327356954687504],
                  [-2.3076127003138476, -2.626765000031189],
                  [-2.4612247389462083, -3.01599545897798],
                  [-1.9150486015867028, -3.0758770680467165],
                  [-1.6078245243219813, -3.3154035043216648],
                  [-0.7714923139902385, -2.986054654443611],
                  [-0.6178802753578774, -2.926173045374874],
                  [-0.0546361337058876, -3.1656994816498223],
                  [0.11604390921895824, -3.2255810907185594],
                  [0.25258794355883474, -3.2555218952529277],
                  [-0.1911801680457632, -3.7046339632684555],
                  [-0.44720023243303153, -3.5848707451309814],
                  [-0.856832335452661, -3.5848707451309814],
                  [-1.0445803826699906, -3.5848707451309814],
                  [-1.5395525071520426, -3.5848707451309814],
                  [-2.631904781871054, -3.465107526993507],
                  [-2.290544696021363, -3.794456376871561],
                  [-2.5465647604086312, -4.033982813146509],
                  [-2.1540006616814864, -3.734574767802824]],
            '5': [[-1.795572571539311, 0.21761143073381906],
                  [-2.1198646530965175, -0.2315006372817079],
                  [-2.1881366702664558, -0.32132305088481417],
                  [-2.444156734653724, -0.7105535098316045],
                  [-2.666040790456023, -1.0997839687783948],
                  [-2.9561968634282603, -1.6686592549313968],
                  [-3.3316929578629204, -2.3572977592218725],
                  [-3.6047810265426734, -2.8064098272374],
                  [-4.133889159609694, -3.465107526993507],
                  [-4.4581812411669, -3.734574767802824],
                  [-4.7824733227241065, -3.9741012040777717],
                  [-4.662997292676716, -3.1357586771154535],
                  [-4.765405318431623, -2.9561138499092428],
                  [-4.748337314139138, -2.2974161501531354],
                  [-4.5605892669218075, -1.219547186915869],
                  [-4.611793279799262, -0.6806127052972357],
                  [-4.731269309846653, -0.2315006372817079],
                  [-4.731269309846653, 0.45713786700876735],
                  [-4.748337314139138, 1.20565798036798],
                  [-4.799541327016591, 1.774533266520982],
                  [-4.594725275506777, 2.8224614252238798],
                  [-4.765405318431623, 3.2715734932394067],
                  [-4.594725275506777, 4.738672915423464],
                  [-4.86781334418653, 4.4392648700797785],
                  [-4.850745339894045, 4.2296792383392],
                  [-4.4581812411669, 4.858436133560938],
                  [-4.338705211119509, 4.82849532902657],
                  [-4.08268514673224, 4.768613719957832],
                  [-3.7242570565900652, 4.798554524492202],
                  [-3.5877130222501887, 4.559028088217254],
                  [-3.690121048005096, 3.9901528020642516],
                  [-3.7071890522975806, 3.5709815385830925],
                  [-3.6047810265426734, 3.091928666033196],
                  [-3.570645017957704, 2.7026982070864065],
                  [-3.502373000787766, 2.103882116399035],
                  [-3.3999649750328587, 1.20565798036798],
                  [-3.1268769063531057, 0.8763091304899264],
                  [-2.6831087947485077, 0.7266051078180844],
                  [-2.478292743238693, 0.5170194760775049],
                  [-2.3929527217762705, 0.36731545340566196],
                  [-2.597768773286085, -0.08179661460986587],
                  [-2.819652829088384, -0.3812046599535508],
                  [-3.1439449106455903, -0.8303167279690786],
                  [-3.673053043712611, -1.4590736231908172],
                  [-3.997345125269818, -2.2974161501531354],
                  [-4.202161176779632, -2.626765000031189],
                  [-4.304569202534539, -2.566883390962452],
                  [-4.099753151024725, -1.8183632776032397],
                  [-4.031481133854787, -1.1896063823815002],
                  [-4.150957163902179, -0.8901983370378153],
                  [-4.270433193949571, -0.32132305088481417],
                  [-4.304569202534539, 0.4870786715431361],
                  [-4.2875011982420554, 1.4451844166429284],
                  [-4.389909223996963, 2.4332309662770886],
                  [-4.406977228289447, 2.8823430342926173],
                  [-4.236297185364601, 3.930271192995514],
                  [-4.202161176779632, 3.3015142977737764],
                  [-4.065617142439756, 2.7625798161551423],
                  [-3.8949370995149106, 2.1937045300021403],
                  [-3.8608010909299413, 1.2954803939710855],
                  [-3.6218490308351576, 0.5170194760775049],
                  [-3.3316929578629204, 0.09784821259634491],
                  [-3.0927408977681368, -0.1416782236786025],
                  [-2.887924846258322, -0.2315006372817079],
                  [-3.2122169278155286, -0.7404943143659732],
                  [-3.3999649750328587, -0.7704351189003411],
                  [-3.6389170351276423, -1.0997839687783948],
                  [-3.8437330866374566, -1.489014427725186],
                  [-3.8095970780524877, -0.6506719007628678],
                  [-3.758393065175034, -0.5608494871597616],
                  [-3.9290731080998795, -0.05185581007549711],
                  [-4.099753151024725, 0.4271970624743986],
                  [-4.270433193949571, 0.9960723486274006],
                  [-4.355773215411993, 1.6847108529178767],
                  [-4.406977228289447, 2.253586139070878]]
        }

        size = 10

        # Class ids '1'..'5' map to self.classes[0..4], whose [1] is the colour.
        # Drawing order (class 1 first, then 2, ...) is preserved.
        for class_index, class_id in enumerate(('1', '2', '3', '4', '5')):
            color = self.classes[class_index][1]
            for item in dataset[class_id]:
                self.ax.scatter(item[0], item[1], s=size, c=color)

        self.points = dataset

        self.canvas.draw()

        self.update_last_layer_input(len(dataset))

    def set_xor(self):
        self.selected_class.clear()
        self.points.clear()
        self.clearPlot()

        xor = {
            '1': [[-1.8809125930017339, 0.30743384433692533],
                  [-1.6078245243219813, 0.2774930398025566],
                  [-1.2152604255948365, 0.3972562579400307],
                  [-0.8226963268676917, 0.2774930398025566],
                  [-0.3959962195555784, 0.21761143073381906],
                  [-0.17411216375327943, 0.3373746488712932],
                  [-0.20824817233824788, 0.9062499350242952],
                  [-0.20824817233824788, 1.4451844166429284],
                  [-0.20824817233824788, 1.8643556801240875],
                  [-1.0957843955474447, 1.7445924619866133],
                  [-0.6690842882353305, 1.8344148755897187],
                  [-0.4813362410180009, 1.80447407105535],
                  [-1.4542124856896201, 1.774533266520982],
                  [-1.7785045672468263, 1.8643556801240875],
                  [-1.846776584416765, 1.654770048383508],
                  [-1.8809125930017339, 1.20565798036798],
                  [-1.8638445887092492, 0.6966643032837156],
                  [-1.5054164985670742, 0.6367826942149781],
                  [-1.4712804899821048, 1.2355987849023489],
                  [-1.1981924213023518, 1.4451844166429284],
                  [-0.8226963268676917, 1.2655395894367167],
                  [-0.5666762624804234, 1.355362003039823],
                  [-0.4813362410180009, 1.085894762230506],
                  [-0.6520162839428467, 0.8463683259555577],
                  [-0.9763083655000528, 0.8763091304899264],
                  [-1.2493964341798058, 0.8463683259555577],
                  [-0.5154722496029702, 0.5469602806118727],
                  [0.09897590492647357, -0.26144144181607665],
                  [0.09897590492647357, -0.7404943143659732],
                  [0.13311191351144291, -1.2794287959846065],
                  [0.11604390921895824, -1.3093696005189743],
                  [0.16724792209641226, -1.7584816685345022],
                  [0.4061999821911959, -1.7584816685345022],
                  [0.7475600680408867, -1.7584816685345022],
                  [0.7134240594559174, -1.6986000594657655],
                  [1.1571921710605153, -1.6986000594657655],
                  [1.6180282869575988, -1.8483040821376076],
                  [1.8399123427598978, -1.7584816685345022],
                  [1.8569803470523825, -1.1297247733127636],
                  [1.891116355637351, -0.5309086826253937],
                  [1.8057763341749284, -0.17161902821297126],
                  [1.3790762268628152, -0.2913822463504454],
                  [0.9353081152582163, -0.11173741914423463],
                  [0.3549959693137419, -0.17161902821297126],
                  [0.679288050870948, -0.32132305088481417],
                  [0.8670360980882785, -0.35126385541918204],
                  [0.6110160337010102, -0.5907902916941303],
                  [0.6110160337010102, -0.7404943143659732],
                  [0.679288050870948, -1.0399023597096582],
                  [0.9353081152582163, -1.3093696005189743],
                  [1.276668201107908, -1.339310405053343],
                  [1.2596001968154233, -0.8303167279690786],
                  [1.0547841453056082, -0.7704351189003411],
                  [0.38913197789871123, -1.0399023597096582],
                  [0.32085996072877254, -0.6207310962284991],
                  [1.498552256910207, -1.4590736231908172],
                  [1.5156202612026917, -0.7704351189003411]],
            '2': [[0.18431592638889605, 1.80447407105535],
                  [0.18431592638889605, 1.3254211985054543],
                  [0.20138393068138072, 0.7565459123524523],
                  [0.2184519349738654, 0.2774930398025566],
                  [0.4232679864836806, 0.30743384433692533],
                  [0.6280840379934949, 0.2774930398025566],
                  [1.003580132428155, 0.30743384433692533],
                  [1.276668201107908, 0.3972562579400307],
                  [1.5668242740801448, 0.3373746488712932],
                  [1.8569803470523825, 0.4271970624743986],
                  [1.7545723212974753, 1.0260131531617693],
                  [1.8057763341749284, 1.4751252211772972],
                  [1.7375043170049906, 1.8344148755897187],
                  [1.0889201538905775, 1.80447407105535],
                  [0.5427440165310724, 1.80447407105535],
                  [0.7475600680408867, 1.80447407105535],
                  [1.3790762268628152, 1.774533266520982],
                  [1.2425321925229387, 1.4152436121085596],
                  [0.5768800251160409, 1.355362003039823],
                  [0.7646280723333714, 1.3853028075741909],
                  [0.6110160337010102, 1.1158355667648747],
                  [0.5256760122385877, 0.7565459123524523],
                  [0.730492063748402, 0.786486716886821],
                  [0.9011721066732479, 1.1457763712992435],
                  [1.1230561624755468, 0.8463683259555577],
                  [1.3449402182778458, 0.8463683259555577],
                  [1.447348244032753, 1.20565798036798],
                  [1.5497562697876601, 1.3853028075741909],
                  [-1.8638445887092492, -0.26144144181607665],
                  [-1.846776584416765, -0.6806127052972357],
                  [-1.8809125930017339, -1.1596655778471323],
                  [-1.8126405758317956, -1.6986000594657655],
                  [-1.6419605329069498, -1.7285408640001334],
                  [-1.300600447057259, -1.6986000594657655],
                  [-1.0104443740850222, -1.788422473068871],
                  [-0.6690842882353305, -1.8183632776032397],
                  [-0.3959962195555784, -1.8183632776032397],
                  [-0.17411216375327943, -1.8183632776032397],
                  [-0.20824817233824788, -1.339310405053343],
                  [-0.20824817233824788, -0.8303167279690786],
                  [-0.20824817233824788, -0.32132305088481417],
                  [-0.4130642238480631, -0.20155983274734002],
                  [-0.7202883011127845, -0.32132305088481417],
                  [-1.1981924213023518, -0.20155983274734002],
                  [-1.5054164985670742, -0.11173741914423463],
                  [-1.7102325500768885, -0.5907902916941303],
                  [-1.7102325500768885, -0.920139141572184],
                  [-1.6078245243219813, -1.339310405053343],
                  [-1.4542124856896201, -1.3692512095877118],
                  [-1.1811244170098671, -1.3692512095877118],
                  [-0.805628322575207, -1.219547186915869],
                  [-0.7373563054052692, -1.1896063823815002],
                  [-0.7885603182827223, -0.8003759234347099],
                  [-1.0616483869624753, -0.6506719007628678],
                  [-1.300600447057259, -0.7105535098316045],
                  [-1.3347364556422283, -0.8901983370378153],
                  [-1.0616483869624753, -1.1297247733127636],
                  [-0.9933763697925375, -1.069843164244027],
                  [-0.6008122710653927, -1.4590736231908172],
                  [-0.4813362410180009, -0.6506719007628678],
                  [-0.4642682367255162, -1.0399023597096582]],
            '3': [[-1.8638445887092492, 2.7326390116207744],
                  [-1.5566205114445273, 2.7026982070864065],
                  [-1.0445803826699906, 2.6128757934832993],
                  [-0.7202883011127845, 2.7625798161551423],
                  [-0.31065619809315503, 2.642816598017669],
                  [0.16724792209641226, 2.7625798161551423],
                  [0.45740399506864904, 2.7625798161551423],
                  [0.7646280723333714, 2.672757402552037],
                  [1.1059881581830622, 2.523053379880194],
                  [1.5156202612026917, 2.523053379880194],
                  [1.993524381392259, 2.4032901617427207],
                  [2.1300684157321346, 2.4032901617427207],
                  [2.2324764414870426, 2.073941311864667],
                  [2.1983404329020733, 1.3254211985054543],
                  [2.2495444457795273, 0.8164275214211898],
                  [2.2324764414870426, 0.3972562579400307],
                  [2.2324764414870426, -0.17161902821297126],
                  [2.1812724286095886, -0.6806127052972357],
                  [2.2495444457795273, -1.1896063823815002],
                  [2.2495444457795273, -2.0578897138781875],
                  [2.2495444457795273, -2.4171793682906095],
                  [2.0617963985621968, -2.5968241954968203],
                  [1.7033683084200213, -2.5968241954968203],
                  [1.3790762268628152, -2.626765000031189],
                  [1.1230561624755468, -2.686646609099926],
                  [0.6110160337010102, -2.566883390962452],
                  [0.25258794355883474, -2.5369425864280837],
                  [-0.3789282152630937, -2.5369425864280837],
                  [-0.6861522925278152, -2.7165874136342945],
                  [-1.300600447057259, -2.626765000031189],
                  [-1.6078245243219813, -2.5968241954968203],
                  [-2.2052046745589404, -2.5968241954968203],
                  [-2.3929527217762705, -2.327356954687504],
                  [-2.4612247389462083, -1.8183632776032397],
                  [-2.375884717483786, -1.069843164244027],
                  [-2.1881366702664558, 0.06790740806197704],
                  [-2.2052046745589404, 0.786486716886821],
                  [-2.2734766917288782, 1.1457763712992435],
                  [-2.2052046745589404, 2.073941311864667],
                  [-2.2734766917288782, 2.5829349889489315],
                  [-2.1198646530965175, 2.6128757934832993],
                  [-2.358816713191301, -0.2913822463504454],
                  [-2.5124287518236623, -0.7404943143659732],
                  [-2.9903328720132296, -0.4410862690222874],
                  [-3.2975569492779515, 1.4751252211772972],
                  [-3.1098089020606214, 3.0320470569644584],
                  [-2.939128859135776, 3.3015142977737764],
                  [-2.8367208333808684, 1.9841188982615616],
                  [-2.9561968634282603, 0.6068418896806103],
                  [-3.0927408977681368, -1.7584816685345022],
                  [-3.0244688805981985, -3.1656994816498223],
                  [-2.3929527217762705, -3.914219595009035],
                  [-1.8979805972942185, -3.5848707451309814],
                  [-1.795572571539311, -3.6148115496653497],
                  [-2.853788837673353, -3.1656994816498223],
                  [-1.624892528614466, -3.734574767802824],
                  [-0.3789282152630937, -3.6746931587340868],
                  [0.23551993926635006, -3.5848707451309814],
                  [1.1059881581830622, -3.6447523541997184],
                  [1.8740483513448671, -3.2854626997872964],
                  [2.369020475826918, -3.2255810907185594],
                  [2.812788587431516, -2.447120172824978],
                  [2.8469245960164855, -1.489014427725186],
                  [2.8469245960164855, -0.11173741914423463],
                  [2.8469245960164855, 1.4751252211772972],
                  [2.7274485659690946, 2.523053379880194],
                  [2.539700518751765, 3.511099929514355],
                  [2.2324764414870426, 3.7506263657893033],
                  [1.4644162483252376, 3.900330388461146],
                  [1.0206481367206397, 4.050034411132989],
                  [-1.2152604255948365, 4.4392648700797785],
                  [-2.0857286445115486, 4.2296792383392],
                  [-1.9321166058791874, 3.780567170323671],
                  [-1.5907565200294966, 3.481159124979987],
                  [-1.1811244170098671, 3.3913367113768818],
                  [-0.44720023243303153, 3.481159124979987],
                  [0.2696559478513194, 3.481159124979987],
                  [1.1230561624755468, 3.511099929514355],
                  [1.771640325589959, 3.1518102751019335],
                  [2.095932407147166, 2.8524022297582476],
                  [2.71038056167661, 1.5948884393147704],
                  [-2.3246807046063322, 3.4512183204456175],
                  [-2.7343128076259613, 2.1937045300021403],
                  [-3.570645017957704, 1.1457763712992435],
                  [-3.536509009372735, -0.05185581007549711],
                  [-3.6047810265426734, -2.0279489093438188],
                  [-3.3316929578629204, -3.105817872581085],
                  [-3.1268769063531057, -3.764515572337192],
                  [-2.5807007689936, -4.183686835818351],
                  [-1.07871639125496, -4.692680512902616],
                  [-0.31065619809315503, -4.60285809929951],
                  [0.7475600680408867, -4.632798903833879],
                  [1.7033683084200213, -4.4232132720932995],
                  [2.539700518751765, -4.4232132720932995],
                  [3.239488694743631, -4.483094881162037],
                  [3.7173928149331985, -3.5249891360622443],
                  [3.734460819225683, -0.7105535098316045],
                  [3.683256806348229, 0.21761143073381906],
                  [3.64912079776326, 2.253586139070878],
                  [3.580848780593321, 3.900330388461146],
                  [3.512576763423384, 4.109916020201725],
                  [2.2495444457795273, 4.4392648700797785],
                  [1.003580132428155, 4.7087321108890965],
                  [0.1501799178039276, 4.738672915423464],
                  [-1.0104443740850222, 4.648850501820359],
                  [-2.068660640219064, 4.529087283682884],
                  [-2.819652829088384, 4.469205674614148],
                  [-3.4341009836178276, 4.349442456476673],
                  [-3.8095970780524877, 3.660803952186198],
                  [-3.980277120977333, 1.80447407105535],
                  [-3.7242570565900652, 0.6367826942149781],
                  [-3.9120051038073953, -1.7584816685345022],
                  [-3.997345125269818, -2.147712127481293],
                  [-4.185093172487148, -2.566883390962452],
                  [-4.219229181072117, -3.01599545897798],
                  [-4.031481133854787, -3.5249891360622443],
                  [-2.3929527217762705, -4.542976490230774],
                  [3.956344875027982, -4.303450053955825],
                  [4.24650094800022, -3.9441603995434034],
                  [4.690269059604818, -1.638718450397028],
                  [4.468385003802519, -0.920139141572184],
                  [3.6149847891782905, -1.8183632776032397],
                  [3.46137275054593, -2.447120172824978],
                  [4.4854530080950035, 1.1158355667648747],
                  [4.178228930830281, 3.63086314765183],
                  [4.041684896490404, 4.888376938095307],
                  [3.410168737668476, 3.0021062524300905],
                  [3.376032729083507, 0.4271970624743986],
                  [3.3931007333759915, -0.08179661460986587],
                  [4.058752900782888, 2.792520620689512],
                  [4.55372502526494, 1.9541780937271929],
                  [-4.304569202534539, 4.0200936065986195],
                  [-4.662997292676716, 2.373349357208353],
                  [-4.611793279799262, 0.45713786700876735],
                  [-4.424045232581932, -1.638718450397028],
                  [-4.321637206827024, -2.4171793682906095],
                  [-4.133889159609694, -3.345344308856033],
                  [-3.8608010909299413, -2.20759373655003],
                  [-3.5877130222501887, -1.0099615551752894],
                  [-3.536509009372735, -0.8303167279690786],
                  [-3.8095970780524877, 2.8524022297582476],
                  [-3.8266650823449724, 2.8524022297582476],
                  [4.58786103384991, -2.3872385637562408],
                  [4.417180990925065, -3.9741012040777717],
                  [4.4342489952175494, -4.572917294765142],
                  [2.6933125573841252, -4.273509249421457],
                  [2.8981286088939395, -3.345344308856033],
                  [3.410168737668476, -2.7764690227030315]]
        }

        class_a = xor['1']
        class_b = xor['2']
        class_c = xor['3']

        size = 10

        for item in class_a:
            self.ax.scatter(item[0], item[1], s=size, c=self.classes[0][1])

        for item in class_b:
            self.ax.scatter(item[0], item[1], s=size, c=self.classes[1][1])

        for item in class_c:
            self.ax.scatter(item[0], item[1], s=size, c=self.classes[2][1])

        self.points = xor

        self.canvas.draw()

        self.update_last_layer_input(3)

    def fill_plot(self, algorithm, progress_bar, size=30, dpi=40):
        """Evaluate ``algorithm`` on a grid over the plot area and paint the class regions.

        A ``dpi`` x ``dpi`` grid spanning [-5, 5] on both axes (pulled in by
        ``size * 0.005`` on each side) is classified cell by cell with
        ``algorithm.forwardPropagation([x, y])``.  Each cell is drawn as a
        square marker in the colour of its winning class (see ``class_type``).
        The class indices are stored row by row (one row per y value) in
        ``self.plane`` so ``show_planes`` can redraw the map without re-running
        the model.  The points in ``self.points`` are then plotted on top.

        :param algorithm: model exposing ``forwardPropagation(inputs)``; kept
            in ``self.algorithm``.
        :param progress_bar: widget with ``setValue(int)``; advanced from 80
            to 100 while the grid rows are evaluated.
        :param size: marker size of each grid square.
        :param dpi: number of grid samples along each axis.
        """
        self.maped = False
        self.algorithm = algorithm

        self.figure = plt.figure(2)
        plt.clf()
        self.init_graph()
        self.ax = plt.gca()
        self.colors_class_type(len(self.classes))

        margin = size * 0.005
        x = list(np.linspace(-5 + margin, 5 - margin, dpi))
        y = list(np.linspace(-5 + margin, 5 - margin, dpi))
        rows = len(y)
        self.plane.clear()

        for ind, i in enumerate(y):
            row = [self.class_type(list(self.algorithm.forwardPropagation([j, i])))
                   for j in x]
            self.plane.append(row)
            # One scatter call per row instead of one per cell: same picture,
            # but dpi instead of dpi**2 artists for every later redraw.
            self.ax.scatter(x,
                            [i] * len(x),
                            s=size,
                            c=[self.colors_class[class_type] for class_type in row],
                            marker='s')
            # QProgressBar.setValue only accepts an int (a float raises
            # TypeError on Python >= 3.10).  Deriving the value from the row
            # count also guarantees the bar ends exactly at 100.
            progress_bar.setValue(round(80 + 20 * (ind + 1) / rows))

        # Reference points on top of the region map, coloured per class label.
        for label, points in self.points.items():
            color = self.classes[int(label) - 1][1]
            for point in points:
                self.ax.scatter(point[0],
                                point[1],
                                s=10,
                                marker='o',
                                c=color)

        self.canvas.draw()
        self.maped = True

    def normalize_class(self, class_vector):
        """One-hot encode ``class_vector``: 1 at its first maximum, 0 elsewhere.

        The zero entries are numpy int32 scalars; the hot entry is a plain int.
        Raises ValueError for an empty vector (from ``max``).
        """
        winner = class_vector.index(max(class_vector))
        one_hot = list(np.zeros(len(class_vector), dtype=np.int32))
        one_hot[winner] = 1
        return one_hot

    def class_type(self, class_vector):
        """Return the index of the first largest entry of ``class_vector``."""
        peak = max(class_vector)
        return class_vector.index(peak)

    def colors_class_type(self, classes_count):
        """Pick a distinct random palette colour for each of ``classes_count`` classes.

        Colours are drawn one at a time, without replacement, from a fixed
        11-colour palette with ``np.random.choice``.  They are stored in draw
        order in ``self.colors_class``.  More than 11 classes exhausts the
        palette and ``np.random.choice`` raises ValueError.
        """
        remaining = [
            'red', 'black', 'darkgreen', 'navy', 'orange', 'yellowgreen',
            'fuchsia', 'gold', 'cyan', 'pink', 'brown'
        ]
        picked = []
        self.colors_class = picked
        for _ in range(classes_count):
            choice = np.random.choice(remaining)
            remaining.remove(choice)
            picked.append(choice)

    def show_lines(self, init_layer, bias):
        """Plot one line per first-layer neuron over the points in ``self.points``.

        For a neuron with weights ``(w1, w2)`` and bias ``tetha`` the drawn line
        is ``y = -(w2 / w1) * x - tetha / w1`` for ``x`` in ``[-5, 5]``, i.e. the
        set where ``w2 * x + w1 * y + tetha == 0``.  If ``w1 == 0`` that set is
        the vertical line ``x = -tetha / w2``.  If both weights are 0 there is
        no line, but the legend entry is kept.  Lines are labelled
        "Neurona <n>" (1-based).

        :param init_layer: per-neuron weights, read as ``neuron[0]`` and
            ``neuron[1]``.
        :param bias: per-neuron bias, paired with ``init_layer`` by position.
        """
        plt.figure(2)
        plt.clf()
        plt.tight_layout()

        self.fig = plt.figure(2)
        self.ax = plt.gca()
        self.init_lines(self.fig, self.ax)

        for index, (neuron, tetha) in enumerate(zip(init_layer, bias)):
            w1 = neuron[0]
            w2 = neuron[1]
            # NOTE(review): this solves w2*x + w1*y + tetha = 0, pairing w1 with
            # the vertical axis, while fill_plot feeds inputs as [x, y].
            # Confirm the weight layout of init_layer matches this pairing.
            if w1 != 0:
                # Same slope as -(tetha/w1)/(tetha/w2), but still defined when
                # the bias is 0.
                slope = -w2 / w1
                intercept = -tetha / w1
                x = [-5, 5]
                y = [slope * -5 + intercept, slope * 5 + intercept]
            elif w2 != 0:
                x = [-tetha / w2, -tetha / w2]
                y = [-5, 5]
            else:
                x = []
                y = []
            line, = self.ax.plot(x, y)
            line.set_label('Neurona {}'.format(index + 1))

        self.ax.legend()

        for label, points in self.points.items():
            color = self.classes[int(label) - 1][1]
            for point in points:
                self.ax.scatter(point[0],
                                point[1],
                                s=5,
                                marker='o',
                                c=color)

        self.canvas.draw()

    def show_planes(self, size=30, dpi=40):
        """Redraw the class-region map cached in ``self.plane``.

        Rebuilds the same grid as ``fill_plot`` and colours cell (row, col)
        with ``self.colors_class[self.plane[row][col]]``.  The points in
        ``self.points`` are then plotted on top.  ``size`` and ``dpi`` should
        match the values used to fill ``self.plane``.  A larger ``dpi`` indexes
        past the cached rows and raises IndexError.
        """
        self.figure = plt.figure(2)
        plt.clf()
        self.init_graph()
        self.ax = plt.gca()

        margin = size * 0.005
        x = list(np.linspace(-5 + margin, 5 - margin, dpi))
        y = list(np.linspace(-5 + margin, 5 - margin, dpi))

        for ind_y, i in enumerate(y):
            plane_row = self.plane[ind_y]
            # One scatter call per row instead of one per cell: same picture,
            # but dpi instead of dpi**2 artists for every later redraw.
            self.ax.scatter(x,
                            [i] * len(x),
                            s=size,
                            c=[self.colors_class[plane_row[ind_x]]
                               for ind_x in range(len(x))],
                            marker='s')

        # Reference points on top of the region map, coloured per class label.
        for label, points in self.points.items():
            color = self.classes[int(label) - 1][1]
            for point in points:
                self.ax.scatter(point[0],
                                point[1],
                                s=10,
                                marker='o',
                                c=color)

        self.canvas.draw()

    def init_lines(self, fig, ax):
        """Apply the dark theme used by ``show_lines`` to ``fig`` and ``ax``.

        Sets a #323232 figure background and a grid drawn below the data.
        Fixes the view to [-5, 5] on both axes with integer ticks and draws
        #323232 lines through the origin.  Hides all four spines and colours
        the tick labels #b1b1b1.
        """
        dark = '#323232'
        fig.set_facecolor(dark)

        ax.grid(zorder=0)
        ax.set_axisbelow(True)

        ax.set_xlim([-5, 5])
        ax.set_ylim([-5, 5])
        ticks = range(-5, 6)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

        ax.axhline(y=0, color=dark)
        ax.axvline(x=0, color=dark)

        for side in ('right', 'top', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, colors='#b1b1b1')
class MatplotlibWidget(QWidget):
    """Qt widget with a matplotlib canvas that animates a freehand drawing.

    Mouse press / drag / release record a stroke on the canvas.  On release
    the stroke is turned into complex samples, transformed with
    ``FourierTransform`` and replayed with ``animation.FuncAnimation``.
    ``parent`` must provide ``horizontalSlider``.  Its value (+1) truncates
    the per-frame epicycle chain drawn in ``animate``.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        self.CanvasWindow = parent

        self.canvas = FigureCanvas(Figure(dpi=DPI))

        # attach matplotlib canvas to layout
        vertical_layout = QVBoxLayout()
        vertical_layout.addWidget(self.canvas)
        self.setLayout(vertical_layout)

        # canvas setup: a single axes filling the whole figure, fixed limits
        self.canvas.axis = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
        self.canvas.axis.spines['top'].set_visible(False)
        self.canvas.axis.spines['left'].set_visible(False)
        self.canvas.axis.set_xlim(AXIS_LIMIT)
        self.canvas.axis.set_ylim(AXIS_LIMIT)

        # state
        self.pressed = False  # mouse button currently held down
        self.plotted = False  # a stroke was drawn since the last reset
        # True once the mouse was pressed: init() then skips arr_drawing[0],
        # the placeholder row created by np.empty below.
        self.pen = False
        self.N = 0
        self.arr_drawing_complex = []
        self.time = []
        self.arr_coordinates = []
        self.line_draw, = self.canvas.axis.plot([], [], c='#eb4034', linewidth=1)
        # Row 0 is an uninitialised placeholder; recorded samples start at row 1.
        self.arr_drawing = np.empty((1,2))

        # listeners
        self.canvas.mpl_connect('button_press_event', self.on_press)
        self.canvas.mpl_connect('motion_notify_event', self.on_motion)
        self.canvas.mpl_connect('button_release_event', self.on_release)

    def on_press(self, event):
        """Begin a new stroke: stop the running animation and clear a finished drawing."""
        running = getattr(self, 'animation', None)
        if len(self.arr_drawing[1:]) > 0 and running is not None:
            # NOTE(review): _stop is private matplotlib API; consider the public
            # pause() / event_source.stop() if this breaks on upgrade.
            running._stop()

        if self.plotted:
            # reset variables
            self.plotted = False
            self.arr_drawing = np.empty((1,2))
            self.N = 0

            # clear and reload subplot
            self.canvas.axis.clear()
            self.canvas.axis.set_xlim(AXIS_LIMIT)
            self.canvas.axis.set_ylim(AXIS_LIMIT)
            self.canvas.draw()

        self.pressed = True
        self.pen = True

    def on_motion(self, event):
        """While the button is held, append the cursor position and blit the stroke."""
        if not self.pressed:
            return
        self.plotted = True

        # xdata/ydata are None when the cursor is outside the axes.  Appending
        # them would make arr_drawing an object array and complex() in init()
        # would fail on the None entries.
        if event.xdata is None or event.ydata is None:
            return

        # plot drawing points
        coordinates = [event.xdata, event.ydata]
        self.arr_drawing = np.append(self.arr_drawing, [coordinates], axis=0)
        self.line_draw.set_data(self.arr_drawing[1:,0],self.arr_drawing[1:,1])
        self.canvas.axis.draw_artist(self.line_draw)
        self.canvas.blit(self.canvas.axis.bbox)
        self.canvas.flush_events()

    def on_release(self, event):
        """Finish the stroke and start the animation if any point was recorded."""
        self.pressed = False
        self.line_draw.set_data([],[])

        # arr_drawing always holds the placeholder row, so test the recorded
        # samples (rows 1..).  A click without a drag must not start an
        # animation with N == 0 (animate() would compute i % 0).
        if len(self.arr_drawing[1:]) > 0:
            self.runAnimation()

    def getSizes(self):
        """Return scatter sizes (points**2) whose diameters equal twice each epicycle amplitude.

        Amplitudes are converted from data units to pixels via ``transData``
        (x component), then to points with ``72 / DPI``.
        """
        arr_radius = np.array([item['amplitude'] for item in self.ft.arr_epicycles])
        rr_pix = (self.canvas.axis.transData.transform(np.vstack([arr_radius, arr_radius]).T) - self.canvas.axis.transData.transform(np.vstack([np.zeros(self.N), np.zeros(self.N)]).T))
        rpix, _ = rr_pix.T
        size_pt = (2*rpix/DPI*72)**2
        return size_pt

    def runAnimation(self):
        """Start a blitted FuncAnimation driven by ``init`` and ``animate``."""
        self.animation = animation.FuncAnimation(
                            self.canvas.figure,
                            self.animate,
                            init_func=self.init,
                            interval=ANIMATION_INTERVAL,
                            blit=True)

    def init(self):
        """FuncAnimation init_func: transform the stroke and create the animated artists.

        Returns the artists that ``animate`` updates (required for blitting).
        """
        # Skip the placeholder row when the stroke came from the mouse.
        samples = self.arr_drawing[1:] if self.pen else self.arr_drawing[:]
        self.arr_drawing_complex = [complex(coordinates[0], coordinates[1]) for coordinates in samples]

        self.N = len(self.arr_drawing_complex)
        self.CanvasWindow.horizontalSlider.setMaximum(self.N)
        self.CanvasWindow.horizontalSlider.setValue(self.N)

        # calculate fourier transform via FFT algorithm
        self.ft = FourierTransform(self.arr_drawing_complex)
        self.ft.toEpicycles()

        # sample the reconstruction at N evenly spaced times in [0, 2*pi)
        self.time = np.linspace(0,2*pi,endpoint = False, num=self.N)
        self.arr_coordinates = np.array([self.ft.getPoint(dt) for dt in self.time])

        # create all components responsible for showing animation
        self.circle = self.canvas.axis.scatter([],[], fc='None', ec='#9ac7e4', lw=1)
        self.circle.set_sizes(self.getSizes())
        self.line_connect, = self.canvas.axis.plot([], [], c='#9ac7e4', lw=1)
        self.line_plot, = self.canvas.axis.plot([], [], c='#4c6bd5', lw=2)
        self.line_plot_all, = self.canvas.axis.plot([], [], c='#4c6bd5', lw=0.5)

        return [self.circle, self.line_connect, self.line_plot, self.line_plot_all]

    def animate(self, i):
        """FuncAnimation frame callback: draw frame ``i`` (cycled modulo ``self.N``)."""
        s = self.CanvasWindow.horizontalSlider.value() + 1
        coords = self.arr_coordinates[:, :s]
        k = i % self.N

        # update components
        self.circle.set_offsets(coords[k, :-1])
        self.line_connect.set_data(coords[k, :, 0], coords[k, :, 1])
        self.line_plot.set_data(coords[:k + 1, -1, 0], coords[:k + 1, -1, 1])
        self.line_plot_all.set_data(coords[:, -1, 0], coords[:, -1, 1])

        return [self.circle, self.line_connect, self.line_plot, self.line_plot_all]
class MainWindow(QMainWindow):
    def __init__(self, *args, **kwargs):
        """Build the NSOR analysis window.

        Creates the actions, menu bar, toolbar, the time/frequency plots, the
        per-parameter line edits and the phase controls.  Initial limits and
        cursor positions come from ``read_parameter(PARAMETER_FILE)``.
        ``*args`` / ``**kwargs`` are accepted but not forwarded to the base
        class.
        """
        super(MainWindow, self).__init__()

        self.setWindowTitle('Data Analysis for NSOR project')
        self.setWindowIcon(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\window_icon.png'))

        # The screen height selects icon size, font size and tick padding
        # below; query it once.
        screen_height = app.desktop().screenGeometry().height()

        # ---- actions used in the menu bar and/or toolbar ----
        openFile = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\open_file.png'),
            '&Open File...', self)
        openFile.setShortcut('Ctrl+O')
        openFile.setStatusTip('Open the data file')
        openFile.triggered.connect(self.open_file)

        exitProgram = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\exit_program.png'),
            '&Exit', self)
        exitProgram.setShortcut("Ctrl+W")
        exitProgram.setStatusTip('Close the Program')
        exitProgram.triggered.connect(self.exit_program)

        renewData = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\renew.png'), '&Renew',
            self)
        renewData.setShortcut("Ctrl+R")
        renewData.setStatusTip('Reload the original data')
        renewData.triggered.connect(self.renew_data)

        self.verticalZoom = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\vertical_zoom.png'),
            '&Vertical Zoom', self)
        self.verticalZoom.setShortcut("Ctrl+Shift+V")
        self.verticalZoom.setStatusTip('Zoom in the vertical direction')
        self.verticalZoom.setCheckable(True)
        self.verticalZoom.toggled.connect(self.vzoom)

        self.horizontalZoom = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\horizontal_zoom.png'),
            '&Horizontal Zoom', self)
        self.horizontalZoom.setShortcut("Ctrl+Shift+H")
        self.horizontalZoom.setStatusTip('Zoom in the horizaontal direction')
        self.horizontalZoom.setCheckable(True)
        self.horizontalZoom.toggled.connect(self.hzoom)

        self.moveCursor = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\move_cursor.png'),
            '&Move Cursor', self)
        self.moveCursor.setShortcut("Ctrl+M")
        self.moveCursor.setStatusTip('Move cursors')
        self.moveCursor.setCheckable(True)
        self.moveCursor.toggled.connect(self.move_cursor)

        # One auto-scale action per axis, keyed 'time_x', 'time_y', 'freq_x', 'freq_y'.
        self.autoAxis = {}
        self.autoAxis['time_x'] = MyQAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_time_x.png'),
            '&Auto X axis (time)', 'time_x', self)
        self.autoAxis['time_y'] = MyQAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_time_y.png'),
            '&Auto Y axis (time)', 'time_y', self)
        self.autoAxis['freq_x'] = MyQAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_freq_x.png'),
            '&Auto X axis (freq)', 'freq_x', self)
        self.autoAxis['freq_y'] = MyQAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_freq_y.png'),
            '&Auto Y axis (freq)', 'freq_y', self)

        editParameters = QAction('&Edit Parameter', self)
        editParameters.setShortcut('Ctrl+E')
        editParameters.setStatusTip('open and edit the parameter file')
        editParameters.triggered.connect(self.edit_parameters)

        saveParameters = QAction('&Save Parameter', self)
        saveParameters.setShortcut('Ctrl+S')
        saveParameters.setStatusTip('save the parameters on screen to file')
        saveParameters.triggered.connect(self.save_parameters)

        self.data_type = QComboBox()
        self.data_type.setStatusTip(
            'bin for legacy data recorded from labview program, big endian coded binary data, npy for numpy type data'
        )
        self.data_type.addItems(['bin', '.npy'])

        # ---- menu bar ----
        mainMenu = self.menuBar()
        fileMenu = mainMenu.addMenu('&File')
        fileMenu.addAction(openFile)
        fileMenu.addSeparator()
        fileMenu.addAction(exitProgram)
        parameterMenu = mainMenu.addMenu('&Parameter')
        parameterMenu.addAction(editParameters)
        parameterMenu.addAction(saveParameters)

        # ---- toolbar ----
        self.toolbar = self.addToolBar('nsor_toolbar')
        if screen_height == 2160:
            self.toolbar.setIconSize(QSize(100, 100))
        else:
            self.toolbar.setIconSize(QSize(60, 60))
        self.toolbar.addAction(openFile)

        self.toolbar.addWidget(self.data_type)
        self.toolbar.addAction(renewData)
        self.toolbar.addSeparator()

        self.toolbar.addAction(self.verticalZoom)
        self.toolbar.addAction(self.horizontalZoom)
        self.toolbar.addAction(self.moveCursor)
        self.toolbar.addSeparator()
        for key, action in self.autoAxis.items():
            action.setStatusTip(f'set {key} axis to size automatically')
            action.btnToggled.connect(self.auto_axis)
            self.toolbar.addAction(action)

        self.statusBar()  # create the status bar

        # ---- matplotlib figure ----
        if screen_height == 2160:
            matplotlib.rcParams.update({'font.size': 28})
        else:
            matplotlib.rcParams.update({'font.size': 14})
        self.canvas = FigureCanvas(Figure(figsize=(40, 9)))
        self.fig = self.canvas.figure

        # self.ax holds the 'time' (left) and 'freq' (right) subplots.
        # self.vline holds the cursor lines 'time_l', 'time_r', 'freq_l',
        # 'freq_r', created below from the *_cursor parameters.
        self.ax = {}
        self.vline = {}
        self.ax['time'] = self.fig.add_subplot(121)
        self.ax['freq'] = self.fig.add_subplot(122)

        for axis in self.ax.values():
            if screen_height == 2160:
                axis.tick_params(pad=20)
            elif screen_height == 1080:
                axis.tick_params(pad=10)

        self.fourier_lb = QLabel("Ready", self)

        self.parameters = read_parameter(PARAMETER_FILE)

        # One label + MyLineEdit per parameter except 'file_name'.  Keys are
        # expected to start with 'time' or 'freq' (limits and cursors); list
        # values are shown as "a b".
        self.edits = {}
        labels = {}
        for key, value in self.parameters.items():
            if key == 'file_name':
                continue
            # Always assign val so a non-list parameter never reuses the
            # previous key's text (or hits an unbound name on the first key).
            if isinstance(value, list):
                val = str(value[0]) + ' ' + str(value[1])
            else:
                val = str(value)
            labels[key] = QLabel(key.replace('_', ' ').title(), self)
            self.edits[key] = MyLineEdit(key, val, self)
            self.edits[key].setStatusTip(f'{key}')
            self.edits[key].textModified.connect(self.limit_and_cursor)
            if key[0:4] == 'freq':
                self.edits[key].setFixedWidth(250)
            if 'cursor' in key:
                prefix = key[0:4]
                self.vline[prefix + '_l'] = self.ax[prefix].axvline(
                    float(value[0]), c='red')
                self.vline[prefix + '_r'] = self.ax[prefix].axvline(
                    float(value[1]), c='red')
                # Animated artists are skipped by canvas.draw(); they are
                # painted explicitly with draw_artist (see slider_released).
                self.vline[prefix + '_l'].set_animated(True)
                self.vline[prefix + '_r'].set_animated(True)

        self.integral_label = QLabel('Peak Intensity: \n0', self)

        self.zeroPadPower = QComboBox(self)
        self.zeroPadPower.addItems(['x1', 'x2', 'x4', 'x8'])
        self.zeroPadPower.setStatusTip('This sets the zerofilling of the data')
        self.zeroPadPower.activated[str].connect(self.zero_padding)

        # ---- phase controls ----
        self.toolbar.addSeparator()
        first_order_phase_check = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_phase_check.png'),
            '&First order on', self)
        first_order_phase_check.setStatusTip('Check to enbale 1st order phase')
        first_order_phase_check.setShortcut('Ctrl+F')
        first_order_phase_check.toggled.connect(self.first_order_phase_check)
        first_order_phase_check.setCheckable(True)
        self.toolbar.addAction(first_order_phase_check)

        auto_phase_btn = QAction(
            QIcon(BASE_FOLDER + r'\pyqt_analysis\icons\auto_phase_btn.png'),
            '&Auto Phase', self)
        auto_phase_btn.setStatusTip('Auto phase the peak (0th order only)')
        auto_phase_btn.setShortcut('Ctrl+A')
        auto_phase_btn.triggered.connect(self.auto_phase)
        self.toolbar.addAction(auto_phase_btn)

        # 0th- and 1st-order phase sliders, in degrees (0-360).
        self.zeroth_slider = QSlider(self)
        self.zeroth_slider.setMinimum(0)
        self.zeroth_slider.setMaximum(360)
        self.zeroth_slider.setValue(0)
        self.zeroth_slider.setTickInterval(1)
        self.zeroth_slider.valueChanged.connect(self.zeroth_order_phase)
        self.zeroth_slider.sliderReleased.connect(self.slider_released)

        self.first_slider = QSlider(self)
        self.first_slider.setMinimum(0)
        self.first_slider.setMaximum(360)
        self.first_slider.setValue(0)
        self.first_slider.hide()  # shown by the "First order on" action
        self.first_slider.valueChanged.connect(self.first_order_phase)

        self.phase_info = QLabel('Current Phase: \n0th: 0\n1st: 0 \nInt: 0',
                                 self)

        # ---- layout: [time edits | canvas + fourier_lb | freq edits + phase] ----
        self._main = QWidget()
        self.setCentralWidget(self._main)
        layout1 = QHBoxLayout(self._main)
        layout2 = QVBoxLayout()
        layout3 = QVBoxLayout()
        layout4 = QVBoxLayout()
        layout5 = QHBoxLayout()

        for key in labels:
            if key[0:4] == 'time':
                layout2.addWidget(labels[key])
                layout2.addWidget(self.edits[key])
            elif key[0:4] == 'freq':
                layout4.addWidget(labels[key])
                layout4.addWidget(self.edits[key])
        layout4.addWidget(self.integral_label)
        layout4.addWidget(self.phase_info)

        layout4.addLayout(layout5)
        layout5.addWidget(self.zeroth_slider)
        layout5.addWidget(self.first_slider)

        layout2.addWidget(self.zeroPadPower)
        layout1.addLayout(layout2)
        layout2.addStretch(1)
        layout1.addLayout(layout3)
        layout3.addWidget(self.canvas)
        layout3.addWidget(self.fourier_lb)
        layout1.addLayout(layout4)

        self.threadpool = QThreadPool()  # Multithreading

    '''
    ################################################################################
    phase
    '''

    def slider_released(self):
        """Redraw the canvas once the 0th-order phase slider is let go.

        After the full redraw the frequency axes is configured for scientific
        tick labels, and the two cursor lines of every axes are painted again
        with draw_artist.
        """
        canvas = self.canvas
        canvas.draw()
        freq_axis = self.ax['freq']
        freq_axis.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
        for name, axis in self.ax.items():
            axis.draw_artist(self.vline[name + '_l'])
            axis.draw_artist(self.vline[name + '_r'])

    def first_order_phase_check(self, toggle_state):
        """Show the 1st-order phase slider when checked, otherwise hide it.

        Hiding first resets the slider value to 0.
        """
        slider = self.first_slider
        if not toggle_state:
            slider.setValue(0)
            slider.hide()
            return
        slider.show()

    def auto_phase(self):
        """Pick the 0th-order phase that maximises the integrated real signal.

        Every whole-degree angle 0..359 is tried on the frequency data between
        the stored cursor indices ``self.csL:self.csR``.  For a rotation ``phi``
        the integral of ``re*cos(phi) + im*sin(phi)`` equals
        ``cos(phi)*sum(re) + sin(phi)*sum(im)``, so all candidates are scored
        in one vectorised expression (the previous loop grew an array with
        ``np.append``, copying it on every step).

        The best angle (a plain ``int``) is written to the 0th-order slider,
        the whole spectrum is rotated into ``self.data['freq_real']`` and
        redrawn.  A warning box is shown when no frequency data or cursor
        indices exist yet.
        """
        try:
            freq_y = self.data['freq_y']
            window = freq_y[self.csL:self.csR]
        except (AttributeError, KeyError):
            # AttributeError: nothing loaded / no cursor set yet.
            # KeyError: data loaded but the FFT worker has not delivered
            # 'freq_y' yet (open_file starts it asynchronously).
            QMessageBox.warning(self, 'WARNING',
                                'No original data available!',
                                QMessageBox.Ok)
            return

        phis = np.deg2rad(np.arange(360))
        intensity_int = (np.cos(phis) * np.sum(window.real) +
                         np.sin(phis) * np.sum(window.imag))
        # int(): hand Qt a Python int rather than a numpy integer
        best_angle = int(intensity_int.argmax())
        best_phi = np.deg2rad(best_angle)
        # setValue() may fire valueChanged -> zeroth_order_phase; the explicit
        # update below also covers the case where the value did not change.
        self.zeroth_slider.setValue(best_angle)
        self.data['freq_real'] = (freq_y.real * np.cos(best_phi) +
                                  freq_y.imag * np.sin(best_phi))
        self.draw_phased_data()

    def zeroth_order_phase(self, value):
        """Apply a 0th-order phase rotation of ``value`` degrees.

        The rotated real part ``re*cos(phi) + im*sin(phi)`` of the whole
        spectrum is stored in ``self.data['freq_real']``.  Its sum over the
        cursor window ``self.csL:self.csR`` is used for the 'Int' line of the
        phase label.  That label holds four lines ('Current Phase: ',
        '0th: ...', '1st: ...', 'Int: ...'); the '1st' line is kept.  The
        phased spectrum is then redrawn.  A warning box is shown when no
        frequency data or cursor indices exist yet.
        """
        phi = np.deg2rad(value)
        try:
            freq_y = self.data['freq_y']
            window = freq_y[self.csL:self.csR]
        except (AttributeError, KeyError):
            # KeyError: the FFT result has not arrived yet (see open_file)
            QMessageBox.warning(self, 'WARNING',
                                'No original data available!',
                                QMessageBox.Ok)
            return

        cos_phi, sin_phi = np.cos(phi), np.sin(phi)
        self.data['freq_real'] = freq_y.real * cos_phi + freq_y.imag * sin_phi
        intensity = np.sum(cos_phi * window.real + sin_phi * window.imag)

        lines = self.phase_info.text().split('\n')
        # NOTE(review): the displayed value is twice the window sum; the
        # reason for the factor 2 is not visible here.
        intensity_str = "{:.5f}".format(intensity * 2)
        self.phase_info.setText(f'Current Phase: \n0th: {value}\n' +
                                lines[2] + f'\nInt: {intensity_str}')
        self.draw_phased_data()
        self.canvas.blit(self.ax['freq'].bbox)

    def first_order_phase(self, value):
        """Show the 1st-order phase slider ``value`` in the phase label.

        The label holds four lines ('Current Phase: ', '0th: ...',
        '1st: ...', 'Int: ...'); the 0th-order line is kept, the '1st' line
        gets ``value`` and the 'Int' line is rewritten.

        NOTE(review): no 1st-order correction is applied to the data here,
        and the intensity shown is a placeholder 0.
        """
        intensity = 0  # placeholder until a 1st-order correction exists
        lines = self.phase_info.text().split('\n')
        intensity_str = "{:.5f}".format(intensity * 2)
        self.phase_info.setText('Current Phase: \n' + lines[1] +
                                f'\n1st: {value}' + f'\nInt: {intensity_str}')

    def draw_phased_data(self):
        """Replot the phase-corrected spectrum on the frequency axes.

        Plots ``self.data['freq_real']`` against ``self.data['freq_x']``.  The
        two frequency cursor lines move to the numbers in the 'freq_cursor'
        edit and the x-limits come from the 'freq_x_limit' edit.  The canvas
        is then redrawn and the cursor lines of every axes are painted again.
        """
        freq_axis = self.ax['freq']
        freq_axis.clear()
        freq_axis.plot(self.data['freq_x'], self.data['freq_real'])

        cursors = list(map(float, self.edits['freq_cursor'].text().split(' ')))
        left_line = self.vline['freq_l']
        left_line.set_xdata([cursors[0], cursors[0]])
        right_line = self.vline['freq_r']
        right_line.set_xdata([cursors[1], cursors[1]])

        limits = list(map(float, self.edits['freq_x_limit'].text().split(' ')))
        freq_axis.set_xlim(limits[0], limits[1])

        self.canvas.draw()
        freq_axis.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
        for name, axis in self.ax.items():
            axis.draw_artist(self.vline[name + '_l'])
            axis.draw_artist(self.vline[name + '_r'])

    '''
    ################################################################################
    some less complicated slot
    '''

    def edit_parameters(self):
        """Open PARAMETER_FILE with the system's default application.

        ``os.startfile`` only exists on Windows.  Elsewhere the usual desktop
        opener is launched without waiting: 'open' on macOS, 'xdg-open' on
        other POSIX systems.
        """
        launch = getattr(os, 'startfile', None)
        if launch is not None:
            launch(PARAMETER_FILE)
            return
        import subprocess
        import sys
        opener = 'open' if sys.platform == 'darwin' else 'xdg-open'
        subprocess.Popen([opener, PARAMETER_FILE])

    def save_parameters(self):
        """Copy the edit-box texts into self.parameters and persist them.

        Every parameter except 'file_name' is replaced by its edit's text split
        on single spaces (presumably 'file_name' has no edit box — it is
        written by open_file).  The complete dict is then saved to
        PARAMETER_FILE via save_parameter.
        """
        for key in self.parameters:
            if key == 'file_name':
                continue
            # only existing keys are reassigned, so iterating is safe
            self.parameters[key] = self.edits[key].text().split(' ')

        save_parameter(PARAMETER_FILE, **self.parameters)

    def auto_axis(self, key):
        '''
        Rescale one axis direction automatically.

        ``key`` looks like 'time_x' / 'freq_y': the first four characters pick
        the axes and character 5 the direction.  'time_y' is special: its
        limits become +/- twice the mean absolute signal.  If no data is
        loaded (AttributeError), it falls back to matplotlib autoscaling.
        Afterwards the canvas is redrawn and the cursor lines repainted.
        '''
        name = key[0:4]
        if key == 'time_y':
            try:
                mean_amp = np.mean(np.abs(self.data['time_y']))
                self.ax['time'].set_ylim(-2 * mean_amp, 2 * mean_amp)
            except AttributeError:
                self.ax[name].autoscale(axis=key[5])
        else:
            self.ax[name].autoscale(axis=key[5])

        self.canvas.draw()
        self.ax[name].ticklabel_format(style='sci', axis='both',
                                       scilimits=(0, 0))
        for axis_key, axis in self.ax.items():
            axis.draw_artist(self.vline[axis_key + '_l'])
            axis.draw_artist(self.vline[axis_key + '_r'])

    '''
    ################################################################################
    browse the figure
    calculate based on cursors
    '''

    def limit_and_cursor(self, key, text):
        '''
        Respond to the change of text in the edits.

        key:  edit name, e.g. 'time_x_limit', 'freq_y_limit', 'time_cursor'
              or 'freq_cursor'; its first four characters pick the axes.
        text: two space-separated numbers.

        Limit edits set the axes' x- or y-limits.  Cursor edits move the two
        cursor lines and pass the nearest sample indices, smaller first, to
        cursor_operation.  The canvas is then redrawn.  Non-numeric input, or
        fewer than two numbers, shows an 'Input only number' warning.  Missing
        data (including a spectrum that has not been computed yet) shows
        'No original data available!'.
        '''
        prefix = key[0:4]
        try:
            value = [float(x) for x in text.split(' ')]
            if len(value) < 2:
                # a single number used to escape as an uncaught IndexError
                raise ValueError('two numbers expected')

            if 'limit' in key:
                target = self.ax[prefix]
                if 'x' in key:
                    target.set_xlim(value[0], value[1])
                elif 'y' in key:
                    target.set_ylim(value[0], value[1])

            elif 'cursor' in key:
                self.vline[prefix + '_l'].set_xdata([value[0], value[0]])
                self.vline[prefix + '_r'].set_xdata([value[1], value[1]])
                try:
                    x_data = self.data[prefix + '_x']
                    # index of the sample closest to each cursor position
                    cs1 = np.argmin(np.abs(x_data - value[0]))
                    cs2 = np.argmin(np.abs(x_data - value[1]))
                    self.cursor_operation(key, min(cs1, cs2), max(cs1, cs2))
                except (AttributeError, KeyError):
                    # KeyError: e.g. 'freq_x' before the FFT has finished
                    QMessageBox.warning(self, 'WARNING',
                                        'No original data available!',
                                        QMessageBox.Ok)

            self.canvas.draw()
            self.ax[prefix].ticklabel_format(
                style='sci', axis='both',
                scilimits=(0, 0))  # format the tick label of the axes
            for name, axis in self.ax.items():
                axis.draw_artist(self.vline[name + '_l'])
                axis.draw_artist(self.vline[name + '_r'])

        except ValueError:
            QMessageBox.warning(self, 'WARNING', 'Input only number',
                                QMessageBox.Ok)

    def cursor_operation(self, key, csL, csR):
        """React to new cursor sample indices ``csL`` <= ``csR``.

        The indices are stored in self.csL / self.csR, which auto_phase and
        zeroth_order_phase use to slice the spectrum.
        'time' keys: the time signal between the cursors is zero-padded and
        Fourier transformed via zero_padding.
        'freq' keys: the peak intensity |sum(freq_y[csL:csR])| of the
        spectrum between the cursors is shown in integral_label.

        NOTE(review): a time cursor also overwrites self.csL/csR with
        time-domain indices; they are freq indices again only after the next
        'freq' call — confirm this is intended.
        """
        self.csL = csL
        self.csR = csR
        if 'time' in key:
            self.zero_padding(self.zeroPadPower.currentText(), [csL, csR])
        elif 'freq' in key:
            # Integrate only the window between the cursors.
            # |sum(y)| == sqrt(sum(re)**2 + sum(im)**2), i.e. phase independent.
            intensity = np.abs(np.sum(self.data['freq_y'][csL:csR]))
            self.integral_label.setText(f'Peak Intensity: \n{intensity:.5f}')

    def cursor_lines_in_axis(self, ax):
        """Return the (left, right) cursor lines that belong to ``ax``.

        The time axes maps to vline['time_l'] / vline['time_r']; any other
        axes is treated as the frequency axes.
        """
        prefix = 'time' if ax == self.ax['time'] else 'freq'
        return self.vline[prefix + '_l'], self.vline[prefix + '_r']

    def move_cursor(self, state):
        """Toggle dragging of the two cursor lines with the mouse.

        state True: uncheck the zoom actions and connect matplotlib callbacks.
        Hovering within 2 % of the axes' x-span from a cursor line selects
        that line and switches the application cursor to a cross.  A
        left-button drag then moves the line, blitting over a cached
        background.  On release both line positions are written, sorted and
        formatted '{:.5E}', into the 'freq_cursor' or 'time_cursor' edit.
        Releasing outside the axes the drag started in pins the line to that
        axes' x-limit on the side the pointer moved towards.
        state False: disconnect the callbacks and restore the mouse cursor.
        """
        def reset_pointer():
            """Drop the cross override cursor and forget the hovered line."""
            if self.cursor == 'cross':
                QApplication.restoreOverrideCursor()
            self.cursor = 'arrow'
            self.current_line = None

        def on_press(event):
            ax = event.inaxes
            if ax is None or self.current_line is None or event.button != 1:
                return
            if self.current_line not in self.cursor_lines_in_axis(ax):
                return  # hovered line belongs to another axes
            self.last_ax = ax
            self.c_lock = True
            self.c_side = None  # set by on_motion once the pointer moves
            self.x0 = event.xdata
            self.current_line.set_xdata([event.xdata, event.xdata])
            line1, line2 = self.cursor_lines_in_axis(ax)

            self.canvas.draw()
            self.background = self.canvas.copy_from_bbox(ax.bbox)
            ax.draw_artist(line1)
            ax.draw_artist(line2)
            self.canvas.blit(ax.bbox)

        def on_motion(event):
            ax = event.inaxes
            if ax is None:
                return

            if self.c_lock:
                if ax is not self.last_ax:
                    return  # xdata of a foreign axes is in the wrong units
                line1, line2 = self.cursor_lines_in_axis(ax)
                self.current_line.set_xdata([event.xdata, event.xdata])
                self.canvas.restore_region(self.background)
                ax.draw_artist(line1)
                ax.draw_artist(line2)
                self.canvas.blit(ax.bbox)
                self.c_side = 'left' if self.x0 > event.xdata else 'right'
                return

            line1, line2 = self.cursor_lines_in_axis(ax)
            xmin, xmax = ax.get_xlim()
            # current span (the limits may have changed since entering the
            # axes); abs() because the limits may be given in decreasing order
            span = abs(xmax - xmin)
            for line in (line1, line2):
                if abs(event.xdata - line.get_xdata()[0]) / span <= 0.02:
                    self.current_line = line
                    if self.cursor == 'arrow':
                        QApplication.setOverrideCursor(Qt.CrossCursor)
                        self.cursor = 'cross'
                    break
            else:
                reset_pointer()

        def on_release(event):
            if not self.c_lock:
                return
            self.background = None
            self.c_lock = False

            ax = self.last_ax
            if event.inaxes is not ax and self.c_side is not None:
                # released outside the drag's axes: pin the line to the limit
                xmin, xmax = ax.get_xlim()
                edge = xmin if self.c_side == 'left' else xmax
                self.current_line.set_xdata([edge, edge])

            line1, line2 = self.cursor_lines_in_axis(ax)
            low, high = sorted((line1.get_xdata()[0], line2.get_xdata()[0]))
            text = "{:.5E} {:.5E}".format(low, high)

            if ax == self.ax['freq']:
                self.edits['freq_cursor'].setText(text)
            else:
                self.edits['time_cursor'].setText(text)

        def move_in_ax(event):
            self.in_ax = True
            xmin, xmax = event.inaxes.get_xlim()
            self.xrange = xmax - xmin

        def move_out_ax(event):
            self.in_ax = False
            if not self.c_lock:
                # the override cursor is application wide: don't leave it set
                reset_pointer()

        if state:
            self.cursor = 'arrow'
            self.current_line = None
            self.verticalZoom.setChecked(False)
            self.horizontalZoom.setChecked(False)
            self.c_lock = False
            self.c_onpick = False
            connect = self.canvas.mpl_connect
            self.c_cid_press = connect('button_press_event', on_press)
            self.c_cid_release = connect('button_release_event', on_release)
            self.c_cid_motion = connect('motion_notify_event', on_motion)
            self.c_in_ax = connect('axes_enter_event', move_in_ax)
            self.c_out_ax = connect('axes_leave_event', move_out_ax)

        else:
            for cid in (self.c_cid_press, self.c_cid_release,
                        self.c_cid_motion, self.c_in_ax, self.c_out_ax):
                self.canvas.mpl_disconnect(cid)
            reset_pointer()

    def vzoom(self, state):
        """Toggle interactive vertical (y-range) zoom.

        state True: uncheck the horizontal-zoom and cursor-move actions and
        connect matplotlib callbacks.  A left-button press inside an axes
        draws three magenta markers: a fixed horizontal tick at the press
        point, and a moving tick plus a vertical line that follow the pointer.
        On release, the y-range between press and release is written, sorted
        and formatted '{:.5E}', into the 'freq_y_limit' or 'time_y_limit'
        edit.  If released outside that axes, the range is clamped to the
        axes' y-limit.  Any other mouse button aborts a drag in progress.
        state False: abort a drag in progress and disconnect the callbacks.
        """
        import traceback  # local: only used to report errors in callbacks

        def redraw_cursors(ax):
            """Blit the two cursor lines of ``ax`` over the current canvas."""
            line1, line2 = self.cursor_lines_in_axis(ax)
            ax.draw_artist(line1)
            ax.draw_artist(line2)
            self.canvas.blit(ax.bbox)

        def remove_markers():
            """Remove the three temporary magenta marker lines."""
            self.top_ln.remove()
            self.vzoom_ln.remove()
            self.btm_ln.remove()

        def cancel_drag():
            """Abort the current drag and restore a clean picture."""
            remove_markers()
            self.canvas.draw()
            self.background = None
            self.vlock = False
            redraw_cursors(self.last_ax)

        def on_press(event):
            try:
                if event.button != 1:
                    # any other button aborts a drag in progress (if any)
                    if self.vlock:
                        cancel_drag()
                    return
                ax = event.inaxes
                if ax is None or self.vlock:
                    return
                self.last_ax = ax
                self.vside = None  # set by on_motion once the pointer moves
                ymin, ymax = ax.get_ylim()
                self.yrange = ymax - ymin
                xmin, xmax = ax.get_xlim()
                self.xrange = xmax - xmin
                self.y0 = event.ydata
                x, y = event.xdata, event.ydata
                half = self.xrange * 0.02
                self.top_ln, = ax.plot([x - half, x + half], [y, y])
                self.btm_ln, = ax.plot([x - half, x + half], [y, y])
                self.vzoom_ln, = ax.plot([x, x], [y, y])
                for line in (self.top_ln, self.btm_ln, self.vzoom_ln):
                    line.set_color('m')
                # the fixed top tick stays in the cached background; only the
                # moving parts are animated and blitted during the drag
                self.btm_ln.set_animated(True)
                self.vzoom_ln.set_animated(True)
                self.vlock = True
                self.canvas.draw()
                self.background = self.canvas.copy_from_bbox(ax.bbox)
                ax.draw_artist(self.vzoom_ln)
                ax.draw_artist(self.btm_ln)
                redraw_cursors(ax)
            except Exception:
                traceback.print_exc()

        def on_release(event):
            if not self.vlock:
                return
            try:
                remove_markers()
                self.canvas.draw()
                self.background = None
                self.vlock = False
                ax = self.last_ax
                y_end = event.ydata
                if event.inaxes is not ax:
                    if self.vside is None:
                        # left the axes without moving inside it: no range
                        redraw_cursors(ax)
                        return
                    ymin, ymax = ax.get_ylim()
                    y_end = ymin if self.vside == 'btm' else ymax

                low, high = sorted((self.y0, y_end))
                text = "{:.5E} {:.5E}".format(low, high)
                if ax == self.ax['freq']:
                    self.edits['freq_y_limit'].setText(text)
                else:
                    self.edits['time_y_limit'].setText(text)
            except Exception:
                traceback.print_exc()

        def on_motion(event):
            if not self.vlock or event.inaxes is not self.last_ax:
                # outside the drag's axes ydata is None or in foreign units
                return
            ax = event.inaxes
            self.btm_ln.set_ydata([event.ydata, event.ydata])
            self.vzoom_ln.set_ydata([self.y0, event.ydata])
            self.canvas.restore_region(self.background)
            ax.draw_artist(self.vzoom_ln)
            ax.draw_artist(self.btm_ln)
            redraw_cursors(ax)
            self.vside = 'btm' if self.y0 > event.ydata else 'top'

        def move_in_ax(event):
            self.in_ax = True

        def move_out_ax(event):
            self.in_ax = False

        if state:
            self.horizontalZoom.setChecked(False)
            self.moveCursor.setChecked(False)
            self.vlock = False
            connect = self.canvas.mpl_connect
            self.vcid_press = connect('button_press_event', on_press)
            self.vcid_release = connect('button_release_event', on_release)
            self.vcid_motion = connect('motion_notify_event', on_motion)
            self.vin_ax = connect('axes_enter_event', move_in_ax)
            self.vout_ax = connect('axes_leave_event', move_out_ax)
        else:
            if getattr(self, 'vlock', False):
                cancel_drag()  # don't leave half-drawn markers behind
            for cid in (self.vcid_press, self.vcid_release, self.vcid_motion,
                        self.vin_ax, self.vout_ax):
                self.canvas.mpl_disconnect(cid)

    def hzoom(self, state):
        """Toggle interactive horizontal (x-range) zoom.

        state True: uncheck the cursor-move and vertical-zoom actions and
        connect matplotlib callbacks.  A left-button press inside an axes
        draws three magenta markers: a fixed vertical tick at the press
        point, and a moving tick plus a horizontal line that follow the
        pointer.  On release, the x-range between press and release is
        written, sorted and formatted '{:.5E}', into the 'freq_x_limit' or
        'time_x_limit' edit.  If released outside that axes, the range is
        clamped to the axes' x-limit.  Any other mouse button aborts a drag
        in progress.
        state False: abort a drag in progress and disconnect the callbacks.
        """
        import traceback  # local: only used to report errors in callbacks

        def redraw_cursors(ax):
            """Blit the two cursor lines of ``ax`` over the current canvas."""
            line1, line2 = self.cursor_lines_in_axis(ax)
            ax.draw_artist(line1)
            ax.draw_artist(line2)
            self.canvas.blit(ax.bbox)

        def remove_markers():
            """Remove the three temporary magenta marker lines."""
            self.left_ln.remove()
            self.hzoom_ln.remove()
            self.right_ln.remove()

        def cancel_drag():
            """Abort the current drag and restore a clean picture."""
            remove_markers()
            self.canvas.draw()
            self.background = None
            self.hlock = False
            redraw_cursors(self.last_ax)

        def on_press(event):
            try:
                if event.button != 1:
                    # any other button aborts a drag in progress (if any)
                    if self.hlock:
                        cancel_drag()
                    return
                ax = event.inaxes
                if ax is None or self.hlock:
                    return
                self.last_ax = ax
                self.hside = None  # set by on_motion once the pointer moves
                ymin, ymax = ax.get_ylim()
                self.yrange = ymax - ymin
                xmin, xmax = ax.get_xlim()
                self.xrange = xmax - xmin
                self.x0 = event.xdata
                x, y = event.xdata, event.ydata
                half = self.yrange * 0.02
                self.left_ln, = ax.plot([x, x], [y - half, y + half])
                self.right_ln, = ax.plot([x, x], [y - half, y + half])
                self.hzoom_ln, = ax.plot([x, x], [y, y])
                for line in (self.left_ln, self.right_ln, self.hzoom_ln):
                    line.set_color('m')
                # the fixed left tick stays in the cached background; only the
                # moving parts are animated and blitted during the drag
                self.right_ln.set_animated(True)
                self.hzoom_ln.set_animated(True)
                self.hlock = True
                self.canvas.draw()
                self.background = self.canvas.copy_from_bbox(ax.bbox)
                ax.draw_artist(self.hzoom_ln)
                ax.draw_artist(self.right_ln)
                redraw_cursors(ax)
            except Exception:
                traceback.print_exc()

        def on_motion(event):
            if not self.hlock or event.inaxes is not self.last_ax:
                # outside the drag's axes xdata is None or in foreign units
                return
            ax = event.inaxes
            self.right_ln.set_xdata([event.xdata, event.xdata])
            self.hzoom_ln.set_xdata([self.x0, event.xdata])
            self.canvas.restore_region(self.background)
            ax.draw_artist(self.hzoom_ln)
            ax.draw_artist(self.right_ln)
            redraw_cursors(ax)
            self.hside = 'left' if self.x0 > event.xdata else 'right'

        def on_release(event):
            if not self.hlock:
                return
            try:
                remove_markers()
                self.canvas.draw()
                self.background = None
                self.hlock = False
                ax = self.last_ax
                x_end = event.xdata
                if event.inaxes is not ax:
                    if self.hside is None:
                        # left the axes without moving inside it: no range
                        redraw_cursors(ax)
                        return
                    xmin, xmax = ax.get_xlim()
                    x_end = xmin if self.hside == 'left' else xmax

                low, high = sorted((self.x0, x_end))
                text = "{:.5E} {:.5E}".format(low, high)
                if ax == self.ax['freq']:
                    self.edits['freq_x_limit'].setText(text)
                else:
                    self.edits['time_x_limit'].setText(text)
            except Exception:
                traceback.print_exc()

        def move_in_ax(event):
            self.in_ax = True

        def move_out_ax(event):
            self.in_ax = False

        if state:
            self.moveCursor.setChecked(False)
            self.verticalZoom.setChecked(False)
            self.hlock = False
            connect = self.canvas.mpl_connect
            self.hcid_press = connect('button_press_event', on_press)
            self.hcid_release = connect('button_release_event', on_release)
            self.hcid_motion = connect('motion_notify_event', on_motion)
            self.hin_ax = connect('axes_enter_event', move_in_ax)
            self.hout_ax = connect('axes_leave_event', move_out_ax)
        else:
            if getattr(self, 'hlock', False):
                cancel_drag()  # don't leave half-drawn markers behind
            for cid in (self.hcid_press, self.hcid_release, self.hcid_motion,
                        self.hin_ax, self.hout_ax):
                self.canvas.mpl_disconnect(cid)

    '''
    ################################################################################
    Multithreading fft calculation
    '''

    def fourier_multithreading(self, time_sig):
        """Compute the FFT of ``time_sig`` on the thread pool.

        The status label reads 'Waiting...' until the worker's ``finished``
        signal.  Its ``data`` signal delivers the spectrum to set_fourier.
        """
        self.fourier_lb.setText('Waiting...')
        worker = FourierWorker(time_sig, self.f_max)
        signals = worker.signals
        signals.data.connect(self.set_fourier)
        signals.finished.connect(self.fourier_finished)
        self.threadpool.start(worker)

    def set_fourier(self, data):
        """Store an FFT result and refresh the frequency view.

        ``data[0]`` / ``data[1]`` become self.data['freq_x'] / ['freq_y'].
        The frequency plot is redrawn, then the 'freq_x_limit' and
        'freq_cursor' edits re-emit returnPressed (presumably so that their
        connected handlers re-apply their values to the new spectrum).
        """
        store = self.data
        store['freq_x'] = data[0]
        store['freq_y'] = data[1]
        self.draw('freq')
        for edit_name in ('freq_x_limit', 'freq_cursor'):
            self.edits[edit_name].returnPressed.emit()

    def fourier_finished(self):
        """Set the FFT status label to 'Ready' when a worker has finished."""
        status_label = self.fourier_lb
        status_label.setText('Ready')

    '''
    ################################################################################
    make zerofilling work
    '''

    def zero_padding(self, pad_power, value=None):
        """Cut the time signal between two samples, zero-pad it and start its FFT.

        pad_power: text whose characters after the first form an integer p
            (presumably the zero-padding combo text such as 'x2' — confirm).
        value: [start, stop] sample indices.  When omitted or empty, the cursor
            times saved in self.parameters['time_cursor'] are mapped to the
            nearest samples of self.data['time_x'].

        The indices are ordered, self.data['time_y'][start:stop] is padded
        with zeros to 2**ceil(log2(len(time_y))) * 2**(p - 1) samples, and
        the result is handed to fourier_multithreading.  Without loaded data
        a warning box is shown and the padding combo is reset to index 0.
        """
        use_saved_cursors = value is None or len(value) == 0
        if use_saved_cursors:
            cursor_times = [float(val) for val in self.parameters['time_cursor']]
        try:
            if use_saved_cursors:
                time_x = self.data['time_x']
                # index of the sample closest to each saved cursor time
                cs1 = np.argmin(np.abs(time_x - cursor_times[0]))
                cs2 = np.argmin(np.abs(time_x - cursor_times[1]))
            else:
                cs1, cs2 = value[0], value[1]
            # a reversed pair would otherwise select an empty slice
            cs1, cs2 = min(cs1, cs2), max(cs1, cs2)

            time_y = self.data['time_y']
            time_data = time_y[cs1:cs2]
            factor = 2 ** (int(pad_power[1:]) - 1)
            # padded length is based on the full record, not the slice
            base_len = 2 ** np.ceil(np.log2(len(time_y)))
            padded_len = int(base_len * factor)
            time_sig = np.pad(time_data, (0, padded_len - len(time_data)),
                              'constant')
            self.fourier_multithreading(time_sig)
        except (AttributeError, KeyError):
            QMessageBox.warning(self, 'WARNING',
                                'No original data available!',
                                QMessageBox.Ok)
            self.zeroPadPower.setCurrentIndex(0)

    '''
    ################################################################################
    other miscellaneous function
    '''

    def renew_data(self):
        """Restore the working time data to the originally loaded signal.

        raw_x / raw_y are copied back into time_x / time_y, the time plot is
        redrawn and the zero-padding combo goes back to its first entry.  A
        warning box is shown when nothing has been loaded (AttributeError).
        """
        try:
            data = self.data
            for axis in ('x', 'y'):
                data['time_' + axis] = data['raw_' + axis]
            self.draw('time')
            self.zeroPadPower.setCurrentIndex(0)
        except AttributeError:
            QMessageBox.warning(self, 'WARNING',
                                'No original data available!',
                                QMessageBox.Ok)

    def exit_program(self):
        """Ask for confirmation and terminate the process on 'Yes'."""
        answer = QMessageBox.question(self, 'Exiting',
                                      'Are you sure about exit?',
                                      QMessageBox.Yes | QMessageBox.No)
        if answer != QMessageBox.Yes:
            return
        sys.exit()

    def open_file(self):
        '''
        Let the user pick a data file, load it into self.data and show it.

        The file is read as big-endian float64 ('bin') or with np.load
        ('.npy'), depending on the data-type combo.  Values are interleaved
        (x, y) pairs:
        self.data['raw_x'], self.data['raw_y']    the original data
        self.data['time_x'], self.data['time_y']  working copy (starts as raw)
        self.data['freq_x'], self.data['freq_y']  filled later by set_fourier
                                                  when the FFT worker finishes
        The chosen path is saved as 'file_name' in PARAMETER_FILE.  An
        unknown data type or a file with fewer than two (x, y) pairs shows a
        warning and leaves the current data untouched.
        '''
        dlg = QFileDialog()
        dlg.setDirectory(read_parameter(PARAMETER_FILE)['file_name'])
        if not dlg.exec_():
            return  # dialog cancelled

        file_name = dlg.selectedFiles()[0]
        save_parameter(PARAMETER_FILE, **{"file_name": file_name})
        data_type = str(self.data_type.currentText())
        if data_type == 'bin':
            raw_data = np.fromfile(file_name, '>f8')  # big-endian float64
        elif data_type == '.npy':
            raw_data = np.load(file_name)
        else:
            # previously fell through with raw_data unbound (UnboundLocalError)
            QMessageBox.warning(self, 'WARNING',
                                'Unsupported data type: ' + data_type,
                                QMessageBox.Ok)
            return
        if np.size(raw_data) < 4:
            # two (x, y) pairs are needed to derive the sample step below
            QMessageBox.warning(self, 'WARNING',
                                'Not enough data points in file!',
                                QMessageBox.Ok)
            return

        self.data = {}
        self.data['raw_x'] = raw_data[::2]
        self.data['raw_y'] = raw_data[1::2]
        self.data['time_x'] = self.data['raw_x']
        self.data['time_y'] = self.data['raw_y']
        dt = self.data['time_x'][1] - self.data['time_x'][0]
        self.f_max = 1 / (2 * dt)  # Nyquist frequency of the sampling step
        # NOTE(review): the returnPressed emit below presumably leads to
        # cursor_operation -> zero_padding -> a second FFT; the two workers
        # race for set_fourier — confirm whether this first call is needed.
        self.fourier_multithreading(self.data['time_y'])
        self.edits['time_cursor'].returnPressed.emit()

        self.draw('time')

    '''
    ################################################################################
    '''

    def draw(self, key):
        """Replot the ``key`` axes ('time' or 'freq') with its cursor lines.

        'time' plots self.data['time_y'] against 'time_x'.  'freq' plots the
        magnitude np.abs(self.data['freq_y']) against 'freq_x'.  The two
        cursor lines move to the numbers in the '<key>_cursor' edit.  The
        canvas is then redrawn, the axes is set to scientific tick labels and
        the cursor lines are blitted on top.
        """
        target = self.ax[key]
        target.clear()
        if key == 'time':
            target.plot(self.data[key + '_x'], self.data[key + '_y'])
        elif key == 'freq':
            target.plot(self.data[key + '_x'], np.abs(self.data[key + '_y']))

        positions = list(map(float, self.edits[key + '_cursor'].text().split(' ')))
        left_line = self.vline[key + '_l']
        left_line.set_xdata([positions[0], positions[0]])
        right_line = self.vline[key + '_r']
        right_line.set_xdata([positions[1], positions[1]])

        self.canvas.draw()
        target.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
        for line in (left_line, right_line):
            target.draw_artist(line)
        self.canvas.blit(target.bbox)
Example #33
0
class FramePanel(QtWidgets.QWidget):
    '''GUI panel containing frame display widget
    Can scroll through frames of parent's EMCReader object

    Other parameters:
        compare - Side-by-side view of frames and best guess tomograms from reconstruction
        powder - Show sum of all frames

    Required members of parent class:
        emc_reader - Instance of EMCReader class
        geom - Instance of DetReader class
        output_folder - (Only for compare mode) Folder with output data
        need_scaling - (Only for compare mode) Whether reconstruction was done with scaling

    Other parent members read by this panel:
        cmap, close, blacklist, log_fname (compare mode), and, when present,
        mode_val, embedding_panel, conversion_panel, classes, mlp_panel
    '''
    def __init__(self, parent, compare=False, powder=False, **kwargs):
        super(FramePanel, self).__init__(**kwargs)

        # Dark color theme. Note: this updates matplotlib's global rcParams,
        # so it affects every figure created afterwards, not just this panel.
        # (An alternative darker background is '#232629'.)
        matplotlib.rcParams.update({
            'text.color': '#eff0f1',
            'xtick.color': '#eff0f1',
            'ytick.color': '#eff0f1',
            'axes.labelcolor': '#eff0f1',
            'axes.facecolor': '#2a2a2f',
            'figure.facecolor': '#2a2a2f'})

        self.parent = parent
        self.emc_reader = self.parent.emc_reader
        self.do_compare = compare
        self.do_powder = powder
        if self.do_compare:
            self.slices = slices.SliceGenerator(self.parent.geom, 'data/quat.dat',
                                                folder=self.parent.output_folder,
                                                need_scaling=self.parent.need_scaling)
        if self.do_powder:
            self.powder_sum = self.emc_reader.get_powder()

        # Placeholders: _init_ui() replaces rangestr with a QLineEdit, and
        # numstr too unless in powder mode (where it stays a plain string)
        self.numstr = '0'
        self.rangestr = '10'

        self._init_ui()

    def _init_ui(self):
        '''Build the figure, toolbar and control widgets, then draw the first plot'''
        vbox = QtWidgets.QVBoxLayout(self)

        self.fig = Figure(figsize=(6, 6))
        self.fig.subplots_adjust(left=0.05, right=0.99, top=0.9, bottom=0.05)
        self.canvas = FigureCanvas(self.fig)
        self.navbar = MyNavigationToolbar(self.canvas, self)
        # Clicking the canvas gives this panel keyboard focus (see _frame_focus)
        self.canvas.mpl_connect('button_press_event', self._frame_focus)
        vbox.addWidget(self.navbar)
        vbox.addWidget(self.canvas)

        # First row: frame number entry, compare options, plot range
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        if not self.do_powder:
            label = QtWidgets.QLabel('Frame number: ', self)
            hbox.addWidget(label)
            self.numstr = QtWidgets.QLineEdit('0', self)
            self.numstr.setFixedWidth(64)
            hbox.addWidget(self.numstr)
            label = QtWidgets.QLabel('/%d'%self.emc_reader.num_frames, self)
            hbox.addWidget(label)
        hbox.addStretch(1)
        if not self.do_powder and self.do_compare:
            self.compare_flag = QtWidgets.QCheckBox('Compare', self)
            self.compare_flag.clicked.connect(self._compare_flag_changed)
            self.compare_flag.setChecked(False)
            hbox.addWidget(self.compare_flag)
            label = QtWidgets.QLabel('CMap:', self)
            hbox.addWidget(label)
            self.slicerange = QtWidgets.QLineEdit('10', self)
            self.slicerange.setFixedWidth(30)
            hbox.addWidget(self.slicerange)
            label = QtWidgets.QLabel('^', self)
            hbox.addWidget(label)
            self.exponent = QtWidgets.QLineEdit('1.0', self)
            self.exponent.setFixedWidth(30)
            hbox.addWidget(self.exponent)
            hbox.addStretch(1)
        label = QtWidgets.QLabel('PlotMax:', self)
        hbox.addWidget(label)
        self.rangestr = QtWidgets.QLineEdit('10', self)
        self.rangestr.setFixedWidth(48)
        hbox.addWidget(self.rangestr)

        # Second row: action buttons
        hbox = QtWidgets.QHBoxLayout()
        vbox.addLayout(hbox)
        button = QtWidgets.QPushButton('Plot', self)
        # The lambda is required: clicked() emits a 'checked' bool which would
        # otherwise be passed to plot_frame() as its 'frame' argument
        button.clicked.connect(lambda: self.plot_frame()) # pylint: disable=unnecessary-lambda
        hbox.addWidget(button)
        if self.do_powder:
            button = QtWidgets.QPushButton('Save', self)
            button.clicked.connect(self._save_powder)
            hbox.addWidget(button)
        else:
            gui_utils.add_scroll_hbox(self, hbox)
        hbox.addStretch(1)
        button = QtWidgets.QPushButton('Quit', self)
        button.clicked.connect(self.parent.close)
        hbox.addWidget(button)

        self.show()
        self.plot_frame()

    def plot_frame(self, frame=None):
        '''Update canvas according to GUI parameters
        Updated plot depends on mode (for classifier) and whether the GUI is in
        'compare' or 'powder' mode.

        Args:
            frame - Optional array to display instead of the frame selected in
                    the GUI (or the powder sum). No frame number is associated
                    with it, so number-dependent features (compare view, class
                    labels, predictions, bad-frame flag) are skipped.
        '''
        try:
            mode = self.parent.mode_val
        except AttributeError:
            mode = None

        # Frame number of what is displayed; stays None for the powder sum or
        # an explicitly supplied frame
        num = None
        if frame is None:
            if self.do_powder:
                frame = self.powder_sum
            else:
                num = self.get_num()
                if num is None:
                    return
                frame = self.emc_reader.get_frame(num)

        # Parse the color limit before clearing the figure, so a bad entry
        # keeps the current plot instead of raising inside a Qt slot
        try:
            vmax = float(self.rangestr.text())
        except ValueError:
            sys.stderr.write('PlotMax must be a number\n')
            return

        self._remove_roi_points()

        self.fig.clear()
        if mode == 2:
            subp = self.parent.conversion_panel.plot_converted_frame()
        elif num is not None and self._compare_enabled():
            subp = self._plot_slice(num)
        else:
            subp = self.fig.add_subplot(111)
        subp.imshow(frame.T, vmin=0, vmax=vmax,
                    interpolation='none', cmap=self.parent.cmap)
        subp.set_title(self._get_plot_title(frame, num, mode))
        self.fig.tight_layout()
        self.canvas.draw()

    def _compare_enabled(self):
        '''Return True if the 'Compare' checkbox exists and is checked'''
        # The checkbox is only created in compare mode without powder mode
        return self.do_compare and not self.do_powder and self.compare_flag.isChecked()

    def _remove_roi_points(self):
        '''Best-effort removal of the parent's embedding-panel ROI markers'''
        try:
            roi_list = self.parent.embedding_panel.roi_list
        except AttributeError:
            return
        for point in roi_list:
            # Handle each point separately so one failure doesn't leave the
            # remaining markers in place
            try:
                point.remove()
            except (ValueError, AttributeError):
                pass

    def get_num(self):
        '''Get valid frame number from GUI
        Returns None if the typed number is either unparseable or out of bounds
        '''
        try:
            num = int(self.numstr.text())
        except ValueError:
            sys.stderr.write('Frame number must be integer\n')
            return None
        if num < 0 or num >= self.emc_reader.num_frames:
            sys.stderr.write('Frame number %d out of range!\n' % num)
            return None
        return num

    def _plot_slice(self, num):
        '''Create the subplots for compare mode and return the one for the frame
        If a positive iteration number is found in the parent's log file, the
        figure is split in two: the frame subplot (left, returned) and the best
        guess tomogram slice for frame num (right, drawn here). Otherwise a
        single subplot is returned.
        '''
        iteration = self._get_iteration()
        if iteration is None or iteration <= 0:
            return self.fig.add_subplot(111)

        subp = self.fig.add_subplot(121)
        subpc = self.fig.add_subplot(122)
        tomo, info = self.slices.get_slice(iteration, num)
        subpc.imshow(tomo**float(self.exponent.text()), cmap=self.parent.cmap,
                     vmin=0, vmax=float(self.slicerange.text()),
                     interpolation='gaussian')
        subpc.set_title('Mutual Info. = %f'%info)
        return subp

    def _get_iteration(self):
        '''Return the integer starting the last line of the parent's log file
        Returns None (after writing a message to stderr) if the file cannot be
        read, is empty, or its last line does not start with an integer.
        '''
        fname = self.parent.log_fname
        try:
            with open(fname, 'r') as fptr:
                lines = fptr.readlines()
        except (IOError, OSError) as err:
            sys.stderr.write('Unable to read %s: %s\n' % (fname, err))
            return None
        line = lines[-1] if lines else ''
        try:
            return int(line.split()[0])
        except (IndexError, ValueError):
            sys.stderr.write('Unable to determine iteration number from %s\n' % fname)
            sys.stderr.write('%s\n' % line)
            return None

    def _step_frame(self, step):
        '''Shift the frame number in the GUI by step and replot if still in range'''
        if self.do_powder:
            # There is no frame number entry in powder mode
            return
        try:
            num = int(self.numstr.text()) + step
        except ValueError:
            sys.stderr.write('Frame number must be integer\n')
            return
        if 0 <= num < self.emc_reader.num_frames:
            self.numstr.setText(str(num))
            self.plot_frame()

    def _next_frame(self):
        self._step_frame(1)

    def _prev_frame(self):
        self._step_frame(-1)

    def _rand_frame(self):
        '''Select and plot a uniformly random frame number'''
        if self.do_powder or self.emc_reader.num_frames < 1:
            return
        num = np.random.randint(0, self.emc_reader.num_frames)
        self.numstr.setText(str(num))
        self.plot_frame()

    def _get_plot_title(self, frame, num, mode):
        '''Build the plot title: photon count plus per-frame annotations
        Annotations need a frame number and are skipped when num is None
        (powder sum or an explicitly supplied frame).
        '''
        title = '%d photons' % frame.sum()
        if num is None:
            return title
        if mode == 1 or mode == 3:
            title += ' (%s)' % self.parent.classes.clist[num]
        if mode == 4 and self.parent.mlp_panel.predictions is not None:
            title += ' [%s]' % self.parent.mlp_panel.predictions[num]
        if (mode is None and
                self.parent.blacklist is not None and
                self.parent.blacklist[num] == 1):
            title += ' (bad frame)'
        return title

    def _compare_flag_changed(self):
        self.plot_frame()

    def _frame_focus(self, event): # pylint: disable=unused-argument
        '''Give this panel keyboard focus when the canvas is clicked'''
        self.setFocus()

    def _save_powder(self):
        '''Write assembled and raw powder sums with tofile() into output_folder'''
        fname = '%s/assem_powder.bin' % self.parent.output_folder
        sys.stderr.write('Saving assembled powder sum with shape %s to %s\n' %
                         (self.powder_sum.shape, fname))
        self.powder_sum.data.tofile(fname)

        raw_powder = self.emc_reader.get_powder(raw=True)
        fname = '%s/powder.bin' % self.parent.output_folder
        sys.stderr.write('Saving raw powder sum with shape %s to %s\n' %
                         (raw_powder.shape, fname))
        raw_powder.tofile(fname)

    def keyPressEvent(self, event): # pylint: disable=C0103
        '''Override of default keyPress event handler
        Ctrl+N / Right / Down: next frame; Ctrl+P / Left / Up: previous frame;
        Ctrl+R: random frame; Return / Enter: replot. Other keys are ignored.
        '''
        key = event.key()
        mod = int(event.modifiers())

        if QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+N'):
            self._next_frame()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+P'):
            self._prev_frame()
        elif QtGui.QKeySequence(mod+key) == QtGui.QKeySequence('Ctrl+R'):
            self._rand_frame()
        elif key == QtCore.Qt.Key_Return or key == QtCore.Qt.Key_Enter:
            self.plot_frame()
        elif key == QtCore.Qt.Key_Right or key == QtCore.Qt.Key_Down:
            self._next_frame()
        elif key == QtCore.Qt.Key_Left or key == QtCore.Qt.Key_Up:
            self._prev_frame()
        else:
            event.ignore()