예제 #1
0
class AirInventoryView(QWidget):
    """Widget showing aircraft inventory rows in a sortable, read-only table.

    Rows come from the inventories of captured control points and, unless the
    "Unallocated Only?" checkbox is ticked, from the flights of every package
    in the blue ATO.
    """

    def __init__(self, game_model: GameModel) -> None:
        """Create the checkbox and the table, then populate the table once."""
        super().__init__()

        self.game_model = game_model
        self.country = game_model.game.country_for(player=True)

        root = QVBoxLayout()
        self.setLayout(root)

        # ``toggled`` delivers the new checked state, which lands directly in
        # update_table's ``only_unallocated`` parameter.
        checkbox = QCheckBox("Unallocated Only?")
        checkbox.toggled.connect(self.update_table)
        self.only_unallocated_cb = checkbox
        root.addWidget(checkbox)

        table = QTableWidget()
        self.table = table
        root.addWidget(table)
        table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        self.update_table(False)

    def update_table(self, only_unallocated: bool) -> None:
        """Rebuild the whole table from the current game state.

        Args:
            only_unallocated: When true, rows for aircraft already assigned to
                flights are left out.
        """
        table = self.table
        # Sorting is disabled while cells are written and re-enabled at the
        # end of the rebuild.
        table.setSortingEnabled(False)
        table.clear()

        rows = list(self.get_data(only_unallocated))
        table.setRowCount(len(rows))
        titles = AircraftInventoryData.headers()
        table.setColumnCount(len(titles))
        table.setHorizontalHeaderLabels(titles)

        for row_index, entry in enumerate(rows):
            for column_index, text in enumerate(entry.columns):
                table.setItem(row_index, column_index, QTableWidgetItem(text))

        table.resizeColumnsToContents()
        table.setSortingEnabled(True)

    def iter_allocated_aircraft(self) -> Iterator[AircraftInventoryData]:
        """Yield inventory rows for every flight of every blue ATO package."""
        packages = self.game_model.game.blue_ato.packages
        flights = (flight for package in packages for flight in package.flights)
        for flight in flights:
            yield from AircraftInventoryData.from_flight(flight)

    def iter_unallocated_aircraft(self) -> Iterator[AircraftInventoryData]:
        """Yield inventory rows for each captured control point's inventory."""
        inventories = self.game_model.game.aircraft_inventory.inventories
        for control_point, inventory in inventories.items():
            if not control_point.captured:
                continue
            yield from AircraftInventoryData.each_from_inventory(inventory)

    def get_data(self, only_unallocated: bool) -> Iterator[AircraftInventoryData]:
        """Yield unallocated rows first, then allocated rows unless filtered."""
        yield from self.iter_unallocated_aircraft()
        if only_unallocated:
            return
        yield from self.iter_allocated_aircraft()
예제 #2
0
class AirInventoryView(QWidget):
    """Widget showing one coalition's aircraft in a sortable, read-only table.

    Two checkboxes drive the contents: one hides aircraft that are already
    assigned to flights, the other switches which coalition is queried
    (``coalition_for(not enemy_info)``).
    """

    def __init__(self, game_model: GameModel) -> None:
        """Create the checkbox row and the table, then populate the table."""
        super().__init__()
        self.game_model = game_model

        # Filter state mirrored from the two checkboxes.
        self.only_unallocated = False
        self.enemy_info = False

        root = QVBoxLayout()
        self.setLayout(root)

        options = QHBoxLayout()
        root.addLayout(options)

        unallocated_box = QCheckBox("Unallocated only")
        unallocated_box.toggled.connect(self.set_only_unallocated)
        self.only_unallocated_cb = unallocated_box
        options.addWidget(unallocated_box)

        enemy_box = QCheckBox("Show enemy info")
        enemy_box.toggled.connect(self.set_enemy_info)
        self.enemy_info_cb = enemy_box
        options.addWidget(enemy_box)

        # Trailing stretch in the checkbox row, after both checkboxes.
        options.addStretch()

        table = QTableWidget()
        self.table = table
        root.addWidget(table)
        table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        # Going through the setter performs the initial table fill.
        self.set_only_unallocated(False)

    def set_only_unallocated(self, value: bool) -> None:
        """Store the "unallocated only" filter and refresh the table."""
        self.only_unallocated = value
        self.update_table()

    def set_enemy_info(self, value: bool) -> None:
        """Store the "show enemy info" flag and refresh the table."""
        self.enemy_info = value
        self.update_table()

    def update_table(self) -> None:
        """Rebuild the whole table from the current filters and game state."""
        table = self.table
        # Sorting is disabled while cells are written and re-enabled at the
        # end of the rebuild.
        table.setSortingEnabled(False)
        table.clear()

        rows = list(self.get_data())
        table.setRowCount(len(rows))
        titles = AircraftInventoryData.headers()
        table.setColumnCount(len(titles))
        table.setHorizontalHeaderLabels(titles)

        for row_index, entry in enumerate(rows):
            for column_index, text in enumerate(entry.columns):
                table.setItem(row_index, column_index, QTableWidgetItem(text))

        table.resizeColumnsToContents()
        table.setSortingEnabled(True)

    def iter_allocated_aircraft(self) -> Iterator[AircraftInventoryData]:
        """Yield rows for every flight in the selected coalition's ATO."""
        packages = self.game_model.game.coalition_for(
            not self.enemy_info).ato.packages
        flights = (flight for package in packages for flight in package.flights)
        for flight in flights:
            yield from AircraftInventoryData.from_flight(flight)

    def iter_unallocated_aircraft(self) -> Iterator[AircraftInventoryData]:
        """Yield untasked-aircraft rows for each squadron of the coalition."""
        air_wing = self.game_model.game.coalition_for(
            not self.enemy_info).air_wing
        for squadron in air_wing.iter_squadrons():
            yield from AircraftInventoryData.each_untasked_from_squadron(
                squadron)

    def get_data(self) -> Iterator[AircraftInventoryData]:
        """Yield unallocated rows first, then allocated rows unless filtered."""
        yield from self.iter_unallocated_aircraft()
        if self.only_unallocated:
            return
        yield from self.iter_allocated_aircraft()
예제 #3
0
class assCheck(QDialog):
    """Dialog that shows one subtitle track's style and timing for review.

    The table lists the style fields first, then one row per subtitle line
    with start time, end time, duration and text.  Overlapping neighbours and
    suspicious durations are highlighted.

    Signals:
        getSub: emitted when the refresh button is clicked.
        position: emitted with the value of ``calSubTime`` for the start time
            of a double-clicked subtitle row.
    """
    getSub = Signal()
    position = Signal(int)

    # Style field labels; ``styles[index][i]`` is shown next to the i-th one.
    STYLE_FIELDS = (
        'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour',
        'OutlineColour', 'BackColour', 'Bold', 'Italic', 'Underline',
        'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle',
        'BorderStyle', 'Outline', 'Shadow', 'Alignment', 'MarginL',
        'MarginR', 'MarginV', 'Encoding',
    )
    WARNING_RED = '#B22222'
    WARNING_ORANGE = '#FA8072'

    def __init__(self, subtitleDict, index, styles, styleNameList):
        """Build the dialog and show track ``index``.

        Args:
            subtitleDict: indexable by track; each entry maps a start time in
                ms to a ``(duration_ms, text)`` pair.
            index: track shown initially.
            styles: indexable by track; each entry is indexable by the
                position of a field in ``STYLE_FIELDS``.
            styleNameList: labels for the track combo box.
        """
        super().__init__()
        self.subtitleDict = subtitleDict
        self.index = index
        self.styles = styles
        self.resize(950, 800)
        self.setWindowTitle('检查字幕')
        layout = QGridLayout()
        self.setLayout(layout)
        layout.addWidget(QLabel('选择字幕轨道:'), 0, 0, 1, 1)
        layout.addWidget(QLabel(''), 0, 1, 1, 1)
        self.subCombox = QComboBox()
        self.subCombox.addItems(styleNameList)
        self.subCombox.setCurrentIndex(index)
        self.subCombox.currentIndexChanged.connect(self.selectChange)
        layout.addWidget(self.subCombox, 0, 2, 1, 1)
        self.subTable = QTableWidget()
        self.subTable.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.subTable.doubleClicked.connect(self.clickTable)
        self.subTable.setEditTriggers(QAbstractItemView.NoEditTriggers)
        layout.addWidget(self.subTable, 1, 0, 6, 3)
        self.refresh = QPushButton('刷新')
        self.refresh.clicked.connect(self.refreshSub)
        layout.addWidget(self.refresh, 7, 0, 1, 1)
        self.cancel = QPushButton('确定')
        self.cancel.clicked.connect(self.hide)
        layout.addWidget(self.cancel, 7, 2, 1, 1)
        self.refreshTable()

    def setDefault(self, subtitleDict, styles):
        """Replace the displayed data and redraw the table."""
        self.subtitleDict = subtitleDict
        self.styles = styles
        self.refreshTable()

    def selectChange(self, index):
        """Switch to the track chosen in the combo box."""
        self.index = index
        self.refreshTable()

    def refreshSub(self):
        """Emit ``getSub`` in response to the refresh button."""
        self.getSub.emit()

    def _setCell(self, row, column, text, colour=None):
        """Put ``text`` into a cell, optionally with a background colour."""
        item = QTableWidgetItem(text)
        if colour is not None:
            item.setBackground(QColor(colour))
        self.subTable.setItem(row, column, item)

    def refreshTable(self):
        """Redraw the style rows and the subtitle rows of the current track."""
        style = self.styles[self.index]
        subDict = self.subtitleDict[self.index]
        headerRows = len(self.STYLE_FIELDS)
        self.subTable.clear()
        self.subTable.setRowCount(headerRows + len(subDict))
        self.subTable.setColumnCount(4)
        for col in range(3):
            self.subTable.setColumnWidth(col, 160)
        self.subTable.setColumnWidth(3, 350)
        for y, name in enumerate(self.STYLE_FIELDS):
            self._setCell(y, 0, name)
            self._setCell(y, 1, str(style[y]))

        startList = sorted(subDict)
        lastIndex = len(startList) - 1
        preConflict = False  # previous line runs into this one
        for y, start in enumerate(startList):
            row = headerRows + y
            delta, text = subDict[start]
            end = start + delta
            # Recomputed for every row: the last line has no successor and
            # therefore can never overlap a following line.
            nextConflict = y < lastIndex and end > startList[y + 1]

            if delta < 500 or delta > 8000:  # shorter than 500 ms or longer than 8 s
                durationColour = self.WARNING_RED
            elif delta > 4500:  # longer than 4.5 s, up to 8 s
                durationColour = self.WARNING_ORANGE
            else:
                durationColour = None

            s, ms = divmod(delta, 1000)
            ms = ('%03d' % ms)[:2]  # keep two decimals (centiseconds)
            duration = '持续 %s.%ss' % (s, ms)

            # Start time: red when the previous line overlaps this one.
            self._setCell(row, 0, ms2ASSTime(start),
                          self.WARNING_RED if preConflict else None)
            # End time: red when this line overlaps the next one.
            self._setCell(row, 1, ms2ASSTime(end),
                          self.WARNING_RED if nextConflict else None)
            self._setCell(row, 2, duration, durationColour)
            self._setCell(row, 3, text)
            preConflict = nextConflict  # carry the overlap flag to the next line

    def clickTable(self):
        """Emit ``position`` for the double-clicked subtitle row."""
        selected = self.subTable.selectedItems()
        if not selected:  # nothing selected, e.g. double click on an empty cell
            return
        row = selected[0].row()
        if row < len(self.STYLE_FIELDS):  # style rows carry no timestamp
            return
        startItem = self.subTable.item(row, 0)
        if startItem is None:
            return
        pos = calSubTime(startItem.text())
        self.position.emit(pos)  # report the clicked position
예제 #4
0
class assSelect(QDialog):
    """Dialog for picking which style of an .ass file to import.

    ``assCheck`` parses the file into ``self.subDict``, keyed by style name.
    Each entry holds the style fields plus ``'Tableview'`` (rows of
    ``[start, end, text]`` strings for display) and ``'Events'`` (a dict
    mapping start time to ``[delta, text]``).

    Signals:
        assSummary: emitted with ``[index, styleName, styleEntry]`` when the
            user confirms an import.
    """
    assSummary = Signal(list)

    # Style fields kept for every style, in display order.
    STYLE_FIELDS = (
        'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour',
        'OutlineColour', 'BackColour', 'Bold', 'Italic', 'Underline',
        'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle',
        'BorderStyle', 'Outline', 'Shadow', 'Alignment', 'MarginL',
        'MarginR', 'MarginV', 'Encoding',
    )
    WARNING_RED = '#B22222'
    WARNING_ORANGE = '#FA8072'

    @classmethod
    def _emptyStyle(cls):
        """Return a fresh style entry with blank fields and no events."""
        entry = {name: '' for name in cls.STYLE_FIELDS}
        entry['Tableview'] = []
        entry['Events'] = {}
        return entry

    def __init__(self):
        """Build the dialog with an empty placeholder style."""
        super().__init__()
        # Track index handed back in assSummary; setDefault overrides it.
        self.index = 0
        # The '' entry backs the combo box's empty text (emitted when the box
        # is cleared), so selectChange always finds a key.
        self.subDict = {'': self._emptyStyle()}
        self.resize(950, 800)
        self.setWindowTitle('选择要导入的ass字幕轨道')
        layout = QGridLayout()
        self.setLayout(layout)
        layout.addWidget(QLabel('检测到字幕样式:'), 0, 0, 1, 1)
        layout.addWidget(QLabel(''), 0, 1, 1, 1)
        self.subCombox = QComboBox()
        self.subCombox.currentTextChanged.connect(self.selectChange)
        layout.addWidget(self.subCombox, 0, 2, 1, 1)
        self.subTable = QTableWidget()
        self.subTable.setEditTriggers(QAbstractItemView.NoEditTriggers)
        layout.addWidget(self.subTable, 1, 0, 6, 3)
        self.confirm = QPushButton('导入')
        self.confirm.clicked.connect(self.sendSub)
        layout.addWidget(self.confirm, 7, 0, 1, 1)
        self.confirmStyle = QPushButton('导入样式')
        self.confirmStyle.clicked.connect(self.sendSubStyle)
        layout.addWidget(self.confirmStyle, 7, 1, 1, 1)
        self.cancel = QPushButton('取消')
        self.cancel.clicked.connect(self.hide)
        layout.addWidget(self.cancel, 7, 2, 1, 1)

    def setDefault(self, subtitlePath='', index=0):
        """Parse ``subtitlePath`` and remember the target track ``index``.

        Nothing happens when ``subtitlePath`` is empty.
        """
        if subtitlePath:
            self.assCheck(subtitlePath)
            self.index = index

    def _setCell(self, row, column, text, colour=None):
        """Put ``text`` into a cell, optionally with a background colour."""
        item = QTableWidgetItem(text)
        if colour is not None:
            item.setBackground(QColor(colour))
        self.subTable.setItem(row, column, item)

    def selectChange(self, styleName):
        """Show the fields and subtitle lines of the style ``styleName``."""
        entry = self.subDict[styleName]
        tableview = entry['Tableview']
        self.subTable.clear()
        # One row per key except 'Tableview' and 'Events', then one row per
        # subtitle line.
        self.subTable.setRowCount(len(entry) + len(tableview) - 2)
        self.subTable.setColumnCount(4)
        for col in range(3):
            self.subTable.setColumnWidth(col, 160)
        self.subTable.setColumnWidth(3, 350)

        y = 0
        for k, v in entry.items():
            if k in ('Tableview', 'Events'):
                continue
            self._setCell(y, 0, k)
            self._setCell(y, 1, v)
            y += 1

        lastIndex = len(tableview) - 1
        preConflict = False  # previous line runs into this one
        for cnt, line in enumerate(tableview):
            start = calSubTime(line[0])
            end = calSubTime(line[1])
            # The last line has no successor and never overlaps a next line.
            nextConflict = (cnt < lastIndex
                            and end > calSubTime(tableview[cnt + 1][0]))
            delta = end - start
            if delta < 500 or delta > 8000:  # shorter than 500 ms or longer than 8 s
                durationColour = self.WARNING_RED
            elif delta > 4500:  # longer than 4.5 s, up to 8 s
                durationColour = self.WARNING_ORANGE
            else:
                durationColour = None
            s, ms = divmod(delta, 1000)
            ms = ('%03d' % ms)[:2]  # keep two decimals (centiseconds)
            duration = '持续 %s.%ss' % (s, ms)

            # Start time: red when the previous line overlaps this one.
            self._setCell(y, 0, line[0],
                          self.WARNING_RED if preConflict else None)
            # End time: red when this line overlaps the next one.
            self._setCell(y, 1, line[1],
                          self.WARNING_RED if nextConflict else None)
            self._setCell(y, 2, duration, durationColour)
            self._setCell(y, 3, line[2])  # subtitle text
            y += 1
            preConflict = nextConflict  # carry the overlap flag to the next line

    def sendSub(self):
        """Emit the selected style together with its events, then hide."""
        styleName = self.subCombox.currentText()
        self.assSummary.emit([self.index, styleName, self.subDict[styleName]])
        self.hide()

    def sendSubStyle(self):
        """Emit the selected style with its events removed, then hide."""
        styleName = self.subCombox.currentText()
        # Work on a shallow copy so the parsed events stay in self.subDict
        # and a later full import still has them.
        subData = dict(self.subDict[styleName])
        subData['Events'] = {}
        self.assSummary.emit([self.index, styleName, subData])
        self.hide()

    def assCheck(self, subtitlePath):
        """Parse the .ass file at ``subtitlePath`` into ``self.subDict``.

        Reads the ``[V4+ Styles]`` and ``[Events]`` sections, groups Dialogue
        and Comment events by style name (compared with spaces removed) and
        refills the style combo box.  Event lines with fewer fields than the
        Events ``Format:`` line are skipped; if that Format line lacks any of
        Start/End/Style/Text, no events are imported.

        Raises:
            OSError: if the file cannot be opened or read.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        self.subDict = {'': self._emptyStyle()}
        with codecs.open(subtitlePath, 'r', 'utf_8_sig') as ass:
            rawLines = ass.readlines()

        V4Token = False
        styleFormat = []
        styles = []
        eventToken = False
        eventFormat = []
        events = []
        for rawLine in rawLines:
            line = rawLine.strip()
            # Keywords are matched at the start of the line so that subtitle
            # text containing e.g. "Comment:" or "Format:" is not mistaken
            # for a record of that kind.
            if line.startswith('[V4+ Styles]'):
                V4Token = True
            elif V4Token and line.startswith('Format:'):
                styleFormat = line.replace(' ', '').split(':')[1].split(',')
            elif V4Token and line.startswith('Style:') and styleFormat:
                styles.append(line.split(':')[1].split(','))
            elif line.startswith('[Events]'):
                eventToken = True
                V4Token = False
            elif eventToken and line.startswith('Format:'):
                eventFormat = line.split(':')[1].split(',')
            elif eventToken and eventFormat:
                for keyword in ('Comment:', 'Dialogue:'):
                    if line.startswith(keyword):
                        # Split at most len(eventFormat) - 1 times so commas
                        # inside the trailing text field are preserved.
                        fields = line[len(keyword):].split(
                            ',', len(eventFormat) - 1)
                        if len(fields) == len(eventFormat):
                            events.append(fields)
                        break

        # Column positions of the fields needed from each event.
        columns = {
            name.replace(' ', ''): cnt for cnt, name in enumerate(eventFormat)
        }
        required = ('Start', 'End', 'Style', 'Text')
        if all(name in columns for name in required):
            Start, End, Style, Text = (columns[name] for name in required)
        else:
            Start = End = Style = Text = None
            events = []

        for style in styles:
            # The name keeps whatever spacing followed "Style:" in the file.
            styleName = style[0]
            entry = self._emptyStyle()
            # zip stops at the shorter sequence, so a truncated Style line
            # cannot index past its end.
            for _format, value in zip(styleFormat, style):
                if _format in self.STYLE_FIELDS:
                    entry[_format] = value
            styleKey = styleName.replace(' ', '')
            for line in events:
                if styleKey != line[Style].replace(' ', ''):
                    continue
                start = calSubTime(line[Start]) // 10 * 10
                # NOTE(review): precedence makes this
                # ``end - (start // 10 * 10)``, i.e. only ``start`` is rounded
                # to 10 ms; confirm whether ``(end - start) // 10 * 10`` was
                # intended.  Kept unchanged.
                delta = calSubTime(line[End]) - start // 10 * 10
                entry['Tableview'].append(
                    [line[Start], line[End], line[Text]])
                entry['Events'][start] = [delta, line[Text]]
            self.subDict[styleName] = entry

        self.subCombox.clear()
        self.subCombox.addItems([name for name in self.subDict if name])
예제 #5
0
class Ui_MainWindow(QMainWindow):
    """Main window with a search box and a three-column table.

    The table shows rows of the ``Sheet1`` table in the SQLite database at
    ``DB_PATH``: all rows at start-up, and rows whose ``name`` equals the
    search text after the run button is clicked.
    """

    # Database read by both the initial load and the search.
    DB_PATH = "/home/eamon/Desktop/test.sqlite"

    def __init__(self, parent=None):
        """Create the window and build its widgets."""
        QMainWindow.__init__(self, parent)
        self.setupUi()

    def setupUi(self):
        """Build the widgets, wire the run button and load all rows.

        Raises:
            sqlite3.Error: if the database cannot be opened or queried.
        """
        self.setObjectName(_fromUtf8("MainWindow"))
        self.resize(565, 358)
        self.centralwidget = QWidget(self)
        self.centralwidget.setObjectName(_fromUtf8("centralwidget"))
        self.verticalLayout = QVBoxLayout(self.centralwidget)
        self.verticalLayout.setObjectName(_fromUtf8("verticalLayout"))

        # Top row: spacer, search field, search button, run button.
        self.horizontalLayout = QHBoxLayout()
        self.horizontalLayout.setSizeConstraint(QLayout.SetDefaultConstraint)
        self.horizontalLayout.setObjectName(_fromUtf8("horizontalLayout"))
        spacerItem = QSpacerItem(40, 20, QSizePolicy.MinimumExpanding,
                                 QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem)
        self.searchTE = QLineEdit(self.centralwidget)
        self.searchTE.setObjectName(_fromUtf8("searchTE"))
        self.horizontalLayout.addWidget(self.searchTE)

        self.searchBtn = QPushButton(self.centralwidget)
        sizePolicy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.searchBtn.sizePolicy().hasHeightForWidth())
        self.searchBtn.setSizePolicy(sizePolicy)
        self.searchBtn.setStyleSheet(_fromUtf8("border: none"))
        self.searchBtn.setText(_fromUtf8(""))
        icon = QIcon()
        icon.addPixmap(
            QPixmap(
                _fromUtf8(
                    "../.designer/gitlab/ExcelToSql/icons/searchBtn.png")),
            QIcon.Normal, QIcon.Off)
        self.searchBtn.setIcon(icon)
        self.searchBtn.setIconSize(QtCore.QSize(48, 24))
        self.searchBtn.setObjectName(_fromUtf8("searchBtn"))
        self.horizontalLayout.addWidget(self.searchBtn)

        self.pushButton_2 = QPushButton(self.centralwidget)
        sizePolicy = QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.pushButton_2.sizePolicy().hasHeightForWidth())
        self.pushButton_2.setSizePolicy(sizePolicy)
        self.pushButton_2.setMaximumSize(QtCore.QSize(32, 16777215))
        self.pushButton_2.setStyleSheet(_fromUtf8("color: blue"))
        self.pushButton_2.setObjectName(_fromUtf8("pushButton_2"))
        self.horizontalLayout.addWidget(self.pushButton_2)
        self.verticalLayout.addLayout(self.horizontalLayout)

        # Second row: two date editors.
        self.gridLayout = QGridLayout()
        self.gridLayout.setObjectName(_fromUtf8("gridLayout"))
        self.dateEdit = QDateEdit(self.centralwidget)
        self.dateEdit.setObjectName(_fromUtf8("dateEdit"))
        self.gridLayout.addWidget(self.dateEdit, 0, 0, 1, 1)
        self.dateEdit_2 = QDateEdit(self.centralwidget)
        self.dateEdit_2.setObjectName(_fromUtf8("dateEdit_2"))
        self.gridLayout.addWidget(self.dateEdit_2, 0, 1, 1, 1)
        self.verticalLayout.addLayout(self.gridLayout)

        # Result table with three header items; retranslateUi labels them.
        self.tableWidget = QTableWidget(self.centralwidget)
        sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.tableWidget.sizePolicy().hasHeightForWidth())
        self.tableWidget.setSizePolicy(sizePolicy)
        self.tableWidget.setObjectName(_fromUtf8("tableWidget"))
        self.tableWidget.setColumnCount(3)
        self.tableWidget.setRowCount(0)
        for column in range(3):
            self.tableWidget.setHorizontalHeaderItem(column,
                                                     QTableWidgetItem())
        self.verticalLayout.addWidget(self.tableWidget)
        self.setCentralWidget(self.centralwidget)
        self.statusbar = QStatusBar(self)
        self.statusbar.setObjectName(_fromUtf8("statusbar"))
        self.setStatusBar(self.statusbar)

        self.retranslateUi()

        QtCore.QMetaObject.connectSlotsByName(self)
        # New-style connection: the bound method itself is the slot.
        self.pushButton_2.clicked.connect(self.onRunBtnClick)

        self._fillTable(self._fetchRows("SELECT * FROM Sheet1"))

    def retranslateUi(self):
        """Set the window title and the three column header labels."""
        self.setWindowTitle(_translate("MainWindow", "MainWindow", None))
        item = self.tableWidget.horizontalHeaderItem(0)
        item.setText(_translate("MainWindow", "id", None))
        item = self.tableWidget.horizontalHeaderItem(1)
        item.setText(_translate("MainWindow", "Name", None))
        item = self.tableWidget.horizontalHeaderItem(2)
        item.setText(_translate("MainWindow", "Cost", None))

    def _fetchRows(self, query, params=()):
        """Run ``query`` against ``DB_PATH`` and return all result rows."""
        conn = sqlite3.connect(self.DB_PATH)
        try:
            return conn.execute(query, params).fetchall()
        finally:
            # Close the connection whether or not the query succeeded.
            conn.close()

    def _fillTable(self, rows):
        """Replace the table's rows with ``rows``, keeping the headers."""
        # clearContents() leaves the header labels intact, and resetting the
        # row count removes the previous result before new rows are inserted.
        self.tableWidget.clearContents()
        self.tableWidget.setRowCount(0)
        for rowNumber, rowData in enumerate(rows):
            self.tableWidget.insertRow(rowNumber)
            for columnNumber, data in enumerate(rowData):
                self.tableWidget.setItem(rowNumber, columnNumber,
                                         QTableWidgetItem(str(data)))

    def onRunBtnClick(self, checked=False):
        """Show the rows whose ``name`` equals the search field's text.

        Args:
            checked: ignored; accepted so the method can be connected to the
                button's ``clicked`` signal, which passes the checked state.

        Raises:
            sqlite3.Error: if the database cannot be opened or queried.
        """
        # The search text is bound as a parameter, never pasted into the SQL,
        # so quotes in the input cannot alter the statement.
        rows = self._fetchRows("SELECT * FROM Sheet1 WHERE name=?",
                               (str(self.searchTE.text()), ))
        self._fillTable(rows)
예제 #6
0
class GrainSizeDatasetViewer(QDialog):
    PAGE_ROWS = 20
    logger = logging.getLogger("root.ui.GrainSizeDatasetView")
    gui_logger = logging.getLogger("GUI")

    def __init__(self, parent=None):
        """Build the viewer window, its UI, the chart windows and the dialogs.

        Args:
            parent: Optional parent widget, forwarded to the base class.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("Grain-size Dataset Viewer"))
        # Start with an empty dataset; on_data_loaded() swaps in a real one.
        self.__dataset = GrainSizeDataset()  # type: GrainSizeDataset
        self.init_ui()
        # init_ui() creates the table with 100 placeholder rows; drop them.
        self.data_table.setRowCount(0)

        # Chart windows used by the context-menu plot actions.
        chart_options = dict(parent=self, toolbar=True)
        self.frequency_curve_chart = FrequencyCurveChart(**chart_options)
        self.frequency_curve_3D_chart = FrequencyCurve3DChart(**chart_options)
        self.cumulative_curve_chart = CumulativeCurveChart(**chart_options)
        self.folk54_GSM_diagram_chart = Folk54GSMDiagramChart(**chart_options)
        self.folk54_SSC_diagram_chart = Folk54SSCDiagramChart(**chart_options)
        self.BP12_GSM_diagram_chart = BP12GSMDiagramChart(**chart_options)
        self.BP12_SSC_diagram_chart = BP12SSCDiagramChart(**chart_options)

        # Auxiliary dialogs.
        self.load_dataset_dialog = LoadDatasetDialog(parent=self)
        self.load_dataset_dialog.dataset_loaded.connect(self.on_data_loaded)
        self.file_dialog = QFileDialog(parent=self)
        self.normal_msg = QMessageBox(self)

    def init_ui(self):
        """Create the table, paging controls, statistic options and menu.

        The seven chart sub-menus share one layout (plot/append x selected/all),
        so they are described by a table and built by `_add_plot_menu`. Every
        attribute created by the previous hand-written version (for example
        ``self.cumulative_plot_selected_action``) is still created under the
        same name.
        """
        self.setWindowTitle(self.tr("Dataset Viewer"))
        self.data_table = QTableWidget(100, 100)
        self.data_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.data_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.data_table.setAlternatingRowColors(True)
        self.data_table.setContextMenuPolicy(Qt.CustomContextMenu)
        self.main_layout = QGridLayout(self)
        self.main_layout.addWidget(self.data_table, 0, 0, 1, 3)

        # Paging controls (second grid row).
        self.previous_button = QPushButton(self.tr("Previous"))
        self.previous_button.setToolTip(
            self.tr("Click to back to the previous page."))
        self.previous_button.clicked.connect(self.on_previous_button_clicked)
        self.current_page_combo_box = QComboBox()
        self.current_page_combo_box.currentIndexChanged.connect(
            self.update_page)
        self.next_button = QPushButton(self.tr("Next"))
        self.next_button.setToolTip(self.tr("Click to jump to the next page."))
        self.next_button.clicked.connect(self.on_next_button_clicked)
        self.main_layout.addWidget(self.previous_button, 1, 0)
        self.main_layout.addWidget(self.current_page_combo_box, 1, 1)
        self.main_layout.addWidget(self.next_button, 1, 2)

        # Statistic options (third grid row).
        self.geometric_checkbox = QCheckBox(self.tr("Geometric"))
        self.geometric_checkbox.setChecked(True)
        self.geometric_checkbox.stateChanged.connect(
            self.on_is_geometric_changed)
        self.main_layout.addWidget(self.geometric_checkbox, 2, 0)
        self.FW57_checkbox = QCheckBox(self.tr("Method of statistic moments"))
        self.FW57_checkbox.setChecked(False)
        self.FW57_checkbox.stateChanged.connect(self.on_is_FW57_changed)
        self.main_layout.addWidget(self.FW57_checkbox, 2, 1)
        self.proportion_combo_box = QComboBox()
        self.supported_proportions = [
            ("GSM_proportion", self.tr("Gravel, Sand, Mud")),
            ("SSC_proportion", self.tr("Sand, Silt, Clay")),
            ("BGSSC_proportion", self.tr("Boulder, Gravel, Sand, Silt, Clay"))
        ]
        self.proportion_combo_box.addItems(
            [description for _, description in self.supported_proportions])
        self.proportion_combo_box.currentIndexChanged.connect(
            lambda: self.update_page(self.page_index))
        self.main_layout.addWidget(self.proportion_combo_box, 2, 2)

        # Context menu of the table.
        self.menu = QMenu(self.data_table)
        self.load_dataset_action = self.menu.addAction(qta.icon("fa.database"),
                                                       self.tr("Load Dataset"))
        self.load_dataset_action.triggered.connect(self.load_dataset)

        # (sub-menu attribute, action attribute prefix, chart attribute,
        #  icon name, title). Titles are kept verbatim, including the
        # "Cumlulative" spelling, because they double as translation keys.
        plot_menus = [
            ("plot_cumulative_curve_menu", "cumulative",
             "cumulative_curve_chart", "mdi.chart-bell-curve-cumulative",
             self.tr("Plot Cumlulative Curve Chart")),
            ("plot_frequency_curve_menu", "frequency",
             "frequency_curve_chart", "mdi.chart-bell-curve",
             self.tr("Plot Frequency Curve Chart")),
            ("plot_frequency_curve_3D_menu", "frequency_3D",
             "frequency_curve_3D_chart", "mdi.video-3d",
             self.tr("Plot Frequency Curve 3D Chart")),
            ("folk54_GSM_diagram_menu", "folk54_GSM",
             "folk54_GSM_diagram_chart", "mdi.triangle-outline",
             self.tr("Plot GSM Diagram (Folk, 1954)")),
            ("folk54_SSC_diagram_menu", "folk54_SSC",
             "folk54_SSC_diagram_chart", "mdi.triangle-outline",
             self.tr("Plot SSC Diagram (Folk, 1954)")),
            ("BP12_GSM_diagram_menu", "BP12_GSM",
             "BP12_GSM_diagram_chart", "mdi.triangle-outline",
             self.tr("Plot GSM Diagram (Blott && Pye, 2012)")),
            ("BP12_SSC_diagram_menu", "BP12_SSC",
             "BP12_SSC_diagram_chart", "mdi.triangle-outline",
             self.tr("Plot SSC Diagram (Blott && Pye, 2012)")),
        ]
        for menu_attr, prefix, chart_attr, icon_name, title in plot_menus:
            self._add_plot_menu(menu_attr, prefix, chart_attr, icon_name,
                                title)

        self.save_action = self.menu.addAction(qta.icon("mdi.microsoft-excel"),
                                               self.tr("Save Summary"))
        self.save_action.triggered.connect(self.on_save_clicked)
        self.data_table.customContextMenuRequested.connect(self.show_menu)

    def _add_plot_menu(self, menu_attr, action_prefix, chart_attr, icon_name,
                       title):
        """Add one chart sub-menu with its four plot/append actions.

        Args:
            menu_attr: Attribute name under which the sub-menu is stored.
            action_prefix: Prefix of the four ``<prefix>_<kind>_action``
                attributes that hold the created actions.
            chart_attr: Name of the chart attribute the actions draw on.
            icon_name: Icon name handed to ``qta.icon``.
            title: Already translated sub-menu title.
        """
        submenu = self.menu.addMenu(qta.icon(icon_name), title)
        setattr(self, menu_attr, submenu)
        # (attribute suffix, caption, selected samples only?, append?)
        entries = (
            ("plot_selected", self.tr("Plot Selected Samples"), True, False),
            ("append_selected", self.tr("Append Selected Samples"), True,
             True),
            ("plot_all", self.tr("Plot All Samples"), False, False),
            ("append_all", self.tr("Append All Samples"), False, True),
        )
        for suffix, caption, selected_only, append in entries:
            action = submenu.addAction(caption)
            action.triggered.connect(
                self._make_plot_slot(chart_attr, selected_only, append))
            setattr(self, f"{action_prefix}_{suffix}_action", action)

    def _make_plot_slot(self, chart_attr, selected_only, append):
        """Return a zero-argument slot that plots onto the named chart.

        The chart and the samples are looked up when the slot runs, not when
        it is created: the chart attributes are assigned after `init_ui`
        returns, and the selection/dataset change over time. A separate
        factory (rather than a lambda in the loop) gives every action its own
        bound arguments.
        """
        return lambda: self.plot_chart(
            getattr(self, chart_attr),
            self.selections if selected_only else self.__dataset.samples,
            append)

    def show_menu(self, pos):
        """Pop up the context menu at the current cursor position.

        Args:
            pos: Position supplied by the context-menu signal; not used, the
                global cursor position is queried instead.
        """
        cursor_position = QCursor.pos()
        self.menu.popup(cursor_position)

    def show_message(self, title: str, message: str):
        """Show `message` in the shared message box, titled `title`.

        Args:
            title: Window title of the message box.
            message: Text displayed in the box.
        """
        box = self.normal_msg
        box.setWindowTitle(title)
        box.setText(message)
        box.exec_()

    def show_info(self, message: str):
        """Show `message` in a message box titled "Info"."""
        title = self.tr("Info")
        self.show_message(title, message)

    def show_warning(self, message: str):
        """Show `message` in a message box titled "Warning"."""
        title = self.tr("Warning")
        self.show_message(title, message)

    def show_error(self, message: str):
        """Show `message` in a message box titled "Error"."""
        title = self.tr("Error")
        self.show_message(title, message)

    def load_dataset(self):
        """Open the dataset loading dialog."""
        dialog = self.load_dataset_dialog
        dialog.show()

    def on_data_loaded(self, dataset: GrainSizeDataset):
        """Adopt `dataset`, rebuild the page selector and show the first page.

        Args:
            dataset: The newly loaded dataset; it replaces the current one.
        """
        self.__dataset = dataset
        # Ceiling division: a partially filled last page still needs an entry.
        page_count = -(-dataset.n_samples // self.PAGE_ROWS)

        combo = self.current_page_combo_box
        # clear() and addItems() both change the current index, and each
        # change would fire currentIndexChanged -> update_page: once with
        # index -1 while the selector is empty, and once more with 0 right
        # before the explicit call below. Silence the selector while it is
        # being repopulated so the page is rendered exactly once.
        previously_blocked = combo.blockSignals(True)
        try:
            combo.clear()
            combo.addItems(
                [f"{self.tr('Page')} {i+1}" for i in range(page_count)])
        finally:
            combo.blockSignals(previously_blocked)
        self.update_page(0)

    @property
    def is_geometric(self) -> bool:
        """Whether the geometric checkbox is currently checked."""
        checked = self.geometric_checkbox.isChecked()
        return checked

    def on_is_geometric_changed(self, state):
        """Relabel the geometric checkbox for `state` and redraw the page.

        Args:
            state: Check state delivered by the checkbox's ``stateChanged``
                signal.
        """
        caption = (self.tr("Geometric")
                   if state == Qt.Checked else self.tr("Logarithmic"))
        self.geometric_checkbox.setText(caption)
        self.update_page(self.page_index)

    @property
    def is_FW57(self) -> bool:
        """Whether the Folk and Ward (1957) checkbox is currently checked."""
        checked = self.FW57_checkbox.isChecked()
        return checked

    def on_is_FW57_changed(self, state):
        """Relabel the method checkbox for `state` and redraw the page.

        Args:
            state: Check state delivered by the checkbox's ``stateChanged``
                signal.
        """
        caption = (self.tr("Folk and Ward (1957) method")
                   if state == Qt.Checked else
                   self.tr("Method of statistic moments"))
        self.FW57_checkbox.setText(caption)
        self.update_page(self.page_index)

    @property
    def proportion(self) -> tuple:
        """Return the ``(key, description)`` pair of the selected proportion.

        The pair is the entry of ``self.supported_proportions`` at the current
        index of the proportion combo box.
        """
        index = self.proportion_combo_box.currentIndex()
        key, description = self.supported_proportions[index]
        return key, description

    @property
    def page_index(self) -> int:
        """Current index of the page selector combo box."""
        selector = self.current_page_combo_box
        return selector.currentIndex()

    @property
    def n_pages(self) -> int:
        """Number of entries in the page selector combo box."""
        selector = self.current_page_combo_box
        return selector.count()

    @property
    def unit(self) -> str:
        """Unit label used in headers: "μm" when geometric, otherwise "φ"."""
        if self.is_geometric:
            return "μm"
        return "φ"

    def update_page(self, page_index: int):
        """Fill the table with the statistics of the samples on one page.

        The page range is clamped to the dataset. The page selector emits
        index -1 whenever it is empty or cleared; a negative index is treated
        as the first page, and the row count always equals the number of
        samples actually shown (no blank rows for an empty dataset).

        Args:
            page_index: Zero-based page number.
        """
        if self.__dataset is None:
            return

        def write(row: int, col: int, value):
            # Floats are shown with two decimals; everything else that is not
            # already text goes through str().
            if isinstance(value, float):
                value = f"{value: 0.2f}"
            elif not isinstance(value, str):
                value = str(value)
            item = QTableWidgetItem(value)
            item.setTextAlignment(Qt.AlignCenter)
            self.data_table.setItem(row, col, item)

        # necessary to clear
        self.data_table.clear()
        start = max(page_index, 0) * self.PAGE_ROWS
        # Slicing clamps the end of the range to the dataset size.
        page_samples = self.__dataset.samples[start:start + self.PAGE_ROWS]

        proportion_key, proportion_desciption = self.proportion
        col_names = [
            f"{self.tr('Mean')}[{self.unit}]",
            self.tr("Mean Desc."), f"{self.tr('Median')} [{self.unit}]",
            f"{self.tr('Modes')} [{self.unit}]",
            self.tr("STD (Sorting)"),
            self.tr("Sorting Desc."),
            self.tr("Skewness"),
            self.tr("Skew. Desc."),
            self.tr("Kurtosis"),
            self.tr("Kurt. Desc."),
            f"({proportion_desciption})\n{self.tr('Proportion')} [%]",
            self.tr("Group\n(Folk, 1954)"),
            self.tr("Group\nSymbol (Blott & Pye, 2012)"),
            self.tr("Group\n(Blott & Pye, 2012)")
        ]
        # (True, key): looked up in the selected statistic family;
        # (False, key): looked up at the top level of the statistic result.
        col_keys = [(True, "mean"), (True, "mean_description"),
                    (True, "median"), (True, "modes"), (True, "std"),
                    (True, "std_description"), (True, "skewness"),
                    (True, "skewness_description"), (True, "kurtosis"),
                    (True, "kurtosis_description"), (False, proportion_key),
                    (False, "group_Folk54"), (False, "group_BP12_symbol"),
                    (False, "group_BP12")]
        self.data_table.setRowCount(len(page_samples))
        self.data_table.setColumnCount(len(col_names))
        self.data_table.setHorizontalHeaderLabels(col_names)
        self.data_table.setVerticalHeaderLabels(
            [sample.name for sample in page_samples])

        # The statistic family depends only on the two checkboxes, so it is
        # chosen once rather than once per sample.
        sub_key = "geometric" if self.is_geometric else "logarithmic"
        if self.is_FW57:
            sub_key += "_FW57"

        for row, sample in enumerate(page_samples):
            statistic = get_all_statistic(sample.classes_μm, sample.classes_φ,
                                          sample.distribution)
            for col, (in_sub, key) in enumerate(col_keys):
                value = statistic[sub_key][key] if in_sub else statistic[key]
                if key == "modes":
                    write(row, col, ", ".join([f"{m:0.2f}" for m in value]))
                elif key.endswith("_proportion"):
                    # Fractions are displayed as percentages.
                    write(row, col,
                          ", ".join([f"{p*100:0.2f}" for p in value]))
                else:
                    write(row, col, value)

        self.data_table.resizeColumnsToContents()

    @property
    def selections(self):
        """Return the samples behind the selected table rows, in dataset order.

        Shows a warning and returns an empty list when the dataset holds no
        samples. Selected rows are mapped to dataset indexes through the
        current page offset; rows outside the page or past the end of the
        dataset are ignored.
        """
        n_samples = self.__dataset.n_samples
        if n_samples == 0:
            self.show_warning(self.tr("Dataset has not been loaded."))
            return []

        # The page selector reports -1 when empty; treat that as page 0.
        start = max(self.page_index, 0) * self.PAGE_ROWS
        indexes = set()
        for selected in self.data_table.selectedRanges():
            # A page holds rows 0 .. PAGE_ROWS - 1 (the previous clamp of
            # PAGE_ROWS + 1 admitted one row too many).
            last_row = min(selected.bottomRow(), self.PAGE_ROWS - 1)
            for row in range(selected.topRow(), last_row + 1):
                index = start + row
                if index < n_samples:
                    indexes.add(index)
        return [self.__dataset.samples[i] for i in sorted(indexes)]

    def on_previous_button_clicked(self):
        """Select the previous page unless the first one is already shown."""
        if not self.page_index > 0:
            return
        self.current_page_combo_box.setCurrentIndex(self.page_index - 1)

    def on_next_button_clicked(self):
        """Select the next page unless the last one is already shown."""
        if not self.page_index < self.n_pages - 1:
            return
        self.current_page_combo_box.setCurrentIndex(self.page_index + 1)

    def plot_chart(self, chart, samples, append):
        """Draw `samples` on `chart` and show it; do nothing for no samples.

        Args:
            chart: Object providing ``show_samples(samples, append=...)`` and
                ``show()``.
            samples: Sized collection of samples to draw.
            append: Forwarded to ``show_samples`` as the ``append`` keyword.
        """
        if len(samples) != 0:
            chart.show_samples(samples, append=append)
            chart.show()

    def save_file(self, filename: str):
        """Export the statistic parameters and groups of all samples to Excel.

        The workbook gets a README sheet (description and citation text) and a
        "Parameters and Groups" sheet with one row per sample. The statistic
        family (geometric/logarithmic, optionally Folk and Ward) and the
        proportion column follow the current state of the option widgets.

        Args:
            filename: Path handed to ``Workbook.save``.
        """
        description = \
            """
            This Excel file was generated by QGrain ({0}).

            Please cite:
            Liu, Y., Liu, X., Sun, Y., 2021. QGrain: An open-source and easy-to-use software for the comprehensive analysis of grain size distributions. Sedimentary Geology 423, 105980. https://doi.org/10.1016/j.sedgeo.2021.105980

            It contains one sheet:
            1. The sheet puts the statistic parameters and the classification groups of the samples.

            The statistic formulas are referred to Blott & Pye (2001)'s work.
            The classification of GSDs is referred to Folk (1954)'s and Blott & Pye (2012)'s scheme.

            References:
                1.Blott, S. J. & Pye, K. Particle size scales and classification of sediment types based on particle size distributions: Review and recommended procedures. Sedimentology 59, 2071–2096 (2012).
                2.Blott, S. J. & Pye, K. GRADISTAT: a grain-size distribution and statistics package for the analysis of unconsolidated sediments. Earth Surf. Process. Landforms 26, 1237–1248 (2001).
                3.Folk, R. L. The Distinction between Grain Size and Mineral Composition in Sedimentary-Rock Nomenclature. The Journal of Geology 62, 344–359 (1954).

            """.format(QGRAIN_VERSION)

        wb = openpyxl.Workbook()
        # Close the workbook even when building or saving it fails.
        try:
            prepare_styles(wb)

            ws = wb.active
            ws.title = self.tr("README")

            def write(row, col, value, style="normal_light"):
                # `ws` is resolved at call time, so this writes to whichever
                # sheet is current; indexes are zero-based here and shifted to
                # the one-based cell coordinates of the worksheet.
                cell = ws.cell(row + 1, col + 1, value=value)
                cell.style = style

            for row, line in enumerate(description.split("\n")):
                write(row, 0, line, style="description")
            ws.column_dimensions[column_to_char(0)].width = 200

            ws = wb.create_sheet(self.tr("Parameters and Groups"))
            proportion_key, proportion_desciption = self.proportion
            col_names = [
                f"{self.tr('Mean')}[{self.unit}]",
                self.tr("Mean Desc."), f"{self.tr('Median')} [{self.unit}]",
                f"{self.tr('Modes')} [{self.unit}]",
                self.tr("STD (Sorting)"),
                self.tr("Sorting Desc."),
                self.tr("Skewness"),
                self.tr("Skew. Desc."),
                self.tr("Kurtosis"),
                self.tr("Kurt. Desc."),
                f"({proportion_desciption})\n{self.tr('Proportion')} [%]",
                self.tr("Group\n(Folk, 1954)"),
                self.tr("Group\nSymbol (Blott & Pye, 2012)"),
                self.tr("Group\n(Blott & Pye, 2012)")
            ]
            # (True, key): looked up in the selected statistic family;
            # (False, key): looked up at the top level of the statistic result.
            col_keys = [(True, "mean"), (True, "mean_description"),
                        (True, "median"), (True, "modes"), (True, "std"),
                        (True, "std_description"), (True, "skewness"),
                        (True, "skewness_description"), (True, "kurtosis"),
                        (True, "kurtosis_description"),
                        (False, proportion_key), (False, "group_Folk54"),
                        (False, "group_BP12_symbol"), (False, "group_BP12")]

            # Header row; column 0 holds the sample names.
            write(0, 0, self.tr("Sample Name"), style="header")
            ws.column_dimensions[column_to_char(0)].width = 16
            for col, moment_name in enumerate(col_names, 1):
                write(0, col, moment_name, style="header")
                if col in (2, 4, 6, 8, 10, 11, 12, 14):
                    ws.column_dimensions[column_to_char(col)].width = 30
                else:
                    ws.column_dimensions[column_to_char(col)].width = 16
            ws.column_dimensions[column_to_char(len(col_names))].width = 40

            # The statistic family depends only on the two checkboxes, so it
            # is chosen once rather than once per sample.
            sub_key = "geometric" if self.is_geometric else "logarithmic"
            if self.is_FW57:
                sub_key += "_FW57"

            for row, sample in enumerate(self.__dataset.samples, 1):
                # Alternate the row style for readability.
                style = "normal_dark" if row % 2 == 0 else "normal_light"
                write(row, 0, sample.name, style=style)
                statistic = get_all_statistic(sample.classes_μm,
                                              sample.classes_φ,
                                              sample.distribution)
                for col, (in_sub, key) in enumerate(col_keys, 1):
                    value = (statistic[sub_key][key]
                             if in_sub else statistic[key])
                    if key == "modes":
                        value = ", ".join([f"{m:0.4f}" for m in value])
                    elif key.endswith("_proportion"):
                        # Fractions are exported as percentages.
                        value = ", ".join([f"{p*100:0.4f}" for p in value])
                    write(row, col, value, style=style)

            wb.save(filename)
        finally:
            wb.close()

    def on_save_clicked(self):
        """Ask for a target file and save the summary, reporting the outcome.

        Warns and returns when no samples are loaded; returns silently when
        the file dialog yields no filename. Any exception raised while saving
        is shown in an error box instead of propagating.
        """
        dataset = self.__dataset
        if dataset is None or dataset.n_samples == 0:
            self.show_warning(self.tr("Dataset has not been loaded."))
            return

        filename, _selected_filter = self.file_dialog.getSaveFileName(
            self, self.tr("Select Filename"), None, "Excel (*.xlsx)")
        if filename is None or filename == "":
            return

        try:
            self.save_file(filename)
            success_text = self.tr(
                "The summary of this dataset has been saved to:\n    {0}")
            self.show_info(success_text.format(filename))
        except Exception as error:
            failure_text = self.tr(
                "Error raised while save summary to Excel file.\n    {0}")
            self.show_error(failure_text.format(str(error)))
예제 #7
0
class manager(confStack):
    def __init_stack__(self):
        """Set the stack metadata and create the installer and menu helpers."""
        self.dbg = False
        self._debug("manager load")
        # Texts and icon describing this stack.
        self.description = _("Air Apps Manager")
        self.menu_description = _("Manage air apps")
        self.icon = 'dialog-password'
        self.tooltip = _(
            "From here you can manage the air apps installed on your system")
        self.index = 1
        self.enabled = True
        self.level = 'system'
        self.hideControlButtons()
        # Helpers used to list installed apps and read their desktop entries.
        self.airinstaller = installer.AirManager()
        self.menu = App2Menu.app2menu()
        self.setStyleSheet(self._setCss())
        # '' means "no widget selected for removal" (checked in writeConfig).
        self.widget = ''

    #def __init_stack__

    def _load_screen(self):
        """Create the single-column app table, fill it and return this stack."""
        table = QTableWidget(0, 1)
        self.lst_airApps = table
        table.setShowGrid(False)
        h_header = table.horizontalHeader()
        h_header.hide()
        v_header = table.verticalHeader()
        v_header.hide()
        # The column takes the full width; row heights follow their content.
        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        table.verticalHeader().setSectionResizeMode(
            QHeaderView.ResizeToContents)

        layout = QVBoxLayout()
        layout.addWidget(table)
        self.setLayout(layout)
        self.updateScreen()
        return self

    #def _load_screen

    def updateScreen(self):
        """Rebuild the table with one cell widget per installed air app.

        Apps for which `_paintCell` returns no widget are skipped. When
        nothing is listed, a single placeholder label is shown instead.
        Always returns True.
        """
        table = self.lst_airApps
        table.clear()
        installed = self.airinstaller.get_installed_apps()
        used_rows = 0
        for _app_name, app_info in installed.items():
            cell = self._paintCell(app_info)
            if not cell:
                continue
            table.insertRow(used_rows)
            table.setCellWidget(used_rows, 0, cell)
            table.resizeRowToContents(used_rows)
            used_rows += 1

        if used_rows == 0:
            table.insertRow(0)
            placeholder = QLabel(_("There's no app installed"))
            placeholder.setStyleSheet(
                "background:silver;border:0px;margin:0px")
            table.setCellWidget(0, 0, placeholder)
            used_rows += 1

        # New rows were inserted above the rows of the previous run; remove
        # those leftovers.
        while table.rowCount() > used_rows:
            table.removeRow(used_rows)
        table.resizeColumnsToContents()

        return True

    #def updateScreen

    def _paintCell(self, airApp):
        """Build the cell widget for one app, or return None.

        Args:
            airApp: Mapping describing the app; its 'desktop' entry is used to
                look up the desktop information.

        Returns:
            A configured ``airWidget``, or None when `airApp` is empty or its
            desktop information has no 'Name'.
        """
        if not airApp:
            return None
        desktop = self.menu.get_desktop_info(airApp.get('desktop', ''))
        name = desktop.get('Name', '')
        if not name:
            return None

        cell = airWidget()
        cell.setDesktop(airApp.get('desktop'))
        cell.remove.connect(self._removeAir)
        cell.setName(name)
        cell.setIcon(desktop.get('Icon', ''))
        cell.setDesc(desktop.get('Comment', ''))
        cell.setExe(desktop.get('Exec', ''))
        return cell

    #def _paintCell

    def writeConfig(self):
        """Uninstall the app of the pending widget and refresh the list.

        Does nothing when no widget is pending (``self.widget == ''``). The
        removal runs through ``pkexec``; a failure is printed, and the
        "uninstalled" message is shown only when the helper succeeded. The
        list is refreshed in either case.
        """
        if self.widget == '':
            return
        name = self.widget.getName()
        desktop = self.widget.getDesktop()

        # NOTE(review): "xhost +" turns X access control off for every client
        # while the helper runs. Kept as is; consider a narrower grant.
        subprocess.check_call(['/usr/bin/xhost', '+'])
        removed = False
        try:
            subprocess.check_call([
                'pkexec', '/usr/bin/air-helper-installer.py', 'remove',
                name, desktop
            ])
            removed = True
        except Exception as e:
            # Best effort: report the failure and still refresh the list.
            print(e)
        finally:
            # Always restore X access control, whatever happened above.
            subprocess.check_call(['/usr/bin/xhost', '-'])

        if removed:
            # Translate the template first, then insert the name; formatting
            # inside _() yields a msgid no catalogue can contain.
            self.showMsg(_("App %s uninstalled") % name)
        self.updateScreen()

    #def writeConfig

    def _removeAir(self, widget):
        """Remember `widget` as the one to remove and run `writeConfig`.

        Args:
            widget: The cell widget whose app should be uninstalled.
        """
        self.widget = widget
        self.writeConfig()

    #def _removeAir

    def _setCss(self):
        css = """
		#cell{
			padding:10px;
			margin:6px;
			background-color:rgb(250,250,250);

		}
		#appName{
			font-weight:bold;
			border:0px;
		}
		#btnRemove{
			background:red;
			color:white;
			font-size:9pt;
			padding:3px;
			margin:3px;
		}
		
		"""

        return (css)
예제 #8
0
class Ventana(QtWidgets.QWidget):
    '''Constructor de la clase'''
    def __init__(self, vectorizer, parent=None):
        """Create every widget and layout of the window and wire the menu.

        vectorizer: stored on ``self.vectorizer``; other methods call its
            ``transform``.
        parent: optional parent widget forwarded to ``QWidget``.
        """
        super().__init__(parent)
        # Vectorizer used when training / predicting.
        self.vectorizer = vectorizer
        # Results of the best training runs (read by mostrarReporte).
        self.resultadoEntrenamiento = {}

        # Training report: per classifier a title, four metric labels and a
        # matrix label, i.e. self.svcTitulo ... self.nbMatrix, in this order.
        self.layoutReporte = QVBoxLayout()
        for prefijo in ("svc", "mlp", "knn", "nb"):
            for sufijo in ("Titulo", "F1Label", "RecallLabel",
                           "PrecisionLabel", "AccuracyLabel", "Matrix"):
                etiqueta = QLabel(self)
                setattr(self, prefijo + sufijo, etiqueta)
                self.layoutReporte.addWidget(etiqueta)

        # Classification reports: self.svcClass, self.mlpClass,
        # self.knnClass and self.nbClass.
        self.layoutClass = QVBoxLayout()
        for prefijo in ("svc", "mlp", "knn", "nb"):
            etiqueta = QLabel(self)
            setattr(self, prefijo + "Class", etiqueta)
            self.layoutClass.addWidget(etiqueta)
        self.layoutEntrenamiento = QHBoxLayout()

        # Main menu buttons.
        self.buttonBloqueFile = QPushButton("&Analizar Bloque desde archivo",
                                            self)
        self.buttonBloqueTwitter = QPushButton(
            "&Analizar Bloque desde Twitter.com", self)
        self.buttonUnTweet = QPushButton("&Analizar un Tweet", self)
        self.buttonEntrenar = QPushButton("&Entrenar algoritmos", self)
        self.layoutMenu = QHBoxLayout()
        for boton in (self.buttonBloqueFile, self.buttonBloqueTwitter,
                      self.buttonUnTweet, self.buttonEntrenar):
            self.layoutMenu.addWidget(boton)

        # Layout that will hold the results.
        self.layoutWidget = QVBoxLayout()
        # Layout with every widget used to show a single tweet.
        self.infoLayout = QVBoxLayout()
        self.tituloLabel = QLabel(self)   # title
        self.tweetLabel = QLabel(self)    # tweet text
        self.nerLabel = QLabel(self)      # NER section title
        self.nonerLabel = QLabel(self)    # "no entities" notice
        self.tabla = QTableWidget()       # NER results
        self.claLabel = QLabel(self)      # title above the sentiment table
        self.tablaSent = QTableWidget()   # sentiment per algorithm
        for widget in (self.tituloLabel, self.tweetLabel, self.nerLabel,
                       self.nonerLabel, self.tabla, self.claLabel,
                       self.tablaSent):
            self.infoLayout.addWidget(widget)

        # Label for the training view.
        self.entrenamientoLabel = QLabel(self)

        # Progress bar and its status label.
        self.progressBarUnTweet = QProgressBar(self)
        self.progressLayout = QVBoxLayout()
        self.progresLabel = QLabel(self)
        self.progressLayout.addWidget(self.progresLabel)
        self.progressLayout.addWidget(self.progressBarUnTweet)

        # NER entity buttons.
        self.nerLayout = QVBoxLayout()
        self.layoutWidget.addLayout(self.nerLayout)
        self.botones = []

        # Values chosen in the "load from Twitter" dialogs.
        self.consultaText = ""
        self.consultaTweets = 0
        self.nerCantidadValor = 0

        # File dialog.
        self.dialogo1 = QFileDialog(self)

        # Main layout nesting all the others.
        self.layoutPrincipal = QVBoxLayout()
        self.layoutPrincipal.addLayout(self.layoutMenu)
        self.layoutPrincipal.addLayout(self.progressLayout)
        self.layoutPrincipal.addStretch()
        self.layoutPrincipal.addLayout(self.layoutWidget)
        self.setLayout(self.layoutPrincipal)

        # Dialogs used to configure a Twitter block analysis.
        self.dialogConsulta = QInputDialog(self)
        self.dialogTweets = QInputDialog(self)
        self.nerCantidad = QInputDialog(self)

        # Button -> handler wiring.
        self.buttonUnTweet.clicked.connect(self.analizarUnTweet)
        self.buttonEntrenar.clicked.connect(self.entrenar_algoritmos)
        self.buttonBloqueTwitter.clicked.connect(self.cuadroDialogo)
        self.buttonBloqueFile.clicked.connect(self.analizarDeArchivo)

    '''Reinicia el estado de la barra de estado'''

    def reiniciarEstado(self, maximo):
        """Reset the progress bar, then set its maximum to ``maximo`` and minimum to 0."""
        barra = self.progressBarUnTweet
        barra.reset()
        barra.setMaximum(maximo)
        barra.setMinimum(0)

    '''Actualiza el estado de la barra de estado y de la etiqueta de estado'''

    def actualizarEstado(self, porcentaje, etiqueta):
        """Set the progress bar value to ``porcentaje`` and the status label text to ``etiqueta``."""
        barra = self.progressBarUnTweet
        barra.setValue(porcentaje)
        rotulo = self.progresLabel
        rotulo.setText(etiqueta)

    '''Función que oculta todos los widgets que hay en el layout que se le pasa por parámetro'''

    def limpiarLayout(self, layout):
        """Hide every widget held by ``layout``, including nested layouts.

        ``QLayoutItem.widget()`` returns ``None`` for items that are spacers
        or child layouts, so those are not dereferenced: child layouts are
        walked in turn and anything else is skipped.
        """
        pendientes = [layout]
        while pendientes:
            actual = pendientes.pop()
            for i in reversed(range(actual.count())):
                item = actual.itemAt(i)
                widget = item.widget()
                if widget is not None:
                    widget.hide()
                    continue
                anidado = item.layout()
                if anidado is not None:
                    pendientes.append(anidado)

    '''Función que muestra todos los widgets que hay en el layout que se le pasa por parámetro'''

    def mostrarLayout(self, layout):
        """Show every widget held by ``layout``, including nested layouts.

        ``QLayoutItem.widget()`` returns ``None`` for items that are spacers
        or child layouts, so those are not dereferenced: child layouts are
        walked in turn and anything else is skipped.
        """
        pendientes = [layout]
        while pendientes:
            actual = pendientes.pop()
            for i in range(actual.count()):
                item = actual.itemAt(i)
                widget = item.widget()
                if widget is not None:
                    widget.show()
                    continue
                anidado = item.layout()
                if anidado is not None:
                    pendientes.append(anidado)

    '''Función que muestra los cuadros de diálogo para analizar un bloque de tweets desde twitter.com'''

    def cuadroDialogo(self):
        """Ask for the Twitter query parameters and start the block analysis.

        Three dialogs are shown in turn: the search text, the number of
        100-tweet pages to download and the number of tweets to run NER on.
        Each answer is stored as the ``(value, accepted)`` pair returned by
        the dialog, which is the shape the other methods index with ``[0]``.
        The download only starts when every dialog was accepted, the query
        is not blank and at least one page was requested.
        """
        self.consultaText = self.dialogConsulta.getText(
            self, "Consulta de Twitter", "¿Sobre qué quieres buscar?")
        if not self.consultaText[1] or not self.consultaText[0].strip():
            return  # cancelled or empty query

        self.consultaTweets = self.dialogTweets.getInt(
            self, "Cuántos twits quieres usar",
            "La cantidad se multiplica por 100")
        if not self.consultaTweets[1] or self.consultaTweets[0] < 1:
            return  # cancelled, or nothing to download

        self.nerCantidadValor = self.nerCantidad.getInt(
            self, "Cuántos twits quieres usar en NER",
            "Tweets que se usarán para NER")
        if not self.nerCantidadValor[1]:
            return  # cancelled

        self.cargar_datos_de_twitter()

    '''Función que muestra un gráfico con los datos pasados por parámetros'''

    def mostrarUnGraph(self, entidad, dataframe, estado):
        """Plot the sentiment predicted by the four classifiers.

        entidad: entity name; used for the chart titles and, when
            ``estado == 2``, to keep only the rows whose ``text`` contains it.
        dataframe: frame with a ``text`` column.  It is copied before being
            tokenized, so the caller's frame is left untouched.
        estado: ``2`` filters by entity; any other value uses every row.
        """
        if estado == 2:
            # Literal match: entity names may contain regex metacharacters.
            coincide = dataframe['text'].str.contains(entidad,
                                                      regex=False,
                                                      na=False)
            dataframe = dataframe[coincide]
        dataframe = self.tokenizar(dataframe.copy())
        test_data = [' '.join(tokens) for tokens in dataframe['final']]
        test_vectors = self.vectorizer.transform(test_data)
        self.mostrar_graph(self.predecir_Naive_Bayes(test_vectors, entidad),
                           self.predecir_SVC(test_vectors, entidad),
                           self.predecir_KNN(test_vectors, entidad),
                           self.predecir_MLP(test_vectors, entidad))

    '''Función que carga todos los botones correspondientes a las entidades reconocidas'''

    def buttonsNER(self, resultadoNER, df2):
        """Rebuild the entity buttons shown in ``self.nerLayout``.

        resultadoNER: value returned by ``usar_NER``: either ``0`` (no
            entities) or an iterable of groups whose elements are indexed
            with ``[0]`` to obtain the entity name.
        df2: frame handed to ``mostrarUnGraph`` when a button is clicked.

        One button for all entities is always created, followed by one
        button per recognised entity.
        """
        self.limpiarLayout(self.nerLayout)
        # Drop the buttons of the previous run instead of leaving them hidden
        # in the layout forever.
        for anterior in self.botones:
            self.nerLayout.removeWidget(anterior)
            anterior.deleteLater()
        self.botones = []

        def agregar_boton(texto, estado):
            boton = QPushButton(texto, self)
            self.botones.append(boton)
            self.nerLayout.addWidget(boton)
            # ``clicked`` emits a ``checked`` bool; give it its own parameter
            # so it cannot overwrite the bound entity name.
            boton.clicked.connect(
                lambda checked=False, entidad=texto: self.mostrarUnGraph(
                    entidad, df2, estado))

        agregar_boton('Todas las entidades', 1)
        # usar_NER returns 0 when nothing was recognised.
        for grupo in (resultadoNER or ()):
            for entidad in grupo:
                agregar_boton(entidad[0], 2)

    '''Función para configurar todos los datos de la API de Twitter'''

    def configurarAPITwitter(self, buscar, restoConsulta):
        """Build the request settings for the Twitter search endpoint.

        buscar: search text.
        restoConsulta: extra query operators appended to ``buscar``.

        Returns a dict with ``auth`` (an ``OAuth1`` object), ``pms`` (query
        parameters) and ``url``.

        Credentials are read from the environment variables
        ``TWITTER_CONSUMER_KEY``, ``TWITTER_CONSUMER_SECRET``,
        ``TWITTER_ACCESS_TOKEN`` and ``TWITTER_ACCESS_TOKEN_SECRET`` when set.
        """
        import os  # local import: the file's import block is not in view

        # NOTE(review): the fallback values below are secrets committed to
        # source; they should be rotated and removed once the environment
        # variables are in place.
        consumer_key = os.environ.get('TWITTER_CONSUMER_KEY',
                                      'ynSB0dFvqPl3xRU7AmYk39rGT')
        consumer_secret = os.environ.get(
            'TWITTER_CONSUMER_SECRET',
            '6alIXTKSxf0RE57QK3fDQ8dxdvlsVr1IRsHDZmoSlMx96YKBFD')
        access_token = os.environ.get(
            'TWITTER_ACCESS_TOKEN',
            '966591013182722049-BVXW14Hf5s6O2oIwS3vtJ3S3dOsKLbY')
        access_token_secret = os.environ.get(
            'TWITTER_ACCESS_TOKEN_SECRET',
            '829DTKPjmwsSytmp1ky9fMCJkjV0LZ04TbL9oqHGV6cDm')
        # Use the argument rather than reaching back into self.consultaText.
        q = buscar + restoConsulta
        url = 'https://api.Twitter.com/1.1/search/tweets.json'
        pms = {'q': q, 'count': 100, 'lang': 'en', 'result_type': 'recent'}
        auth = OAuth1(consumer_key, consumer_secret, access_token,
                      access_token_secret)
        return {'auth': auth, 'pms': pms, 'url': url}

    '''Función que configura la base de datos y realiza el proceso de paginación'''

    def paginacionMongo(self, url, pms, auth, nombre, colection, cliente,
                        paginas):
        """Download up to ``paginas`` result pages and store them in MongoDB.

        url, pms, auth: passed to ``requests.get`` (``pms`` as the query
            parameters; it is copied, not modified).
        nombre, colection: database and collection names.
        cliente: connection string handed to ``MongoClient``.
        paginas: maximum number of pages to request.

        Each page's ``statuses`` are inserted into the collection and the
        next request is limited with ``max_id`` to tweets older than the
        oldest one received.  Paging stops early when a page comes back
        empty.  HTTP errors are raised by ``raise_for_status``.

        Returns the collection object.
        """
        collection = MongoClient(cliente)[nombre][colection]
        params = dict(pms)  # keep the caller's dict free of max_id
        for _pagina in range(paginas):
            res = requests.get(url, params=params, auth=auth, timeout=30)
            res.raise_for_status()
            statuses = res.json().get('statuses') or []
            if not statuses:
                # Nothing older to fetch; min() and insert_many() would both
                # fail on an empty list.
                break
            params['max_id'] = min(tweet['id'] for tweet in statuses) - 1
            collection.insert_many(statuses)
        return collection

    '''Función para cargar datos de twitter directamente, lo almacena en una base de datos y lo devuelve en un dataframe. Se usa sólo para ver los resultados. No para entrenar.'''

    def cargar_datos_de_twitter(self):
        """Download a block of tweets and build the entity buttons for it.

        Reads ``self.consultaText[0]`` (query), ``self.consultaTweets[0]``
        (pages of 100 tweets) and ``self.nerCantidadValor[0]`` (tweets to
        run NER on).  The tweets go through MongoDB, the newest
        ``pages * 100`` documents are loaded into a DataFrame, cleaned and
        tokenized, NER runs on the last tweets and the result is handed to
        ``buttonsNER``.  Used only to look at results, not to train.
        """
        self.limpiarLayout(self.infoLayout)
        paginas = self.consultaTweets[0]
        esperados = paginas * 100
        datosAPI = self.configurarAPITwitter(
            self.consultaText[0], ' -filter:retweets AND -filter:replies')
        collection = self.paginacionMongo(datosAPI['url'], datosAPI['pms'],
                                          datosAPI['auth'], "baseDeDatos",
                                          "coleccion",
                                          'mongodb://localhost:27017/',
                                          paginas)
        # Database -> DataFrame.  count_documents replaces the removed
        # Collection.count(); the skip is clamped because fewer documents
        # than requested may exist and skip() rejects negative values.
        total = collection.count_documents({})
        documents = list(collection.find().skip(max(0, total - esperados)))
        df = pd.DataFrame(documents)
        # Data cleaning.
        df = self.limpieza_de_datos_de_twitter(df)
        df2 = pd.DataFrame(data=df['text'][-esperados:])
        dfNER = pd.DataFrame(data=df['text'])
        dfNER = self.tokenizar(dfNER)
        # Only the last tweets are sent to NER.
        anNER = dfNER['final'][-self.nerCantidadValor[0]:]
        resultadoNER = self.usar_NER(anNER, 3)
        self.buttonsNER(resultadoNER, df2)

    '''Función que analiza twits sacados de un archivo a elegir por el usuario'''

    def analizarDeArchivo(self):
        """Analyse the tweets of a CSV file chosen by the user.

        The file must have a ``text`` column.  The user is also asked how
        many of the last tweets to run NER on.  Nothing happens if either
        dialog is cancelled.  The file is opened with the full path returned
        by the file dialog, so it does not have to live in the current
        working directory.
        """
        seleccion = self.dialogo1.getOpenFileName(
            self, "Selecciona el fichero a analizar", "/")
        ruta = seleccion[0]
        if not ruta:
            return  # dialog cancelled

        self.nerCantidadValor = self.nerCantidad.getInt(
            self, "Cuántos twits quieres usar en NER",
            "Tweets que se usarán para NER")
        if not self.nerCantidadValor[1]:
            return  # dialog cancelled

        self.reiniciarEstado(100)
        dataset = pd.read_csv(ruta)
        df3 = pd.DataFrame(data=dataset['text'])
        dfNER = pd.DataFrame(data=dataset['text'])
        self.actualizarEstado(20, "Tokenizando")
        dfNER = self.tokenizar(dfNER)
        anNER = dfNER['final'][-self.nerCantidadValor[0]:]
        self.actualizarEstado(30, "Analizando NER")
        resultadoNER = self.usar_NER(anNER, 3)
        # usar_NER changes the bar's range, so restore it before continuing.
        self.reiniciarEstado(100)
        self.actualizarEstado(90, "Distribuyendo botones NER")
        self.buttonsNER(resultadoNER, df3)
        self.actualizarEstado(100, "FINALIZADO")

    '''Función que analiza un tweet individual sacado de twitter.com'''

    def analizarUnTweet(self):
        """Download tweets for a query and analyse the most recent one.

        Asks for the search text (nothing happens if the dialog is cancelled
        or left empty), downloads one page of tweets into MongoDB, keeps the
        last tweet that survives the source filter, runs NER on it and
        classifies it with the four stored models.  The tweet, its entities
        and the sentiment given by each classifier are then shown in
        ``self.infoLayout``.
        """
        self.consultaText = self.dialogConsulta.getText(
            self, "Consulta de Twitter", "¿Sobre qué quieres buscar?")
        if not self.consultaText[1] or not self.consultaText[0]:
            return  # dialog cancelled or empty query

        self.reiniciarEstado(10)
        datosAPI = self.configurarAPITwitter(
            self.consultaText[0], ' -filter:retweets AND -filter:replies')
        self.actualizarEstado(2, "Iniciando base de datos")
        collection = self.paginacionMongo(datosAPI['url'], datosAPI['pms'],
                                          datosAPI['auth'], "baseDeDatos",
                                          "coleccion",
                                          'mongodb://localhost:27017/', 1)
        self.actualizarEstado(3, "Guardando tweets en base de datos")

        # Database -> DataFrame.  count_documents replaces the removed
        # Collection.count(); the skip is clamped because skip() rejects
        # negative values when fewer than ten documents exist.
        total = collection.count_documents({})
        documents = list(collection.find().skip(max(0, total - 10)))
        df = pd.DataFrame(documents)
        self.progressBarUnTweet.setValue(4)

        # Data cleaning.
        df = self.limpieza_de_datos_de_twitter(df)
        df2 = pd.DataFrame(data=df['text'][-1:])
        dfNER = pd.DataFrame(data=df['text'][-1:])
        # Show the tweet that is actually analysed: take it from the
        # filtered frame, not from the raw documents.
        tweet = df['text'][-1:]
        dfNER = self.tokenizar(dfNER)
        anNER = dfNER['final'][-1:]
        resultadoNER = self.usar_NER(anNER, 10)
        self.actualizarEstado(5, "Limpiando datos")

        df2 = self.tokenizar(df2)
        test_data = list(df2['final'].apply(' '.join))
        test_vectors = self.vectorizer.transform(test_data)
        self.actualizarEstado(6, "Transformando datos datos")
        nb = self.predecir_Naive_Bayes(test_vectors, 'nada')
        self.actualizarEstado(7, "Naive Bayes")
        svc = self.predecir_SVC(test_vectors, 'nada')
        self.actualizarEstado(8, "SVC")
        knn = self.predecir_KNN(test_vectors, 'nada')
        self.actualizarEstado(9, "KNN")
        mlp = self.predecir_MLP(test_vectors, 'nada')
        self.actualizarEstado(10, "FINALIZADO")

        self.tituloLabel.setText(
            "<h1>ANÁLISIS DE SENTIMIENTOS EN UN TWEET INDIVIDUAL</h1>")
        self.tweetLabel.setText("<b>TWEET:</b> " + tweet.to_string())
        self.nerLabel.setText("<h2>Entidades</h2>")
        if resultadoNER == 0:
            self.nonerLabel.setText(
                "<b><font size=3 color=red>"
                "NO SE RECONOCIERON LAS ENTIDADES</font></b>")
            self.tabla.clear()
        else:
            self.nonerLabel.setText(" ")
            self.tabla.setVerticalHeaderLabels(["prueba", "prueba2"])
            self.tabla.horizontalHeader().hide()
            self.tabla.verticalHeader().hide()
            self.tabla.setColumnCount(1)
            # One row per recognised entity, all groups in a single column.
            entidades = [ent for grupo in resultadoNER for ent in grupo]
            self.tabla.setRowCount(len(entidades))
            for fila, entidad in enumerate(entidades):
                self.tabla.setItem(fila, 0, QTableWidgetItem(entidad[0]))

        self.claLabel.setText("<h2>Clasificación de sentimientos</h2>")
        self.tablaSent.setColumnCount(2)
        self.tablaSent.setRowCount(4)
        self.tablaSent.horizontalHeader().hide()
        self.tablaSent.verticalHeader().hide()
        self.tablaSent.setVerticalHeaderLabels(['Clasificador', 'Resultado'])
        # Each prediction is (entity, [positive, neutral, negative]); with a
        # single tweet the class that was predicted has a count of 1.
        clasificadores = (
            ("Naive Bayes", nb),
            ("Clasificador SVC", svc),
            ("Clasificador K-Neighbors", knn),
            ("Clasificador MLP", mlp),
        )
        sentimientos = ("Positivo", "Neutro", "Negativo")
        for fila, (nombre, prediccion) in enumerate(clasificadores):
            self.tablaSent.setItem(fila, 0, QTableWidgetItem(nombre))
            conteos = prediccion[1]
            for indice, sentimiento in enumerate(sentimientos):
                if conteos[indice] == 1:
                    self.tablaSent.setItem(fila, 1,
                                           QTableWidgetItem(sentimiento))
                    break

        # A layout can only be added once; on later runs it is already in
        # place and only needs to be shown again.
        if self.infoLayout.parent() is None:
            self.layoutWidget.addLayout(self.infoLayout)
        self.mostrarLayout(self.infoLayout)

    '''Función que tokeniza los datos de un tweet, eliminando las stopwords y los caracteres especiales'''

    def tokenizar(self, df):
        """Tokenize the ``text`` column and add the cleaned token columns.

        Adds, in place, the columns ``tokens`` (``TweetTokenizer`` output),
        ``stopwords`` (English stop words removed, case-insensitive),
        ``punctuation`` (tokens equal to a punctuation character removed),
        ``digits`` (tokens starting with a digit removed) and ``final``
        (one-character tokens removed).

        Returns the same ``df``.
        """
        # Sets give constant-time membership tests and are built once
        # instead of once per token.
        vocabulario = set(stopwords.words('english'))
        puntuacion = set(string.punctuation)
        digitos = set(string.digits)
        tokenizador = TweetTokenizer()

        # Initial tokenization.
        df.loc[:, 'tokens'] = df['text'].apply(tokenizador.tokenize)
        # Stop words.
        df.loc[:, 'stopwords'] = df['tokens'].apply(
            lambda x: [i for i in x if i.lower() not in vocabulario])
        # Special characters.
        df.loc[:, 'punctuation'] = df['stopwords'].apply(
            lambda x: [i for i in x if i not in puntuacion])
        df.loc[:, 'digits'] = df['punctuation'].apply(
            lambda x: [i for i in x if i[0] not in digitos])
        df.loc[:, 'final'] = df['digits'].apply(
            lambda x: [i for i in x if len(i) > 1])
        return df

    '''Función que recibe un dataframe con tweets de twitter y los deja preparados para ser tokenizados'''

    def limpieza_de_datos_de_twitter(self, df):
        """Keep only the tweets whose source name starts with ``Twitter``.

        Adds a ``tweet_source`` column to ``df`` holding the text of the
        HTML found in the ``source`` column, and returns the rows whose
        ``tweet_source`` starts with ``'Twitter'``.
        """
        # Name the parser explicitly so the result does not depend on which
        # optional parsers happen to be installed.
        df['tweet_source'] = df['source'].apply(
            lambda x: BeautifulSoup(x, 'html.parser').get_text())
        return df[df['tweet_source'].str.startswith('Twitter')]

    '''Funciones de predicción de los diferentes algoritmos para los diferentes modelos'''

    def predecir_Naive_Bayes(self, test_vectors, it):
        """Classify ``test_vectors`` with the model stored in the file ``NaiveBayes``.

        Returns ``(it, [positive, neutral, negative])`` where the counts are
        the number of predictions equal to 4, 2 and 0 respectively.
        """
        # NOTE(review): ``load`` deserialises the file; keep model files
        # trusted.  The ``with`` block closes the handle.
        with open('NaiveBayes', 'rb') as archivo:
            mod = load(archivo)
        result = mod.predict(test_vectors)
        y = [len(result[result == etiqueta]) for etiqueta in (4, 2, 0)]
        return (it, y)

    def predecir_SVC(self, test_vectors, it):
        """Classify ``test_vectors`` with the model stored in the file ``Svc``.

        Returns ``(it, [positive, neutral, negative])`` where the counts are
        the number of predictions equal to 4, 2 and 0 respectively.
        """
        # NOTE(review): ``load`` deserialises the file; keep model files
        # trusted.  The ``with`` block closes the handle.
        with open('Svc', 'rb') as archivo:
            mod = load(archivo)
        result = mod.predict(test_vectors)
        y = [len(result[result == etiqueta]) for etiqueta in (4, 2, 0)]
        return (it, y)

    def predecir_KNN(self, test_vectors, it):
        """Classify ``test_vectors`` with the model stored in the file ``Knn``.

        Returns ``(it, [positive, neutral, negative])`` where the counts are
        the number of predictions equal to 4, 2 and 0 respectively.
        """
        # NOTE(review): ``load`` deserialises the file; keep model files
        # trusted.  The ``with`` block closes the handle.
        with open('Knn', 'rb') as archivo:
            mod = load(archivo)
        result = mod.predict(test_vectors)
        y = [len(result[result == etiqueta]) for etiqueta in (4, 2, 0)]
        return (it, y)

    def predecir_MLP(self, test_vectors, it):
        """Classify ``test_vectors`` with the model stored in the file ``Mlp``.

        Returns ``(it, [positive, neutral, negative])`` where the counts are
        the number of predictions equal to 4, 2 and 0 respectively.
        """
        # NOTE(review): ``load`` deserialises the file; keep model files
        # trusted.  The ``with`` block closes the handle.
        with open('Mlp', 'rb') as archivo:
            mod = load(archivo)
        result = mod.predict(test_vectors)
        y = [len(result[result == etiqueta]) for etiqueta in (4, 2, 0)]
        return (it, y)

    '''Función que muestra los gráficos utilizando los plots de python'''

    def mostrar_graph(self, NB, SVC, KNN, MLP):
        """Draw one bar chart per classifier in a 2x2 figure and show it.

        Each argument is indexed as ``[0]`` (entity name, appended to the
        panel title) and ``[1]`` (the counts plotted as bars, labelled
        positive / neutral / negative).
        """
        paneles = (
            (221, "NB para la entidad ", NB),
            (222, "SVC para la entidad ", SVC),
            (223, "KNN para la entidad ", KNN),
            (224, "MLP para la entidad ", MLP),
        )
        plt.figure(figsize=(9, 7))
        for posicion, prefijo, resultado in paneles:
            plt.subplot(posicion)
            plt.title(prefijo + resultado[0])
            plt.ylabel('tweets')
            plt.xticks(range(len(resultado[1])),
                       ['positive', 'neutral', 'negative'])
            plt.bar(range(len(resultado[1])),
                    height=resultado[1],
                    width=0.75,
                    align='center',
                    alpha=0.8)
        plt.show()

    '''Función que utiliza NER para detectar entidades.'''

    def usar_NER(self, tweetys, n):
        """Tag the tweets with Stanford NER and return the most common entities.

        tweetys: sized iterable of token lists, one per tweet.
        n: how many entities to keep per category.

        Returns ``0`` when no entity was found, otherwise a tuple
        ``(organizations, persons, locations)`` where each element is the
        ``Counter.most_common(n)`` list of ``(word, count)`` pairs.

        The classifier model path can be overridden with the
        ``STANFORD_NER_MODEL`` environment variable; without it the original
        machine-specific default is used.
        """
        import os  # local import: the file's import block is not in view

        self.reiniciarEstado(len(tweetys) * 10)
        self.actualizarEstado(1, "ANALIZANDO ENTIDADES: ")
        modelo = os.environ.get(
            'STANFORD_NER_MODEL',
            r'C:\Users\Servicio Técnico\Documents\stanford-ner-2018-02-27\classifiers\english.all.3class.distsim.crf.ser.gz'
        )
        st = StanfordNERTagger(modelo)

        entities = []
        tindice = 0
        for r in tweetys:
            PySide2.QtWidgets.QApplication.processEvents()
            descripcion = "ANALIZANDO ENTIDADES EN: " + str(r)
            # The tokens are already split, so they are tagged as they are.
            for tup in st.tag(r):
                # Keep the GUI responsive while tagging.
                PySide2.QtWidgets.QApplication.processEvents()
                self.actualizarEstado(tindice, descripcion)
                tindice += 1
                if tup[1] != 'O':
                    entities.append(tup)
        self.actualizarEstado(len(tweetys) * 10, "FINALIZADO")

        if not entities:
            return 0

        # Count the words of each category directly from the tagged pairs.
        organizaciones = Counter(
            tup[0] for tup in entities
            if "ORGANIZATION" in tup[1]).most_common(n)
        personas = Counter(
            tup[0] for tup in entities if "PERSON" in tup[1]).most_common(n)
        lugares = Counter(
            tup[0] for tup in entities if "LOCATION" in tup[1]).most_common(n)
        return (organizaciones, personas, lugares)

    '''Función que muestra el reporte del entrenamiento'''

    def mostrarReporte(self):
        """Fill in and show the training report.

        Reads ``self.resultadoEntrenamiento``, a dict keyed by ``'svc'``,
        ``'mlp'``, ``'knn'`` and ``'nb'`` whose values hold the keys
        ``'matrix'``, ``'f1'``, ``'recall'``, ``'precisión'``,
        ``'puntuación'`` and ``'clasificación'``, and writes each value to
        the matching label.  The report layouts are attached on the first
        call only; later calls just refresh the texts and show them again.
        """
        self.limpiarLayout(self.layoutClass)
        self.limpiarLayout(self.layoutReporte)
        # layoutEntrenamiento holds the two layouts above as child layouts,
        # whose items have no widget; only hide real widgets here.
        for i in range(self.layoutEntrenamiento.count()):
            widget = self.layoutEntrenamiento.itemAt(i).widget()
            if widget is not None:
                widget.hide()

        self.svcTitulo.setText("<h2>RESULTADOS SVC</h2>")
        self.nbTitulo.setText("<h2>RESULTADOS NAIBE BAYES</h2>")
        self.knnTitulo.setText("<h2>RESULTADOS K-NEIGHBORS NEAREST</h2>")
        self.mlpTitulo.setText("<h2>RESULTADOS MLP</h2>")

        # The labels are named <key><Suffix>, e.g. self.svcF1Label.
        for clave, sigla in (('svc', 'SVC'), ('mlp', 'MLP'), ('knn', 'KNN'),
                             ('nb', 'NB')):
            resultado = self.resultadoEntrenamiento[clave]
            getattr(self, clave + 'Matrix').setText(str(resultado['matrix']))
            getattr(self, clave + 'F1Label').setText(
                "Puntuación F1 " + sigla + ": " + str(resultado['f1']))
            getattr(self, clave + 'RecallLabel').setText(
                "Recall " + sigla + ": " + str(resultado['recall']))
            getattr(self, clave + 'PrecisionLabel').setText(
                "Precisión " + sigla + ": " + str(resultado['precisión']))
            getattr(self, clave + 'AccuracyLabel').setText(
                "Puntuación Total " + sigla + ": " +
                str(resultado['puntuación']))
            getattr(self, clave + 'Class').setText(
                str(resultado['clasificación']))

        # A layout can only have one parent, so attach each one only once.
        if self.layoutEntrenamiento.parent() is None:
            self.layoutWidget.addLayout(self.layoutEntrenamiento)
        if self.layoutReporte.parent() is None:
            self.layoutEntrenamiento.addLayout(self.layoutReporte)
        if self.layoutClass.parent() is None:
            self.layoutEntrenamiento.addLayout(self.layoutClass)
        self.mostrarLayout(self.layoutReporte)
        self.mostrarLayout(self.layoutClass)
        print("Se muestra el reporte")

    '''Función que entrena todos los algoritmos utilizando datos de ficheros de entrenamiento y de test. En la misma función se limpian los datos tokenizados. Al final detecta cuál es la mejor configuración para el algoritmo y los entrena con dicha configuración'''

    def entrenar_algoritmos(self):
        """Grid-search, train and persist the SVC, Naive Bayes, KNN and MLP models.

        The user picks a training CSV and a test CSV (both need ``text`` and
        ``label`` columns). The first 500 training rows and the first 125 test
        rows are cleaned, vectorized with ``self.vectorizer`` (which is saved
        to the file ``fvecto``), and every hyper-parameter combination of
        SVC, KNN and MLP is scored through ``self.entrenar``. Each algorithm
        is finally retrained with its best combination, the reports are stored
        in ``self.resultadoEntrenamiento`` and the report view is shown.

        Returns early, without training, if either file dialog is cancelled.
        """
        from itertools import product

        # Hyper-parameter grids. ``product`` iterates with the last list
        # varying fastest, i.e. in the same order as nested ``for`` loops.
        # 'precomputed' is deliberately absent from the SVC kernels: it needs
        # a square kernel matrix, not TF-IDF style feature vectors.
        grid_svc = list(
            product(['linear', 'poly', 'rbf', 'sigmoid'], [3, 5, 10],
                    [0.1, 0.5, 0.9], [True, False], [True, False]))
        grid_knn = list(
            product([1, 5, 10], ['uniform', 'distance'],
                    ['ball_tree', 'kd_tree', 'brute', 'auto'], [5, 30, 100],
                    [1, 2]))
        grid_mlp = list(
            product([50, 100, 150], ['identity', 'logistic', 'tanh', 'relu'],
                    [0.00005, 0.0001, 0.001],
                    ['constant', 'invscaling', 'adaptive']))

        # 4 preparation steps + one step per combination + 2 closing steps.
        # NOTE(review): assumes reiniciarEstado() takes the progress maximum;
        # the previous hard-coded 438 matched this formula for the old grids.
        total_pasos = 4 + len(grid_svc) + len(grid_knn) + len(grid_mlp) + 2
        self.reiniciarEstado(total_pasos)

        # Keep the full path returned by the dialog so files outside the
        # current working directory can be opened; an empty path means the
        # dialog was cancelled.
        fichero_entrenamiento = self.dialogo1.getOpenFileName(
            self, "Selecciona el fichero de entrenamiento", "/")[0]
        if not fichero_entrenamiento:
            return
        fichero_pruebas = self.dialogo1.getOpenFileName(
            self, "Selecciona el fichero de pruebas", "/")[0]
        if not fichero_pruebas:
            return
        dataset2 = pd.read_csv(fichero_entrenamiento)
        dataset = pd.read_csv(fichero_pruebas)

        # Number of tweets used by the models, always an 80:20 split.
        self.progresLabel.setText("Limpiando datos fichero entrenamiento")
        train_data = self._limpiar_textos(dataset2['text'][0:500])
        train_labels = dataset2['label'][0:500]
        self.progressBarUnTweet.setValue(1)
        self.progresLabel.setText("Limpiando datos fichero de pruebas")
        test_data = self._limpiar_textos(dataset['text'][0:125])
        test_labels = dataset['label'][0:125]
        self.progressBarUnTweet.setValue(2)

        self.progressBarUnTweet.setValue(3)
        self.progresLabel.setText("Actualizando datos de entrenamiento")
        # Fit the vectorizer on the training texts only and persist it.
        train_vectors = self.vectorizer.fit_transform(train_data)
        test_vectors = self.vectorizer.transform(test_data)
        with open('fvecto', 'wb') as fvecto:
            dump(self.vectorizer, fvecto)

        self.progressBarUnTweet.setValue(4)
        self.progresLabel.setText("Preparando parámetros de algoritmos")

        procesar_eventos = PySide2.QtWidgets.QApplication.processEvents
        progreso = 5

        def buscar_mejor(alg, combinaciones, construir, describir):
            """Score every combination and return the best-scoring one."""
            nonlocal progreso
            # Scores are accuracies in [0, 1]; starting below that range
            # guarantees a combination is always selected.
            mejor_punt, mejor = -1.0, None
            for combo in combinaciones:
                procesar_eventos()  # keep the GUI responsive
                self.progressBarUnTweet.setValue(progreso)
                self.progresLabel.setText(describir(combo))
                progreso += 1
                punt = self.entrenar(1, alg, train_vectors, train_labels,
                                     test_vectors, test_labels,
                                     construir(combo))
                if punt > mejor_punt:
                    mejor_punt, mejor = punt, combo
            return mejor

        def crear_svc(p):
            return SVC(kernel=p[0],
                       degree=p[1],
                       coef0=p[2],
                       probability=p[3],
                       shrinking=p[4])

        def crear_knn(p):
            return KNeighborsClassifier(n_neighbors=p[0],
                                        weights=p[1],
                                        algorithm=p[2],
                                        leaf_size=p[3],
                                        p=p[4])

        def crear_mlp(p):
            return MLPClassifier(hidden_layer_sizes=p[0],
                                 activation=p[1],
                                 alpha=p[2],
                                 learning_rate=p[3])

        # TRAINING ALGORITHMS
        procesar_eventos()
        best_svc = buscar_mejor(
            'Svc', grid_svc, crear_svc,
            lambda p: "Entrenando SVC con kernel " + p[0])
        best_knn = buscar_mejor(
            'Knn', grid_knn, crear_knn,
            lambda p: "Entrenando KNN con kernel " + p[1] + p[2])
        best_mlp = buscar_mejor(
            'Mlp', grid_mlp, crear_mlp,
            lambda p: "Entrenando MLP con kernel " + p[1] + p[3])

        # Retrain every algorithm with its best configuration and keep the
        # full report of each one.
        self.progressBarUnTweet.setValue(progreso)
        progreso += 1
        self.resultadoEntrenamiento['svc'] = self.entrenar(
            2, 'Svc', train_vectors, train_labels, test_vectors, test_labels,
            crear_svc(best_svc))
        self.resultadoEntrenamiento['nb'] = self.entrenar(2,
                                                          'NaiveBayes',
                                                          train_vectors,
                                                          train_labels,
                                                          test_vectors,
                                                          test_labels,
                                                          mod=MultinomialNB())
        self.resultadoEntrenamiento['knn'] = self.entrenar(
            2, 'Knn', train_vectors, train_labels, test_vectors, test_labels,
            crear_knn(best_knn))
        self.resultadoEntrenamiento['mlp'] = self.entrenar(
            2, 'Mlp', train_vectors, train_labels, test_vectors, test_labels,
            crear_mlp(best_mlp))
        self.progressBarUnTweet.setValue(progreso)
        self.progresLabel.setText("FINALIZADO")
        self.entrenamientoLabel.setText(
            "<h1>ENTRENAMIENTO REALIZADO CON ÉXITO</h1>")
        self.layoutWidget.addWidget(self.entrenamientoLabel)
        self.mostrarReporte()

    def _limpiar_textos(self, textos):
        """Tokenize and clean tweets; return one space-joined string per tweet.

        Drops English stopwords (case-insensitively), punctuation tokens,
        tokens starting with a digit and single-character tokens.
        """
        tokenizer = TweetTokenizer()
        vocabulario = set(stopwords.words('english'))
        signos = set(string.punctuation)
        digitos = set(string.digits)

        def limpiar(texto):
            # The length check comes first so tok[0] is always valid.
            return ' '.join(tok for tok in tokenizer.tokenize(texto)
                            if len(tok) > 1 and tok.lower() not in vocabulario
                            and tok not in signos and tok[0] not in digitos)

        return [limpiar(texto) for texto in textos]

    '''Función auxiliar para usar_NER que guarda los archivos de entrenamiento para que se guarden en futuras sesiones'''

    def entrenar(self, opc, alg, train_vectors, train_labels, test_vectors,
                 test_labels, mod):
        """Fit ``mod``, save it to the file named ``alg`` and evaluate it.

        The estimator passed in is always the one that is trained. A model
        previously saved under the same name is overwritten rather than
        reloaded, so the hyper-parameters chosen by the caller take effect.

        Parameters:
        opc -- 1 to get only the accuracy, 2 to get the full report dict.
        alg -- name of the file the fitted model is dumped to.
        train_vectors, train_labels -- data used to fit the model.
        test_vectors, test_labels -- data used for every metric.
        mod -- unfitted estimator to train.

        Returns the accuracy (``opc == 1``), a dict with the keys
        'clasificación', 'matrix', 'puntuación', 'f1', 'recall' and
        'precisión' (``opc == 2``), or None for any other ``opc``.
        """
        mod.fit(train_vectors, train_labels)
        # The context manager closes the file even if dumping fails.
        with open(alg, 'wb') as file:
            dump(mod, file)
        # NOTE(review): the accuracy comes from a 10-fold cross-validation run
        # on the test set itself, not from the model fitted above. Kept as-is
        # so reported scores do not change.
        predicted = cross_val_predict(mod, test_vectors, test_labels, cv=10)
        puntuacion = accuracy_score(test_labels, predicted)
        if opc == 1:
            return puntuacion
        elif opc == 2:
            # Predict once and reuse the result for every metric.
            prediccion = mod.predict(test_vectors)
            return {
                'clasificación':
                classification_report(test_labels, prediccion),
                'matrix':
                confusion_matrix(test_labels, prediccion),
                'puntuación':
                puntuacion,
                'f1':
                f1_score(test_labels, prediccion, average='macro'),
                'recall':
                recall_score(test_labels, prediccion, average='macro'),
                'precisión':
                precision_score(test_labels, prediccion, average='macro')
            }
예제 #9
0
class TableWidget(QWidget):
    '''
    Widget made of a search field stacked above a table; the text typed in
    the field filters which rows of the populated data are displayed.
    '''

    def __init__(self):
        '''
        Creates the empty data holders, the table, the search bar and the layout.
        '''
        super().__init__()

        # Data: header labels, every row, and the rows passing the filter
        self.labels = []
        self.values = []
        self.values_view = []

        # Table
        self.item_count = 0
        self.table = QTableWidget()

        # Search bar: every text change re-filters and reloads the table
        self.search_bar = QLineEdit()
        self.search_bar.setSizePolicy(QSizePolicy.MinimumExpanding,
                                      QSizePolicy.Minimum)
        self.search_bar.textChanged[str].connect(self.update_table)
        self.search_label = QLabel('Cerca')
        self.search_layout = QHBoxLayout()
        for search_widget in (self.search_label, self.search_bar):
            self.search_layout.addWidget(search_widget)

        # Layout
        self.layout = QtWidgets.QVBoxLayout()
        self.layout.addLayout(self.search_layout)
        self.layout.addWidget(self.table)
        self.setLayout(self.layout)

    def populate(self, labels: list, values: list):
        '''
        Stores the result of a query and displays it.
        Parameters:
        labels : list -- The column labels.
        values : list -- A list of tuples with the values to display. len(values[i]) = len(labels) for each i
        '''
        self.labels, self.values = labels, values
        self.update_table()

    def update_table(self):
        '''
        Re-applies the search filter, then rebuilds the table contents
        '''
        self.filter_rows()
        self.load_table()

    def load_table(self):
        '''
        Empties the table and fills it with the rows of self.values_view
        '''
        self.table.clear()
        for _ in range(self.table.rowCount()):
            self.table.removeRow(0)
        self.item_count = 0
        column_count = len(self.labels)
        self.table.setColumnCount(column_count)
        self.table.setHorizontalHeaderLabels(self.labels)
        self.table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        for position, record in enumerate(self.values_view):
            self.table.insertRow(position)
            for column in range(column_count):
                cell = QTableWidgetItem(str(record[column]))
                self.table.setItem(position, column, cell)
            self.item_count = position + 1

    def filter_rows(self):
        '''
        Keeps in self.values_view the rows whose string form contains the
        text of self.search_bar (case-insensitive)
        '''
        needle = self.search_bar.text().lower()
        self.values_view = [
            record for record in self.values if needle in str(record).lower()
        ]
예제 #10
0
class FittingResultViewer(QDialog):
    PAGE_ROWS = 20
    logger = logging.getLogger("root.QGrain.ui.FittingResultViewer")
    result_marked = Signal(SSUResult)

    def __init__(self, reference_viewer: ReferenceResultViewer, parent=None):
        """Build the viewer window, its charts, background worker and dialogs.

        Parameters:
        reference_viewer -- viewer stored for later use by this dialog.
        parent -- optional parent widget.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Fitting Result Viewer"))
        self.__fitting_results = []  # type: list[SSUResult]
        self.retry_tasks = {}  # type: dict[UUID, SSUTask]
        self.__reference_viewer = reference_viewer
        # The table and page combo box must exist before the pages are updated below
        self.init_ui()
        self.boxplot_chart = BoxplotChart(parent=self, toolbar=True)
        self.typical_chart = SSUTypicalComponentChart(parent=self,
                                                      toolbar=True)
        self.distance_chart = DistanceCurveChart(parent=self, toolbar=True)
        self.mixed_distribution_chart = MixedDistributionChart(
            parent=self, toolbar=True, use_animation=True)
        self.file_dialog = QFileDialog(parent=self)
        # Route the worker's success and failure signals to this dialog's handlers
        self.async_worker = AsyncWorker()
        self.async_worker.background_worker.task_succeeded.connect(
            self.on_fitting_succeeded)
        self.async_worker.background_worker.task_failed.connect(
            self.on_fitting_failed)
        # Show the (initially empty) first page
        self.update_page_list()
        self.update_page(self.page_index)

        # Generic message box reused by show_message()
        self.normal_msg = QMessageBox(self)
        # Yes/No confirmation shown before removing all results; "No" is the default
        self.remove_warning_msg = QMessageBox(self)
        self.remove_warning_msg.setStandardButtons(QMessageBox.No
                                                   | QMessageBox.Yes)
        self.remove_warning_msg.setDefaultButton(QMessageBox.No)
        self.remove_warning_msg.setWindowTitle(self.tr("Warning"))
        self.remove_warning_msg.setText(
            self.tr("Are you sure to remove all SSU results?"))
        # Discard/Retry/Ignore box; "Ignore" is the default
        self.outlier_msg = QMessageBox(self)
        self.outlier_msg.setStandardButtons(QMessageBox.Discard
                                            | QMessageBox.Retry
                                            | QMessageBox.Ignore)
        self.outlier_msg.setDefaultButton(QMessageBox.Ignore)
        # Progress box created without a parent; its Ok button is added but hidden
        self.retry_progress_msg = QMessageBox()
        self.retry_progress_msg.addButton(QMessageBox.Ok)
        self.retry_progress_msg.button(QMessageBox.Ok).hide()
        self.retry_progress_msg.setWindowTitle(self.tr("Progress"))
        # Single-shot timer that opens the progress box when it fires
        self.retry_timer = QTimer(self)
        self.retry_timer.setSingleShot(True)
        self.retry_timer.timeout.connect(
            lambda: self.retry_progress_msg.exec_())

    def init_ui(self):
        """Create the result table, the paging and distance controls and the context menu."""
        # Read-only table with whole-row selection and a custom context menu
        self.data_table = QTableWidget(100, 100)
        self.data_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.data_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.data_table.setAlternatingRowColors(True)
        self.data_table.setContextMenuPolicy(Qt.CustomContextMenu)
        self.main_layout = QGridLayout(self)
        self.main_layout.addWidget(self.data_table, 0, 0, 1, 3)

        # Paging row: previous button, page selector, next button
        self.previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Previous"))
        self.previous_button.setToolTip(
            self.tr("Click to back to the previous page."))
        self.previous_button.clicked.connect(self.on_previous_button_clicked)
        self.current_page_combo_box = QComboBox()
        self.current_page_combo_box.addItem(self.tr("Page {0}").format(1))
        self.current_page_combo_box.currentIndexChanged.connect(
            self.update_page)
        self.next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                       self.tr("Next"))
        self.next_button.setToolTip(self.tr("Click to jump to the next page."))
        self.next_button.clicked.connect(self.on_next_button_clicked)
        self.main_layout.addWidget(self.previous_button, 1, 0)
        self.main_layout.addWidget(self.current_page_combo_box, 1, 1)
        self.main_layout.addWidget(self.next_button, 1, 2)

        # Distance selector; changing it redraws the current page
        self.distance_label = QLabel(self.tr("Distance"))
        self.distance_label.setToolTip(
            self.
            tr("It's the function to calculate the difference (on the contrary, similarity) between two samples."
               ))
        self.distance_combo_box = QComboBox()
        self.distance_combo_box.addItems(built_in_distances)
        self.distance_combo_box.setCurrentText("log10MSE")
        self.distance_combo_box.currentTextChanged.connect(
            lambda: self.update_page(self.page_index))
        self.main_layout.addWidget(self.distance_label, 2, 0)
        self.main_layout.addWidget(self.distance_combo_box, 2, 1, 1, 2)
        # Context menu of the table; actions appear in the order they are added
        self.menu = QMenu(self.data_table)
        self.menu.setShortcutAutoRepeat(True)
        self.mark_action = self.menu.addAction(
            qta.icon("mdi.marker-check"),
            self.tr("Mark Selection(s) as Reference"))
        self.mark_action.triggered.connect(self.mark_selections)
        self.remove_selection_action = self.menu.addAction(
            qta.icon("fa.remove"), self.tr("Remove Selection(s)"))
        self.remove_selection_action.triggered.connect(self.remove_selections)
        self.remove_all_action = self.menu.addAction(qta.icon("fa.remove"),
                                                     self.tr("Remove All"))
        self.remove_all_action.triggered.connect(self.remove_all_results)
        self.plot_loss_chart_action = self.menu.addAction(
            qta.icon("mdi.chart-timeline-variant"), self.tr("Plot Loss Chart"))
        self.plot_loss_chart_action.triggered.connect(self.show_distance)
        self.plot_distribution_chart_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"), self.tr("Plot Distribution Chart"))
        self.plot_distribution_chart_action.triggered.connect(
            self.show_distribution)
        self.plot_distribution_animation_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"),
            self.tr("Plot Distribution Chart (Animation)"))
        self.plot_distribution_animation_action.triggered.connect(
            self.show_history_distribution)

        # "Detect Outliers" sub-menu
        self.detect_outliers_menu = self.menu.addMenu(
            qta.icon("mdi.magnify"), self.tr("Detect Outliers"))
        self.check_nan_and_inf_action = self.detect_outliers_menu.addAction(
            self.tr("Check NaN and Inf"))
        self.check_nan_and_inf_action.triggered.connect(self.check_nan_and_inf)
        self.check_final_distances_action = self.detect_outliers_menu.addAction(
            self.tr("Check Final Distances"))
        self.check_final_distances_action.triggered.connect(
            self.check_final_distances)
        # The four moment checks share one handler, told apart by the moment name
        self.check_component_mean_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Mean"))
        self.check_component_mean_action.triggered.connect(
            lambda: self.check_component_moments("mean"))
        self.check_component_std_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component STD"))
        self.check_component_std_action.triggered.connect(
            lambda: self.check_component_moments("std"))
        self.check_component_skewness_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Skewness"))
        self.check_component_skewness_action.triggered.connect(
            lambda: self.check_component_moments("skewness"))
        self.check_component_kurtosis_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Kurtosis"))
        self.check_component_kurtosis_action.triggered.connect(
            lambda: self.check_component_moments("kurtosis"))
        self.check_component_fractions_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Fractions"))
        self.check_component_fractions_action.triggered.connect(
            self.check_component_fractions)
        self.degrade_results_action = self.detect_outliers_menu.addAction(
            self.tr("Degrade Results"))
        self.degrade_results_action.triggered.connect(self.degrade_results)
        self.try_align_components_action = self.detect_outliers_menu.addAction(
            self.tr("Try Align Components"))
        self.try_align_components_action.triggered.connect(
            self.try_align_components)

        # Analysis, dump and Excel export actions
        self.analyse_typical_components_action = self.menu.addAction(
            qta.icon("ei.tags"), self.tr("Analyse Typical Components"))
        self.analyse_typical_components_action.triggered.connect(
            self.analyse_typical_components)
        self.load_dump_action = self.menu.addAction(
            qta.icon("fa.database"), self.tr("Load Binary Dump"))
        self.load_dump_action.triggered.connect(self.load_dump)
        self.save_dump_action = self.menu.addAction(
            qta.icon("fa.save"), self.tr("Save Binary Dump"))
        self.save_dump_action.triggered.connect(self.save_dump)
        self.save_excel_action = self.menu.addAction(
            qta.icon("mdi.microsoft-excel"), self.tr("Save Excel"))
        self.save_excel_action.triggered.connect(
            lambda: self.on_save_excel_clicked(align_components=False))
        self.save_excel_align_action = self.menu.addAction(
            qta.icon("mdi.microsoft-excel"),
            self.tr("Save Excel (Force Alignment)"))
        self.save_excel_align_action.triggered.connect(
            lambda: self.on_save_excel_clicked(align_components=True))
        self.data_table.customContextMenuRequested.connect(self.show_menu)
        # necessary to add actions of menu to this widget itself,
        # otherwise, the shortcuts will not be triggered
        self.addActions(self.menu.actions())

    def show_menu(self, pos: QPoint):
        """Pop up the context menu at the current cursor position.

        The ``pos`` value delivered with the request is not used.
        """
        cursor_position = QCursor.pos()
        self.menu.popup(cursor_position)

    def show_message(self, title: str, message: str):
        """Show ``message`` in the shared message box under ``title`` and run it modally."""
        box = self.normal_msg
        box.setWindowTitle(title)
        box.setText(message)
        box.exec_()

    def show_info(self, message: str):
        """Show ``message`` in a message box titled with the translated "Info"."""
        title = self.tr("Info")
        self.show_message(title, message)

    def show_warning(self, message: str):
        """Show ``message`` in a message box titled with the translated "Warning"."""
        title = self.tr("Warning")
        self.show_message(title, message)

    def show_error(self, message: str):
        """Show ``message`` in a message box titled with the translated "Error"."""
        title = self.tr("Error")
        self.show_message(title, message)

    @property
    def distance_name(self) -> str:
        """Text currently selected in the distance combo box."""
        combo_box = self.distance_combo_box
        return combo_box.currentText()

    @property
    def distance_func(self) -> typing.Callable:
        """Distance function looked up by the name selected in the distance combo box."""
        selected_name = self.distance_combo_box.currentText()
        return get_distance_func_by_name(selected_name)

    @property
    def page_index(self) -> int:
        """Current index of the page combo box."""
        combo_box = self.current_page_combo_box
        return combo_box.currentIndex()

    @property
    def n_pages(self) -> int:
        """Number of entries in the page combo box."""
        combo_box = self.current_page_combo_box
        return combo_box.count()

    @property
    def n_results(self) -> int:
        """Number of fitting results currently held by the viewer."""
        stored_results = self.__fitting_results
        return len(stored_results)

    @property
    def selections(self):
        """Sorted, de-duplicated indexes of the selected results.

        Row numbers of the selected table ranges are shifted by the offset of
        the current page (``page_index * PAGE_ROWS``) so they index the full
        result list rather than the visible page.
        """
        offset = self.page_index * self.PAGE_ROWS
        picked = set()
        for selected_range in self.data_table.selectedRanges():
            # A page holds rows 0 .. PAGE_ROWS - 1, so the exclusive upper
            # bound is PAGE_ROWS (PAGE_ROWS + 1 let one extra row through).
            stop = min(self.PAGE_ROWS, selected_range.bottomRow() + 1)
            picked.update(row + offset
                          for row in range(selected_range.topRow(), stop))
        return sorted(picked)

    def update_page_list(self):
        """Rebuild the page combo box to match the current number of results.

        The previously selected page is kept when it still exists; otherwise
        the last page is selected. Signals of the combo box are blocked while
        it is rebuilt.
        """
        previous_index = self.page_index
        # Ceiling division; an empty result list still gets one page
        page_count = max(1, -(-self.n_results // self.PAGE_ROWS))
        page_labels = [
            self.tr("Page {0}").format(number)
            for number in range(1, page_count + 1)
        ]
        combo_box = self.current_page_combo_box
        combo_box.blockSignals(True)
        combo_box.clear()
        combo_box.addItems(page_labels)
        combo_box.setCurrentIndex(min(previous_index, page_count - 1))
        combo_box.blockSignals(False)

    def update_page(self, page_index: int):
        """Fill the table with the results belonging to page ``page_index``.

        Each row shows resolver, distribution type, component count,
        iteration count, time spent, the distance computed with the selected
        distance function, and whether the task had a reference.
        """
        def as_text(value) -> str:
            # Strings pass through, floats get four decimals, the rest uses str()
            if isinstance(value, str):
                return value
            if isinstance(value, float):
                return f"{value: 0.4f}"
            return str(value)

        def write(row: int, col: int, value):
            cell = QTableWidgetItem(as_text(value))
            cell.setTextAlignment(Qt.AlignCenter)
            self.data_table.setItem(row, col, cell)

        # necessary to clear
        self.data_table.clear()
        start = page_index * self.PAGE_ROWS
        # The last page stops at the final result, any other page is full
        if page_index == self.n_pages - 1:
            end = self.n_results
        else:
            end = start + self.PAGE_ROWS
        self.data_table.setRowCount(end - start)
        self.data_table.setColumnCount(7)
        column_titles = [
            self.tr("Resolver"),
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("N_iterations"),
            self.tr("Spent Time [s]"),
            self.tr("Final Distance"),
            self.tr("Has Reference")
        ]
        self.data_table.setHorizontalHeaderLabels(column_titles)
        self.data_table.setVerticalHeaderLabels(
            [entry.sample.name for entry in self.__fitting_results[start:end]])
        for row, entry in enumerate(self.__fitting_results[start:end]):
            task = entry.task
            write(row, 0, task.resolver)
            write(row, 1, self.get_distribution_name(task.distribution_type))
            write(row, 2, task.n_components)
            write(row, 3, entry.n_iterations)
            write(row, 4, entry.time_spent)
            distance = self.distance_func(entry.sample.distribution,
                                          entry.distribution)
            write(row, 5, distance)
            lacks_reference = task.initial_guess is None and task.reference is None
            write(row, 6, self.tr("No") if lacks_reference else self.tr("Yes"))

        self.data_table.resizeColumnsToContents()

    def on_previous_button_clicked(self):
        """Select the previous page, unless the first page is already shown."""
        if self.page_index <= 0:
            return
        self.current_page_combo_box.setCurrentIndex(self.page_index - 1)

    def on_next_button_clicked(self):
        """Select the next page, unless the last page is already shown."""
        if self.page_index >= self.n_pages - 1:
            return
        self.current_page_combo_box.setCurrentIndex(self.page_index + 1)

    def get_distribution_name(self, distribution_type: DistributionType):
        """Return the translated display name of ``distribution_type``.

        Raises NotImplementedError for any type other than Normal, Weibull
        and SkewNormal.
        """
        if distribution_type == DistributionType.Normal:
            return self.tr("Normal")
        if distribution_type == DistributionType.Weibull:
            return self.tr("Weibull")
        if distribution_type == DistributionType.SkewNormal:
            return self.tr("Skew Normal")
        raise NotImplementedError(distribution_type)

    def add_result(self, result: SSUResult):
        """Append one result and refresh the page list.

        The visible page is redrawn only when the list was empty, or when the
        last page is shown and still has free rows.
        """
        # Both conditions are evaluated before the list grows
        was_empty = self.n_results == 0
        on_partial_last_page = (self.page_index == self.n_pages - 1
                                and self.n_results % self.PAGE_ROWS != 0)
        need_update = was_empty or on_partial_last_page
        self.__fitting_results.append(result)
        self.update_page_list()
        if need_update:
            self.update_page(self.page_index)

    def add_results(self, results: typing.List[SSUResult]):
        """Append several results and refresh the page list.

        The visible page is redrawn only when the list was empty, or when the
        last page is shown and still has free rows.
        """
        # Both conditions are evaluated before the list grows
        was_empty = self.n_results == 0
        on_partial_last_page = (self.page_index == self.n_pages - 1
                                and self.n_results % self.PAGE_ROWS != 0)
        need_update = was_empty or on_partial_last_page
        self.__fitting_results.extend(results)
        self.update_page_list()
        if need_update:
            self.update_page(self.page_index)

    def mark_selections(self):
        """Emit ``result_marked`` once for every selected result, in index order."""
        stored_results = self.__fitting_results
        for position in self.selections:
            self.result_marked.emit(stored_results[position])

    def remove_results(self, indexes):
        """Remove the results at ``indexes`` and refresh the page list and page.

        Parameters:
        indexes -- iterable of positions in the result list; it may be
                   unsorted and may contain duplicates.
        """
        # Pop from the highest index down so earlier removals never shift the
        # positions still to be removed; the set drops duplicates.
        for index in sorted(set(indexes), reverse=True):
            self.__fitting_results.pop(index)
        self.update_page_list()
        self.update_page(self.page_index)

    def remove_selections(self):
        """Remove every currently selected result."""
        self.remove_results(self.selections)

    def remove_all_results(self):
        """Ask for confirmation and, if the answer is Yes, drop every result and show page 0."""
        answer = self.remove_warning_msg.exec_()
        if answer != QMessageBox.Yes:
            return
        self.__fitting_results.clear()
        self.update_page_list()
        self.update_page(0)

    def show_distance(self):
        """Show the distance series of the first selected result.

        Does nothing when ``self.selections`` is empty.
        """
        selected = [self.__fitting_results[i] for i in self.selections]
        if not selected:
            return
        first = selected[0]
        series = first.get_distance_series(self.distance_name)
        self.distance_chart.show_distance_series(series,
                                                 title=first.sample.name)
        self.distance_chart.show()

    def show_distribution(self):
        """Show the ``view_model`` of the first selected result.

        Does nothing when ``self.selections`` is empty.
        """
        selected = [self.__fitting_results[i] for i in self.selections]
        if not selected:
            return
        chart = self.mixed_distribution_chart
        chart.show_model(selected[0].view_model)
        chart.show()

    def show_history_distribution(self):
        """Pass the first selected result to the chart's ``show_result``.

        Does nothing when ``self.selections`` is empty.
        """
        selected = [self.__fitting_results[i] for i in self.selections]
        if not selected:
            return
        chart = self.mixed_distribution_chart
        chart.show_result(selected[0])
        chart.show()

    def load_dump(self):
        """Load SSU results from a user-chosen pickle dump and append them.

        The file must unpickle to a ``list`` whose items are all ``SSUResult``
        instances. When results are already stored, the grain-size classes
        (``classes_φ``) of the first loaded result must match those of the
        first stored result within 1e-8; otherwise an error is shown and
        nothing is added. A file that cannot be read or unpickled is reported
        as an invalid dump instead of raising out of this method.
        """
        filename, _ = self.file_dialog.getOpenFileName(
            self, self.tr("Select a binary dump file of SSU results"), None,
            self.tr("Binary dump (*.dump)"))
        if filename is None or filename == "":
            return

        # SECURITY: unpickling can execute arbitrary code stored in the file.
        # Only dump files from a trusted source should be opened here.
        try:
            with open(filename, "rb") as f:
                results = pickle.load(f)  # type: list[SSUResult]
        except Exception:
            # Unpickling a damaged or foreign file may raise many unrelated
            # exception types, so this UI boundary catches broadly, logs the
            # traceback and reports the file as invalid.
            self.logger.exception("Failed to load the dump file: %s", filename)
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        # The file is closed from here on; the checks below may open dialogs.
        valid = isinstance(results, list) and all(
            isinstance(result, SSUResult) for result in results)
        if not valid:
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        if self.n_results != 0 and len(results) != 0:
            old_classes = self.__fitting_results[0].classes_φ
            new_classes = results[0].classes_φ
            if len(old_classes) != len(new_classes):
                classes_inconsistent = True
            else:
                classes_error = np.abs(old_classes - new_classes)
                classes_inconsistent = not np.all(
                    np.less_equal(classes_error, 1e-8))
            if classes_inconsistent:
                self.show_error(
                    self.
                    tr("The results in the dump file has inconsistent grain-size classes with that in your list."
                       ))
                return
        self.add_results(results)

    def save_dump(self):
        """Pickle the stored results into a file chosen by the user.

        Shows a warning and returns when there are no results; returns
        silently when the file dialog yields no filename.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        caption = self.tr("Save the SSU results to binary dump file")
        name_filter = self.tr("Binary dump (*.dump)")
        filename, _ = self.file_dialog.getSaveFileName(self, caption, None,
                                                       name_filter)
        if not filename:
            return
        with open(filename, "wb") as dump_file:
            pickle.dump(self.__fitting_results, dump_file)

    def save_excel(self, filename, align_components=False):
        """Write every stored SSU result into an Excel workbook.

        The workbook gets a README sheet, "Sample Distributions",
        "Information of Fitting", "Statistic Moments", "Unmixed Components"
        and one "Unmixed EM{n}" sheet per distinct component flag.

        Each component carries a *flag* that selects its column group in the
        moments sheet and its "Unmixed EM" sheet. Without alignment the flag
        is the component's position inside its result. With alignment all
        component distributions are clustered with KMeans (one cluster per
        possible component), the labels are made unique within each result,
        and the clusters are renumbered by ascending logarithmic mean.

        Args:
            filename: Path the workbook is saved to.
            align_components: Whether to align components across results by
                clustering instead of using their position.

        Returns:
            None. Returns immediately when there are no results.
        """
        if self.n_results == 0:
            return

        results = self.__fitting_results.copy()
        classes_μm = results[0].classes_μm
        n_components_list = [result.n_components for result in results]
        count_dict = Counter(n_components_list)
        max_n_components = max(count_dict.keys())
        self.logger.debug(
            f"N_components: {count_dict}, Max N_components: {max_n_components}"
        )

        flags = []
        if not align_components:
            for result in results:
                flags.extend(range(result.n_components))
        else:
            n_components_desc = "\n".join([
                self.tr("{0} Component(s): {1}").format(n_components, count)
                for n_components, count in count_dict.items()
            ])
            self.show_info(
                self.tr("N_components distribution of Results:\n{0}").format(
                    n_components_desc))
            stacked_components = np.array([
                component.distribution for result in results
                for component in result.components
            ])
            cluster = KMeans(n_clusters=max_n_components)
            flags = cluster.fit_predict(stacked_components)
            centers = cluster.cluster_centers_
            # Make the labels unique within each result. Two components of
            # one result sharing a label would be written to the same cells
            # below and overwrite each other. A duplicate is moved to the
            # closest cluster centre that this result has not used yet.
            flag_index = 0
            for result in results:
                used_flags = set()
                for _ in result.components:
                    flag = int(flags[flag_index])
                    if flag in used_flags:
                        free_flags = [
                            candidate
                            for candidate in range(max_n_components)
                            if candidate not in used_flags
                        ]
                        # A free label exists as long as the result has no
                        # more components than there are clusters.
                        if free_flags:
                            gaps = np.linalg.norm(
                                centers[free_flags] -
                                stacked_components[flag_index],
                                axis=1)
                            flag = free_flags[int(np.argmin(gaps))]
                            flags[flag_index] = flag
                    used_flags.add(flag)
                    flag_index += 1

            # Renumber the clusters by the logarithmic mean of the first
            # component found in each cluster, so flag 0 has the lowest mean.
            flag_set = set(flags)
            picked = []
            for target_flag in flag_set:
                for i, flag in enumerate(flags):
                    if flag == target_flag:
                        picked.append(
                            (target_flag,
                             logarithmic(classes_μm,
                                         stacked_components[i])["mean"]))
                        break
            picked.sort(key=lambda x: x[1])
            flag_map = {flag: index for index, (flag, _) in enumerate(picked)}
            flags = np.array([flag_map[flag] for flag in flags])

        wb = openpyxl.Workbook()
        # NOTE(review): presumably registers the named cell styles used below
        # ("header", "description", "normal_light", "normal_dark") - confirm.
        prepare_styles(wb)
        ws = wb.active
        ws.title = self.tr("README")
        description = \
            """
            This Excel file was generated by QGrain ({0}).

            Please cite:
            Liu, Y., Liu, X., Sun, Y., 2021. QGrain: An open-source and easy-to-use software for the comprehensive analysis of grain size distributions. Sedimentary Geology 423, 105980. https://doi.org/10.1016/j.sedgeo.2021.105980

            It contanins 4 + max(N_components) sheets:
            1. The first sheet is the sample distributions of SSU results.
            2. The second sheet is used to put the infomation of fitting.
            3. The third sheet is the statistic parameters calculated by statistic moment method.
            4. The fouth sheet is the distributions of unmixed components and their sum of each sample.
            5. Other sheets are the unmixed end-member distributions which were discretely stored.

            The SSU algorithm is implemented by QGrain.

            """.format(QGRAIN_VERSION)

        def write(row, col, value, style="normal_light"):
            # Takes 0-based row/col. `ws` is looked up at call time, so this
            # always writes to whichever sheet `ws` currently refers to.
            cell = ws.cell(row + 1, col + 1, value=value)
            cell.style = style

        lines_of_desc = description.split("\n")
        for row, line in enumerate(lines_of_desc):
            write(row, 0, line, style="description")
        ws.column_dimensions[column_to_char(0)].width = 200

        # Sheet: one row per sample holding its measured distribution.
        ws = wb.create_sheet(self.tr("Sample Distributions"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        for col, value in enumerate(classes_μm, 1):
            write(0, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            for col, value in enumerate(result.sample.distribution, 1):
                write(row, col, value, style=style)
            QCoreApplication.processEvents()

        # Sheet: one row per sample describing how it was fitted.
        ws = wb.create_sheet(self.tr("Information of Fitting"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        headers = [
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("Resolver"),
            self.tr("Resolver Settings"),
            self.tr("Initial Guess"),
            self.tr("Reference"),
            self.tr("Spent Time [s]"),
            self.tr("N_iterations"),
            self.tr("Final Distance [log10MSE]")
        ]
        for col, value in enumerate(headers, 1):
            write(0, col, value, style="header")
            # The original if/else on columns 4-6 set the same width in both
            # branches, so a single assignment is equivalent.
            ws.column_dimensions[column_to_char(col)].width = 10
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            write(row, 1, result.distribution_type.name, style=style)
            write(row, 2, result.n_components, style=style)
            write(row, 3, result.task.resolver, style=style)
            write(row,
                  4,
                  self.tr("Default") if result.task.resolver_setting is None
                  else result.task.resolver_setting.__str__(),
                  style=style)
            write(row,
                  5,
                  self.tr("None") if result.task.initial_guess is None else
                  result.task.initial_guess.__str__(),
                  style=style)
            write(row,
                  6,
                  self.tr("None") if result.task.reference is None else
                  result.task.reference.__str__(),
                  style=style)
            write(row, 7, result.time_spent, style=style)
            write(row, 8, result.n_iterations, style=style)
            write(row, 9, result.get_distance("log10MSE"), style=style)

        # Sheet: two header rows (component label, then moment names) and one
        # row per sample. The flag picks each component's column group.
        ws = wb.create_sheet(self.tr("Statistic Moments"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
        ws.column_dimensions[column_to_char(0)].width = 16
        headers = []
        sub_headers = [
            self.tr("Proportion"),
            self.tr("Mean [φ]"),
            self.tr("Mean [μm]"),
            self.tr("STD [φ]"),
            self.tr("STD [μm]"),
            self.tr("Skewness"),
            self.tr("Kurtosis")
        ]
        for i in range(max_n_components):
            write(0,
                  i * len(sub_headers) + 1,
                  self.tr("C{0}").format(i + 1),
                  style="header")
            ws.merge_cells(start_row=1,
                           start_column=i * len(sub_headers) + 2,
                           end_row=1,
                           end_column=(i + 1) * len(sub_headers) + 1)
            headers.extend(sub_headers)
        for col, value in enumerate(headers, 1):
            write(1, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        flag_index = 0
        for row, result in enumerate(results, 2):
            if row % 2 == 0:
                style = "normal_light"
            else:
                style = "normal_dark"
            write(row, 0, result.sample.name, style=style)
            for component in result.components:
                index = flags[flag_index]
                write(row,
                      index * len(sub_headers) + 1,
                      component.fraction,
                      style=style)
                write(row,
                      index * len(sub_headers) + 2,
                      component.logarithmic_moments["mean"],
                      style=style)
                write(row,
                      index * len(sub_headers) + 3,
                      component.geometric_moments["mean"],
                      style=style)
                write(row,
                      index * len(sub_headers) + 4,
                      component.logarithmic_moments["std"],
                      style=style)
                write(row,
                      index * len(sub_headers) + 5,
                      component.geometric_moments["std"],
                      style=style)
                write(row,
                      index * len(sub_headers) + 6,
                      component.logarithmic_moments["skewness"],
                      style=style)
                write(row,
                      index * len(sub_headers) + 7,
                      component.logarithmic_moments["kurtosis"],
                      style=style)
                flag_index += 1

        # Sheet: per sample, one row per component (scaled by its fraction)
        # followed by a "Sum" row; the sample name cell spans those rows.
        ws = wb.create_sheet(self.tr("Unmixed Components"))
        ws.merge_cells(start_row=1, start_column=1, end_row=1, end_column=2)
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        for col, value in enumerate(classes_μm, 2):
            write(0, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        row = 1
        for result_index, result in enumerate(results, 1):
            if result_index % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            ws.merge_cells(start_row=row + 1,
                           start_column=1,
                           end_row=row + result.n_components + 1,
                           end_column=1)
            for component_i, component in enumerate(result.components, 1):
                write(row, 1, self.tr("C{0}").format(component_i), style=style)
                for col, value in enumerate(
                        component.distribution * component.fraction, 2):
                    write(row, col, value, style=style)
                row += 1
            write(row, 1, self.tr("Sum"), style=style)
            for col, value in enumerate(result.distribution, 2):
                write(row, col, value, style=style)
            row += 1

        # One sheet per flag holding the unscaled component distributions.
        ws_dict = {}
        flag_set = set(flags)
        for flag in flag_set:
            ws = wb.create_sheet(self.tr("Unmixed EM{0}").format(flag + 1))
            write(0, 0, self.tr("Sample Name"), style="header")
            ws.column_dimensions[column_to_char(0)].width = 16
            for col, value in enumerate(classes_μm, 1):
                write(0, col, value, style="header")
                ws.column_dimensions[column_to_char(col)].width = 10
            ws_dict[flag] = ws

        flag_index = 0
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"

            for component in result.components:
                flag = flags[flag_index]
                # Rebinding `ws` redirects `write` to the sheet of this flag.
                ws = ws_dict[flag]
                write(row, 0, result.sample.name, style=style)
                for col, value in enumerate(component.distribution, 1):
                    write(row, col, value, style=style)
                flag_index += 1

        wb.save(filename)
        wb.close()

    def on_save_excel_clicked(self, align_components=False):
        """Ask for a target file and export the results to Excel.

        Shows a warning when there are no results, returns silently when no
        filename is chosen, and reports any exception raised during the
        export through ``show_error`` instead of propagating it.

        Args:
            align_components: Forwarded unchanged to ``save_excel``.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any SSU result."))
            return
        caption = self.tr("Choose a filename to save SSU Results")
        filename, _ = self.file_dialog.getSaveFileName(
            None, caption, None, "Microsoft Excel (*.xlsx)")
        if not filename:
            return
        try:
            self.save_excel(filename, align_components)
            saved_text = self.tr("SSU results have been saved to:\n    {0}")
            self.show_info(saved_text.format(filename))
        except Exception as e:
            failed_text = self.tr(
                "Error raised while save SSU results to Excel file.\n    {0}")
            self.show_error(failed_text.format(str(e)))

    def on_fitting_succeeded(self, result: SSUResult):
        """Store the result of a retried task in place of the old result.

        ``retry_tasks`` maps a task uuid to the list position of the result
        it replaces. The entry is removed once the new result is stored, the
        progress message is updated (and closed when no task is left), and
        the current page is redrawn.

        Args:
            result: The new result; ``result.task.uuid`` must be a key of
                ``retry_tasks``.
        """
        task_id = result.task.uuid
        position = self.retry_tasks[task_id]
        self.__fitting_results[position] = result
        # The entry is dropped only after the replacement succeeded.
        self.retry_tasks.pop(task_id)
        remaining = len(self.retry_tasks)
        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(remaining))
        if remaining == 0:
            self.retry_progress_msg.close()

        self.logger.debug(
            f"Retried task succeeded, sample name={result.task.sample.name}, distribution_type={result.task.distribution_type.name}, n_components={result.task.n_components}"
        )
        self.update_page(self.page_index)

    def on_fitting_failed(self, failed_info: str, task: SSUTask):
        """Report a retried task that failed and stop tracking it.

        Args:
            failed_info: Failure description included in the error message.
            task: The failed task; ``task.uuid`` must be a key of
                ``retry_tasks``.
        """
        # The entry must leave the dict, otherwise the pending count that
        # closes the progress message below would never reach zero.
        self.retry_tasks.pop(task.uuid)
        if not self.retry_tasks:
            self.retry_progress_msg.close()
        error_text = self.tr("Failed to retry task, sample name={0}.\n{1}")
        self.show_error(error_text.format(task.sample.name, failed_info))
        self.logger.warning(
            f"Failed to retry task, sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
        )

    def retry_results(self, indexes, results):
        """Re-run the fitting of the given results using a reference result.

        For each result a reference is taken from the reference viewer's
        ``query_reference``; when that returns ``None``, ``find_similar`` is
        called with up to five stored results on each side of the result's
        position. The new task copies the reference's distribution type,
        component count, resolver and resolver setting, and uses its
        ``last_func_args`` as the initial guess. Every task is registered in
        ``retry_tasks`` (uuid -> list position) and handed to the async
        worker.

        Args:
            indexes: Positions of ``results`` in the stored result list.
            results: The results to retry, parallel to ``indexes``.
        """
        assert len(indexes) == len(results)
        if len(results) == 0:
            return
        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(len(results)))
        self.retry_timer.start(1)
        for index, result in zip(indexes, results):
            ref_result = self.__reference_viewer.query_reference(result.sample)
            if ref_result is None:
                # Clamp the window start at 0. A negative start would make
                # the slice count from the end of the list and drop the
                # neighbours in front of the first few results.
                window_start = max(index - 5, 0)
                nearby_results = (
                    self.__fitting_results[window_start:index] +
                    self.__fitting_results[index + 1:index + 6])
                ref_result = self.__reference_viewer.find_similar(
                    result.sample, nearby_results)
            task = SSUTask(
                result.sample,
                ref_result.distribution_type,
                ref_result.n_components,
                resolver=ref_result.task.resolver,
                resolver_setting=ref_result.task.resolver_setting,
                initial_guess=ref_result.last_func_args)

            self.logger.debug(
                f"Retry task: sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
            )
            self.retry_tasks[task.uuid] = index
            self.async_worker.execute_task(task)

    def degrade_results(self):
        """Retry results that contain a near-zero component with fewer components.

        A component counts as redundant when its fraction is below 1e-3. For
        every affected result a new task is created with the redundant
        components removed (but at least one component kept), using the
        logarithmic mean/std/skewness of the remaining components as the
        reference. The tasks are registered in ``retry_tasks`` and handed to
        the async worker.
        """
        flagged = [
            (position, result)
            for position, result in enumerate(self.__fitting_results)
            if any(component.fraction < 1e-3
                   for component in result.components)
        ]
        degrade_results = [result for _, result in flagged]
        self.logger.debug(
            f"Results should be degrade (have a redundant component): {[result.sample.name for result in degrade_results]}"
        )
        if not degrade_results:
            self.show_info(
                self.tr("No fitting result was evaluated as an outlier."))
            return
        sample_names = ", ".join(
            [result.sample.name for result in degrade_results])
        self.show_info(
            self.
            tr("The results below should be degrade (have a redundant component:\n    {0}"
               ).format(sample_names))

        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(len(degrade_results)))
        self.retry_timer.start(1)
        for index, result in flagged:
            n_redundant = sum(1 for component in result.components
                              if component.fraction < 1e-3)
            # Everything that is not redundant becomes part of the reference.
            reference = [
                dict(mean=component.logarithmic_moments["mean"],
                     std=component.logarithmic_moments["std"],
                     skewness=component.logarithmic_moments["skewness"])
                for component in result.components
                if not component.fraction < 1e-3
            ]
            if result.n_components > n_redundant:
                n_components = result.n_components - n_redundant
            else:
                n_components = 1
            task = SSUTask(
                result.sample,
                result.distribution_type,
                n_components,
                resolver=result.task.resolver,
                resolver_setting=result.task.resolver_setting,
                reference=reference)
            self.logger.debug(
                f"Retry task: sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
            )
            self.retry_tasks[task.uuid] = index
            self.async_worker.execute_task(task)

    def ask_deal_outliers(self, outlier_results: typing.List[SSUResult],
                          outlier_indexes: typing.List[int]):
        """Ask the user what to do with the given outliers and act on it.

        With no outliers only an information message is shown. Otherwise the
        sample names (at most the first 100) are listed in ``outlier_msg``;
        ``QMessageBox.Discard`` removes the outliers, ``QMessageBox.Retry``
        retries them, and any other answer leaves them untouched.

        Args:
            outlier_results: The results judged to be outliers.
            outlier_indexes: Their positions in the stored result list,
                parallel to ``outlier_results``.
        """
        assert len(outlier_indexes) == len(outlier_results)
        n_outliers = len(outlier_results)
        if n_outliers == 0:
            self.show_info(
                self.tr("No fitting result was evaluated as an outlier."))
            return

        # Slicing is a no-op for 100 or fewer outliers.
        names = ", ".join(
            [result.sample.name for result in outlier_results[:100]])
        if n_outliers > 100:
            text = self.tr(
                "The fitting results have the component that its fraction is near zero:\n    {0}...(total {1} outliers)\nHow to deal with them?"
            ).format(names, n_outliers)
        else:
            text = self.tr(
                "The fitting results have the component that its fraction is near zero:\n    {0}\nHow to deal with them?"
            ).format(names)
        self.outlier_msg.setText(text)

        answer = self.outlier_msg.exec_()
        if answer == QMessageBox.Discard:
            self.remove_results(outlier_indexes)
        elif answer == QMessageBox.Retry:
            self.retry_results(outlier_indexes, outlier_results)

    def check_nan_and_inf(self):
        """Collect the results whose ``is_valid`` is false and ask what to do.

        Shows a warning and returns when there are no results.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        invalid = [(position, result)
                   for position, result in enumerate(self.__fitting_results)
                   if not result.is_valid]
        outlier_indexes = [position for position, _ in invalid]
        outlier_results = [result for _, result in invalid]
        self.logger.debug(
            f"Outlier results with the nan or inf value(s): {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_final_distances(self):
        """Flag results whose final distance is unusually large.

        Requires at least 10 results; otherwise a warning is shown. The
        distances (measured with ``self.distance_name``) are shown in the
        box-plot chart. A result is an outlier when its distance exceeds the
        upper quartile by more than 1.5 times the quartile range. Unusually
        small distances are deliberately not treated as outliers.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        if self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return

        distances = np.array([
            result.get_distance(self.distance_name)
            for result in self.__fitting_results
        ])
        self.boxplot_chart.show_dataset([distances],
                                        xlabels=[self.distance_name],
                                        ylabel=self.tr("Distance"))
        self.boxplot_chart.show()

        # Quartiles are taken as the medians of the values strictly below and
        # strictly above the overall median.
        median = np.median(distances)
        value_1_4 = np.median(distances[np.less(distances, median)])
        value_3_4 = np.median(distances[np.greater(distances, median)])
        distance_QR = value_3_4 - value_1_4
        upper_fence = value_3_4 + distance_QR * 1.5

        outlier_indexes = [
            position for position, distance in enumerate(distances)
            if distance > upper_fence
        ]
        outlier_results = [
            self.__fitting_results[position] for position in outlier_indexes
        ]
        self.logger.debug(
            f"Outlier results with too greater distances: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_component_moments(self, key: str):
        """Flag results whose component moments are unusual.

        Requires at least 10 results; otherwise a warning is shown. For each
        component position the finite values of ``logarithmic_moments[key]``
        are collected and shown in the box-plot chart. A result is an outlier
        when, for any of its components, the value lies more than 1.5
        quartile ranges above the upper or below the lower quartile of that
        position.

        Args:
            key: Moment name; one of "mean", "std", "skewness", "kurtosis".
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        if self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return

        stored = self.__fitting_results
        max_n_components = max([0, *(result.n_components for result in stored)])

        # moments[i] holds the finite values of the i-th component position.
        moments = [[] for _ in range(max_n_components)]
        for result in stored:
            for i, component in enumerate(result.components):
                value = component.logarithmic_moments[key]
                if np.isfinite(value):
                    moments[i].append(value)

        key_label_trans = {
            "mean": self.tr("Mean [φ]"),
            "std": self.tr("STD [φ]"),
            "skewness": self.tr("Skewness"),
            "kurtosis": self.tr("Kurtosis")
        }
        self.boxplot_chart.show_dataset(
            moments,
            xlabels=[f"C{i+1}" for i in range(max_n_components)],
            ylabel=key_label_trans[key])
        self.boxplot_chart.show()

        # Keyed by list position so a result flagged by several components
        # is reported only once.
        outlier_dict = {}
        for i in range(max_n_components):
            stacked_moments = np.array(moments[i])
            # Quartiles are taken as the medians of the values strictly below
            # and strictly above the overall median.
            median = np.median(stacked_moments)
            value_1_4 = np.median(
                stacked_moments[np.less(stacked_moments, median)])
            value_3_4 = np.median(
                stacked_moments[np.greater(stacked_moments, median)])
            distance_QR = value_3_4 - value_1_4
            upper_fence = value_3_4 + distance_QR * 1.5
            lower_fence = value_1_4 - distance_QR * 1.5

            for j, result in enumerate(stored):
                if result.n_components <= i:
                    continue
                value = result.components[i].logarithmic_moments[key]
                if value > upper_fence or value < lower_fence:
                    outlier_dict[j] = result

        outlier_indexes = sorted(outlier_dict)
        outlier_results = [outlier_dict[index] for index in outlier_indexes]
        self.logger.debug(
            f"Outlier results with abnormal {key} values of their components: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_component_fractions(self):
        """Flag results that have a component with a fraction below 1e-3."""
        flagged = [
            (position, result)
            for position, result in enumerate(self.__fitting_results)
            if any(component.fraction < 1e-3
                   for component in result.components)
        ]
        outlier_indexes = [position for position, _ in flagged]
        outlier_results = [result for _, result in flagged]
        self.logger.debug(
            f"Outlier results with the component that its fraction is near zero: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def try_align_components(self):
        """Cluster all components and flag results with clashing components.

        Requires at least 10 results; otherwise a warning is shown. The
        distributions of all components are clustered with KMeans, using as
        many clusters as the largest component count. The clustered
        distributions are plotted, and every result that has two components
        in the same cluster is passed to ``ask_deal_outliers``.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        elif self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return
        import matplotlib.pyplot as plt
        n_components_list = [
            result.n_components for result in self.__fitting_results
        ]
        count_dict = Counter(n_components_list)
        max_n_components = max(count_dict.keys())
        self.logger.debug(
            f"N_components: {count_dict}, Max N_components: {max_n_components}"
        )
        n_components_desc = "\n".join([
            self.tr("{0} Component(s): {1}").format(n_components, count)
            for n_components, count in count_dict.items()
        ])
        self.show_info(
            self.tr("N_components distribution of Results:\n{0}").format(
                n_components_desc))

        x = self.__fitting_results[0].classes_μm
        # One row per component, in result order; `flags` is parallel to it.
        stacked_components = np.array([
            component.distribution for result in self.__fitting_results
            for component in result.components
        ])

        cluster = KMeans(n_clusters=max_n_components)
        flags = cluster.fit_predict(stacked_components)

        figure = plt.figure(figsize=(6, 4))
        cmap = plt.get_cmap("tab10")
        axes = figure.add_subplot(1, 1, 1)
        for flag, distribution in zip(flags, stacked_components):
            axes.plot(x, distribution, c=cmap(flag), zorder=flag)
        axes.set_xscale("log")
        axes.set_xlabel(self.tr("Grain-size [μm]"))
        axes.set_ylabel(self.tr("Frequency"))
        figure.tight_layout()
        figure.show()

        outlier_results = []
        outlier_indexes = []
        flag_index = 0
        for i, result in enumerate(self.__fitting_results):
            result_flags = set()
            has_duplicate = False
            # Walk over *every* component, even after a duplicate was found.
            # Stopping early would leave flag_index pointing into this
            # result's labels and misalign all following results.
            for _ in result.components:
                flag = flags[flag_index]
                flag_index += 1
                if flag in result_flags:
                    has_duplicate = True
                result_flags.add(flag)
            if has_duplicate:
                outlier_results.append(result)
                outlier_indexes.append(i)
        self.logger.debug(
            f"Outlier results that have two components in the same cluster: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def analyse_typical_components(self):
        """Show the stored results in the typical-component chart.

        Requires at least 10 results; otherwise a warning is shown and
        nothing is displayed.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        if self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return

        chart = self.typical_chart
        chart.show_typical(self.__fitting_results)
        chart.show()
예제 #11
0
class SoundBoard(QDialog):
    """Sound-board dialog: a grid of buttons persisted in ``buttons.json``.

    A single ``QTableWidget`` is reused for every view: the play grid, the
    delete grid, the edit grid, the add/edit form and the grid-size settings.

    Playback goes through the names ``p``, ``instance`` and ``soundRep``,
    which are not defined in this class (presumably a VLC media player, a VLC
    instance and the sound directory prefix at module level -- confirm
    against the rest of the file).
    """

    # Text written into a mandatory form field that was left empty.
    _REQUIRED = 'Obligatoire !'
    # Style applied to a form field (or the settings banner) holding an error.
    _ERROR_STYLE = "color: rgb(255,0,0); font-weight: bold;"
    # File the button definitions and the grid size are loaded from / saved to.
    _BUTTONS_FILE = 'buttons.json'

    def __init__(self):
        """Initialise the default state and build the whole UI."""
        super(SoundBoard, self).__init__()
        self.title = '=== SoundBoard ==='
        # window position when it opens
        self.left = 50
        self.top = 50
        # default width and height
        # NOTE(review): ``width`` / ``height`` shadow the QWidget methods of
        # the same name on this instance; kept unchanged because outside code
        # may read these attributes.
        self.width = 500
        self.height = 500
        self.currFileName = ""
        # position of the button being edited, -1 when the form adds a button
        self.pbPosToModify = -1
        # the table widget is put into the window layout only once
        self._tableAttached = False
        # colour field of the add/edit form; None while the form is not shown
        self.leColor = None
        self.initUI()

    def initUI(self):
        """Create the window, the shared table widget and every sub-part."""
        self.setWindowTitle(self.title)
        self.setGeometry(self.left, self.top, self.width, self.height)
        self.windowLayout = QHBoxLayout()
        self.tableWidget = QTableWidget()
        self.tableWidget.horizontalHeader().hide()
        self.tableWidget.verticalHeader().hide()
        self.initIcons()
        self.initMenu()
        self.initColorPicker()
        self.initButtons()
        self.windowLayout.setStretch(1, 0)
        self.setLayout(self.windowLayout)
        self.show()

    def initIcons(self):
        """Load the four toolbar icons from ``./icons``."""

        def load(path):
            icon = QIcon()
            icon.addPixmap(QPixmap(path), QIcon.Normal, QIcon.Off)
            return icon

        self.iEdit = load("./icons/edit.png")
        self.iPlus = load("./icons/plus.png")
        self.iMinus = load("./icons/minus.png")
        self.iParam = load("./icons/cog.png")

    def _makeToolButton(self, icon, objectName, slot, layout):
        """Create a 32x32 tool button, add it to *layout* and wire *slot*."""
        button = QToolButton()
        button.setGeometry(QRect(0, 0, 32, 32))
        button.setIcon(icon)
        button.setObjectName(objectName)
        layout.addWidget(button)
        button.clicked.connect(slot)
        return button

    def initMenu(self):
        """Build the left column: tool buttons, the stop button and a spacer."""
        layout = QVBoxLayout()
        hlayout = QHBoxLayout()

        # add / delete / edit / settings buttons
        self.tbPlus = self._makeToolButton(self.iPlus, "tbPlus", self.add,
                                           hlayout)
        self.tbMinus = self._makeToolButton(self.iMinus, "tbMinus",
                                            self.delete, hlayout)
        self.tbEdit = self._makeToolButton(self.iEdit, "tbEdit", self.editBtn,
                                           hlayout)
        self.tbParam = self._makeToolButton(self.iParam, "tbParam",
                                            self.settings, hlayout)

        layout.addLayout(hlayout)

        self.pbStop = QPushButton("Don't STOP\n\nthe\n\nSoundBoard")
        self.pbStop.setStyleSheet("font-weight: bold;")
        self.pbStop.setMinimumSize(QSize(100, 100))
        self.pbStop.setGeometry(QRect(0, 0, 100, 100))
        layout.addWidget(self.pbStop)
        self.pbStop.clicked.connect(self.stop)

        spacerMenu = QSpacerItem(20, 40, QSizePolicy.Minimum,
                                 QSizePolicy.Expanding)
        layout.addItem(spacerMenu)

        self.windowLayout.addLayout(layout)

    def _applyGridGeometry(self):
        """Resize the window to fit the current ``BtnW`` x ``BtnH`` grid."""
        self.setGeometry(self.left, self.top, 140 + self.BtnW * 100,
                         175 if self.BtnH * 31 < 175 else 25 + self.BtnH * 30)

    def _attachTable(self):
        """Put the shared table widget into the window layout, once.

        The original code wrapped the table in a brand-new layout on every
        refresh, leaving one more empty layout in the window each time.
        """
        if self._tableAttached:
            return
        tableLayout = QVBoxLayout()
        tableLayout.addWidget(self.tableWidget)
        self.windowLayout.addLayout(tableLayout)
        self._tableAttached = True

    def _saveButtons(self):
        """Write ``self.data_buttons`` back to the JSON file."""
        with open(self._BUTTONS_FILE, 'w', encoding='utf-8') as outfile:
            json.dump(self.data_buttons, outfile, indent=4)

    def _flagField(self, lineEdit, message):
        """Show *message* in *lineEdit* using the error style."""
        lineEdit.setText(message)
        lineEdit.setStyleSheet(self._ERROR_STYLE)

    def startInitButtons(self):
        """Reset the table, reload ``buttons.json`` and size the grid.

        Sets ``data_buttons``, ``positions``, ``max_pos``, ``BtnH`` and
        ``BtnW``.
        """
        self.tableWidget.clear()
        self.tableWidget.clearSpans()
        self.tableWidget.setColumnWidth(0, 100)
        self.tableWidget.setColumnWidth(2, 100)
        self.cdColorPicker.setVisible(False)
        # the table was cleared, so the form's colour field is gone
        self.leColor = None

        self.tableWidget.horizontalHeader().hide()
        # load the button definitions stored in the json file
        with open(self._BUTTONS_FILE, encoding='utf-8') as json_file:
            self.data_buttons = json.load(json_file)

        # remember the highest position in use (0 when there is no button,
        # which happens after the last button has been deleted)
        self.positions = [p['position'] for p in self.data_buttons['buttons']]
        self.max_pos = max(self.positions, default=0)

        # number of buttons per column and per row
        self.BtnH = self.data_buttons['buttons_grid']['height']
        self.BtnW = self.data_buttons['buttons_grid']['width']
        self._applyGridGeometry()
        self.tableWidget.setColumnCount(self.BtnW)
        self.tableWidget.setRowCount(self.BtnH)

    def endInitButtons(self):
        """Make sure the table is in the window and re-apply the grid size."""
        self._attachTable()
        self._applyGridGeometry()

    def _populateGrid(self, onClick, icon=None, iconOnEmpty=False):
        """Rebuild the button grid from ``buttons.json``.

        Args:
            onClick: slot connected to every configured button.
            icon: optional icon set on every configured button.
            iconOnEmpty: when true, *icon* is also set on the 'Nouveau'
                placeholders shown in free cells.

        Free cells always get a 'Nouveau' button wired to :meth:`add`, whose
        ``pbPos`` property is ``"nouveau,<position>"``.
        """
        self.startInitButtons()

        # One lookup table instead of scanning the list for every cell; with
        # duplicated positions the last definition wins, as it did before.
        byPosition = {b['position']: b for b in self.data_buttons['buttons']}

        # positions are 1-based and run row by row
        for ligne in range(self.BtnH):
            for colonne in range(self.BtnW):
                position = ligne * self.BtnW + colonne + 1
                b = byPosition.get(position)
                if b is not None:
                    pb = QPushButton(b['name'][:9])
                    pb.setProperty('pbPos', b['position'])
                    if icon is not None:
                        pb.setIcon(icon)
                    # black text on a light background, white on a dark one
                    luminance = (b['r'] * 0.299 + b['g'] * 0.587 +
                                 b['b'] * 0.114)
                    textColor = '#000000' if luminance > 186 else '#ffffff'
                    pb.setStyleSheet(
                        f"background-color: rgb({b['r']},{b['g']},{b['b']}); color: {textColor};"
                    )
                    pb.clicked.connect(onClick)
                else:
                    pb = QPushButton('Nouveau')
                    if icon is not None and iconOnEmpty:
                        pb.setIcon(icon)
                    pb.setProperty('pbPos', f"nouveau,{position}")
                    pb.clicked.connect(self.add)
                self.tableWidget.setCellWidget(ligne, colonne, pb)

        self.endInitButtons()

    def initButtons(self):
        """Show the play grid: clicking a button plays its sound."""
        self._populateGrid(self.play)

    def initColorPicker(self):
        """Embed a hidden, button-less colour dialog in the window layout."""
        self.lColorPicker = QVBoxLayout()
        self.cdColorPicker = QColorDialog()
        self.cdColorPicker.setOption(self.cdColorPicker.NoButtons, True)
        self.colorSelected = self.cdColorPicker.currentColor()

        self.lColorPicker.addWidget(self.cdColorPicker)
        self.cdColorPicker.setVisible(False)
        self.cdColorPicker.currentColorChanged.connect(self.colorChanged)

        self.windowLayout.addLayout(self.lColorPicker)

    def play(self):
        """Play the sound of the clicked button.

        Clicking the button of the sound that is currently playing only stops
        it; clicking another button while playing switches to that sound.
        """
        pbPos = self.sender().property('pbPos')
        pbFile = None
        for b in self.data_buttons['buttons']:
            if pbPos == b['position']:
                pbFile = b['file']
        if pbFile is None:
            # no button is defined at that position: nothing to play
            return

        if p.get_state() == vlc.State.Playing:
            p.stop()
            if self.currFileName == pbFile:
                return
        p.set_media(instance.media_new(soundRep + pbFile))
        p.play()
        self.currFileName = pbFile

    def stop(self):
        """Stop the playback."""
        p.stop()

    def add(self):
        """Open the form to create a new button.

        When triggered by a 'Nouveau' grid button, the position field is
        pre-filled with the position of that cell.
        """
        # read the sender first: the table (and its buttons) is cleared below
        sender = self.sender()
        pbPos = sender.property('pbPos') if sender is not None else None

        # This is a creation, not an edit.  Without this reset, validating
        # the form after an earlier edit would delete the edited button.
        self.pbPosToModify = -1
        self._showEditor()

        if pbPos is not None and str(pbPos).startswith('nouveau,'):
            self.lePos.setText(str(pbPos)[8:])

    def _showEditor(self):
        """Fill the table with the add/edit form and the existing buttons."""
        self.cdColorPicker.setVisible(True)
        self.tableWidget.clear()
        self.tableWidget.clearSpans()

        self.tableWidget.setColumnCount(6)
        self.tableWidget.setRowCount(len(self.data_buttons['buttons']) + 1)

        self.tableWidget.horizontalHeader().show()
        self.tableWidget.setHorizontalHeaderLabels(
            ['Nom', 'Fichier', '', 'Position', 'Couleur', ''])
        self.tableWidget.setColumnWidth(2, 22)

        # name
        self.leName = QLineEdit()
        self.leName.setPlaceholderText('Nom (10 max.)')
        self.tableWidget.setCellWidget(0, 0, self.leName)
        # file
        self.leFile = QLineEdit()
        self.leFile.setPlaceholderText('Fichier')
        self.tableWidget.setCellWidget(0, 1, self.leFile)
        # browse
        pbBrowser = QPushButton('...')
        pbBrowser.setMinimumSize(QSize(21, 21))
        pbBrowser.clicked.connect(self.browseMedia)
        self.tableWidget.setCellWidget(0, 2, pbBrowser)
        # position
        self.lePos = QLineEdit()
        self.lePos.setPlaceholderText('Position')
        self.tableWidget.setCellWidget(0, 3, self.lePos)
        # colour
        self.leColor = QLineEdit()
        self.leColor.setPlaceholderText('255,255,255')
        self.leColor.setText(
            str(self.colorSelected.red()) + "," +
            str(self.colorSelected.green()) + "," +
            str(self.colorSelected.blue()))
        self.tableWidget.setCellWidget(0, 4, self.leColor)
        # validation
        pbValid = QPushButton('Valider')
        pbValid.clicked.connect(self.addValid)
        self.tableWidget.setCellWidget(0, 5, pbValid)

        # list the existing buttons below the form, ordered by position
        self.data_buttons['buttons'].sort(key=lambda b: b['position'])
        for ligne, b in enumerate(self.data_buttons['buttons'], start=1):
            self.tableWidget.setSpan(ligne, 1, 1, 2)
            self.tableWidget.setCellWidget(ligne, 0, QLabel(b['name']))
            self.tableWidget.setCellWidget(ligne, 1, QLabel(b['file']))
            self.tableWidget.setCellWidget(ligne, 3,
                                           QLabel(str(b['position'])))
            self.tableWidget.setCellWidget(ligne, 4, QLabel('Couleur'))

        # 530 color picker width
        self.setGeometry(self.left, self.top, 690 + 530, 300)

    def addValid(self):
        """Validate the form and save the new or edited button.

        Every field is mandatory; the position must be an integer inside the
        grid (1 .. height * width) that no other button uses.  Errors are
        reported inside the offending field.  The saved colour is the one
        currently selected in the colour picker.
        """
        fields = (self.leName, self.leFile, self.lePos, self.leColor)
        values = [field.text() for field in fields]
        for field in fields:
            field.setStyleSheet("color: rgb(0,0,0);")

        # empty fields (or fields still holding the "required" marker)
        missing = [
            field for field, value in zip(fields, values)
            if value in ('', self._REQUIRED)
        ]
        if missing:
            for field in missing:
                self._flagField(field, self._REQUIRED)
            return

        gName, gFile, gPos, _ = values

        # the position must be a number
        try:
            position = int(gPos)
        except ValueError:
            self._flagField(self.lePos, f"{str(gPos)} n'est pas un nombre")
            return

        # the position must be inside the grid (positions start at 1)
        grid = self.data_buttons['buttons_grid']
        if position < 1 or position > grid['height'] * grid['width']:
            self._flagField(self.lePos, f"{str(gPos)} hors grille")
            return

        # when editing, the button being edited is replaced, so its own
        # position does not count as taken
        editing = self.pbPosToModify != -1
        if editing:
            others = [
                b for b in self.data_buttons['buttons']
                if b['position'] != self.pbPosToModify
            ]
        else:
            others = list(self.data_buttons['buttons'])

        # the position must be free
        if any(b['position'] == position for b in others):
            self._flagField(self.lePos, f"{str(gPos)} déjà prise")
            return

        others.append({
            "name": gName,
            "file": gFile,
            "position": position,
            "r": self.colorSelected.red(),
            "g": self.colorSelected.green(),
            "b": self.colorSelected.blue()
        })
        self.data_buttons['buttons'] = others
        self.pbPosToModify = -1
        self._saveButtons()
        self.initButtons()

    def delete(self):
        """Show the delete grid: clicking a button removes it."""
        self._populateGrid(self.deleteTw, icon=self.iMinus)

    def deleteTw(self):
        """Remove the clicked button from ``buttons.json`` and refresh."""
        pbPos = self.sender().property('pbPos')
        buttons = self.data_buttons['buttons']
        # build a new list rather than removing while iterating
        remaining = [b for b in buttons if b['position'] != pbPos]
        if len(remaining) == len(buttons):
            return
        self.data_buttons['buttons'] = remaining
        self._saveButtons()
        self.delete()

    def editBtn(self):
        """Show the edit grid: clicking a button opens it in the form."""
        self._populateGrid(self.editTw, icon=self.iEdit, iconOnEmpty=True)

    def editTw(self):
        """Open the form pre-filled with the clicked button's values."""
        pbPos = self.sender().property('pbPos')
        self._showEditor()
        self.pbPosToModify = pbPos
        for b in self.data_buttons['buttons']:
            if b['position'] == pbPos:
                self.leName.setText(b['name'])
                self.leFile.setText(b['file'])
                self.lePos.setText(str(b['position']))
                self.cdColorPicker.setCurrentColor(
                    QColor(b['r'], b['g'], b['b']))

    def settings(self):
        """Show the grid-size settings (height and width in buttons)."""
        self.tableWidget.clear()
        self.tableWidget.clearSpans()
        self.tableWidget.setColumnWidth(2, 100)

        self.cdColorPicker.setVisible(False)
        # the table was cleared, so the form's colour field is gone
        self.leColor = None

        self.tableWidget.setColumnCount(2)
        self.tableWidget.setRowCount(4)
        self.tableWidget.horizontalHeader().setSectionResizeMode(
            0, QHeaderView.Stretch)

        self.tableWidget.horizontalHeader().hide()

        # validation button
        pb = QPushButton('Valider')
        self.tableWidget.setCellWidget(3, 0, pb)
        pb.clicked.connect(self.saveSettings)

        # cancel button
        pb = QPushButton('Annuler')
        self.tableWidget.setCellWidget(3, 1, pb)
        pb.clicked.connect(self.refreshUI)

        # parameters
        self.tableWidget.setSpan(0, 0, 1, 2)
        self.lAlert = QLabel("La modification de ces valeurs entrainera la "
                             "modification de position des boutons")
        self.lAlert.setStyleSheet("font-weight: bold;")
        self.tableWidget.setCellWidget(0, 0, self.lAlert)
        self.tableWidget.setCellWidget(1, 0,
                                       QLabel('Nombre de boutons en Hauteur'))
        self.leH = QLineEdit(str(self.data_buttons['buttons_grid']['height']))
        self.tableWidget.setCellWidget(1, 1, self.leH)
        self.tableWidget.setCellWidget(2, 0,
                                       QLabel('Nombre de boutons en Largeur'))
        self.leW = QLineEdit(str(self.data_buttons['buttons_grid']['width']))
        self.tableWidget.setCellWidget(2, 1, self.leW)

        self._attachTable()

        self.setGeometry(self.left, self.top, 600, 300)

    def saveSettings(self):
        """Validate and save the grid size, then return to the play grid.

        The size is rejected when it is not made of two positive integers or
        when an existing button would fall outside the new grid.
        """
        try:
            h = int(self.leH.text())
            w = int(self.leW.text())
        except ValueError:
            h = w = 0
        if h < 1 or w < 1:
            self.lAlert.setText("La hauteur et la largeur doivent être des "
                                "nombres entiers supérieurs à 0")
            self.lAlert.setStyleSheet(self._ERROR_STYLE)
            return

        if h * w < self.max_pos:
            self.lAlert.setText(f"Le bouton à la position {str(self.max_pos)} "
                                f"est en dehors de la grille {h} x {w}")
            self.lAlert.setStyleSheet(self._ERROR_STYLE)
            return

        self.data_buttons['buttons_grid']['height'] = h
        self.data_buttons['buttons_grid']['width'] = w
        self._saveButtons()
        self.initButtons()

    def refreshUI(self):
        """Go back to the play grid."""
        self.initButtons()

    def browseMedia(self):
        """Let the user pick a media file and copy its name into the form."""
        self.openFile = QFileDialog.getOpenFileName(
            self, "Sélectionner un média...", "./sons",
            "Image Files (*.avi *.mp3 *.wav)")
        selectedPath = self.openFile[0]
        if not selectedPath:
            # dialog cancelled: keep whatever was typed in the field
            return
        self.leFile.setText(selectedPath.split('/')[-1])

    def colorChanged(self):
        """Store the picked colour and mirror it in the form's colour field."""
        self.colorSelected = self.cdColorPicker.currentColor()
        if self.leColor is None:
            # the form is not displayed, there is no field to update
            return
        self.leColor.setText(
            str(self.colorSelected.red()) + "," +
            str(self.colorSelected.green()) + "," +
            str(self.colorSelected.blue()))
예제 #12
0
class assSelect(QDialog):
    """Dialog for choosing which style of an ASS subtitle file to import.

    ``assCheck`` parses the file into ``subDict`` (one entry per style name,
    plus a blank entry under ``''``).  Confirming emits ``assSummary`` with
    ``[index, styleData]``.
    """
    assSummary = Signal(list)

    # Style attributes copied from the "[V4+ Styles]" section when the
    # section's Format line names them.
    _STYLE_KEYS = (
        'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour',
        'OutlineColour', 'BackColour', 'Bold', 'Italic', 'Underline',
        'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle',
        'Outline', 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV',
        'Encoding',
    )

    @classmethod
    def _blankStyle(cls, events):
        """Return a style entry with empty attributes and no table rows.

        *events* becomes the ``'Events'`` value.  The blank ``''`` entry uses
        a list and parsed styles use a dict, exactly as before.
        """
        entry = dict.fromkeys(cls._STYLE_KEYS, '')
        entry['Tableview'] = []
        entry['Events'] = events
        return entry

    def __init__(self):
        """Build the dialog: style combo box, preview table, two buttons."""
        super().__init__()
        # blank entry shown while the combo box is empty
        self.subDict = {'': self._blankStyle([])}
        # value emitted as the first element of assSummary; defined here so
        # that sendSub cannot fail when setDefault was never called
        self.index = 0
        self.resize(550, 800)
        self.setWindowTitle('选择要导入的ass字幕轨道')
        layout = QGridLayout()
        self.setLayout(layout)
        layout.addWidget(QLabel('检测到字幕样式:'), 0, 0, 1, 1)
        layout.addWidget(QLabel(''), 0, 1, 1, 1)
        self.subCombox = QComboBox()
        self.subCombox.currentTextChanged.connect(self.selectChange)
        layout.addWidget(self.subCombox, 0, 2, 1, 1)
        self.subTable = QTableWidget()
        self.subTable.setEditTriggers(QAbstractItemView.NoEditTriggers)
        layout.addWidget(self.subTable, 1, 0, 6, 3)
        self.confirm = QPushButton('导入')
        self.confirm.clicked.connect(self.sendSub)
        layout.addWidget(self.confirm, 7, 0, 1, 1)
        self.cancel = QPushButton('取消')
        self.cancel.clicked.connect(self.hide)
        layout.addWidget(self.cancel, 7, 2, 1, 1)

    def setDefault(self, subtitlePath='', index=0):
        """Parse *subtitlePath* (when given) and remember *index*."""
        if subtitlePath:
            self.assCheck(subtitlePath)
            self.index = index

    def selectChange(self, styleName):
        """Fill the preview table with the attributes and lines of a style."""
        self.subTable.clear()
        style = self.subDict.get(styleName)
        if style is None:
            # unknown style name: leave the table empty
            self.subTable.setRowCount(0)
            return

        # one row per attribute ('Tableview' and 'Events' are not attributes)
        # followed by one row per subtitle line
        self.subTable.setRowCount(
            len(style) + len(style['Tableview']) - 2)
        self.subTable.setColumnCount(3)
        self.subTable.setColumnWidth(2, 270)
        row = 0
        for key, value in style.items():
            if key == 'Events':
                continue
            if key == 'Tableview':
                for start, end, text in value:
                    self.subTable.setItem(row, 0, QTableWidgetItem(start))
                    self.subTable.setItem(row, 1, QTableWidgetItem(end))
                    self.subTable.setItem(row, 2, QTableWidgetItem(text))
                    row += 1
            else:
                self.subTable.setItem(row, 0, QTableWidgetItem(key))
                self.subTable.setItem(row, 1, QTableWidgetItem(value))
                row += 1

    def sendSub(self):
        """Emit the selected style's data and hide the dialog."""
        self.assSummary.emit(
            [self.index, self.subDict[self.subCombox.currentText()]])
        self.hide()

    def assCheck(self, subtitlePath):
        """Parse an ASS file and rebuild ``subDict`` and the combo box.

        Reads the ``Format:``/``Style:`` lines of ``[V4+ Styles]`` and the
        ``Format:``/``Dialogue:``/``Comment:`` lines of ``[Events]``.  For
        every style, ``'Tableview'`` receives ``[start, end, text]`` rows and
        ``'Events'`` maps ``calSubTime(start)`` to ``[duration, text]``
        (``calSubTime`` is defined elsewhere; presumably it converts an ASS
        timestamp to a number -- confirm).

        Fields are split on the Format line's column count, so spaces and
        commas inside the subtitle text and spaces inside font or style names
        are preserved.  Event lines with fewer fields than the Format line
        are skipped.
        """
        self.subDict = {'': self._blankStyle([])}

        with codecs.open(subtitlePath, 'r', 'utf_8_sig') as ass:
            rawLines = ass.readlines()

        section = None
        styleFormat = []
        styles = []
        eventFormat = []
        events = []
        for rawLine in rawLines:
            line = rawLine.strip()
            if line.startswith('[V4+ Styles]'):
                section = 'styles'
                continue
            if line.startswith('[Events]'):
                section = 'events'
                continue
            if section is None:
                continue

            # only the keyword before the first colon identifies the line, so
            # text that merely contains "Format:" or "Comment:" is harmless
            keyword, _, payload = line.partition(':')
            keyword = keyword.strip()
            if section == 'styles':
                if keyword == 'Format':
                    styleFormat = [f.strip() for f in payload.split(',')]
                elif keyword == 'Style' and styleFormat:
                    styles.append([f.strip() for f in payload.split(',')])
            elif keyword == 'Format':
                eventFormat = [f.strip() for f in payload.split(',')]
            elif keyword in ('Comment', 'Dialogue') and eventFormat:
                # the last column keeps every remaining comma and space
                parts = payload.split(',', len(eventFormat) - 1)
                if len(parts) == len(eventFormat):
                    events.append([f.strip() for f in parts[:-1]] +
                                  [parts[-1]])

        # columns needed from each event; without them events are unusable
        try:
            startCol, endCol, styleCol, textCol = (
                eventFormat.index(name)
                for name in ('Start', 'End', 'Style', 'Text'))
        except ValueError:
            startCol = endCol = styleCol = textCol = None
            events = []

        for style in styles:
            styleName = style[0]
            entry = self._blankStyle({})
            # zip stops at the shorter list, so a short Style line is safe
            for key, value in zip(styleFormat, style):
                if key in entry:
                    entry[key] = value
            for event in events:
                if event[styleCol] != styleName:
                    continue
                start = calSubTime(event[startCol])
                delta = calSubTime(event[endCol]) - start
                entry['Tableview'].append(
                    [event[startCol], event[endCol], event[textCol]])
                entry['Events'][start] = [delta, event[textCol]]
            self.subDict[styleName] = entry

        self.subCombox.clear()
        self.subCombox.addItems([name for name in self.subDict if name])
예제 #13
0
class ReferenceResultViewer(QDialog):
    """Paged viewer of SSU fitting results that also manages reference results.

    Results are kept in an internal list and displayed ``PAGE_ROWS`` at a time
    in a table. Through the table's context menu, results can be marked as
    references (stored in a map keyed by ``result.uuid``), unmarked, removed,
    plotted, and saved to / loaded from a pickle dump file.
    """
    PAGE_ROWS = 20
    logger = logging.getLogger("root.QGrain.ui.ReferenceResultViewer")

    def __init__(self, parent=None):
        """Build the widgets, charts and message boxes and draw the current page."""
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Reference Result Viewer"))
        self.__fitting_results = []
        # uuid -> result, for every result currently marked as a reference
        self.__reference_map = {}
        self.retry_tasks = {}
        self.init_ui()
        self.distance_chart = DistanceCurveChart(parent=self, toolbar=True)
        self.mixed_distribution_chart = MixedDistributionChart(
            parent=self, toolbar=True, use_animation=True)
        self.file_dialog = QFileDialog(parent=self)
        self.update_page_list()
        self.update_page(self.page_index)

        self.remove_warning_msg = QMessageBox(self)
        self.remove_warning_msg.setStandardButtons(QMessageBox.No
                                                   | QMessageBox.Yes)
        self.remove_warning_msg.setDefaultButton(QMessageBox.No)
        self.remove_warning_msg.setWindowTitle(self.tr("Warning"))
        self.remove_warning_msg.setText(
            self.tr("Are you sure to remove all SSU results?"))

        self.normal_msg = QMessageBox(self)

    def init_ui(self):
        """Create the table, paging controls, distance selector and context menu."""
        self.data_table = QTableWidget(100, 100)
        self.data_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.data_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.data_table.setAlternatingRowColors(True)
        self.data_table.setContextMenuPolicy(Qt.CustomContextMenu)
        self.main_layout = QGridLayout(self)
        self.main_layout.addWidget(self.data_table, 0, 0, 1, 3)

        # Paging controls: previous / page selector / next.
        self.previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Previous"))
        self.previous_button.setToolTip(
            self.tr("Click to back to the previous page."))
        self.previous_button.clicked.connect(self.on_previous_button_clicked)
        self.current_page_combo_box = QComboBox()
        self.current_page_combo_box.addItem(self.tr("Page {0}").format(1))
        self.current_page_combo_box.currentIndexChanged.connect(
            self.update_page)
        self.next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                       self.tr("Next"))
        self.next_button.setToolTip(self.tr("Click to jump to the next page."))
        self.next_button.clicked.connect(self.on_next_button_clicked)
        self.main_layout.addWidget(self.previous_button, 1, 0)
        self.main_layout.addWidget(self.current_page_combo_box, 1, 1)
        self.main_layout.addWidget(self.next_button, 1, 2)

        # Distance function selector; changing it redraws the current page.
        self.distance_label = QLabel(self.tr("Distance"))
        self.distance_label.setToolTip(
            self.tr("It's the function to calculate the difference (on the contrary, similarity) between two samples."))
        self.distance_combo_box = QComboBox()
        self.distance_combo_box.addItems(built_in_distances)
        self.distance_combo_box.setCurrentText("log10MSE")
        self.distance_combo_box.currentTextChanged.connect(
            lambda: self.update_page(self.page_index))
        self.main_layout.addWidget(self.distance_label, 2, 0)
        self.main_layout.addWidget(self.distance_combo_box, 2, 1, 1, 2)

        # Context menu of the table.
        self.menu = QMenu(self.data_table)
        self.mark_action = self.menu.addAction(
            qta.icon("mdi.marker-check"),
            self.tr("Mark Selection(s) as Reference"))
        self.mark_action.triggered.connect(self.mark_selections)
        self.unmark_action = self.menu.addAction(
            qta.icon("mdi.do-not-disturb"), self.tr("Unmark Selection(s)"))
        self.unmark_action.triggered.connect(self.unmark_selections)
        self.remove_action = self.menu.addAction(
            qta.icon("fa.remove"), self.tr("Remove Selection(s)"))
        self.remove_action.triggered.connect(self.remove_selections)
        self.remove_all_action = self.menu.addAction(qta.icon("fa.remove"),
                                                     self.tr("Remove All"))
        self.remove_all_action.triggered.connect(self.remove_all_results)
        self.plot_loss_chart_action = self.menu.addAction(
            qta.icon("mdi.chart-timeline-variant"), self.tr("Plot Loss Chart"))
        self.plot_loss_chart_action.triggered.connect(self.show_distance)
        self.plot_distribution_chart_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"), self.tr("Plot Distribution Chart"))
        self.plot_distribution_chart_action.triggered.connect(
            self.show_distribution)
        self.plot_distribution_animation_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"),
            self.tr("Plot Distribution Chart (Animation)"))
        self.plot_distribution_animation_action.triggered.connect(
            self.show_history_distribution)

        self.load_dump_action = self.menu.addAction(
            qta.icon("fa.database"), self.tr("Load Binary Dump"))
        self.load_dump_action.triggered.connect(
            lambda: self.load_dump(mark_ref=True))
        self.save_dump_action = self.menu.addAction(
            qta.icon("fa.save"), self.tr("Save Binary Dump"))
        self.save_dump_action.triggered.connect(self.save_dump)
        self.data_table.customContextMenuRequested.connect(self.show_menu)

    def show_menu(self, pos):
        """Pop up the context menu at the current cursor position."""
        self.menu.popup(QCursor.pos())

    def show_message(self, title: str, message: str):
        """Show ``message`` in the shared modal message box titled ``title``."""
        self.normal_msg.setWindowTitle(title)
        self.normal_msg.setText(message)
        self.normal_msg.exec_()

    def show_info(self, message: str):
        """Show ``message`` in a message box titled "Info"."""
        self.show_message(self.tr("Info"), message)

    def show_warning(self, message: str):
        """Show ``message`` in a message box titled "Warning"."""
        self.show_message(self.tr("Warning"), message)

    def show_error(self, message: str):
        """Show ``message`` in a message box titled "Error"."""
        self.show_message(self.tr("Error"), message)

    @property
    def distance_name(self) -> str:
        """Name of the distance currently selected in the combo box."""
        return self.distance_combo_box.currentText()

    @property
    def distance_func(self) -> typing.Callable:
        """Distance function looked up from the currently selected name."""
        return get_distance_func_by_name(self.distance_combo_box.currentText())

    @property
    def page_index(self) -> int:
        """Zero-based index of the page selected in the page combo box."""
        return self.current_page_combo_box.currentIndex()

    @property
    def n_pages(self) -> int:
        """Number of entries in the page combo box."""
        return self.current_page_combo_box.count()

    @property
    def n_results(self) -> int:
        """Number of stored fitting results."""
        return len(self.__fitting_results)

    @property
    def selections(self):
        """Sorted indexes (into the result list) of the selected table rows."""
        start = self.page_index * self.PAGE_ROWS
        # Only rows that actually hold a result on this page may be returned;
        # valid page rows are 0 .. PAGE_ROWS - 1, so the cap is PAGE_ROWS.
        n_rows = max(0, min(self.PAGE_ROWS, self.n_results - start))
        selected = set()
        for item in self.data_table.selectedRanges():
            bottom = min(n_rows, item.bottomRow() + 1)
            selected.update(range(start + item.topRow(), start + bottom))
        return sorted(selected)

    def update_page_list(self):
        """Rebuild the page combo box to match the current number of results.

        The previously selected page is kept when it still exists, otherwise
        the last page is selected. Signals are blocked meanwhile, so this
        method does not redraw the table by itself.
        """
        last_page_index = self.page_index
        # Ceiling division; an empty list still shows one (empty) page.
        n_pages = max(1, -(-self.n_results // self.PAGE_ROWS))
        self.current_page_combo_box.blockSignals(True)
        self.current_page_combo_box.clear()
        self.current_page_combo_box.addItems(
            [self.tr("Page {0}").format(i + 1) for i in range(n_pages)])
        self.current_page_combo_box.setCurrentIndex(
            min(last_page_index, n_pages - 1))
        self.current_page_combo_box.blockSignals(False)

    def update_page(self, page_index: int):
        """Redraw the table with the results that belong to page ``page_index``."""
        def write(row: int, col: int, value):
            # Convert the value to its display text and put it in the cell.
            if isinstance(value, str):
                pass
            elif isinstance(value, int):
                value = str(value)
            elif isinstance(value, float):
                value = f"{value: 0.4f}"
            else:
                value = str(value)
            item = QTableWidgetItem(value)
            item.setTextAlignment(Qt.AlignCenter)
            self.data_table.setItem(row, col, item)

        # necessary to clear
        self.data_table.clear()
        start = page_index * self.PAGE_ROWS
        end = min(start + self.PAGE_ROWS, self.n_results)
        # Size the table from the slice actually shown, so the row count can
        # never disagree with the data (or become negative).
        page_results = self.__fitting_results[start:end]
        self.data_table.setRowCount(len(page_results))
        self.data_table.setColumnCount(8)
        self.data_table.setHorizontalHeaderLabels([
            self.tr("Resolver"),
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("N_iterations"),
            self.tr("Spent Time [s]"),
            self.tr("Final Distance"),
            self.tr("Has Reference"),
            self.tr("Is Reference")
        ])
        self.data_table.setVerticalHeaderLabels(
            [result.sample.name for result in page_results])
        # Resolve the distance function once instead of once per row.
        distance_func = self.distance_func
        yes, no = self.tr("Yes"), self.tr("No")
        for row, result in enumerate(page_results):
            write(row, 0, result.task.resolver)
            write(row, 1,
                  self.get_distribution_name(result.task.distribution_type))
            write(row, 2, result.task.n_components)
            write(row, 3, result.n_iterations)
            write(row, 4, result.time_spent)
            write(row, 5,
                  distance_func(result.sample.distribution,
                                result.distribution))
            has_ref = (result.task.initial_guess is not None
                       or result.task.reference is not None)
            write(row, 6, yes if has_ref else no)
            is_ref = result.uuid in self.__reference_map
            write(row, 7, yes if is_ref else no)

        self.data_table.resizeColumnsToContents()

    def on_previous_button_clicked(self):
        """Select the previous page, if there is one."""
        if self.page_index > 0:
            self.current_page_combo_box.setCurrentIndex(self.page_index - 1)

    def on_next_button_clicked(self):
        """Select the next page, if there is one."""
        if self.page_index < self.n_pages - 1:
            self.current_page_combo_box.setCurrentIndex(self.page_index + 1)

    def get_distribution_name(self, distribution_type: DistributionType):
        """Return the translated display name of ``distribution_type``.

        Raises:
            NotImplementedError: for a distribution type without a name here.
        """
        if distribution_type == DistributionType.Normal:
            return self.tr("Normal")
        elif distribution_type == DistributionType.Weibull:
            return self.tr("Weibull")
        elif distribution_type == DistributionType.SkewNormal:
            return self.tr("Skew Normal")
        else:
            raise NotImplementedError(distribution_type)

    def add_result(self, result: SSUResult):
        """Append a single result; see :meth:`add_results`."""
        self.add_results([result])

    def add_results(self, results: typing.List[SSUResult]):
        """Append ``results`` and redraw the table only if the shown page changes.

        The visible page is affected only when the list was empty, or when the
        last page is shown and still has free rows.
        """
        need_update = self.n_results == 0 or (
            self.page_index == self.n_pages - 1
            and self.n_results % self.PAGE_ROWS != 0)
        self.__fitting_results.extend(results)
        self.update_page_list()
        if need_update:
            self.update_page(self.page_index)

    def mark_results(self, results: typing.List[SSUResult]):
        """Mark ``results`` as references and redraw the current page."""
        for result in results:
            self.__reference_map[result.uuid] = result

        self.update_page(self.page_index)

    def unmark_results(self, results: typing.List[SSUResult]):
        """Remove ``results`` from the references and redraw the current page."""
        for result in results:
            self.__reference_map.pop(result.uuid, None)

        self.update_page(self.page_index)

    def add_references(self, results: typing.List[SSUResult]):
        """Append ``results`` and mark all of them as references."""
        self.add_results(results)
        self.mark_results(results)

    def mark_selections(self):
        """Mark the results of the selected rows as references."""
        self.mark_results(
            [self.__fitting_results[i] for i in self.selections])

    def unmark_selections(self):
        """Unmark the results of the selected rows."""
        self.unmark_results(
            [self.__fitting_results[i] for i in self.selections])

    def remove_results(self, indexes):
        """Remove the results at ``indexes`` and drop them from the references.

        Args:
            indexes: positions in the result list; order and duplicates do
                not matter.
        """
        # Pop from the highest index down so the remaining indexes stay valid.
        removed = [
            self.__fitting_results.pop(i)
            for i in sorted(set(indexes), reverse=True)
        ]
        for result in removed:
            self.__reference_map.pop(result.uuid, None)
        # Refresh the page list first so the single redraw sees valid paging.
        self.update_page_list()
        self.update_page(self.page_index)

    def remove_selections(self):
        """Remove the results of the selected rows."""
        self.remove_results(self.selections)

    def remove_all_results(self):
        """Ask for confirmation, then remove every result and reference mark."""
        res = self.remove_warning_msg.exec_()
        if res == QMessageBox.Yes:
            self.__fitting_results.clear()
            # Like remove_results, removed results must not remain references.
            self.__reference_map.clear()
            self.update_page_list()
            self.update_page(0)

    def _first_selected_result(self):
        """Return the result of the first selected row, or None if none is selected."""
        selections = self.selections
        if not selections:
            return None
        return self.__fitting_results[selections[0]]

    def show_distance(self):
        """Plot the distance series of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.distance_chart.show_distance_series(
            result.get_distance_series(self.distance_name),
            title=result.sample.name)
        self.distance_chart.show()

    def show_distribution(self):
        """Plot the distribution chart of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.mixed_distribution_chart.show_model(result.view_model)
        self.mixed_distribution_chart.show()

    def show_history_distribution(self):
        """Plot the distribution chart (animation) of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.mixed_distribution_chart.show_result(result)
        self.mixed_distribution_chart.show()

    def load_dump(self, mark_ref=False):
        """Load a list of ``SSUResult`` from a pickle dump chosen by the user.

        Args:
            mark_ref: also mark the loaded results as references.
        """
        filename, _ = self.file_dialog.getOpenFileName(
            self, self.tr("Select a binary dump file of SSU results"), None,
            self.tr("Binary dump (*.dump)"))
        if not filename:
            return
        # SECURITY NOTE: unpickling can execute arbitrary code, so only dump
        # files from a trusted source should be opened here.
        try:
            with open(filename, "rb") as f:
                results = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError):
            # Unreadable or corrupted file: report it instead of letting the
            # exception escape from the Qt slot.
            self.logger.exception("Failed to load the binary dump file.")
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        valid = isinstance(results, list) and all(
            isinstance(result, SSUResult) for result in results)
        if not valid:
            self.show_error(self.tr("The binary dump file is invalid."))
            return
        self.add_results(results)
        if mark_ref:
            self.mark_results(results)

    def save_dump(self):
        """Pickle all stored results to a file chosen by the user."""
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        filename, _ = self.file_dialog.getSaveFileName(
            self, self.tr("Save the SSU results to binary dump file"), None,
            self.tr("Binary dump (*.dump)"))
        if not filename:
            return
        with open(filename, "wb") as f:
            pickle.dump(self.__fitting_results, f)

    def find_similar(self, target: GrainSizeSample,
                     ref_results: typing.List[SSUResult]):
        """Return the result in ``ref_results`` closest to ``target``.

        The target distribution is linearly interpolated onto the classes of
        each candidate (0 outside the target's range) and compared with the
        currently selected distance function.

        Returns:
            The candidate with the smallest distance, or None when no
            candidate compares smaller than infinity (e.g. all distances NaN).
        """
        assert len(ref_results) != 0

        start_time = time.time()
        from scipy.interpolate import interp1d
        min_distance = float("inf")
        min_result = None
        trans_func = interp1d(target.classes_φ,
                              target.distribution,
                              bounds_error=False,
                              fill_value=0.0)
        # Resolve the distance function once instead of once per candidate.
        distance_func = self.distance_func
        for result in ref_results:
            # TODO: To scale the classes of result to that of sample
            # use moments to calculate? MOMENTS MAY NOT BE PERFECT, MAY IGNORE THE MINOR DIFFERENCE
            trans_dist = trans_func(result.classes_φ)
            distance = distance_func(result.distribution, trans_dist)

            if distance < min_distance:
                min_distance = distance
                min_result = result

        self.logger.debug(
            f"It took {time.time()-start_time:0.4f} s to query the reference from {len(ref_results)} results."
        )
        return min_result

    def query_reference(self, sample: GrainSizeSample):
        """Return the marked reference most similar to ``sample``, or None if none is marked."""
        if len(self.__reference_map) == 0:
            self.logger.debug("No result is marked as reference.")
            return None
        return self.find_similar(sample, self.__reference_map.values())
class SpreadSheet(QTableWidget):
    """Joint-coordinate table that is saved to and loaded from per-image CSV files.

    The widget owns an inner ``QTableWidget`` (``self.table``) whose content is
    stored as ``./Joint_csv/<count>.csv``, where ``count`` is a 1-based counter
    that ``update_table`` / ``prev_table`` move forward / backward.
    """

    def __init__(self, rows, cols, Imagelimit, parent=None):
        """Create the inner table and show the first CSV.

        Args:
            rows: number of rows of the inner table.
            cols: number of columns of the inner table.
            Imagelimit: largest value the counter may reach.
            parent: optional parent widget.
        """
        super().__init__(parent)

        self.control = ControlBox()

        self.rows = rows
        self.cols = cols
        self.Image_limit = Imagelimit
        self.table = QTableWidget(self.rows, self.cols, self)

        self.count = 0

        # Column headers, one per joint.
        self.jointname = [
            '0. nose', '1.neck', '2.Right shoulder', "3.Right elbow",
            "4.Right hand", "5.Left shoulder", "6.Left elbow", "7.Left hand",
            "8.Right hip", "9.Right knee", "10.Right foot", "11.Left hip",
            "12.Left knee", "13.Left foot", "14.Right eye", "15.Left eye",
            "16.Right ear", "17.Left ear"
        ]

        self.update_table()

    def update_table(self):
        """Save the current table (unless nothing was shown yet) and show the next one."""
        if self.count < self.Image_limit:
            if self.count != 0:
                self.Save_table()
            self.count_UP()
            self.Show_table()

    def prev_table(self):
        """Save the current table and show the previous one, if there is one."""
        # count > 1 already implies that a table is currently shown.
        if self.count > 1:
            self.Save_table()
            self.count_DOWN()
            self.Show_table()

    def Show_table(self):
        """Fill the table from the CSV of the current counter value.

        A missing CSV is created empty. The human combo box of the control box
        is refilled with one 1-based entry per non-empty CSV row, or with a
        single '1' when there is none.

        Returns:
            This widget.
        """
        self.table.setSizeAdjustPolicy(QTableWidget.AdjustToContents)

        for c, joint in enumerate(self.jointname):
            self.table.setHorizontalHeaderItem(c, QTableWidgetItem(joint))

        self.next_csv = './Joint_csv/{}.csv'.format(self.count)
        human_num = []  # indexes of the CSV rows that contain at least one field

        if not os.path.exists(self.next_csv):
            print('You dont have next_csv')
            # Make sure the folder exists, otherwise creating the file fails.
            os.makedirs(os.path.dirname(self.next_csv), exist_ok=True)
            with open(self.next_csv, 'w'):
                pass  # only create the empty file
        else:
            print('csv is activate', self.next_csv)
            # newline='' is what the csv module expects for its file objects.
            with open(self.next_csv, newline='') as f_read:
                for rindex, row in enumerate(csv.reader(f_read)):
                    if row:
                        human_num.append(rindex)
                    for cindex, column in enumerate(row):
                        self.table.setItem(rindex, cindex,
                                           QTableWidgetItem(column))

        print('human_num is', human_num)
        self.control.humancombo.clear()
        if human_num:
            for i in human_num:
                self.control.humancombo.addItem(str(i + 1))
        else:
            print('else of human_num')
            self.control.humancombo.addItem('1')

        return self

    def Save_table(self):
        """Write every cell of the table to the CSV of the current counter value.

        Cells without an item are written as empty strings.

        Returns:
            This widget.
        """
        print('spreadsheet is saved')
        self.pre_csv = './Joint_csv/{}.csv'.format(self.count)
        print('save to', self.pre_csv)
        # Make sure the folder exists, otherwise opening the file fails.
        os.makedirs(os.path.dirname(self.pre_csv), exist_ok=True)
        # newline='' lets the csv writer control the line endings itself.
        with open(self.pre_csv, 'w', newline='') as f_out:
            writer = csv.writer(f_out, lineterminator='\n')
            for row in range(self.rows):
                row_data = []
                for column in range(self.cols):
                    item = self.table.item(row, column)
                    row_data.append('' if item is None else item.text())
                writer.writerow(row_data)
        return self

    def count_UP(self):
        """Advance the counter and clear the table."""
        self.count += 1
        self.table.clear()

    def count_DOWN(self):
        """Step the counter back and clear the table."""
        self.count -= 1
        self.table.clear()

    def joint_mouseEvent(self, x, y):
        """Store the coordinate "x,y" in the cell chosen by the control box.

        Args:
            x: x coordinate to store.
            y: y coordinate to store.
        """
        self.coord = str(x) + ',' + str(y)
        # NOTE: despite their names, joint_row (joint combo index) is used as
        # the table column and joint_column (human number - 1) as the table row.
        self.joint_row = self.control.jointcombo.currentIndex()
        self.joint_column = int(self.control.humancombo.currentText()) - 1
        self.table.setItem(self.joint_column, self.joint_row,
                           QTableWidgetItem(self.coord))
예제 #15
0
파일: pay.py 프로젝트: zmyang789/DD_KaoRou2
class pay(QDialog):
    def __init__(self):
        super().__init__()
        self.setWindowTitle('赞助和支持')
        self.resize(600, 520)
        layout = QGridLayout()
        self.setLayout(layout)
        txt = u'DD烤肉机由B站up:执鸣神君 业余时间独立开发制作。\n\
\n所有功能全部永久免费给广大烤肉man使用,无需专门找我获取授权。\n\
\n有独立经济来源的老板们如觉得烤肉机好用的话,不妨小小支持亿下\n\
\n一元也是对我继续更新烤肉机的莫大鼓励。十分感谢!\n'

        label = QLabel(txt)
        label.setTextInteractionFlags(Qt.TextSelectableByMouse)
        label.setAlignment(Qt.AlignCenter)
        layout.addWidget(label, 0, 0, 1, 1)

        bilibili_url = QLabel()
        bilibili_url.setAlignment(Qt.AlignCenter)
        bilibili_url.setOpenExternalLinks(True)
        bilibili_url.setText(
            _translate(
                "MainWindow",
                "<html><head/><body><p><a href=\"https://space.bilibili.com/637783\">\
<span style=\" text-decoration: underline; color:#cccccc;\">执鸣神君B站主页: https://space.bilibili.com/637783</span></a></p></body></html>",
                None))
        layout.addWidget(bilibili_url, 1, 0, 1, 1)

        github_url = QLabel()
        github_url.setAlignment(Qt.AlignCenter)
        github_url.setOpenExternalLinks(True)
        github_url.setText(
            _translate(
                "MainWindow",
                "<html><head/><body><p><a href=\"https://github.com/jiafangjun/DD_KaoRou2\">\
<span style=\" text-decoration: underline; color:#cccccc;\">烤肉机项目开源地址: https://github.com/jiafangjun/DD_KaoRou2</span></a></p></body></html>",
                None))
        layout.addWidget(github_url, 2, 0, 1, 1)

        layout.addWidget(QLabel(), 3, 0, 1, 1)
        alipay = QLabel()
        alipay.setFixedSize(260, 338)
        alipay.setStyleSheet('border-image: url(:/images/0.jpg)')
        layout.addWidget(alipay, 4, 0, 1, 1)
        weixin = QLabel()
        weixin.setFixedSize(260, 338)
        weixin.setStyleSheet('border-image: url(:/images/1.jpg)')
        layout.addWidget(weixin, 4, 1, 1, 1)
        layout.addWidget(QLabel(), 5, 0, 1, 1)

        self.bossTable = QTableWidget()
        self.bossTable.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.bossTable.setRowCount(3)
        self.bossTable.setColumnCount(2)
        for i in range(2):
            self.bossTable.setColumnWidth(i, 105)
        self.bossTable.setHorizontalHeaderLabels(['石油王', '打赏'])
        self.bossTable.setItem(0, 0, QTableWidgetItem('石油王鸣谢名单'))
        self.bossTable.setItem(0, 1, QTableWidgetItem('正在获取...'))
        layout.addWidget(self.bossTable, 0, 1, 3, 1)

        self.thankToBoss = thankToBoss()
        self.thankToBoss.bossList.connect(self.updateBossList)
        self.thankToBoss.start()

    def updateBossList(self, bossList):
        self.bossTable.clear()
        self.bossTable.setColumnCount(2)
        self.bossTable.setRowCount(len(bossList))
        if len(bossList) > 3:
            biggestBossList = []
            for _ in range(3):
                sc = 0
                for cnt, i in enumerate(bossList):
                    money = float(i[1].split(' ')[0])
                    if money > sc:
                        sc = money
                        bossNum = cnt
                biggestBossList.append(bossList.pop(bossNum))
            for y, i in enumerate(biggestBossList):
                self.bossTable.setItem(y, 0, QTableWidgetItem(i[0]))
                self.bossTable.setItem(y, 1, QTableWidgetItem(i[1]))
                self.bossTable.item(y, 0).setTextAlignment(Qt.AlignCenter)
                self.bossTable.item(y, 1).setTextAlignment(Qt.AlignCenter)
            for y, i in enumerate(bossList):
                self.bossTable.setItem(y + 3, 0, QTableWidgetItem(i[0]))
                self.bossTable.setItem(y + 3, 1, QTableWidgetItem(i[1]))
                self.bossTable.item(y + 3, 0).setTextAlignment(Qt.AlignCenter)
                self.bossTable.item(y + 3, 1).setTextAlignment(Qt.AlignCenter)
            self.bossTable.setHorizontalHeaderLabels(['石油王', '打赏'])