示例#1
0
def get_spectrogram(wav: WavFile, graphics_layout: pyqtgraph.GraphicsLayoutWidget,
                    levels=(0, 40000)):
    """Draw the spectrogram of ``wav`` into ``graphics_layout``.

    The spectrogram is computed with ``signal.spectrogram(wav.data, wav.rate)``
    and shown as an image with a linked histogram/gradient control.

    Args:
        wav: object exposing ``data`` (the samples) and ``rate`` (the sample rate).
        graphics_layout: layout widget that is cleared, filled and shown.
        levels: ``(low, high)`` pair handed to the histogram's ``setLevels``.
            The default reproduces the range that used to be hard-coded; pass
            e.g. ``(Sxx.min(), Sxx.max())`` of your data for a fitted range.

    Returns:
        None. The function only has display side effects.
    """
    f, t, Sxx = signal.spectrogram(wav.data, wav.rate)

    # Interpret image data as row-major instead of col-major.
    # NOTE: this changes pyqtgraph's process-wide configuration.
    pyqtgraph.setConfigOptions(imageAxisOrder='row-major')
    pyqtgraph.mkQApp()
    graphics_layout.clear()
    # A plot area (ViewBox + axes) for displaying the image
    plot_widget = graphics_layout.addPlot()

    # Item for displaying image data
    img = pyqtgraph.ImageItem()
    plot_widget.addItem(img)
    # Add a histogram with which to control the gradient of the image
    hist = pyqtgraph.HistogramLUTItem()
    # Link the histogram to the image
    hist.setImageItem(img)
    # If you don't add the histogram to the window, it stays invisible, but I find it useful.
    graphics_layout.addItem(hist)
    # Show the window
    graphics_layout.show()
    # Colour range of the histogram (caller supplied, see ``levels``)
    low, high = levels
    hist.setLevels(low, high)
    # This gradient is roughly comparable to the gradient used by Matplotlib
    # You can adjust it and then save it using hist.gradient.saveState()
    hist.gradient.restoreState(
        {'mode': 'rgb',
         'ticks': [(0.5, (0, 182, 188, 255)),
                   (1.0, (246, 111, 0, 255)),
                   (0.0, (75, 0, 113, 255))]})
    # Sxx contains the amplitude for each pixel
    img.setImage(Sxx)
    # Scale the X and Y Axis to time and frequency (standard is pixels)
    # NOTE(review): newer pyqtgraph releases may no longer accept
    # ImageItem.scale(sx, sy) -- confirm against the installed version.
    img.scale(t[-1] / np.size(Sxx, axis=1),
              f[-1] / np.size(Sxx, axis=0))
    # Limit panning/zooming to the spectrogram
    plot_widget.setLimits(xMin=0, xMax=t[-1], yMin=0, yMax=f[-1])
    # Add labels to the axis
    plot_widget.setLabel('bottom', "Time", units='s')
    # If you include the units, Pyqtgraph automatically scales the axis and adjusts the SI prefix (in this case kHz)
    plot_widget.setLabel('left', "Frequency", units='Hz')
示例#2
0
文件: moldy.py 项目: shrx/moldy
class MainWidget(QWidget):
    def __init__(self):
        """Build the main window: Z-matrix table, menus, GL view and plot windows.

        Creates the table model seeded from ``self.inp``, the menu bar with all
        its actions, the 3D view, the (initially hidden) Gaussian and IR
        frequency windows, lays everything out and finally connects the
        model's ``dataChanged`` signal to ``clearUpdateView``.
        """
        QWidget.__init__(self)

        # define periodic table widget for element selection
        self.periodicTableWidget = widgets.PeriodicTableDialog()

        # initial molecule Zmatrix (can be empty)
        self.inp = [['H'],
                    ['O', 1, 0.9],
                    ['O', 2, 1.4, 1, 105.],
                    ['H', 3, 0.9, 2, 105., 1, 120.]]

        self.atomList = []
        self.highList = []
        self.labelList = []
        self.fast = False

        # define & initialize ZMatModel that will contain Zmatrix data
        self.ZMatModel = QStandardItemModel(len(self.inp), 7, self)
        self.ZMatTable = QTableView(self)
        self.ZMatTable.setModel(self.ZMatModel)
        self.ZMatTable.setFixedWidth(325)
        self.ZMatModel.setHorizontalHeaderLabels(['atom','','bond','','angle','','dihedral'])
        for j, width in enumerate([40, 22, 65, 22, 65, 22, 65]):
            self.ZMatTable.setColumnWidth(j, width)
        # populate the ZMatModel
        self.populateZMatModel()

        # define Menu bar menus and their actions
        self.menuBar = QMenuBar(self)
        fileMenu = self.menuBar.addMenu('&File')
        editMenu = self.menuBar.addMenu('&Edit')
        viewMenu = self.menuBar.addMenu('&View')
        measureMenu = self.menuBar.addMenu('&Measure')
        helpMenu = self.menuBar.addMenu('&Help')

        def makeAction(menu, text, slot, shortcut=None, tip=None):
            # Create a QAction owned by this widget, connect its `triggered`
            # signal to `slot`, append it to `menu` and hand it back.
            action = QAction(text, self)
            if shortcut is not None:
                action.setShortcut(shortcut)
            if tip is not None:
                action.setStatusTip(tip)
            action.triggered.connect(slot)
            menu.addAction(action)
            return action

        # File menu
        makeAction(fileMenu, '&Read &ZMat', self.readZmat, 'Ctrl+O', 'Read Zmat from file')
        makeAction(fileMenu, '&Read &XYZ', self.readXYZ, 'Ctrl+Shift+O', 'Read XYZ from file')
        makeAction(fileMenu, '&Read &Gaussian log', self.readGaussian, 'Ctrl+G', 'Read Gaussian log file')
        makeAction(fileMenu, '&Write &ZMat', self.writeZmat, 'Ctrl+S', 'Write Zmat to file')
        # status tip used to read 'Write XYZ from file'
        makeAction(fileMenu, '&Write &XYZ', self.writeXYZ, 'Ctrl+Shift+S', 'Write XYZ to file')
        makeAction(fileMenu, '&Exit', qApp.quit, 'Ctrl+Q', 'Exit application')

        # Edit menu
        makeAction(editMenu, '&Add &row', self.addRow, 'Ctrl+R', 'Add row to ZMatrix')
        makeAction(editMenu, '&Delete &row', self.deleteRow, 'Ctrl+Shift+R', 'Delete row from ZMatrix')
        makeAction(editMenu, '&Add &atom', self.buildB, 'Ctrl+A', 'Add atom to ZMatrix')

        # View menu
        drawModeMenu = QMenu('Draw mode', self)
        viewMenu.addMenu(drawModeMenu)
        makeAction(drawModeMenu, '&Normal draw', self.normalDraw)
        makeAction(drawModeMenu, '&Fast draw', self.fastDraw)

        makeAction(viewMenu, '&Clear selection', self.clearHighlights,
                   'Ctrl+C', 'Clear highlighted atoms')
        makeAction(viewMenu, '&Clear labels', self.clearLabels,
                   'Ctrl+Alt+C', 'Clear labels')
        makeAction(viewMenu, '&Clear selection and labels', self.clearUpdateView,
                   'Ctrl+Shift+C', 'Clear highlighted atoms and labels')

        # 'Ctrl+G' already belongs to "Read Gaussian log"; two enabled actions
        # sharing one key sequence make the shortcut ambiguous, so this action
        # gets its own sequence.
        self.showGaussAction = makeAction(
            viewMenu, 'Show &Gaussian geometry optimization', self.showGauss, 'Ctrl+Shift+G',
            'Show Gaussian geometry optimization plots for energy, force and displacement.')
        self.showGaussAction.setEnabled(False)
        self.showFreqAction = makeAction(
            viewMenu, 'Show &IR frequency plot', self.showFreq, 'Ctrl+I',
            'Show Gaussian calculated IR frequency plot.')
        self.showFreqAction.setEnabled(False)

        # Measure menu
        makeAction(measureMenu, '&Measure &distance', self.measureDistanceB,
                   'Ctrl+D', 'Measure distance between two atoms')
        makeAction(measureMenu, '&Measure &angle', self.measureAngleB,
                   'Ctrl+Shift+D', 'Measure angle between three atoms')

        # Help menu
        makeAction(helpMenu, '&About', self.about, tip='About this program...')
        makeAction(helpMenu, '&About Qt', self.aboutQt, tip='About Qt...')

        # define GL widget that displays the 3D molecule model
        self.window = widgets.MyGLView()
        self.window.installEventFilter(self)
        self.window.setMinimumSize(500, 500)
        self.updateView()

        # separate window with the Gaussian geometry optimization plots
        self.gaussianPlot = GraphicsLayoutWidget()
        self.gaussianPlot.resize(750, 250)
        self.gaussianPlot.setWindowTitle('Gaussian geometry optimization')

        # table of vibrational frequencies
        self.FreqModel = QStandardItemModel(1, 3, self)
        self.freqTable = QTableView(self)
        self.freqTable.setModel(self.FreqModel)
        self.freqTable.setMinimumWidth(240)
        self.freqTable.installEventFilter(self)
        self.FreqModel.installEventFilter(self)
        self.FreqModel.setHorizontalHeaderLabels(['Frequency','IR Intensity','Raman Intensity'])
        for j, width in enumerate([80, 80, 80]):
            self.freqTable.setColumnWidth(j, width)

        # separate window holding the IR plot next to the frequency table
        self.freqWidget = QWidget()
        self.freqWidget.setWindowTitle('IR frequency plot & table')
        self.freqWidget.resize(800, 400)
        self.freqWidget.layout = QHBoxLayout(self.freqWidget)
        self.freqWidget.layout.setSpacing(1)
        self.freqWidget.layout.setContentsMargins(1, 1, 1, 1)
        self.freqPlot = GraphicsLayoutWidget()
        self.freqWidget.layout.addWidget(self.freqPlot)
        self.freqWidget.layout.addWidget(self.freqTable)
        self.freqTable.clicked.connect(self.freqCellClicked)

        # define other application parts
        self.statusBar = QStatusBar(self)
        self.fileDialog = QFileDialog(self)

        # define application layout
        self.layout = QVBoxLayout(self)
        self.layout.setSpacing(1)
        self.layout.setContentsMargins(1, 1, 1, 1)
        self.layout1 = QHBoxLayout()
        self.layout1.setSpacing(1)
        self.layout1.addWidget(self.ZMatTable)
        self.layout1.addWidget(self.window)
        self.layout.addWidget(self.menuBar)
        self.layout.addLayout(self.layout1)
        self.layout.addWidget(self.statusBar)

        self.adjustSize()
        self.setWindowTitle('Moldy')
        iconPath = 'icon.png'
        icon = QIcon(iconPath)
        for side in (16, 24, 32, 48, 256):
            icon.addFile(iconPath, QSize(side, side))
        self.setWindowIcon(icon)

        # start monitoring changes in the ZMatModel
        self.ZMatModel.dataChanged.connect(self.clearUpdateView)

    # run and show the application
    def run(self):
        """Show the widget, connect the run-time signals and start the Qt event loop."""
        self.show()
        # clicking a table cell highlights the matching atom(s)
        self.ZMatTable.clicked.connect(self.ZMatCellClicked)
        application = qt_app.instance()
        application.aboutToQuit.connect(self.deleteGLwidget)
        qt_app.exec_()

    # fill the ZMatModel with initial data from 'self.inp'
    def populateZMatModel(self):
        """Replace the contents of ``ZMatModel`` with the rows of ``self.inp``.

        Every value is stored as the text of a ``QStandardItem``. In the first
        three rows the cells from column ``2*row + 1`` onwards are then
        replaced by empty gray items that are enabled but not editable.
        """
        model = self.ZMatModel
        model.removeRows(0, model.rowCount())
        for rowIdx, rowData in enumerate(self.inp):
            for colIdx, value in enumerate(rowData):
                model.setItem(rowIdx, colIdx, QStandardItem(str(value)))
        # some cells should not be editable, they are disabled
        lockedRows = min(len(self.inp), 3)
        for rowIdx in range(lockedRows):
            for colIdx in range(2*rowIdx + 1, 7):
                model.setItem(rowIdx, colIdx, QStandardItem())
                placeholder = model.item(rowIdx, colIdx)
                placeholder.setBackground(QColor(150,150,150))
                placeholder.setFlags(Qt.ItemIsEnabled)
    
    def populateFreqModel(self):
        """Refill ``FreqModel`` with one row per (frequency, IR, Raman) triple.

        The triples come from ``self.vibfreqs``, ``self.vibirs`` and
        ``self.vibramans``; each value is stored as item text.
        """
        model = self.FreqModel
        model.removeRows(0, model.rowCount())
        triples = zip(self.vibfreqs, self.vibirs, self.vibramans)
        for rowIdx, triple in enumerate(triples):
            for colIdx, value in enumerate(triple):
                model.setItem(rowIdx, colIdx, QStandardItem(str(value)))

    # add a row to the bottom of the ZMatModel
    def addRow(self):
        """Append one empty row to ``ZMatModel`` and report it in the status bar.

        ``dataChanged`` is disconnected from ``clearUpdateView`` while the row
        is prepared and connected again afterwards.
        """
        model = self.ZMatModel
        # temporarily stop updating the GL window
        model.dataChanged.disconnect(self.clearUpdateView)
        newRow = model.rowCount()
        model.insertRow(newRow)
        # some cells should not be editable
        if newRow < 3:
            for col in range(2*newRow + 1, 7):
                model.setItem(newRow, col, QStandardItem())
                blank = model.item(newRow, col)
                blank.setBackground(QColor(150,150,150))
                blank.setFlags(Qt.ItemIsEnabled)
        # restart GL window updating
        model.dataChanged.connect(self.clearUpdateView)
        bar = self.statusBar
        bar.clearMessage()
        bar.showMessage('Added 1 row.', 3000)

    # delete the last row of the ZMatModel
    def deleteRow(self):
        """Delete the selected table rows, or the last row if nothing is selected.

        With a selection, the matching atoms are dropped from a copy of the
        current Cartesian geometry (module globals ``v`` and ``elems``), the
        Z-matrix is rebuilt from what is left with ``xyz2zmat`` and all
        highlights are removed. The view is redrawn in both cases.
        """
        # snapshot of the current geometry, taken from the module globals
        coords = [list(point) for point in list(v)]
        symbols = [str(elements[number]) for number in elems]
        model = self.ZMatModel
        rowsBefore = model.rowCount()
        selectedRows = sorted({index.row() for index in self.ZMatTable.selectedIndexes()},
                              reverse=True)
        rowsAfter = rowsBefore - len(selectedRows)
        if not selectedRows:
            model.removeRow(model.rowCount() - 1)
        else:
            model.dataChanged.disconnect(self.clearUpdateView)
            # highest row first, so the remaining indices stay valid
            for rowIdx in selectedRows:
                model.removeRow(rowIdx)
                if rowIdx < 3:
                    # rows that moved into the first three positions get
                    # their non-editable gray cells again
                    for i in range(rowIdx, min(3, rowsAfter)):
                        for j in range(2*i + 1, 7):
                            model.setItem(i, j, QStandardItem())
                            blank = model.item(i, j)
                            blank.setBackground(QColor(150,150,150))
                            blank.setFlags(Qt.ItemIsEnabled)
                if len(coords) > rowIdx:
                    coords.pop(rowIdx)
                    symbols.pop(rowIdx)
            self.inp = xyz2zmat(coords, symbols)
            self.populateZMatModel()
            for entry in reversed(self.highList):
                self.window.removeItem(entry[1])
            self.highList = []
            model.dataChanged.connect(self.clearUpdateView)
        self.updateView()
        self.statusBar.clearMessage()
        if selectedRows:
            message = 'Deleted row(s): '+str([i+1 for i in selectedRows])
        else:
            message = 'Deleted last row.'
        self.statusBar.showMessage(message, 3000)

    # show the periodic table widget
    def periodicTable(self):
        """Run the periodic table dialog modally and return its ``selection()``."""
        bar = self.statusBar
        bar.clearMessage()
        bar.showMessage('Select element from periodic table.')
        dialog = self.periodicTableWidget
        dialog.exec_()
        return dialog.selection()

    # import molecule with zmatrix coordinates
    def readZmat(self):
        """Load a molecule from a Z-matrix file picked in the file dialog.

        The first two lines of the file are skipped; every following line is
        split on whitespace and becomes one row of ``self.inp``.

        Cancelling the dialog leaves the current molecule untouched. If
        reading fails, the exception propagates, the current molecule is
        kept, and ``dataChanged`` is connected to ``clearUpdateView`` again.
        """
        filename = self.fileDialog.getOpenFileName(self, 'Open file', expanduser('~'), '*.zmat;;*.*')
        if not filename:
            # dialog cancelled: nothing to load, nothing to reset
            return
        # stop redrawing the GL window while the model is being refilled
        self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
        try:
            with open(filename, 'r') as f:
                # skip the two header lines (tolerating shorter files)
                next(f, None)
                next(f, None)
                rows = [line.split() for line in f]
            self.inp = rows
            self.populateZMatModel()
        finally:
            # always restore the connection, even if the file was unreadable
            self.ZMatModel.dataChanged.connect(self.clearUpdateView)
        self.updateView()
        self.statusBar.clearMessage()
        self.statusBar.showMessage('Read molecule from '+filename+'.', 5000)
        # a plain Z-matrix file carries no Gaussian data
        self.showGaussAction.setEnabled(False)
        self.showFreqAction.setEnabled(False)

    # import molecule with xyz coordinates
    def readXYZ(self):
        """Load a molecule from an XYZ file picked in the file dialog.

        The first two lines are skipped. Every following line that splits
        into exactly four fields is read as ``symbol x y z``; other lines are
        ignored. The coordinates are converted with ``xyz2zmat``.

        Cancelling the dialog leaves the current molecule untouched. If
        reading or conversion fails, the exception propagates, the current
        molecule is kept, and ``dataChanged`` is connected to
        ``clearUpdateView`` again.
        """
        filename = self.fileDialog.getOpenFileName(self, 'Open file', expanduser('~'), '*.xyz;;*.*')
        if not filename:
            # dialog cancelled: nothing to load, nothing to reset
            return
        # stop redrawing the GL window while the model is being refilled
        self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
        try:
            xyz = []
            elems = []
            with open(filename, 'r') as f:
                # skip the two header lines (tolerating shorter files)
                next(f, None)
                next(f, None)
                for line in f:
                    fields = line.split()
                    if len(fields) == 4:
                        elems.append(fields[0])
                        xyz.append([float(value) for value in fields[1:]])
            self.inp = xyz2zmat(xyz, elems)
            self.populateZMatModel()
        finally:
            # always restore the connection, even if the file was unreadable
            self.ZMatModel.dataChanged.connect(self.clearUpdateView)
        self.updateView()
        self.statusBar.clearMessage()
        self.statusBar.showMessage('Read molecule from '+filename+'.', 5000)
        # a plain XYZ file carries no Gaussian data
        self.showGaussAction.setEnabled(False)
        self.showFreqAction.setEnabled(False)

    # import Gaussian log file
    def readGaussian(self):
        """Load a Gaussian log file, show its first geometry and build the plots.

        The file is parsed with ``ccopen(...).parse().getattributes()``. The
        first set of coordinates is converted to a Z-matrix, three plots
        (SCF energies, forces, displacements) are drawn into
        ``gaussianPlot`` and, if the log contains vibrational frequencies,
        the frequency table, the IR plot and the module global ``vsShifted``
        (displaced coordinates per mode and animation step) are prepared.

        Cancelling the dialog does nothing. ``dataChanged`` is connected to
        ``clearUpdateView`` again even when parsing raises.
        """
        global vsShifted
        filename = self.fileDialog.getOpenFileName(self, 'Open file', expanduser('~'), '*.log;;*.*')
        if not filename:
            # dialog cancelled: keep the model connected and unchanged
            return
        # stop redrawing the GL window while the model is being refilled
        self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
        try:
            self.gaussianPlot.clear()
            self.inp = []
            self.populateZMatModel()
            logfile = ccopen(filename)
            data = logfile.parse().getattributes()
            self.natom = data['natom']
            self.atomnos = data['atomnos'].tolist()
            self.atomsymbols = [ str(elements[e]) for e in self.atomnos ]
            self.atomcoords = data['atomcoords'].tolist()
            self.scfenergies = data['scfenergies'].tolist()
            self.geovalues = data['geovalues'].T.tolist()
            self.geotargets = data['geotargets'].tolist()
            if 'vibfreqs' in data:
                self.vibfreqs = data['vibfreqs']
                self.vibirs = data['vibirs']
                if 'vibramans' in data:
                    self.vibramans = data['vibramans']
                else:
                    # no Raman data: keep the table column, but empty
                    self.vibramans = [''] * len(self.vibirs)
                self.vibdisps = data['vibdisps']
            self.inp = xyz2zmat(self.atomcoords[0], self.atomsymbols)
            self.populateZMatModel()

            titles = ['SCF Energies', 'RMS & Max Forces', 'RMS & Max Displacements']
            for i in range(3):
                self.gaussianPlot.addPlot(row=1, col=i+1)
                plot = self.gaussianPlot.getItem(1, i+1)
                plot.setTitle(title=titles[i])
                if i == 0:
                    c = ['c']
                    x = [0]
                    y = [self.scfenergies]
                else:
                    c = ['r', 'y']
                    x = [0, 0]
                    y = [self.geovalues[2*i-2], self.geovalues[2*i-1]]
                    targety = [self.geotargets[2*i-2], self.geotargets[2*i-1]]
                plot.clear()
                plot.maxData = plot.plot(y[0], symbol='o', symbolPen=c[0], symbolBrush=c[0], pen=c[0], symbolSize=5, pxMode=True, antialias=True, autoDownsample=False)
                # white ring marking the currently shown step (step 0 at first)
                plot.highlight=plot.plot(x, [ yy[0] for yy in y ], symbol='o', symbolPen='w', symbolBrush=None, pen=None, symbolSize=15, pxMode=True, antialias=True, autoDownsample=False)
                plot.maxData.sigPointsClicked.connect(self.gausclicked)
                if i > 0:
                    # horizontal lines at the targets; log10 because the
                    # plot is switched to a logarithmic y axis below
                    for j in range(2):
                        plot.addLine(y=np.log10(targety[j]), pen=mkPen((255, 255*j, 0, int(255/2)), width=1))
                    plot.RMSData=plot.plot(y[1], symbol='o', symbolPen=c[1], symbolBrush=c[1], pen=c[1], symbolSize=5, pxMode=True, antialias=True, autoDownsample=False)
                    plot.RMSData.sigPointsClicked.connect(self.gausclicked)
                    plot.setLogMode(y=True)
            self.showGauss()
            self.updateView()
        finally:
            # always restore the connection, even if parsing failed
            self.ZMatModel.dataChanged.connect(self.clearUpdateView)
        self.statusBar.clearMessage()
        self.statusBar.showMessage('Read molecule from '+filename+'.', 5000)
        if self.natom:
            self.showGaussAction.setEnabled(True)
        if 'vibfreqs' in data:
            self.showFreqAction.setEnabled(True)

            # populate the FreqModel
            self.populateFreqModel()

            self.freqPlot.clear()
            irPlot = self.freqPlot.addPlot(row=1, col=1)
            irPlot.clear()
            minFreq = np.min(self.vibfreqs)
            maxFreq = np.max(self.vibfreqs)
            maxInt = np.max(self.vibirs)
            # evaluation grid, extended by the exact frequencies so that
            # every peak is sampled at its centre
            x = np.sort(np.concatenate([np.linspace(minFreq-100, maxFreq+100, num=1000), self.vibfreqs]))
            y = x*0
            # sum of one lorentzv() contribution per mode
            for freq, intensity in zip(self.vibfreqs, self.vibirs):
                y += lorentzv(x, freq, 2*np.pi, intensity)
            irPlot.maxData = irPlot.plot(x, y, antialias=True)
            markers = ErrorBarItem(x=self.vibfreqs, y=self.vibirs, top=maxInt/30, bottom=None, pen='r')
            irPlot.addItem(markers)
            self.showFreq()
            # coordinates for every mode and animation step: the centred
            # geometry `vs` (module global) plus a sinusoidally scaled
            # displacement
            maxt = 100
            vsShifted = np.array([ [ vs + self.vibdisps[i]*np.sin(t*2*np.pi/maxt)/3 for t in range(maxt) ] for i in range(len(self.vibfreqs)) ])
        else:
            self.showFreqAction.setEnabled(False)
            self.freqWidget.hide()

    def showGauss(self):
        """Show the window with the Gaussian geometry optimization plots."""
        plotWindow = self.gaussianPlot
        plotWindow.show()

    def showFreq(self):
        """Show the window with the IR frequency plot and table."""
        freqWindow = self.freqWidget
        freqWindow.show()

    # export Zmatrix to csv
    def writeZmat(self):
        """Ask for a file name and write the current Z-matrix with ``writeOutput``.

        The proposed file name is built by ``getFormula`` from the first
        column of the model, or is ``moldy_output`` when the model is empty
        (the same fallback ``writeXYZ`` uses). Nothing is written if the
        dialog is cancelled.
        """
        zm = model2list(self.ZMatModel)
        if zm:
            formula = getFormula([row[0] for row in zm])
        else:
            # an empty table has no first column to build a formula from
            formula = 'moldy_output'
        filename = self.fileDialog.getSaveFileName(self, 'Save file', expanduser('~')+'/'+formula+'.zmat', '*.zmat;;*.*')
        if filename:
            writeOutput(zm, filename)
            self.statusBar.clearMessage()
            self.statusBar.showMessage('Wrote molecule to '+filename+'.', 5000)

    # export XYZ coordinates to csv
    def writeXYZ(self):
        """Ask for a file name and write the current Cartesian coordinates.

        Each output row is the first cell of the matching model row followed
        by the coordinates from the module global ``v``, rounded to 7
        decimals. The proposed file name is built by ``getFormula`` from the
        first column, or is ``moldy_output`` when there are no atoms. Nothing
        is written if the dialog is cancelled.
        """
        zm = model2list(self.ZMatModel)
        xyz = []
        for i, point in enumerate(v):
            row = np.round(point, 7).tolist()
            # insert the label as ONE cell; slice assignment (`row[:0] = s`)
            # would splice a string in character by character
            row.insert(0, zm[i][0])
            xyz.append(row)
        if xyz:
            formula = getFormula([row[0] for row in xyz])
        else:
            formula = 'moldy_output'
        filename = self.fileDialog.getSaveFileName(self, 'Save file', expanduser('~')+'/'+formula+'.xyz', '*.xyz;;*.*')
        if filename:
            writeOutput(xyz, filename)
            self.statusBar.clearMessage()
            self.statusBar.showMessage('Wrote molecule to '+filename+'.', 5000)

    # redraw the 3D molecule in GL widget
    def updateView(self):
        """Redraw the molecule in the GL widget from the current table model.

        Converts the model to Cartesian coordinates with ``zmat2xyz``. On
        success the GL items are rebuilt (atoms, bonds, stored highlights and
        labels) and the module globals ``v``, ``vs``, ``elems``, ``nelems``,
        ``r`` and ``c`` are refreshed. If the conversion fails, the previous
        drawing and globals are kept. The camera distance is then set from
        the largest coordinate span of ``v``.
        """
        global r
        global c
        global v
        global vs
        global elems
        global nelems
        data = model2list(self.ZMatModel)
        try:
            # create a list with element coordinates
            v = zmat2xyz(data)
        except (AssertionError, IndexError, ZMError):
            # deliberate best effort: a table that cannot be converted
            # (e.g. half-edited) leaves the previous drawing in place
            pass
        else:
            # clear the screen before redraw
            for item in reversed(self.window.items):
                self.window.removeItem(item)
            # create a second coordinate list 'vs' that is centered in the GL view
            self.atomList = []
            if len(v) > 0:
                shift = np.mean(v, axis=0)
                vs = np.add(v, -shift)
                # 1 + index of the `colors` entry containing the symbol
                # (0 when the symbol is not found)
                elems = [ 1 + next((i for i, sublist in enumerate(colors) if row[0] in sublist), -1) for row in data ]
                nelems = len(elems)
                # define molecule radii and colors
                r = [elements[i].covalent_radius for i in elems]
                c = [colors[i-1][-1] for i in elems]
                # draw atoms
                for i in range(nelems):
                    addAtom(self.window, i, r, vs, c, fast=self.fast)
                    self.atomList.append([i, self.window.items[-1]])
                # draw bonds where appropriate
                bonds = [addBond(self.window, a, b, r, vs, c, fast=self.fast)
                         for a, b in itertools.combinations(range(nelems), 2)]
                if self.fast:
                    # explicit None test: `filter((None).__ne__, ...)` relied on
                    # the truth value of NotImplemented, which is deprecated
                    bondedAtoms = {atom for atom in flatten(bonds) if atom is not None}
                    for i in set(range(nelems)) - bondedAtoms:
                        addUnbonded(self.window, i, vs, c)
                        self.atomList[i][1]=self.window.items[-1]

                # put the stored highlights and labels back
                for i in self.highList:
                    self.window.addItem(i[1])
                for i in self.labelList:
                    self.window.addItem(i)
        if len(v) > 1:
            # largest extent of the molecule along any coordinate axis
            maxDim = float('-inf')
            for dim in v.T:
                span = max(dim)-min(dim)
                if span > maxDim:
                    maxDim = span
        else:
            maxDim = 2
        self.window.setCameraPosition(distance=maxDim*1.5+1)

    # Animation frame counter. Because of the `global` statement in the class
    # body, the assignment below binds a module-level name, not a class attribute.
    global index
    index = 0
    def updateFreq(self):
        """Draw the next animation frame of the selected vibration.

        Advances the module-level frame counter ``index`` (wrapping at the
        number of frames in ``vsShifted``), removes all GL items and redraws
        atoms and bonds at ``vsShifted[self.freqIndex, index]``.
        """
        global index
        index += 1
        index = index % len(vsShifted[0])
        frame = vsShifted[self.freqIndex, index]
        for item in reversed(self.window.items):
            self.window.removeItem(item)
        # start from a fresh list, as updateView does; without the reset the
        # list grew by `nelems` stale entries on every timer tick
        self.atomList = []
        for i in range(nelems):
            addAtom(self.window, i, r, frame, c, fast=self.fast)
            self.atomList.append([i, self.window.items[-1]])
        bonds = [addBond(self.window, a, b, r, frame, c, fast=self.fast)
                 for a, b in itertools.combinations(range(nelems), 2)]
        if self.fast:
            # explicit None test: `filter((None).__ne__, ...)` relied on the
            # truth value of NotImplemented, which is deprecated
            bondedAtoms = {atom for atom in flatten(bonds) if atom is not None}
            for i in set(range(nelems)) - bondedAtoms:
                addUnbonded(self.window, i, frame, c)
                self.atomList[i][1]=self.window.items[-1]

    # detect mouse clicks in GL window and process them
    def eventFilter(self, obj, event):
        """Handle mouse presses in the GL window, then defer to the base class.

        A press on an item toggles its highlight. A press on empty space
        starts ``build()`` when no atoms are drawn yet.
        """
        pressedInView = obj == self.window and event.type() == event.MouseButtonPress
        if pressedInView:
            # 4x4 pixel pick region centred on the cursor
            region = (event.pos().x()-2, event.pos().y()-2, 4, 4)
            hits = obj.itemsAt(region)
            if len(hits):
                self.highlight(obj, [hits[0]])
            elif len(self.atomList) == 0:
                self.build()
        # also do the default click action
        return super(MainWidget, self).eventFilter(obj, event)

    def ZMatCellClicked(self):
        """Toggle the highlight of the atoms whose table rows are selected.

        For each selected row, the highlight item is passed on if the atom is
        already in ``highList``, otherwise the atom's own GL item is; the
        collected items go to ``highlight()``.
        """
        selectedRows = sorted({index.row() for index in self.ZMatTable.selectedIndexes()},
                              reverse=True)
        if self.highList:
            # atom indices of the current highlights
            highIdx = list(np.array(self.highList).T[0])
        targets = []
        for rowIdx in selectedRows:
            if self.highList and rowIdx in highIdx:
                position = highIdx.index(rowIdx)
                targets.append(self.highList[position][1])
            elif len(self.atomList) > rowIdx:
                targets.append(self.atomList[rowIdx][1])
        self.highlight(self.window, targets)

    def freqCellClicked(self):
        """Start or stop the vibration animation after a click in the frequency table.

        With exactly one selected cell, its row becomes ``self.freqIndex``
        and a 30 ms timer drives ``updateFreq``; table-driven redraws are
        switched off meanwhile. With any other selection the animation is
        stopped, the selection cleared and the normal view restored.

        NOTE: ``timeout`` is now connected to ``updateFreq`` once. It used to
        be connected twice, which advanced two frames per tick.
        """
        # stop a running animation explicitly instead of relying on the old
        # timer being garbage collected
        previous = getattr(self, 'timer', None)
        if previous is not None:
            previous.stop()
        self.timer = QTimer()
        self.timer.setInterval(30)
        self.timer.timeout.connect(self.updateFreq)
        idxs = [ idx.row() for idx in self.freqTable.selectedIndexes() ]
        if len(idxs) == 1:
            self.freqIndex = idxs[0]
            try:
                self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
            except TypeError:
                # already disconnected by an earlier click
                pass
            self.timer.start()
        else:
            self.freqTable.clearSelection()
            self.timer.timeout.disconnect(self.updateFreq)
            # drop an existing connection first so that repeated clicks do
            # not stack up duplicate dataChanged connections
            try:
                self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
            except TypeError:
                pass
            self.ZMatModel.dataChanged.connect(self.clearUpdateView)
            self.clearUpdateView()

    def gausclicked(self, item, point):
        """Show the optimization step whose plot point was clicked.

        Slot for ``sigPointsClicked`` of the Gaussian plots. Moves the white
        highlight ring in all three plots to the clicked step and loads that
        step's geometry into the Z-matrix table and the GL view.
        """
        # field 7 of each scatter record is compared with the clicked point
        # NOTE(review): relies on pyqtgraph's internal record layout -- confirm
        spots = [record[7] for record in item.scatter.data]
        step = spots.index(point[0])
        for panel in range(3):
            if panel == 0:
                xs = [step]
                ys = [self.scfenergies[step]]
            else:
                xs = [step, step]
                ys = [self.geovalues[2*panel-2][step], self.geovalues[2*panel-1][step]]
            plot = self.gaussianPlot.getItem(1, panel+1)
            plot.removeItem(plot.highlight)
            plot.highlight = plot.plot(xs, ys, symbol='o', symbolPen='w', symbolBrush=None, pen=None, symbolSize=15, pxMode=True, antialias=True, autoDownsample=False)
        model = self.ZMatModel
        model.dataChanged.disconnect(self.clearUpdateView)
        self.inp = []
        self.populateZMatModel()
        # clamp to the last stored geometry
        geometryIdx = min(step, len(self.atomcoords)-1)
        self.inp = xyz2zmat(self.atomcoords[geometryIdx], self.atomsymbols)
        self.populateZMatModel()
        model.dataChanged.connect(self.clearUpdateView)
        self.updateView()

    def highlight(self, obj, itms):
        """Toggle highlights for the given GL items and report the selection.

        An item found in ``atomList`` gets a highlight drawn with ``addAtom``
        and its table row selected. An item found in ``highList`` has its
        highlight removed and the table selection cleared. The status bar
        then lists the highlighted atoms, if any.
        """
        for itm in itms:
            atomPos = next((n for n, entry in enumerate(self.atomList) if itm in entry), -1)
            if atomPos != -1:
                addAtom(obj, atomPos, r, vs, c, opt='highlight', fast=self.fast)
                # the highlight is the item addAtom has just appended
                self.highList.append([atomPos, obj.items[-1]])
                self.ZMatTable.selectRow(atomPos)
            highPos = next((n for n, entry in enumerate(self.highList) if itm in entry), -1)
            if highPos != -1:
                obj.removeItem(self.highList[highPos][1])
                self.highList.pop(highPos)
                self.ZMatTable.clearSelection()
        self.statusBar.clearMessage()
        if self.highList:
            atomIdxs = np.asarray(self.highList).T[0]
            # label = 1-based atom number followed by the element
            selected = [str(i+1)+str(elements[elems[i]]) for i in atomIdxs]
            self.statusBar.showMessage('Selected atoms: '+str(selected), 5000)

    def buildB(self):
        """Start ``build()`` once the right number of atoms is highlighted.

        ``min(nelems, 3)`` highlighted atoms are required. If the module
        global ``nelems`` does not exist yet, ``build()`` is called directly.
        """
        try:
            required = min(nelems, 3)
        except NameError:
            # nothing has been drawn so far
            self.build()
            return
        chosen = len(self.highList)
        if chosen > required:
            self.statusBar.clearMessage()
            self.statusBar.showMessage('Too many atoms selected.')
        elif chosen < required:
            self.statusBar.clearMessage()
            self.statusBar.showMessage('Please select '+str(required - chosen)+' more atom(s).')
        else:
            self.build()

    def build(self):
        """Append a new atom, chosen in the periodic table, to the Z-matrix.

        The highlighted atoms, in highlight order, become the bond, angle and
        dihedral references. The bond length is 2.1 times the geometric mean
        of the two covalent radii, the angle 109.4712 and the dihedral 120.
        Highlights are dropped and the view is redrawn afterwards.
        """
        selection = self.periodicTable()
        row = self.ZMatModel.rowCount()
        self.addRow()
        self.ZMatModel.dataChanged.disconnect(self.clearUpdateView)
        newData = [selection[1]]
        picked = len(self.highList)
        if picked >= 1:
            bondRef = self.highList[0][0]
            radii = [ elements[e].covalent_radius for e in [selection[0], elems[bondRef]] ]
            newData.append(bondRef + 1)
            newData.append(round(2.1*gmean(radii), 4))
        if picked >= 2:
            newData.append(self.highList[1][0] + 1)
            newData.append(109.4712)
        if picked == 3:
            newData.append(self.highList[2][0] + 1)
            newData.append(120.)
        for col, value in enumerate(newData):
            self.ZMatModel.setItem(row, col, QStandardItem(str(value)))
        self.highList = []
        self.ZMatModel.dataChanged.connect(self.clearUpdateView)
        self.updateView()

    def measureDistanceB(self):
        """Run ``measureDistance`` when exactly two atoms are highlighted.

        Otherwise the status bar reports how many atoms are still missing, or
        that too many are selected.
        """
        count = len(self.highList)
        if count == 2:
            self.measureDistance()
            return
        self.statusBar.clearMessage()
        if count < 2:
            self.statusBar.showMessage('Please select '+str(2-count)+' more atom(s).')
        else:
            self.statusBar.showMessage('Too many atoms selected.')

    def measureDistance(self):
        """Draw a line between the highlighted atoms and label it with its length.

        Coordinates are looked up in the global ``vs`` by highlighted atom
        index. The highlights are cleared, a green line item is added to the
        window and remembered in ``self.labelList``, and the distance between
        the first two points (rounded to 4 decimals) is placed as a label at
        their midpoint and shown in the status bar.
        """
        pts = np.array([vs[entry[0]] for entry in self.highList])
        self.clearHighlights()

        connector = gl.GLLinePlotItem(pos=pts, color=(0., 1., 0., 1.), width=3)
        self.window.addItem(connector)
        self.labelList.append(connector)

        delta = pts[1]-pts[0]
        dist = round(np.sqrt(np.dot(delta, delta)), 4)
        midpoint = np.mean(pts[0:2], axis=0)
        self.window.labelPos.append(midpoint)
        self.window.labelText.append(str(dist))

        self.statusBar.clearMessage()
        self.statusBar.showMessage('Measured distance: '+str(dist)+' A.', 3000)

    def measureAngleB(self):
        """Run ``measureAngle`` when exactly three atoms are highlighted.

        Otherwise the status bar reports how many atoms are still missing, or
        that too many are selected.
        """
        count = len(self.highList)
        if count == 3:
            self.measureAngle()
            return
        self.statusBar.clearMessage()
        if count < 3:
            self.statusBar.showMessage('Please select '+str(3-count)+' more atom(s).')
        else:
            self.statusBar.showMessage('Too many atoms selected.')

    def measureAngle(self):
        """Measure and draw the angle spanned at the first highlighted atom.

        The angle is taken between the directions from the first highlighted
        atom to the second and to the third one (coordinates come from the
        global ``vs``). It is drawn as a fan of 12 translucent triangles whose
        outer points are interpolated with ``slerp``, labelled at the middle of
        the arc and reported in the status bar, rounded to one decimal.
        """
        pts = np.array([vs[entry[0]] for entry in self.highList])
        q = pts[1]-pts[0]
        r = pts[2]-pts[0]
        q_u = q / np.sqrt(np.dot(q, q))
        r_u = r / np.sqrt(np.dot(r, r))
        # Rounding can push the dot product of two unit vectors slightly outside
        # [-1, 1] (e.g. collinear atoms), which would make acos raise ValueError.
        cosine = float(np.clip(np.dot(q_u, r_u), -1.0, 1.0))
        angle = round(degrees(acos(cosine)), 1)
        # 13 evenly spaced parameters give the 12 triangle edges; linspace fixes
        # the count, which a float-step arange does not guarantee.
        srange = np.array([slerp(q, r, t) for t in np.linspace(0.0, 1.0, 13)])
        self.clearHighlights()
        for i in range(12):
            mesh = gl.MeshData(np.array([[0,0,0],srange[i],srange[i+1]]))
            tri = gl.GLMeshItem(meshdata=mesh, smooth=False, computeNormals=False, color=(0.3, 1., 0.3, 0.5), glOptions=('translucent'))
            # The fan is built around the origin; move it onto the vertex atom.
            tri.translate(pts[0][0], pts[0][1], pts[0][2])
            self.window.addItem(tri)
            self.labelList.append(tri)
        self.window.labelPos.append(slerp(q, r, 0.5)+pts[0])
        self.window.labelText.append(str(angle))
        self.statusBar.clearMessage()
        self.statusBar.showMessage('Measured angle: '+str(angle)+'°', 3000)

    def clearLabels(self):
        """Forget all measurement labels and label items, then redraw the view."""
        self.window.labelPos, self.window.labelText = [], []
        self.labelList = []
        self.updateView()

    def clearHighlights(self):
        """Remove every highlight marker from the window and redraw the view."""
        # Each highList entry is [atom index, marker item]; remove newest first.
        for entry in reversed(self.highList):
            marker = entry[1]
            self.window.removeItem(marker)
        self.highList = []
        self.updateView()

    def clearUpdateView(self):
        """Drop all labels and highlight markers, then redraw the view once."""
        scene = self.window
        scene.labelPos = []
        scene.labelText = []
        self.labelList = []
        # Each highList entry is [atom index, marker item]; remove newest first.
        for entry in reversed(self.highList):
            scene.removeItem(entry[1])
        self.highList = []
        self.updateView()

    def fastDraw(self):
        """Switch to fast drawing and redraw, unless fast drawing is already on."""
        if self.fast:
            return
        self.fast = True
        self.updateView()

    def normalDraw(self):
        """Switch back to normal drawing and redraw, unless it is already active."""
        if not self.fast:
            return
        self.fast = False
        self.updateView()

    def about(self):
        """Show the application's "About" dialog."""
        title = 'About moldy'
        text = 'moldy beta 15. 9. 2015'
        QMessageBox.about(self, title, text)

    def aboutQt(self):
        """Show the standard "About Qt" dialog."""
        title = 'About Qt'
        QMessageBox.aboutQt(self, title)

    def deleteGLwidget(self):
        """Detach the GL window from its parent and drop the ``window`` attribute."""
        gl_widget = self.window
        gl_widget.setParent(None)
        del self.window
class MainUi(QMainWindow):
    def __init__(self):
        """Configure pyqtgraph, build the UI and initialise the query state."""
        super().__init__()

        # Global pyqtgraph appearance (applied option by option).
        pg.setConfigOption('background', '#19232D')
        pg.setConfigOption('foreground', 'd')
        pg.setConfigOption('antialias', True)

        # Centre the window on the screen, then create the widgets.
        self.center()
        self.init_ui()

        # Query state.
        self.code = ""      # code of the selected city
        self.num = 5        # forecast days per multi-day query (limited by the API)
        self.rep = ""       # last response object returned by the request
        self.plt = []       # placeholder until a temperature plot exists
        self.filename = 1   # running number used in exported image names

        # Default status bar text.
        self.status = self.statusBar()
        self.status.showMessage("我在主页面~")

        # Window title and dark style sheet.
        self.setWindowTitle("天气查询软件")
        self.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5())

    def init_ui(self):
        """Create all widgets and place them in the window.

        The central widget holds a grid with two panels: the left one carries
        the city input, the query/save/quit buttons and the result table, the
        right one carries a caption and the two plot areas.
        """
        # Central widget with the top-level grid.
        self.main_widget = QWidget()
        self.main_layout = QGridLayout()
        self.main_widget.setLayout(self.main_layout)

        # Left panel (controls and table).
        self.left_widget = QWidget()
        self.left_widget.setObjectName('left_widget')
        self.left_layout = QGridLayout()
        self.left_widget.setLayout(self.left_layout)

        # Right panel (plots).
        self.right_widget = QWidget()
        self.right_widget.setObjectName('right_widget')
        self.right_layout = QGridLayout()
        self.right_widget.setLayout(self.right_layout)

        # Left panel: row 0, column 0, spanning 12 rows and 5 columns.
        self.main_layout.addWidget(self.left_widget, 0, 0, 12, 5)
        # Right panel: row 0, column 5, spanning 12 rows and 7 columns.
        self.main_layout.addWidget(self.right_widget, 0, 5, 12, 7)
        self.setCentralWidget(self.main_widget)

        def place_left(widget, grid_row, row_span=1):
            # Every left-panel widget starts in column 0 and spans 5 columns.
            self.left_layout.addWidget(widget, grid_row, 0, row_span, 5)

        # Query buttons: all start disabled and share one click handler.
        query_buttons = (
            ('single_query', "查询今日"),
            ('btn_tempa', "温度预测(可绘图)"),
            ('btn_wind', "风力预测(可绘图)"),
            ('btn_stawea', "综合天气预测"),
        )
        for attr, caption in query_buttons:
            button = QPushButton(caption)
            button.clicked.connect(self.request_weather)
            button.setEnabled(False)
            setattr(self, attr, button)
        for grid_row, (attr, _caption) in enumerate(query_buttons, start=2):
            place_left(getattr(self, attr), grid_row)

        # Line edit for the city name; Return confirms the input.
        self.city_line = QLineEdit()
        self.city_line.setPlaceholderText("输入城市回车确认")
        self.city_line.returnPressed.connect(self.match_city)
        place_left(self.city_line, 1)

        # Save-figure button (disabled until something has been plotted).
        self.save_fig = QPushButton("保存绘图")
        self.save_fig.setEnabled(False)
        self.save_fig.clicked.connect(self.fig_save)
        place_left(self.save_fig, 6)

        # This button is created and placed but not connected to a handler here.
        self.load = QPushButton("写日记")
        place_left(self.load, 7)

        self.quit_btn = QPushButton("退出")
        self.quit_btn.clicked.connect(self.quit_act)
        place_left(self.quit_btn, 8)

        # Table widget showing the query results.
        self.query_result = QTableWidget()
        place_left(self.query_result, 9, row_span=2)
        self.query_result.verticalHeader().setVisible(False)

        # Right panel: caption plus the temperature and wind plot areas.
        self.label = QLabel("预测天气情况绘图展示区")
        self.right_layout.addWidget(self.label, 0, 6, 1, 7)

        self.plot_weather_wind = GraphicsLayoutWidget()
        self.plot_weather_temp = GraphicsLayoutWidget()
        self.right_layout.addWidget(self.plot_weather_temp, 1, 6, 4, 7)
        self.right_layout.addWidget(self.plot_weather_wind, 6, 6, 4, 7)

        # Slightly translucent window, no spacing between the two panels.
        self.setWindowOpacity(0.9)
        self.main_layout.setSpacing(0)

    # 按照城市的code, 请求一个城市的天气 返回 json 形式
    def request_weather(self):
        """Fetch the weather for ``self.code`` and fill the table for the clicked button.

        The response is stored in ``self.rep``; the caption of the button that
        emitted the signal decides which columns ``query`` is asked to show.
        The two plotting buttons are disabled once they have been used.
        """
        url = "http://t.weather.sojson.com/api/weather/city/" + str(self.code)
        self.rep = get_weather.run(url)

        caption = self.sender().text()
        if caption == "查询今日":
            self.query(1, 5, '温度', '风向', '风力', 'PM2.5', '天气描述')
        elif caption == "温度预测(可绘图)":
            self.btn_tempa.setEnabled(False)
            self.query(self.num, 2, '日期', '温度')
        elif caption == "风力预测(可绘图)":
            self.btn_wind.setEnabled(False)
            self.query(self.num, 2, '日期', '风力')
        elif caption == "综合天气预测":
            self.query(self.num, 4, '温度', '风向', '风力', '天气描述')

    # 读取 json 文件, 获得城市的code
    def match_city(self):
        """Resolve the typed city name to its code and enable the query buttons.

        The name/code mapping is read from the external JSON file on every
        call. An unknown name falls back to the code "101010100", the line edit
        is reset to "北京" and a message box informs the user.
        """
        # Queries only become available once a city has been entered.
        for button in (self.btn_tempa, self.btn_wind,
                       self.single_query, self.btn_stawea):
            button.setEnabled(True)

        city_codes = read_citycode.read_code("最新_city.json")
        typed_city = self.city_line.text()
        if typed_city in city_codes:
            self.code = city_codes[typed_city]
            return

        self.code = "101010100"
        self.city_line.setText("北京")
        QMessageBox.about(self, "你犯了一个粗误", "输入城市无效,请示新输入,否则默认为北京")

    # 保存图片成功时的提醒
    def pic_messagebox(self):
        """Show an information box titled with the image name for ``self.filename``."""
        title = '第{}张图片.png'.format(self.filename)
        QMessageBox.information(self, title,
                                "已经成功保存图片到当前目录, 关闭软件后请及时拷贝走")

    # 保存图片的设置 pyqtgraph 保存无法设置图片路径
    def fig_save(self):
        """Export the current plot as a numbered PNG in the working directory.

        The temperature plot (``self.plt``) is exported when it exists;
        otherwise the wind plot (``self.plt1``) is used, because ``self.plt`` is
        still its initial empty list when only the wind chart has been drawn.
        The file counter is advanced only after the confirmation box has been
        shown, so the box names the file that was actually written.
        """
        plot = self.plt if hasattr(self.plt, 'scene') else getattr(self, 'plt1', None)
        if plot is None:
            # Nothing has been plotted yet; tell the user instead of crashing.
            QMessageBox.warning(self, "无法保存", "当前没有可保存的绘图")
            return

        filename = '第' + str(self.filename) + '张图片.png'
        pe.ImageExporter(plot.scene()).export(fileName=filename)
        # pic_messagebox reads self.filename, so increment afterwards.
        self.pic_messagebox()
        self.filename += 1

    def get_date(self, dateFormat="%Y-%m-%d", addDays=0):
        """Return ``addDays`` consecutive dates, starting today, as strings.

        Args:
            dateFormat: ``strftime`` pattern applied to every date.
            addDays: number of days to return; zero or a negative value yields
                an empty list.

        Returns:
            A list of formatted date strings: today first, then the following
            days in order.
        """
        start = datetime.datetime.now()
        return [
            (start + datetime.timedelta(days=offset)).strftime(dateFormat)
            for offset in range(addDays)
        ]

    # 查询, 可以接受多个参数, 更加灵活的查询
    def query(self, row_num, col_num, *args):
        """Fill the result table from the last API response and plot forecasts.

        Args:
            row_num: number of table rows (1 for today's weather, otherwise the
                number of forecast days).
            col_num: number of table columns.
            *args: header labels, one per column.

        Layouts:
            * more than two columns, one row: today's temperature, wind
              direction, wind power, PM2.5 and weather type;
            * more than two columns, several rows: per-day low temperature,
              wind direction, wind power and weather type;
            * two columns, several rows: the date plus either the low
              temperature or the wind power, chosen by the second header label.

        When fewer than four columns are shown, the numbers in column 1 are
        also drawn as a bar chart and the save button is enabled.
        """
        # Decode the response once instead of re-parsing it for every field.
        data = self.rep.json()['data']
        forecast = data['forecast']
        # Never index past the number of days the API actually returned.
        days = forecast[:self.num]
        header = args[1] if len(args) > 1 else ''

        table = self.query_result
        # Row and column counts must be set, otherwise nothing is displayed.
        table.setRowCount(row_num)
        table.setColumnCount(col_num)
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.setHorizontalHeaderLabels(list(args))

        if col_num > 2 and row_num == 1:
            today = forecast[0]
            current = (str(data['wendu']) + "℃", today['fx'], today['fl'],
                       data['pm25'], today['type'])
            for col, value in enumerate(current):
                self._set_result_cell(0, col, value)
        elif col_num > 2 and row_num > 1:
            for row, day in enumerate(days):
                cells = ("最" + str(day['low']), day['fx'], day['fl'], day['type'])
                for col, value in enumerate(cells):
                    self._set_result_cell(row, col, value)
        elif col_num == 2 and row_num > 1:
            for row, date in enumerate(self.get_date(addDays=self.num)):
                self._set_result_cell(row, 0, date)
            if header == '温度':
                for row, day in enumerate(days):
                    self._set_result_cell(row, 1, "最" + str(day['low']))
            elif header == '风力':
                for row, day in enumerate(days):
                    self._set_result_cell(row, 1, day['fl'])

        # Only the narrow (date + value) layouts are plotted.
        if col_num < 4:
            self._plot_forecast_bars(header)

    def _set_result_cell(self, row, col, text):
        """Put ``text`` into the result table using the light-blue text colour."""
        item = QTableWidgetItem(str(text))
        item.setForeground(QBrush(QColor(144, 182, 240)))
        self.query_result.setItem(row, col, item)

    def _plot_forecast_bars(self, header):
        """Draw column 1 of the result table as a bar chart chosen by ``header``."""
        # A leading minus only counts as a sign when it does not follow a digit,
        # so "3-4" is read as the range 3..4 while "-3" stays negative.
        number = r'(?<!\d)-?\d+(?:\.\d+)?'
        x, y = [], []
        for row in range(self.num):
            cell = self.query_result.item(row, 1)
            if cell is None:
                continue
            values = [float(found) for found in re.findall(number, cell.text())]
            if values:
                # One value per day: a range such as "3-4" becomes its mean.
                x.append(float(row + 1))
                y.append(sum(values) / len(values))
        if not y:
            return

        # The target area is cleared first; otherwise a new plot column is added.
        if header == '风力':
            self.plot_weather_wind.clear()
            bars = pg.BarGraphItem(x=np.array(x), height=np.array(y), width=0.3,
                                   brush=QColor(137, 232, 165))
            self.plt1 = self.plot_weather_wind.addPlot(title="近期一个月风力变化(预测)")
            self.plt1.addItem(bars)
        elif header == '温度':
            self.plot_weather_temp.clear()
            self.plt = self.plot_weather_temp.addPlot(title="近期一个月温度变化(预测)")
            bars = pg.BarGraphItem(x=np.array(x), height=np.array(y), width=0.3,
                                   brush=QColor(32, 235, 233))
            self.plt.addItem(bars)
        else:
            return
        # A figure can only be saved once something has been plotted.
        self.save_fig.setEnabled(True)

    # 退出按钮
    def quit_act(self):
        """Log which button was pressed and quit the application."""
        # sender() is the object that emitted the signal.
        pressed = self.sender().text()
        print(pressed + '键被按下')
        QApplication.instance().quit()

    def center(self):
        """Move the window to the centre of the screen.

        Reads the screen geometry and the window geometry and moves the window
        so that the free space is split evenly on both axes. Integer division
        is used because ``QWidget.move`` only accepts integers; passing the
        floats produced by ``/`` raises ``TypeError`` on Python 3.10 and later.
        """
        screen = QDesktopWidget().screenGeometry()
        size = self.geometry()
        self.move((screen.width() - size.width()) // 2,
                  (screen.height() - size.height()) // 2)
示例#4
0
class centralWidget(QtWidgets.QWidget):
    def __init__(self, timeData, voltsData):
        """Build the tab for one signal and draw its plots and spectrogram.

        Args:
            timeData: sample times; the sampling interval is taken from the
                first two entries.
            voltsData: sample values of the signal.
        """
        super(centralWidget, self).__init__()
        # The widgets must exist before the state below is pushed into them.
        self.generateUiWidgets()

        # Data used for plotting.
        self.timeData = timeData
        self.originalVoltsData = voltsData
        self.editedVoltsData = voltsData
        self.plot = None  # plot of the original signal
        self.plot1 = None  # plot of the edited signal

        self.xRangeStack = []
        self.yRangeStack = []

        self.sampleTime = timeData[1] - timeData[0]
        # NOTE(review): this is last sample minus first sample, not max - min;
        # confirm that it is the intended basis for the vertical scroll step.
        yrange = voltsData[len(voltsData) - 1] - voltsData[0]
        self.scrollStep_x = 100 * self.sampleTime
        self.scrollStep_y = yrange / 10

        # Gain slider values: self._value1 .. self._value10 all start at 1.
        for i in range(1, 11):
            setattr(self, "_value" + str(i), 1)

        # Default gradient colours of the spectrogram.
        self.RGB_Pallete_1 = (0, 182, 188, 255)
        self.RGB_Pallete_2 = (246, 111, 0, 255)
        self.RGB_Pallete_3 = (75, 0, 113, 255)

        # Caption every gain slider with its frequency band in kHz.
        ft = fourierTransform(self.originalVoltsData, int(1 / self.sampleTime))
        ranges = ft.rangesOfFrequancy
        for i in range(10):
            band_label = getattr(self, "label" + str(i + 1))
            band_label.setText(str(ranges[i][0] / 1000) + " Khz : \n" +
                               str(ranges[i][1] / 1000) + " Khz")

        self.HorizontalLabel1.setText("Set Minimun Frequancy")
        self.HorizontalLabel2.setText("Set Maximum Frequancy")

        # Frequency window shown in the spectrogram.
        self.minFreqOfSpectrogram = 0
        self.maxFreqOfSpectrogram = ranges[-1][1]

        # QSlider setters only accept integers; floats raise TypeError on
        # Python 3.10 and later, so convert explicitly.
        slider_max = int(self.maxFreqOfSpectrogram)
        tick_step = int(self.maxFreqOfSpectrogram / 10)

        self.horizontalSlider1.setMinimum(self.minFreqOfSpectrogram)
        self.horizontalSlider1.setMaximum(slider_max)
        self.horizontalSlider1.setTickInterval(tick_step)

        self.horizontalSlider2.setMinimum(self.minFreqOfSpectrogram)
        self.horizontalSlider2.setMaximum(slider_max)
        self.horizontalSlider2.setSliderPosition(slider_max)
        self.horizontalSlider2.setTickInterval(tick_step)

        self.SpectrogramViewer.clear()
        self.drawSpectrogram()

        self.xRangeOfSignal = None  # [from , to]
        self.yRangeOfSignal = None

        # Plot the data, then remember the visible ranges of the plots.
        self.startPlotting()
        self.xRangeOfSignal = [0.0, list(self.timeData)[-1]]  # [from , to]
        self.yRangeOfSignal = self.plot.viewRange()[1]

    def startPlotting(self):
        """Create the original and edited signal plots over the full time span."""
        original_plot = self.OriginalSignalViewer.addPlot()
        self.plot = original_plot
        original_plot.plot(self.timeData, self.originalVoltsData)

        edited_plot = self.EditedSignalViewer.addPlot()
        self.plot1 = edited_plot
        edited_plot.plot(self.timeData, self.editedVoltsData)

        # Show both plots from 0 to the last time sample, without padding.
        for target in (original_plot, edited_plot):
            target.setXRange(0.0, list(self.timeData)[-1], 0)

    def minSliderOfSpectrogram(self, value):
        """Store the new lower frequency bound and redraw the spectrogram."""
        self.minFreqOfSpectrogram = value
        viewer = self.SpectrogramViewer
        viewer.clear()
        self.drawSpectrogram()

    def maxSliderOfSpectrogram(self, value):
        """Store the new upper frequency bound and redraw the spectrogram."""
        self.maxFreqOfSpectrogram = value
        viewer = self.SpectrogramViewer
        viewer.clear()
        self.drawSpectrogram()

    def drawSpectrogram(self, minFreq=1, maxFreq=1):
        """Add a spectrogram of the edited signal to the spectrogram viewer.

        The edited signal is transformed, the frequencies below
        ``self.minFreqOfSpectrogram`` and above ``self.maxFreqOfSpectrogram``
        are deleted, and the inverse transform is passed to
        ``signal.spectrogram``. The result is shown as an image with a
        histogram/gradient item next to it.

        ``minFreq`` and ``maxFreq`` are accepted but not used; the bounds are
        always read from the instance attributes.
        """
        low_cut = self.minFreqOfSpectrogram
        high_cut = self.maxFreqOfSpectrogram
        sample_rate = 1 / self.sampleTime

        transform = fourierTransform(self.editedVoltsData, int(sample_rate))
        transform.deleteRangeOfFrequancy(0, low_cut)
        transform.deleteRangeOfFrequancy(high_cut, int(sample_rate / 2))
        filtered = transform.fn_InverceFourier(transform.data_fft)
        freq_axis, time_axis, Sxx = signal.spectrogram(np.array(filtered),
                                                       sample_rate)
        # Interpret image data as row-major.
        pyqtgraph.setConfigOptions(imageAxisOrder='row-major')

        layout = self.SpectrogramViewer
        plot_item = layout.addPlot()

        image = pyqtgraph.ImageItem()
        plot_item.addItem(image)
        # Histogram item that controls the colour gradient of the image.
        levels = pyqtgraph.HistogramLUTItem()
        levels.setImageItem(image)
        layout.addItem(levels)
        levels.setLevels(np.min(Sxx), np.max(Sxx))
        gradient_state = {
            'mode': 'rgb',
            'ticks': [(0.5, self.RGB_Pallete_1), (1.0, self.RGB_Pallete_2),
                      (0.0, self.RGB_Pallete_3)],
        }
        levels.gradient.restoreState(gradient_state)

        image.setImage(Sxx)
        # Scale the pixel grid to time (x) and frequency (y) units.
        last_time = time_axis[-1]
        last_freq = freq_axis[-1]
        image.scale(last_time / np.size(Sxx, axis=1),
                    last_freq / np.size(Sxx, axis=0))
        # Limit panning/zooming to the extent of the spectrogram.
        plot_item.setLimits(xMin=0, xMax=last_time, yMin=0, yMax=last_freq)
        plot_item.setLabel('bottom', "Time", units='s')
        plot_item.setLabel('left', "Frequency", units='Hz')

    def generateUiWidgets(self):
        """Create and lay out all widgets of the tab and connect their signals.

        Builds four group boxes (spectrogram, original signal, edited signal
        and signal editor), the ten band labels ``label1`` .. ``label10``, the
        ten gain sliders ``slider1`` .. ``slider10`` and the two horizontal
        sliders that bound the spectrogram's frequency range.
        """
        def group_font():
            # Bold 10pt heading font used by every group box.
            heading = QtGui.QFont()
            heading.setFamily("Arial Unicode MS")
            heading.setPointSize(10)
            heading.setBold(True)
            heading.setWeight(75)
            return heading

        def sized_font(points):
            # Default font with only the point size changed.
            small = QtGui.QFont()
            small.setPointSize(points)
            return small

        # Base font of the tab itself.
        base_font = QtGui.QFont()
        base_font.setFamily("Arial Unicode MS")
        base_font.setPointSize(10)
        base_font.setBold(False)
        base_font.setItalic(False)
        base_font.setUnderline(False)
        base_font.setWeight(50)
        base_font.setStrikeOut(False)
        base_font.setKerning(True)
        self.setFont(base_font)
        self.setTabletTracking(False)

        self.gridLayout_6 = QtWidgets.QGridLayout(self)
        self.gridLayout_6.setObjectName("gridLayout_6")
        self.gridLayout_4 = QtWidgets.QGridLayout()
        self.gridLayout_4.setObjectName("gridLayout_4")

        # --- Spectrogram group box -------------------------------------
        self.SpectrogramGroupBox = QtWidgets.QGroupBox(self)
        self.SpectrogramGroupBox.setFont(group_font())
        self.SpectrogramGroupBox.setStyleSheet(
            "border-color: qlineargradient(spread:pad, x1:0, y1:0, x2:1, y2:0, stop:0 rgba(0, 0, 0, 255), stop:1 rgba(255, 255, 255, 255));"
        )
        self.SpectrogramGroupBox.setObjectName("SpectrogramGroupBox")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.SpectrogramGroupBox)
        self.gridLayout_2.setObjectName("gridLayout_2")

        # Slider for the upper frequency bound (grid row 3).
        upper = QtWidgets.QSlider(self.SpectrogramGroupBox)
        self.horizontalSlider2 = upper
        upper.setFont(sized_font(7))
        upper.setCursor(QtGui.QCursor(QtCore.Qt.ClosedHandCursor))
        upper.setOrientation(QtCore.Qt.Horizontal)
        upper.setTickPosition(QtWidgets.QSlider.TicksBelow)
        upper.setTickInterval(1)
        upper.setObjectName("horizontalSlider2")
        self.gridLayout_2.addWidget(upper, 3, 0, 1, 1)

        spectro_view = GraphicsLayoutWidget(self.SpectrogramGroupBox)
        self.SpectrogramViewer = spectro_view
        spectro_view.viewport().setProperty(
            "cursor", QtGui.QCursor(QtCore.Qt.PointingHandCursor))
        spectro_view.setStyleSheet("background-color:rgb(0,0,0)")
        spectro_view.setObjectName("SpectrogramViewer")
        self.gridLayout_2.addWidget(spectro_view, 0, 0, 1, 1)

        # Slider for the lower frequency bound (grid row 1).
        lower = QtWidgets.QSlider(self.SpectrogramGroupBox)
        self.horizontalSlider1 = lower
        lower.setCursor(QtGui.QCursor(QtCore.Qt.ClosedHandCursor))
        lower.setOrientation(QtCore.Qt.Horizontal)
        lower.setTickPosition(QtWidgets.QSlider.TicksBelow)
        lower.setTickInterval(1)
        lower.setObjectName("horizontalSlider1")
        self.gridLayout_2.addWidget(lower, 1, 0, 1, 1)

        lower_caption = QtWidgets.QLabel(self.SpectrogramGroupBox)
        self.HorizontalLabel1 = lower_caption
        lower_caption.setFont(sized_font(7))
        lower_caption.setAlignment(QtCore.Qt.AlignCenter)
        lower_caption.setObjectName("HorizontalLabel1")
        self.gridLayout_2.addWidget(lower_caption, 2, 0, 1, 1)

        upper_caption = QtWidgets.QLabel(self.SpectrogramGroupBox)
        self.HorizontalLabel2 = upper_caption
        upper_caption.setAlignment(QtCore.Qt.AlignCenter)
        upper_caption.setObjectName("HorizontalLabel2")
        self.gridLayout_2.addWidget(upper_caption, 4, 0, 1, 1)

        self.gridLayout_4.addWidget(self.SpectrogramGroupBox, 0, 1, 3, 1)

        # --- Original signal group box ---------------------------------
        self.OriginalSignalGroupbox = QtWidgets.QGroupBox(self)
        self.OriginalSignalGroupbox.setFont(group_font())
        self.OriginalSignalGroupbox.setObjectName("OriginalSignalGroupbox")
        self.gridLayout_5 = QtWidgets.QGridLayout(self.OriginalSignalGroupbox)
        self.gridLayout_5.setObjectName("gridLayout_5")
        original_view = GraphicsLayoutWidget(self.OriginalSignalGroupbox)
        self.OriginalSignalViewer = original_view
        original_view.viewport().setProperty(
            "cursor", QtGui.QCursor(QtCore.Qt.PointingHandCursor))
        original_view.setStyleSheet("background-color: rgb(0, 0, 0);")
        original_view.setObjectName("OriginalSignalViewer")
        self.gridLayout_5.addWidget(original_view, 0, 0, 1, 1)
        self.gridLayout_4.addWidget(self.OriginalSignalGroupbox, 0, 0, 1, 1)

        # --- Edited signal group box -----------------------------------
        self.EditedSignalGroupBox = QtWidgets.QGroupBox(self)
        self.EditedSignalGroupBox.setFont(group_font())
        self.EditedSignalGroupBox.setObjectName("EditedSignalGroupBox")
        self.gridLayout_3 = QtWidgets.QGridLayout(self.EditedSignalGroupBox)
        self.gridLayout_3.setObjectName("gridLayout_3")
        edited_view = GraphicsLayoutWidget(self.EditedSignalGroupBox)
        self.EditedSignalViewer = edited_view
        edited_view.viewport().setProperty(
            "cursor", QtGui.QCursor(QtCore.Qt.PointingHandCursor))
        edited_view.setStyleSheet("background-color:rgb(0,0,0)")
        edited_view.setObjectName("EditedSignalViewer")
        self.gridLayout_3.addWidget(edited_view, 0, 0, 1, 1)
        self.gridLayout_4.addWidget(self.EditedSignalGroupBox, 2, 0, 1, 1)

        # --- Signal editor group box (gain sliders) --------------------
        self.SignalEditorGroupBox = QtWidgets.QGroupBox(self)
        self.SignalEditorGroupBox.setFont(group_font())
        self.SignalEditorGroupBox.setObjectName("SignalEditorGroupBox")
        self.gridLayout = QtWidgets.QGridLayout(self.SignalEditorGroupBox)
        self.gridLayout.setObjectName("gridLayout")
        self.gridLayout_7 = QtWidgets.QGridLayout()
        self.gridLayout_7.setObjectName("gridLayout_7")

        # Ten labels, self.label1 .. self.label10, one per gain slider.
        for column in range(10):
            label_name = "label" + str(column + 1)
            band_label = QtWidgets.QLabel(self.SignalEditorGroupBox)
            setattr(self, label_name, band_label)
            label_font = QtGui.QFont()
            label_font.setPointSize(5)
            label_font.setBold(False)
            label_font.setWeight(50)
            band_label.setFont(label_font)
            band_label.setObjectName(label_name)
            self.gridLayout_7.addWidget(band_label, 0, column, 1, 1)

        self.gridLayout.addLayout(self.gridLayout_7, 1, 0, 1, 10)

        # Ten gain sliders, self.slider1 .. self.slider10.
        for column in range(10):
            band_slider = Slider(parent=self.SignalEditorGroupBox,
                                 objectName="Slider" + str(column + 1))
            setattr(self, "slider" + str(column + 1), band_slider)
            self.gridLayout.addWidget(band_slider, 0, column, 1, 1)
            # Forward value changes to the shared handler.
            band_slider.valueChanged.connect(lambda v: self.fn_sliderValue(v))

        self.gridLayout_4.addWidget(self.SignalEditorGroupBox, 1, 0, 1, 1)
        self.gridLayout_6.addLayout(self.gridLayout_4, 0, 0, 1, 1)

        # The spectrogram range sliders redraw the spectrogram on change.
        self.horizontalSlider1.valueChanged[int].connect(
            self.minSliderOfSpectrogram)
        self.horizontalSlider2.valueChanged[int].connect(
            self.maxSliderOfSpectrogram)
        QtCore.QMetaObject.connectSlotsByName(self)

    def fn_sliderValue(self, value):
        """Store a gain slider's new value and re-process the signal.

        The slider number is the first run of digits in the sender's object
        name; the value is stored in the matching ``self._value<N>`` attribute.
        """
        sender_name = str(self.sender().objectName())
        slider_number = re.search("[0-9]+", sender_name).group()
        setattr(self, "_value" + slider_number, value)
        self.process()

    def process(self):
        """Apply the ten gain values to the original signal and refresh the views.

        The original signal is transformed, ``gain`` is applied with
        ``self._value1`` .. ``self._value10``, and the inverse transform
        replaces the edited plot (keeping the stored x/y ranges), becomes
        ``self.editedVoltsData`` and is used to redraw the spectrogram.
        """
        gains = [getattr(self, "_value" + str(i)) for i in range(1, 11)]
        transform = fourierTransform(list(self.originalVoltsData),
                                     1 / self.sampleTime)
        reals = transform.fn_InverceFourier(transform.gain(*gains))

        # Replace the edited-signal plot.
        self.EditedSignalViewer.clear()
        self.plot1 = self.EditedSignalViewer.addPlot()
        self.plot1.plot(self.timeData, reals)
        x_from, x_to = self.xRangeOfSignal[0], self.xRangeOfSignal[1]
        y_from, y_to = self.yRangeOfSignal[0], self.yRangeOfSignal[1]
        self.plot1.setXRange(x_from, x_to, 0)
        self.plot1.setYRange(y_from, y_to, 0)
        self.editedVoltsData = np.array(reals)

        # Redraw the spectrogram from the new edited data.
        self.SpectrogramViewer.clear()
        self.drawSpectrogram()

    def playSound(self):
        """Write the gain-adjusted edited signal to a sound file and play it."""
        # NOTE(review): process() already stores the gain-adjusted signal in
        # editedVoltsData, so the gains look like they are applied a second time
        # here - confirm against fourierTransform.gain before changing.
        gains = [getattr(self, "_value" + str(i)) for i in range(1, 11)]
        transform = fourierTransform(list(self.editedVoltsData),
                                     1 / self.sampleTime)
        reals = transform.fn_InverceFourier(transform.gain(*gains))

        player = soundfileUtility()
        player.fn_CreateSoundFile(list(reals), int(1 / self.sampleTime))
        player.fn_PlaySoundFile()
示例#5
0
class Viewer:
    """
    Abstraction layer over a napari viewer that dispatches visualisation to
    the dataset's depictors and provides the dock widgets (volume selector,
    plots toggle, plots area) used to browse a DataSet.
    """
    def __init__(self, dataset):
        self.dataset = dataset

        # All GUI objects are created lazily by ensure_ready().
        self.napari_viewer = None
        self.plots = None
        self.blik_widget = None
        self.volume_selector = None
        self.plots_toggler = None

    def show(self, napari_viewer=None, **kwargs):
        """
        Make sure the GUI exists, initialise the depictor of every datablock
        (forwarding ``kwargs``), show the first volume if there is one and
        show the plots dock if the dataset has plots.
        """
        self.ensure_ready(napari_viewer=napari_viewer)
        for db in self.dataset:
            db.init_depictor(**kwargs)
        if self.dataset.volumes:
            self.show_volume(next(iter(self.dataset.volumes.keys())))
        if self.dataset.plots:
            self.plots.show()

    @property
    def shown(self):
        """
        The dataset volume currently selected in the volume selector, or
        None when there is no viewer / selector.
        """
        if self.napari_viewer and self.volume_selector:
            return self.dataset.volumes[self.volume_selector.currentText()]
        return None

    def ensure_ready(self, napari_viewer=None):
        """
        Set up the GUI around ``napari_viewer`` if one is given, then verify
        that the stored viewer is still usable and rebuild everything with a
        new viewer if it is not.
        """
        if napari_viewer is not None:
            self._init_gui(napari_viewer)
        # check if viewer exists and is still open
        try:
            self.napari_viewer.window.qt_viewer.actions()
        except (AttributeError, RuntimeError):
            self._init_gui()

    def _init_gui(self, napari_viewer=None):
        """Run the full initialisation sequence: viewer, plots, widget, keys."""
        self._init_viewer(napari_viewer)
        self._init_plots()
        self._init_blik_widget()
        self._hook_keybindings()

    def _init_viewer(self, napari_viewer=None):
        """Adopt ``napari_viewer`` or create a new one, then configure it."""
        if napari_viewer is None:
            napari_viewer = napari.Viewer(title='napari - Blik')
        self.napari_viewer = napari_viewer
        self.napari_viewer.scale_bar.unit = '0.1nm'
        self.napari_viewer.scale_bar.visible = True
        # TODO: workaround until layer issues are fixed (napari #2110)
        self.napari_viewer.window.qt_viewer.destroyed.connect(self.dataset.purge_gui)

    def _init_plots(self):
        """Create the plots widget, dock it at the bottom and start hidden."""
        self.plots = GraphicsLayoutWidget()
        self._plots_napari_widget = self.napari_viewer.window.add_dock_widget(
            self.plots,
            name='Blik - Plots',
            area='bottom',
        )
        # use napari hide and show methods
        self.plots.show = self._plots_napari_widget.show
        self.plots.hide = self._plots_napari_widget.hide
        self.plots.hide()

    def _init_blik_widget(self):
        """Create the left dock widget: volume selector and plots toggle."""
        self.blik_widget = QWidget()
        layout = QVBoxLayout()
        self.blik_widget.setLayout(layout)

        self.volume_selector = QComboBox(self.blik_widget)
        self.volume_selector.addItems(self.dataset.volumes.keys())
        self.volume_selector.currentTextChanged.connect(self.show_volume)
        layout.addWidget(self.volume_selector)

        self.plots_toggler = QPushButton('Show / Hide Plots')
        self.plots_toggler.clicked.connect(self.toggle_plots)
        layout.addWidget(self.plots_toggler)

        self._blik_napari_widget = self.napari_viewer.window.add_dock_widget(
            self.blik_widget,
            name='Blik',
            area='left',
        )
        # use napari hide and show methods
        self.blik_widget.show = self._blik_napari_widget.show
        self.blik_widget.hide = self._blik_napari_widget.hide

    def _hook_keybindings(self):
        """Bind PageUp / PageDown to previous / next volume."""
        self.napari_viewer.bind_key('PageUp', self.previous_volume)
        self.napari_viewer.bind_key('PageDown', self.next_volume)

    def update_blik_widget(self):
        """
        Repopulate the volume selector from the dataset (keeping the current
        selection text, with signals blocked) and then call show().
        """
        if self.blik_widget is not None:
            current_text = self.volume_selector.currentText()
            with block_signals(self.volume_selector):
                self.volume_selector.clear()
                self.volume_selector.addItems(self.dataset.volumes.keys())
                self.volume_selector.setCurrentText(current_text)
        self.show()

    def clear_shown(self):
        """Remove the dataset's layers from the viewer and clear the plots."""
        # iterate over a copy: layers are removed from the list while looping
        for layer in self.napari_viewer.layers.copy():
            if layer in self.dataset.napari_layers:
                self.napari_viewer.layers.remove(layer)
        self.plots.clear()

    def show_volume(self, volume):
        """
        Display the datablocks of ``volume`` (plus the dataset's ``omni``
        datablocks), replacing whatever is currently shown. Does nothing if
        ``volume`` is None.
        """
        if volume is None:
            return
        self.volume_selector.setCurrentText(volume)
        datablocks = self.dataset.omni + self.dataset.volumes[volume]

        layers = []
        plots = []
        for db in datablocks:
            for dep in db.depictors:
                if hasattr(dep, 'layers'):
                    # depict lazily: only when nothing has been drawn yet
                    if not dep.layers:
                        dep.depict()
                    layers.extend(dep.layers)
                elif hasattr(dep, 'plot'):
                    if not dep.plot.curves:
                        dep.depict()
                    plots.append(dep.plot)
        # Image layers first; the stable sort keeps the relative order otherwise
        layers = sorted(
            layers,
            key=lambda layer: isinstance(layer, napari.layers.Image),
            reverse=True,
        )

        self.clear_shown()
        self.napari_viewer.layers.extend(layers)
        for plot in plots:
            self.plots.addItem(plot)

    def _step_volume(self, step):
        """
        Move the volume selector by ``step`` entries, wrapping around.
        Does nothing when the selector is empty (avoids a modulo by zero).
        """
        selector = self.volume_selector
        count = selector.count()
        if count == 0:
            return
        selector.setCurrentIndex((selector.currentIndex() + step) % count)

    def previous_volume(self, viewer=None):
        """Select the previous volume (wraps around). ``viewer`` is unused."""
        self._step_volume(-1)

    def next_volume(self, viewer=None):
        """Select the next volume (wraps around). ``viewer`` is unused."""
        self._step_volume(1)

    def toggle_plots(self):
        """Hide the plots dock if it is visible, otherwise show it."""
        if self.plots.isVisible():
            self.plots.hide()
        else:
            self.plots.show()
示例#6
0
class EEGViewer(QThread):
    """Thread-backed viewer that plots EEG channels in a GraphicsLayoutWidget.

    In ``'single'`` mode every channel is drawn on one plot; in any other
    mode each channel gets its own plot, filled column by column with
    ``rows`` plots per column.  While the thread runs in ``RunningState`` it
    repeatedly re-centres the visible x-window on ``position`` and advances
    ``position`` by one until ``maxPosition`` is reached.
    """

    StoppedState = 0
    PausedState = 1
    RunningState = 2

    def __init__(self, mode='single', rows=4):
        super().__init__()
        self.mode = mode
        self.rows = rows
        self.view = GraphicsLayoutWidget()
        self.view.setAntialiasing(True)
        self.view.setWindowTitle('EEG Viewer')
        self.state = self.StoppedState
        self.position = 0
        self.maxPosition = 0
        self.plotItem = []      # one entry per plot in the layout
        self.plotTrace = {}     # channel -> curve returned by canvas.plot()
        # Holders, filled in by configure()
        # NOTE(review): the ``wait`` attribute shadows QThread.wait() on this
        # instance, so the thread cannot be joined through ``viewer.wait()``.
        self.wait = 0           # seconds slept between run() iterations
        self.wsize = 0          # width of the visible x-window
        self.hsize = 0          # half of wsize
        self.color = {}
        self.window = [0, 0]    # [x_min, x_max] of the visible window
        self.channel = []

    def widget(self):
        """Return the underlying GraphicsLayoutWidget."""
        return self.view

    def show(self):
        """Show the underlying widget."""
        self.view.show()

    def hide(self):
        """Hide the underlying widget."""
        self.view.hide()

    def getState(self):
        """Return the current state (one of the ``*State`` class constants)."""
        return self.state

    def isVisible(self):
        """Return whether the underlying widget is visible."""
        return self.view.isVisible()

    def setSize(self, width, height):
        """Resize the underlying widget."""
        self.view.resize(width, height)

    def configure(self, channel, color, wsize, fs=0):
        """Store the display parameters and rebuild the plot layout.

        ``channel`` is the sequence of channel keys, ``color`` maps each key
        to a pen colour, ``wsize`` is the width of the visible x-window and
        ``fs`` determines the delay between updates
        (``1 / (fs * len(channel))`` seconds; 0 when ``fs`` is not positive
        or there are no channels).
        """
        # Link params
        nCh = len(channel)
        # Guard nCh as well: an empty channel list would divide by zero.
        self.wait = 1 / (fs * nCh) if fs > 0 and nCh > 0 else 0
        self.wsize = wsize
        self.hsize = wsize / 2
        self.color = color
        self.channel = channel
        self.window = np.array([0, wsize])
        # Remove previous items and traces
        self.view.clear()
        self.plotItem.clear()
        self.plotTrace.clear()
        # Create new canvas
        if self.mode == 'single':
            self.singleLayout()
        else:
            self.multipleLayout()

    def _newCanvas(self, row, col):
        """Add a plot at (row, col) with the settings shared by both layouts."""
        canvas = self.view.addPlot(row, col)
        canvas.disableAutoRange()
        canvas.setClipToView(True)
        canvas.setLimits(yMin=0, yMax=1)
        canvas.setDownsampling(mode='subsample')
        canvas.showGrid(x=True, y=True, alpha=0.25)
        return canvas

    def singleLayout(self):
        """Create one plot holding a trace for every channel."""
        canvas = self._newCanvas(0, 0)
        for ch in self.channel:
            pen = mkPen(color=self.color[ch], width=2)
            self.plotTrace[ch] = canvas.plot(pen=pen)
        self.plotItem.append(canvas)

    def multipleLayout(self):
        """Create one plot per channel, ``self.rows`` plots per column."""
        rowLimit = self.rows
        for i, ch in enumerate(self.channel):
            pen = mkPen(color=self.color[ch], width=2)
            col, row = divmod(i, rowLimit)
            canvas = self._newCanvas(row, col)
            self.plotItem.append(canvas)
            self.plotTrace[ch] = canvas.plot(pen=pen)

    def plotData(self, D):
        """Load ``D[ch].values`` into every channel's trace and rewind.

        ``maxPosition`` is set to ``D.index.size``.
        """
        for ch in self.channel:
            self.plotTrace[ch].setData(D[ch].values)
        self.position = 0
        self.maxPosition = D.index.size

    def addMark(self, position, label=None):
        """Add a movable dashed vertical line at ``position`` and return it.

        Only the first plot in ``plotItem`` receives the line; ``None`` is
        returned when no plot exists yet.
        """
        # NOTE(review): in multi-plot mode the remaining plots get no line.
        # Confirm whether every plot was meant to be marked.
        if not self.plotItem:
            return None
        pen = mkPen(color='g', width=2.5, style=Qt.DashLine)
        hpen = mkPen(color='r', width=2.5, style=Qt.DashLine)
        return self.plotItem[0].addLine(x=position,
                                        pen=pen,
                                        label=label,
                                        labelOpts={'position': 0.9},
                                        movable=True,
                                        hoverPen=hpen)

    def setPosition(self, position):
        """Centre the visible window on ``position`` and refresh the plots."""
        self.window[0] = position - self.hsize
        self.window[1] = position + self.hsize
        self.position = position
        self.update()

    def update(self):
        """Apply the current window to every plot and advance ``position``.

        ``position`` stops advancing once it reaches ``maxPosition``.
        """
        for plot in self.plotItem:
            plot.setRange(xRange=self.window)
        if self.position < self.maxPosition:
            self.position += 1

    def play(self):
        """Switch to ``RunningState`` and start the thread."""
        self.state = self.RunningState
        self.start()

    def pause(self):
        """Switch to ``PausedState``."""
        self.state = self.PausedState

    def toggle(self):
        """Pause if running; otherwise switch to ``RunningState``."""
        if self.state == self.RunningState:
            self.state = self.PausedState
        else:
            self.state = self.RunningState

    def stop(self):
        """Switch to ``StoppedState``, call ``quit()`` and rewind to 0."""
        self.state = self.StoppedState
        self.quit()
        self.setPosition(0)

    def run(self):
        """Thread body: step the view while running, idle while paused.

        The loop ends as soon as the state is neither ``RunningState`` nor
        ``PausedState``; it sleeps ``self.wait`` seconds between iterations.
        """
        while True:
            # Read the state once per iteration so one pass acts on one value.
            state = self.state
            if state == self.RunningState:
                self.setPosition(self.position)
            elif state != self.PausedState:
                break
            sleep(self.wait)