Пример #1
0
    def __init__(self, reference_viewer: ReferenceResultViewer, parent=None):
        """Set up the fitting-result viewer window.

        Builds the UI, the chart windows, the background-worker hookup and
        the message boxes / timer kept on the instance.

        Args:
            reference_viewer: Reference viewer stored on the instance.
            parent: Optional parent widget forwarded to the base class.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Fitting Result Viewer"))
        self.__fitting_results = []  # type: list[SSUResult]
        self.retry_tasks = {}  # type: dict[UUID, SSUTask]
        self.__reference_viewer = reference_viewer
        self.init_ui()

        # Chart windows: all parented to this widget and created with a
        # toolbar.
        chart_options = dict(parent=self, toolbar=True)
        self.boxplot_chart = BoxplotChart(**chart_options)
        self.typical_chart = SSUTypicalComponentChart(**chart_options)
        self.distance_chart = DistanceCurveChart(**chart_options)
        self.mixed_distribution_chart = MixedDistributionChart(
            use_animation=True, **chart_options)
        self.file_dialog = QFileDialog(parent=self)

        # Background worker: both outcomes are routed to this viewer.
        self.async_worker = AsyncWorker()
        worker_signals = self.async_worker.background_worker
        worker_signals.task_succeeded.connect(self.on_fitting_succeeded)
        worker_signals.task_failed.connect(self.on_fitting_failed)

        self.update_page_list()
        self.update_page(self.page_index)

        # General-purpose message box.
        self.normal_msg = QMessageBox(self)

        # Yes/No confirmation shown before removing every result; "No" is
        # the default answer.
        remove_box = QMessageBox(self)
        remove_box.setStandardButtons(QMessageBox.No | QMessageBox.Yes)
        remove_box.setDefaultButton(QMessageBox.No)
        remove_box.setWindowTitle(self.tr("Warning"))
        remove_box.setText(
            self.tr("Are you sure to remove all SSU results?"))
        self.remove_warning_msg = remove_box

        # Discard / Retry / Ignore box; "Ignore" is the default answer.
        outlier_box = QMessageBox(self)
        outlier_box.setStandardButtons(
            QMessageBox.Discard | QMessageBox.Retry | QMessageBox.Ignore)
        outlier_box.setDefaultButton(QMessageBox.Ignore)
        self.outlier_msg = outlier_box

        # Parentless progress box whose Ok button exists but is hidden.
        progress_box = QMessageBox()
        progress_box.addButton(QMessageBox.Ok)
        progress_box.button(QMessageBox.Ok).hide()
        progress_box.setWindowTitle(self.tr("Progress"))
        self.retry_progress_msg = progress_box

        # Single-shot timer that opens the progress box when it fires.
        timer = QTimer(self)
        timer.setSingleShot(True)
        timer.timeout.connect(lambda: self.retry_progress_msg.exec_())
        self.retry_timer = timer
Пример #2
0
    def initialize_ui(self):
        """Create the chart and control groups and put them in a splitter.

        The "Try" and "Confirm" buttons occupy rows 1 and 2 of the control
        grid, each spanning four columns.
        """
        self.main_layout = QGridLayout(self)

        # Chart side: a single mixed-distribution chart in its own group.
        chart_group = QGroupBox(self.tr("Chart"))
        chart_grid = QGridLayout(chart_group)
        chart = MixedDistributionChart(show_mode=True, toolbar=False)
        chart_grid.addWidget(chart)
        self.chart_group = chart_group
        self.chart_layout = chart_grid
        self.chart = chart

        # Control side: the two action buttons.
        control_group = QGroupBox(self.tr("Control"))
        control_grid = QGridLayout(control_group)
        self.control_group = control_group
        self.control_layout = control_grid

        try_icon = qta.icon("mdi.test-tube")
        try_button = QPushButton(try_icon, self.tr("Try"))
        try_button.clicked.connect(self.on_try_clicked)
        control_grid.addWidget(try_button, 1, 0, 1, 4)
        self.try_button = try_button

        confirm_icon = qta.icon("ei.ok-circle")
        confirm_button = QPushButton(confirm_icon, self.tr("Confirm"))
        confirm_button.clicked.connect(self.on_confirm_clicked)
        control_grid.addWidget(confirm_button, 2, 0, 1, 4)
        self.confirm_button = confirm_button

        # Chart on the left, controls on the right.
        splitter = QSplitter(Qt.Horizontal)
        for group in (chart_group, control_group):
            splitter.addWidget(group)
        self.splitter = splitter
        self.main_layout.addWidget(splitter)
Пример #3
0
    def __init__(self, parent=None):
        """Set up the reference-result viewer window.

        Args:
            parent: Optional parent widget forwarded to the base class.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Reference Result Viewer"))

        # Internal state, all empty on start-up.
        self.__fitting_results = []
        self.__reference_map = {}
        self.retry_tasks = {}

        self.init_ui()

        # Chart windows and the file dialog are parented to this widget.
        chart_options = dict(parent=self, toolbar=True)
        self.distance_chart = DistanceCurveChart(**chart_options)
        self.mixed_distribution_chart = MixedDistributionChart(
            use_animation=True, **chart_options)
        self.file_dialog = QFileDialog(parent=self)

        self.update_page_list()
        self.update_page(self.page_index)

        # Yes/No confirmation shown before removing every result; "No" is
        # the default answer.
        confirm_box = QMessageBox(self)
        confirm_box.setStandardButtons(QMessageBox.No | QMessageBox.Yes)
        confirm_box.setDefaultButton(QMessageBox.No)
        confirm_box.setWindowTitle(self.tr("Warning"))
        confirm_box.setText(
            self.tr("Are you sure to remove all SSU results?"))
        self.remove_warning_msg = confirm_box

        # General-purpose message box.
        self.normal_msg = QMessageBox(self)
Пример #4
0
 def init_ui(self):
     """Assemble the panel: control, chart, statistics and table groups.

     Every widget is stored as an instance attribute. The groups are nested
     in three splitters and the outermost splitter is placed in the main
     grid layout.
     """

     def place_row(grid, row, *widgets):
         # One widget per column, left to right, starting at column 0.
         for column, widget in enumerate(widgets):
             grid.addWidget(widget, row, column)

     self.setAttribute(Qt.WA_StyledBackground, True)
     self.main_layout = QGridLayout(self)

     # ---- control group ----
     self.control_group = QGroupBox(self.tr("Control"))
     control_grid = QGridLayout(self.control_group)
     self.control_layout = control_grid

     self.resolver_label = QLabel(self.tr("Resolver"))
     resolver_box = QComboBox()
     resolver_box.addItems(["classic", "neural"])
     self.resolver_combo_box = resolver_box
     place_row(control_grid, 0, self.resolver_label, resolver_box)

     generating_button = QPushButton(
         qta.icon("fa.cubes"), self.tr("Configure Sample Generating"))
     generating_button.clicked.connect(
         self.on_configure_generating_clicked)
     self.configure_generating_button = generating_button
     fitting_button = QPushButton(
         qta.icon("fa.gears"), self.tr("Configure Fitting Algorithm"))
     fitting_button.clicked.connect(self.on_configure_fitting_clicked)
     self.configure_fitting_button = fitting_button
     place_row(control_grid, 1, generating_button, fitting_button)

     self.distribution_label = QLabel(self.tr("Distribution Type"))
     distribution_box = QComboBox()
     distribution_box.addItems(
         [label for _type, label in self.distribution_types])
     self.distribution_combo_box = distribution_box
     place_row(control_grid, 2, self.distribution_label, distribution_box)

     self.component_number_label = QLabel(self.tr("Component Number"))
     components_spin = QSpinBox()
     components_spin.setRange(1, 10)
     components_spin.setValue(3)
     self.n_components_input = components_spin
     place_row(control_grid, 3, self.component_number_label,
               components_spin)

     single_button = QPushButton(qta.icon("fa.play-circle"),
                                 self.tr("Single Test"))
     single_button.clicked.connect(self.on_single_test_clicked)
     self.single_test_button = single_button
     continuous_button = QPushButton(qta.icon("mdi.playlist-play"),
                                     self.tr("Continuous Test"))
     continuous_button.clicked.connect(self.on_continuous_test_clicked)
     self.continuous_test_button = continuous_button
     place_row(control_grid, 4, single_button, continuous_button)

     clear_button = QPushButton(qta.icon("fa.eraser"),
                                self.tr("Clear Statistics"))
     clear_button.clicked.connect(self.clear_records)
     self.clear_stats_button = clear_button
     # The clear button spans both columns of the control grid.
     control_grid.addWidget(clear_button, 5, 0, 1, 2)

     # ---- chart group ----
     self.chart_group = QGroupBox(self.tr("Chart"))
     chart_grid = QGridLayout(self.chart_group)
     self.chart_layout = chart_grid
     self.sample_chart = MixedDistributionChart(show_mode=True,
                                                toolbar=False)
     self.result_chart = MixedDistributionChart(show_mode=True,
                                                toolbar=False)
     place_row(chart_grid, 0, self.sample_chart, self.result_chart)

     # ---- statistics group ----
     self.stats_group = QGroupBox(self.tr("Statistics"))
     stats_grid = QGridLayout(self.stats_group)
     self.stats_layout = stats_grid
     self.n_task_label = QLabel(self.tr("Total Tasks:"))
     self.n_tasks_display = QLabel("0")
     self.n_failed_tasks_label = QLabel(self.tr("Failed Tasks:"))
     self.n_failed_tasks_display = QLabel("0")
     self.n_unqualified_tasks_label = QLabel(self.tr("Unqualified Tasks:"))
     self.n_unquelified_tasks_display = QLabel("0")
     self.mean_spent_time_label = QLabel(self.tr("Mean Spent Time [s]:"))
     self.mean_spent_time_display = QLabel("0.0")
     self.mean_n_iterations_label = QLabel(
         self.tr("Mean N<sub>iterations</sub>:"))
     self.mean_n_iterations_display = QLabel("0")
     self.mean_distance_label = QLabel(self.tr("Mean distance:"))
     self.mean_distance_display = QLabel("0.0")
     # Caption in column 0, value in column 1, one statistic per row.
     stat_rows = (
         (self.n_task_label, self.n_tasks_display),
         (self.n_failed_tasks_label, self.n_failed_tasks_display),
         (self.n_unqualified_tasks_label, self.n_unquelified_tasks_display),
         (self.mean_spent_time_label, self.mean_spent_time_display),
         (self.mean_n_iterations_label, self.mean_n_iterations_display),
         (self.mean_distance_label, self.mean_distance_display),
     )
     for row, (caption, value) in enumerate(stat_rows):
         place_row(stats_grid, row, caption, value)

     # ---- table group ----
     self.table_group = QGroupBox(self.tr("Table"))
     self.reference_view = ReferenceResultViewer()
     self.result_view = FittingResultViewer(self.reference_view)
     # A result marked in the result view is added to the reference view.
     self.result_view.result_marked.connect(
         lambda result: self.reference_view.add_references([result]))
     tabs = QTabWidget()
     tabs.addTab(self.result_view, qta.icon("fa.cubes"),
                 self.tr("Result"))
     tabs.addTab(self.reference_view, qta.icon("fa5s.key"),
                 self.tr("Reference"))
     self.table_tab = tabs
     self.result_layout = QGridLayout(self.table_group)
     self.result_layout.addWidget(tabs, 0, 0)

     # ---- pack all groups ----
     # control | stats on top, charts below, tables to the right.
     top_splitter = QSplitter(Qt.Orientation.Horizontal)
     for group in (self.control_group, self.stats_group):
         top_splitter.addWidget(group)
     self.splitter1 = top_splitter

     left_splitter = QSplitter(Qt.Orientation.Vertical)
     for part in (top_splitter, self.chart_group):
         left_splitter.addWidget(part)
     self.splitter2 = left_splitter

     outer_splitter = QSplitter(Qt.Orientation.Horizontal)
     for part in (left_splitter, self.table_group):
         outer_splitter.addWidget(part)
     self.splitter3 = outer_splitter

     self.main_layout.addWidget(outer_splitter, 0, 0)
Пример #5
0
class SSUAlgorithmTesterPanel(QDialog):
    """Dialog for testing SSU fitting against randomly generated samples.

    A test generates an artificial sample, submits a fitting task to the
    background worker and, when the task finishes, updates the charts, the
    result table and the running statistics. In continuous mode a new test
    is started as soon as the previous one finishes.

    Bookkeeping attributes:
        task_table: task uuid -> (artificial sample, task).
        task_results: task uuid -> fitting result of succeeded tasks.
        failed_task_ids: uuids of tasks reported as failed by the worker.
        unquelified_task_ids: uuids of succeeded tasks whose result is
            invalid or outside the evaluation tolerance. (The attribute
            spelling is kept as-is because it is part of the class surface.)
    """

    def __init__(self, parent=None):
        """Create the setting widgets, the worker hookup and the UI.

        Args:
            parent: Optional parent widget forwarded to the base class.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("Algorithm Tester"))
        # (enum value, translated display name); the combo box index maps
        # back into this list.
        self.distribution_types = [
            (DistributionType.Normal, self.tr("Normal")),
            (DistributionType.Weibull, self.tr("Weibull")),
            (DistributionType.SkewNormal, self.tr("Skew Normal"))
        ]
        self.generate_setting = RandomDatasetGenerator(parent=self)
        self.classic_setting = ClassicResolverSettingWidget(parent=self)
        self.neural_setting = NNResolverSettingWidget(parent=self)
        self.async_worker = AsyncWorker()
        self.async_worker.background_worker.task_succeeded.connect(
            self.on_fitting_succeeded)
        self.async_worker.background_worker.task_failed.connect(
            self.on_fitting_failed)
        self.task_table = {}
        self.task_results = {}
        self.failed_task_ids = []
        self.unquelified_task_ids = []
        self.__continuous_flag = False
        self.init_ui()

    def init_ui(self):
        """Build the control, chart, statistics and table groups."""
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.main_layout = QGridLayout(self)
        # control group
        self.control_group = QGroupBox(self.tr("Control"))
        self.control_layout = QGridLayout(self.control_group)
        self.resolver_label = QLabel(self.tr("Resolver"))
        self.resolver_combo_box = QComboBox()
        self.resolver_combo_box.addItems(["classic", "neural"])
        self.control_layout.addWidget(self.resolver_label, 0, 0)
        self.control_layout.addWidget(self.resolver_combo_box, 0, 1)
        self.configure_generating_button = QPushButton(
            qta.icon("fa.cubes"), self.tr("Configure Sample Generating"))
        self.configure_generating_button.clicked.connect(
            self.on_configure_generating_clicked)
        self.configure_fitting_button = QPushButton(
            qta.icon("fa.gears"), self.tr("Configure Fitting Algorithm"))
        self.configure_fitting_button.clicked.connect(
            self.on_configure_fitting_clicked)
        self.control_layout.addWidget(self.configure_generating_button, 1, 0)
        self.control_layout.addWidget(self.configure_fitting_button, 1, 1)
        self.distribution_label = QLabel(self.tr("Distribution Type"))
        self.distribution_combo_box = QComboBox()
        self.distribution_combo_box.addItems(
            [name for _, name in self.distribution_types])
        self.component_number_label = QLabel(self.tr("Component Number"))
        self.n_components_input = QSpinBox()
        self.n_components_input.setRange(1, 10)
        self.n_components_input.setValue(3)
        self.control_layout.addWidget(self.distribution_label, 2, 0)
        self.control_layout.addWidget(self.distribution_combo_box, 2, 1)
        self.control_layout.addWidget(self.component_number_label, 3, 0)
        self.control_layout.addWidget(self.n_components_input, 3, 1)
        self.single_test_button = QPushButton(qta.icon("fa.play-circle"),
                                              self.tr("Single Test"))
        self.single_test_button.clicked.connect(self.on_single_test_clicked)
        self.continuous_test_button = QPushButton(
            qta.icon("mdi.playlist-play"), self.tr("Continuous Test"))
        self.continuous_test_button.clicked.connect(
            self.on_continuous_test_clicked)
        self.control_layout.addWidget(self.single_test_button, 4, 0)
        self.control_layout.addWidget(self.continuous_test_button, 4, 1)
        self.clear_stats_button = QPushButton(qta.icon("fa.eraser"),
                                              self.tr("Clear Statistics"))
        self.clear_stats_button.clicked.connect(self.clear_records)
        self.control_layout.addWidget(self.clear_stats_button, 5, 0, 1, 2)
        # chart group: generated sample on the left, fitted result on the
        # right
        self.chart_group = QGroupBox(self.tr("Chart"))
        self.chart_layout = QGridLayout(self.chart_group)
        self.sample_chart = MixedDistributionChart(show_mode=True,
                                                   toolbar=False)
        self.result_chart = MixedDistributionChart(show_mode=True,
                                                   toolbar=False)
        self.chart_layout.addWidget(self.sample_chart, 0, 0)
        self.chart_layout.addWidget(self.result_chart, 0, 1)
        # stats group
        self.stats_group = QGroupBox(self.tr("Statistics"))
        self.stats_layout = QGridLayout(self.stats_group)
        self.n_task_label = QLabel(self.tr("Total Tasks:"))
        self.n_tasks_display = QLabel("0")
        self.n_failed_tasks_label = QLabel(self.tr("Failed Tasks:"))
        self.n_failed_tasks_display = QLabel("0")
        self.n_unqualified_tasks_label = QLabel(self.tr("Unqualified Tasks:"))
        self.n_unquelified_tasks_display = QLabel("0")
        self.stats_layout.addWidget(self.n_task_label, 0, 0)
        self.stats_layout.addWidget(self.n_tasks_display, 0, 1)
        self.stats_layout.addWidget(self.n_failed_tasks_label, 1, 0)
        self.stats_layout.addWidget(self.n_failed_tasks_display, 1, 1)
        self.stats_layout.addWidget(self.n_unqualified_tasks_label, 2, 0)
        self.stats_layout.addWidget(self.n_unquelified_tasks_display, 2, 1)
        self.mean_spent_time_label = QLabel(self.tr("Mean Spent Time [s]:"))
        self.mean_spent_time_display = QLabel("0.0")
        self.mean_n_iterations_label = QLabel(
            self.tr("Mean N<sub>iterations</sub>:"))
        self.mean_n_iterations_display = QLabel("0")
        self.mean_distance_label = QLabel(self.tr("Mean distance:"))
        self.mean_distance_display = QLabel("0.0")
        self.stats_layout.addWidget(self.mean_spent_time_label, 3, 0)
        self.stats_layout.addWidget(self.mean_spent_time_display, 3, 1)
        self.stats_layout.addWidget(self.mean_n_iterations_label, 4, 0)
        self.stats_layout.addWidget(self.mean_n_iterations_display, 4, 1)
        self.stats_layout.addWidget(self.mean_distance_label, 5, 0)
        self.stats_layout.addWidget(self.mean_distance_display, 5, 1)
        # table group
        self.table_group = QGroupBox(self.tr("Table"))
        self.reference_view = ReferenceResultViewer()
        self.result_view = FittingResultViewer(self.reference_view)
        # A result marked in the result view is added to the reference view.
        self.result_view.result_marked.connect(
            lambda result: self.reference_view.add_references([result]))
        self.table_tab = QTabWidget()
        self.table_tab.addTab(self.result_view, qta.icon("fa.cubes"),
                              self.tr("Result"))
        self.table_tab.addTab(self.reference_view, qta.icon("fa5s.key"),
                              self.tr("Reference"))
        self.result_layout = QGridLayout(self.table_group)
        self.result_layout.addWidget(self.table_tab, 0, 0)
        # pack all group
        self.splitter1 = QSplitter(Qt.Orientation.Horizontal)
        self.splitter1.addWidget(self.control_group)
        self.splitter1.addWidget(self.stats_group)
        self.splitter2 = QSplitter(Qt.Orientation.Vertical)
        self.splitter2.addWidget(self.splitter1)
        self.splitter2.addWidget(self.chart_group)
        self.splitter3 = QSplitter(Qt.Orientation.Horizontal)
        self.splitter3.addWidget(self.splitter2)
        self.splitter3.addWidget(self.table_group)
        self.main_layout.addWidget(self.splitter3, 0, 0)

    @property
    def distribution_type(self) -> DistributionType:
        """Distribution type currently selected in the combo box."""
        distribution_type, _ = self.distribution_types[
            self.distribution_combo_box.currentIndex()]
        return distribution_type

    @property
    def n_components(self) -> int:
        """Number of components currently set in the spin box."""
        return self.n_components_input.value()

    def on_configure_generating_clicked(self):
        """Show the sample-generation setting widget."""
        self.generate_setting.show()

    def on_configure_fitting_clicked(self):
        """Show the setting widget of the currently selected resolver."""
        if self.resolver_combo_box.currentText() == "classic":
            self.classic_setting.show()
        else:
            self.neural_setting.show()

    def update_sample_chart(self, artificial_sample: ArtificialSample):
        """Display the generated sample in the sample chart."""
        self.sample_chart.show_model(artificial_sample.view_model)

    def update_fitting_chart(self, fitting_result: SSUResult):
        """Display the fitting result in the result chart."""
        self.result_chart.show_model(fitting_result.view_model)

    def evaluate_result(self,
                        artificial_sample: ArtificialSample,
                        fitting_result: SSUResult,
                        tolerance: float = 0.1):
        """Compare the fitted components with the generated ones.

        Components are paired positionally. For each pair the relative
        error of the logarithmic mean and of the fraction is computed.

        Args:
            artificial_sample: The generated sample (the target).
            fitting_result: The fitting result to evaluate.
            tolerance: Largest relative error still considered qualified.

        Returns:
            ``(unqualified, component_errors)`` where ``unqualified`` is
            true when ANY component exceeds the tolerance and
            ``component_errors`` is a list of ``(mean_error,
            fraction_error)`` tuples, one per compared pair.
        """
        component_errors = []
        unqualified = False
        for target, result in zip(artificial_sample.components,
                                  fitting_result.components):
            target_moments = logarithmic(artificial_sample.classes_φ,
                                         target.distribution)
            result_moments = logarithmic(artificial_sample.classes_φ,
                                         result.distribution)
            mean_error = np.abs(
                (target_moments["mean"] - result_moments["mean"]) /
                target_moments["mean"])
            fraction_error = np.abs(
                (target.fraction - result.fraction) / target.fraction)
            component_errors.append((mean_error, fraction_error))
            # Accumulate instead of overwriting: one bad component is
            # enough to disqualify the result, whichever position it has.
            if mean_error > tolerance or fraction_error > tolerance:
                unqualified = True
        return unqualified, component_errors

    def generate_task(self, query_ref=True):
        """Generate a random sample and the fitting task for it.

        Args:
            query_ref: When true and the reference view returns a matching
                result, that result's distribution type, component number
                and logarithmic moments are used for the task instead of
                the panel's current selection.

        Returns:
            ``(artificial_sample, task)``.
        """
        artificial_sample = self.generate_setting.get_random_sample()
        resolver = self.resolver_combo_box.currentText()
        if resolver == "classic":
            setting = self.classic_setting.setting
        else:
            setting = self.neural_setting.setting
        sample = artificial_sample.sample_to_fit
        query = self.reference_view.query_reference(sample)  # type: SSUResult
        if not query_ref or query is None:
            task = SSUTask(sample,
                           self.distribution_type,
                           self.n_components,
                           resolver=resolver,
                           resolver_setting=setting)
        else:
            keys = ["mean", "std", "skewness"]
            reference = [{key: comp.logarithmic_moments[key]
                          for key in keys} for comp in query.components]
            task = SSUTask(sample,
                           query.distribution_type,
                           query.n_components,
                           resolver=resolver,
                           resolver_setting=setting,
                           reference=reference)
        return artificial_sample, task

    def update_stats(self):
        """Refresh the statistics labels from the recorded tasks."""
        self.n_tasks_display.setText(str(len(self.task_table)))
        self.n_failed_tasks_display.setText(str(len(self.failed_task_ids)))
        self.n_unquelified_tasks_display.setText(
            str(len(self.unquelified_task_ids)))

        results = list(self.task_results.values())
        if not results:
            # No succeeded task yet (e.g. the first task failed). The mean
            # of an empty list is nan and emits a RuntimeWarning, so show
            # the same placeholders as `clear_records` instead.
            self.mean_spent_time_display.setText("0.0")
            self.mean_n_iterations_display.setText("0")
            self.mean_distance_display.setText("0.0")
            return

        distance_name = self.result_view.distance_name
        mean_spent_time = np.mean([result.time_spent for result in results])
        mean_n_iterations = np.mean(
            [result.n_iterations for result in results])
        mean_distance = np.mean(
            [result.get_distance(distance_name) for result in results])
        self.mean_spent_time_display.setText(f"{mean_spent_time:0.4f}")
        self.mean_n_iterations_display.setText(f"{mean_n_iterations:0.2f}")
        self.mean_distance_display.setText(f"{mean_distance:0.4f}")

    def _on_task_finished(self):
        """Refresh the stats, then chain the next test or unlock the UI."""
        self.update_stats()
        if self.__continuous_flag:
            # `do_test` keeps the buttons locked for the next task. They
            # must NOT be re-enabled here: an enabled "Clear Statistics"
            # button during a running task lets the user empty
            # `task_table`, which the success handler then fails to index.
            self.do_test()
            return
        self.single_test_button.setEnabled(True)
        self.continuous_test_button.setEnabled(True)
        self.clear_stats_button.setEnabled(True)

    def on_fitting_succeeded(self, fitting_result: SSUResult):
        """Record a finished task, update the views and classify it.

        Args:
            fitting_result: Result delivered by the background worker.
        """
        task_id = fitting_result.task.uuid
        artificial_sample = self.task_table[task_id][0]
        # update chart
        self.update_sample_chart(artificial_sample)
        self.update_fitting_chart(fitting_result)
        self.task_results[task_id] = fitting_result
        self.result_view.add_result(fitting_result)
        if not fitting_result.is_valid:
            unqualified = True
        else:
            unqualified, _ = self.evaluate_result(artificial_sample,
                                                  fitting_result)
        if unqualified:
            self.unquelified_task_ids.append(task_id)

        self._on_task_finished()

    def on_fitting_failed(self, failed_info: str, task: SSUTask):
        """Record a failed task.

        Args:
            failed_info: Failure description from the worker (not shown).
            task: The task that failed.
        """
        self.failed_task_ids.append(task.uuid)
        self._on_task_finished()

    def clear_records(self):
        """Forget all recorded tasks and reset the statistics labels."""
        self.task_table = {}
        self.task_results = {}
        self.failed_task_ids = []
        self.unquelified_task_ids = []
        self.n_tasks_display.setText("0")
        self.n_failed_tasks_display.setText("0")
        self.n_unquelified_tasks_display.setText("0")
        self.mean_spent_time_display.setText("0.0")
        self.mean_n_iterations_display.setText("0")
        self.mean_distance_display.setText("0.0")

    def do_test(self):
        """Lock the buttons, generate one task and hand it to the worker."""
        self.single_test_button.setEnabled(False)
        self.clear_stats_button.setEnabled(False)
        # In continuous mode the button doubles as "Cancel" and therefore
        # has to stay clickable.
        if not self.__continuous_flag:
            self.continuous_test_button.setEnabled(False)
        artificial_sample, task = self.generate_task()
        self.task_table[task.uuid] = (artificial_sample, task)
        self.async_worker.execute_task(task)

    def on_single_test_clicked(self):
        """Run a single test."""
        self.do_test()

    def on_continuous_test_clicked(self):
        """Toggle continuous mode; starting it launches the first test."""
        if self.__continuous_flag:
            # Cancel: the running task still finishes, then the buttons
            # are unlocked by the completion handler.
            self.__continuous_flag = False
            self.continuous_test_button.setText(self.tr("Continuous Test"))
        else:
            self.continuous_test_button.setText(self.tr("Cancel"))
            self.__continuous_flag = True
            self.do_test()
Пример #6
0
class ManualFittingPanel(QDialog):
    manual_fitting_finished = Signal(SSUResult)

    def __init__(self, parent=None):
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Manual Fitting Panel"))
        self.control_widgets = []
        self.input_widgets = []
        self.last_task = None
        self.last_result = None
        self.async_worker = AsyncWorker()
        self.async_worker.background_worker.task_succeeded.connect(
            self.on_task_succeeded)
        self.initialize_ui()
        self.normal_msg = QMessageBox(self)
        self.chart_timer = QTimer()
        self.chart_timer.timeout.connect(self.update_chart)
        self.chart_timer.setSingleShot(True)

    def initialize_ui(self):
        self.main_layout = QGridLayout(self)

        self.chart_group = QGroupBox(self.tr("Chart"))
        self.chart_layout = QGridLayout(self.chart_group)
        self.chart = MixedDistributionChart(show_mode=True, toolbar=False)
        self.chart_layout.addWidget(self.chart)

        self.control_group = QGroupBox(self.tr("Control"))
        self.control_layout = QGridLayout(self.control_group)
        self.try_button = QPushButton(qta.icon("mdi.test-tube"),
                                      self.tr("Try"))
        self.try_button.clicked.connect(self.on_try_clicked)
        self.control_layout.addWidget(self.try_button, 1, 0, 1, 4)
        self.confirm_button = QPushButton(qta.icon("ei.ok-circle"),
                                          self.tr("Confirm"))
        self.confirm_button.clicked.connect(self.on_confirm_clicked)
        self.control_layout.addWidget(self.confirm_button, 2, 0, 1, 4)

        self.splitter = QSplitter(Qt.Horizontal)
        self.splitter.addWidget(self.chart_group)
        self.splitter.addWidget(self.control_group)
        self.main_layout.addWidget(self.splitter)

    def change_n_components(self, n_components: int):
        for widget in self.control_widgets:
            self.control_layout.removeWidget(widget)
            widget.hide()
        self.control_widgets.clear()
        self.input_widgets.clear()

        widgets = []
        slider_range = (0, 1000)
        input_widgets = []
        mean_range = (-5, 15)
        std_range = (0.0, 10)
        weight_range = (0, 10)
        names = [self.tr("Mean"), self.tr("STD"), self.tr("Weight")]
        ranges = [mean_range, std_range, weight_range]
        slider_values = [500, 100, 100]
        input_values = [0.0, 1.0, 1.0]

        for i in range(n_components):
            group = QGroupBox(f"C{i+1}")
            group.setMinimumWidth(200)
            group_layout = QGridLayout(group)
            inputs = []
            for j, (name, range_, slider_value, input_value) in enumerate(
                    zip(names, ranges, slider_values, input_values)):
                label = QLabel(name)
                slider = QSlider()
                slider.setRange(*slider_range)
                slider.setValue(slider_value)
                slider.setOrientation(Qt.Horizontal)
                input_ = QDoubleSpinBox()
                input_.setRange(*range_)
                input_.setDecimals(3)
                input_.setSingleStep(0.01)
                input_.setValue(input_value)
                slider.valueChanged.connect(self.on_value_changed)
                input_.valueChanged.connect(self.on_value_changed)
                slider.valueChanged.connect(
                    lambda x, input_=input_, range_=range_: input_.setValue(
                        x / 1000 * (range_[-1] - range_[0]) + range_[0]))
                input_.valueChanged.connect(
                    lambda x, slider=slider, range_=range_: slider.setValue(
                        (x - range_[0]) / (range_[-1] - range_[0]) * 1000))

                group_layout.addWidget(label, j, 0)
                group_layout.addWidget(slider, j, 1)
                group_layout.addWidget(input_, j, 2)
                inputs.append(input_)

            self.control_layout.addWidget(group, i + 5, 0, 1, 4)
            widgets.append(group)
            input_widgets.append(inputs)

        self.control_widgets = widgets
        self.input_widgets = input_widgets

    @property
    def n_components(self) -> int:
        return len(self.input_widgets)

    @property
    def expected(self):
        reference = []
        weights = []
        for i, (mean, std, weight) in enumerate(self.input_widgets):
            reference.append(
                dict(mean=mean.value(), std=std.value(), skewness=0.0))
            weights.append(weight.value())
        weights = np.array(weights)
        fractions = weights / np.sum(weights)
        return reference, fractions

    def show_message(self, title: str, message: str):
        self.normal_msg.setWindowTitle(title)
        self.normal_msg.setText(message)
        self.normal_msg.exec_()

    def show_info(self, message: str):
        self.show_message(self.tr("Info"), message)

    def show_warning(self, message: str):
        self.show_message(self.tr("Warning"), message)

    def show_error(self, message: str):
        self.show_message(self.tr("Error"), message)

    def on_confirm_clicked(self):
        if self.last_result is not None:
            for component, (mean, std,
                            weight) in zip(self.last_result.components,
                                           self.input_widgets):
                mean.setValue(component.logarithmic_moments["mean"])
                std.setValue(component.logarithmic_moments["std"])
                weight.setValue(component.fraction * 10)
            self.manual_fitting_finished.emit(self.last_result)

            self.last_result = None
            self.last_task = None
            self.try_button.setEnabled(False)
            self.confirm_button.setEnabled(False)
            self.hide()

    def on_task_failed(self, info: str, task: SSUTask):
        self.show_error(info)

    def on_task_succeeded(self, result: SSUResult):
        self.chart.show_model(result.view_model)
        self.last_result = result
        self.confirm_button.setEnabled(True)

    def on_try_clicked(self):
        if self.last_task is None:
            return
        new_task = copy.copy(self.last_task)
        reference, fractions = self.expected
        initial_guess = BaseDistribution.get_initial_guess(
            self.last_task.distribution_type, reference, fractions=fractions)
        new_task.initial_guess = initial_guess
        self.async_worker.execute_task(new_task)

    def on_value_changed(self):
        self.chart_timer.stop()
        self.chart_timer.start(10)

    def update_chart(self):
        """Preview the model implied by the current ``expected`` values.

        Nothing is drawn when no task has been set up, or when any reference
        entry has a ``"std"`` of exactly ``0.0``.
        """
        task = self.last_task
        if task is None:
            return
        reference, fractions = self.expected
        if any(entry["std"] == 0.0 for entry in reference):
            return
        guess = BaseDistribution.get_initial_guess(
            task.distribution_type, reference, fractions=fractions)
        preview = SSUResult(task, guess)
        self.chart.show_model(preview.view_model, quick=True)

    def setup_task(self, task: SSUTask):
        """Adopt ``task`` as the current task and draw its starting model.

        The number of component inputs is adjusted to the task first, then a
        result built from the current ``expected`` values is shown (non-quick).
        """
        self.last_task = task
        self.try_button.setEnabled(True)
        required = task.n_components
        if self.n_components != required:
            self.change_n_components(required)
        reference, fractions = self.expected
        guess = BaseDistribution.get_initial_guess(
            task.distribution_type, reference, fractions=fractions)
        starting_result = SSUResult(task, guess)
        self.chart.show_model(starting_result.view_model, quick=False)
Пример #7
0
    def init_ui(self):
        """Create the child widgets and lay them out.

        The window holds one horizontal splitter: its first pane stacks the
        "Control" group above the "Chart" group, and its second pane holds the
        "Table" group with the "Result" and "Reference" tabs.
        """
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.main_layout = QGridLayout(self)
        # self.main_layout.setContentsMargins(0, 0, 0, 0)
        # control group
        self.control_group = QGroupBox(self.tr("Control"))
        self.control_layout = QGridLayout(self.control_group)
        # row 0: which resolver to use
        self.resolver_label = QLabel(self.tr("Resolver"))
        self.resolver_combo_box = QComboBox()
        self.resolver_combo_box.addItems(["classic", "neural"])
        self.control_layout.addWidget(self.resolver_label, 0, 0)
        self.control_layout.addWidget(self.resolver_combo_box, 0, 1)
        # row 1: dataset loading and algorithm configuration
        self.load_dataset_button = QPushButton(qta.icon("fa.database"),
                                               self.tr("Load Dataset"))
        self.load_dataset_button.clicked.connect(self.on_load_dataset_clicked)
        self.configure_fitting_button = QPushButton(
            qta.icon("fa.gears"), self.tr("Configure Fitting Algorithm"))
        self.configure_fitting_button.clicked.connect(
            self.on_configure_fitting_clicked)
        self.control_layout.addWidget(self.load_dataset_button, 1, 0)
        self.control_layout.addWidget(self.configure_fitting_button, 1, 1)
        # rows 2-3: distribution type and number of components (1-10, default 3);
        # the combo box lists the labels in the order of ``distribution_types``
        self.distribution_label = QLabel(self.tr("Distribution Type"))
        self.distribution_combo_box = QComboBox()
        self.distribution_combo_box.addItems(
            [name for _, name in self.distribution_types])
        self.component_number_label = QLabel(self.tr("N<sub>components</sub>"))
        self.n_components_input = QSpinBox()
        self.n_components_input.setRange(1, 10)
        self.n_components_input.setValue(3)
        self.control_layout.addWidget(self.distribution_label, 2, 0)
        self.control_layout.addWidget(self.distribution_combo_box, 2, 1)
        self.control_layout.addWidget(self.component_number_label, 3, 0)
        self.control_layout.addWidget(self.n_components_input, 3, 1)

        # rows 4-6: sample count, sample index and sample name; the displays
        # start as "Unknown" and the index input starts disabled
        self.n_samples_label = QLabel(self.tr("N<sub>samples</sub>"))
        self.n_samples_display = QLabel(self.tr("Unknown"))
        self.control_layout.addWidget(self.n_samples_label, 4, 0)
        self.control_layout.addWidget(self.n_samples_display, 4, 1)
        self.sample_index_label = QLabel(self.tr("Sample Index"))
        self.sample_index_input = QSpinBox()
        self.sample_index_input.valueChanged.connect(
            self.on_sample_index_changed)
        self.sample_index_input.setEnabled(False)
        self.control_layout.addWidget(self.sample_index_label, 5, 0)
        self.control_layout.addWidget(self.sample_index_input, 5, 1)
        self.sample_name_label = QLabel(self.tr("Sample Name"))
        self.sample_name_display = QLabel(self.tr("Unknown"))
        self.control_layout.addWidget(self.sample_name_label, 6, 0)
        self.control_layout.addWidget(self.sample_name_display, 6, 1)

        # row 7: manual test (starts disabled) and reference loading
        self.manual_test_button = QPushButton(qta.icon("fa.sliders"),
                                              self.tr("Manual Test"))
        self.manual_test_button.setEnabled(False)
        self.manual_test_button.clicked.connect(self.on_manual_test_clicked)
        self.load_reference_button = QPushButton(qta.icon("mdi.map-check"),
                                                 self.tr("Load Reference"))
        # ``reference_view`` is only created further down; the lambda looks the
        # attribute up at click time, so connecting here is fine.
        self.load_reference_button.clicked.connect(
            lambda: self.reference_view.load_dump(mark_ref=True))
        self.control_layout.addWidget(self.manual_test_button, 7, 0)
        self.control_layout.addWidget(self.load_reference_button, 7, 1)

        # row 8: single / continuous test (both start disabled)
        self.single_test_button = QPushButton(qta.icon("fa.play-circle"),
                                              self.tr("Single Test"))
        self.single_test_button.setEnabled(False)
        self.single_test_button.clicked.connect(self.on_single_test_clicked)
        self.continuous_test_button = QPushButton(
            qta.icon("mdi.playlist-play"), self.tr("Continuous Test"))
        self.continuous_test_button.setEnabled(False)
        self.continuous_test_button.clicked.connect(
            self.on_continuous_test_clicked)
        self.control_layout.addWidget(self.single_test_button, 8, 0)
        self.control_layout.addWidget(self.continuous_test_button, 8, 1)

        # row 9: test the previous / next sample (both start disabled)
        self.test_previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Test Previous"))
        self.test_previous_button.setEnabled(False)
        self.test_previous_button.clicked.connect(
            self.on_test_previous_clicked)
        self.test_next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                            self.tr("Test Next"))
        self.test_next_button.setEnabled(False)
        self.test_next_button.clicked.connect(self.on_test_next_clicked)
        self.control_layout.addWidget(self.test_previous_button, 9, 0)
        self.control_layout.addWidget(self.test_next_button, 9, 1)

        # chart group
        self.chart_group = QGroupBox(self.tr("Chart"))
        self.chart_layout = QGridLayout(self.chart_group)
        self.result_chart = MixedDistributionChart(show_mode=True,
                                                   toolbar=False)
        self.chart_layout.addWidget(self.result_chart, 0, 0)

        # table group
        self.table_group = QGroupBox(self.tr("Table"))
        self.reference_view = ReferenceResultViewer()
        self.result_view = FittingResultViewer(self.reference_view)
        # a result marked in the result view is added to the reference view
        self.result_view.result_marked.connect(
            lambda result: self.reference_view.add_references([result]))
        self.table_tab = QTabWidget()
        self.table_tab.addTab(self.result_view, qta.icon("fa.cubes"),
                              self.tr("Result"))
        self.table_tab.addTab(self.reference_view, qta.icon("fa5s.key"),
                              self.tr("Reference"))
        self.result_layout = QGridLayout(self.table_group)
        self.result_layout.addWidget(self.table_tab, 0, 0)

        # pack all group
        self.splitter1 = QSplitter(Qt.Orientation.Vertical)
        self.splitter1.addWidget(self.control_group)
        self.splitter1.addWidget(self.chart_group)
        self.splitter2 = QSplitter(Qt.Orientation.Horizontal)
        self.splitter2.addWidget(self.splitter1)
        self.splitter2.addWidget(self.table_group)
        self.main_layout.addWidget(self.splitter2, 0, 0)
Пример #8
0
class SSUResolverPanel(QDialog):
    def __init__(self, parent=None):
        """Create the resolver panel, its helper dialogs and the fitting worker.

        Args:
            parent: Optional parent widget passed on to ``QDialog``.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Resolver"))
        # (enum value, translated label) pairs; ``init_ui`` fills the
        # distribution combo box from the labels in this order, and the
        # ``distribution_type`` property maps the combo index back to the enum.
        self.distribution_types = [
            (DistributionType.Normal, self.tr("Normal")),
            (DistributionType.Weibull, self.tr("Weibull")),
            (DistributionType.SkewNormal, self.tr("Skew Normal"))
        ]
        self.load_dataset_dialog = LoadDatasetDialog(parent=self)
        self.load_dataset_dialog.dataset_loaded.connect(self.on_dataset_loaded)
        self.classic_setting = ClassicResolverSettingWidget(parent=self)
        self.neural_setting = NNResolverSettingWidget(parent=self)
        self.manual_panel = ManualFittingPanel(parent=self)
        # manually confirmed results go through the same handler as the
        # results coming from the background worker
        self.manual_panel.manual_fitting_finished.connect(
            self.on_fitting_succeeded)
        self.async_worker = AsyncWorker()
        self.async_worker.background_worker.task_succeeded.connect(
            self.on_fitting_succeeded)
        self.async_worker.background_worker.task_failed.connect(
            self.on_fitting_failed)
        # reusable message box used by ``show_message``
        self.normal_msg = QMessageBox(self)
        # no dataset until ``on_dataset_loaded`` stores one
        self.dataset = None
        # task uuid -> task, filled by ``do_test``
        self.task_table = {}
        # task uuid -> result, filled by ``on_fitting_succeeded``
        self.task_results = {}
        # uuids of failed tasks, appended by ``on_fitting_failed``
        self.failed_task_ids = []
        # True while a continuous test run is in progress
        self.__continuous_flag = False
        # must come after ``distribution_types`` is set, which ``init_ui`` reads
        self.init_ui()

    def init_ui(self):
        """Create the child widgets and lay them out.

        The window holds one horizontal splitter: its first pane stacks the
        "Control" group above the "Chart" group, and its second pane holds the
        "Table" group with the "Result" and "Reference" tabs.
        """
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.main_layout = QGridLayout(self)
        # self.main_layout.setContentsMargins(0, 0, 0, 0)
        # control group
        self.control_group = QGroupBox(self.tr("Control"))
        self.control_layout = QGridLayout(self.control_group)
        # row 0: which resolver to use
        self.resolver_label = QLabel(self.tr("Resolver"))
        self.resolver_combo_box = QComboBox()
        self.resolver_combo_box.addItems(["classic", "neural"])
        self.control_layout.addWidget(self.resolver_label, 0, 0)
        self.control_layout.addWidget(self.resolver_combo_box, 0, 1)
        # row 1: dataset loading and algorithm configuration
        self.load_dataset_button = QPushButton(qta.icon("fa.database"),
                                               self.tr("Load Dataset"))
        self.load_dataset_button.clicked.connect(self.on_load_dataset_clicked)
        self.configure_fitting_button = QPushButton(
            qta.icon("fa.gears"), self.tr("Configure Fitting Algorithm"))
        self.configure_fitting_button.clicked.connect(
            self.on_configure_fitting_clicked)
        self.control_layout.addWidget(self.load_dataset_button, 1, 0)
        self.control_layout.addWidget(self.configure_fitting_button, 1, 1)
        # rows 2-3: distribution type and number of components (1-10, default 3);
        # the combo box lists the labels in the order of ``distribution_types``
        self.distribution_label = QLabel(self.tr("Distribution Type"))
        self.distribution_combo_box = QComboBox()
        self.distribution_combo_box.addItems(
            [name for _, name in self.distribution_types])
        self.component_number_label = QLabel(self.tr("N<sub>components</sub>"))
        self.n_components_input = QSpinBox()
        self.n_components_input.setRange(1, 10)
        self.n_components_input.setValue(3)
        self.control_layout.addWidget(self.distribution_label, 2, 0)
        self.control_layout.addWidget(self.distribution_combo_box, 2, 1)
        self.control_layout.addWidget(self.component_number_label, 3, 0)
        self.control_layout.addWidget(self.n_components_input, 3, 1)

        # rows 4-6: sample count, sample index and sample name; the displays
        # start as "Unknown" and the index input starts disabled
        self.n_samples_label = QLabel(self.tr("N<sub>samples</sub>"))
        self.n_samples_display = QLabel(self.tr("Unknown"))
        self.control_layout.addWidget(self.n_samples_label, 4, 0)
        self.control_layout.addWidget(self.n_samples_display, 4, 1)
        self.sample_index_label = QLabel(self.tr("Sample Index"))
        self.sample_index_input = QSpinBox()
        self.sample_index_input.valueChanged.connect(
            self.on_sample_index_changed)
        self.sample_index_input.setEnabled(False)
        self.control_layout.addWidget(self.sample_index_label, 5, 0)
        self.control_layout.addWidget(self.sample_index_input, 5, 1)
        self.sample_name_label = QLabel(self.tr("Sample Name"))
        self.sample_name_display = QLabel(self.tr("Unknown"))
        self.control_layout.addWidget(self.sample_name_label, 6, 0)
        self.control_layout.addWidget(self.sample_name_display, 6, 1)

        # row 7: manual test (starts disabled) and reference loading
        self.manual_test_button = QPushButton(qta.icon("fa.sliders"),
                                              self.tr("Manual Test"))
        self.manual_test_button.setEnabled(False)
        self.manual_test_button.clicked.connect(self.on_manual_test_clicked)
        self.load_reference_button = QPushButton(qta.icon("mdi.map-check"),
                                                 self.tr("Load Reference"))
        # ``reference_view`` is only created further down; the lambda looks the
        # attribute up at click time, so connecting here is fine.
        self.load_reference_button.clicked.connect(
            lambda: self.reference_view.load_dump(mark_ref=True))
        self.control_layout.addWidget(self.manual_test_button, 7, 0)
        self.control_layout.addWidget(self.load_reference_button, 7, 1)

        # row 8: single / continuous test (both start disabled)
        self.single_test_button = QPushButton(qta.icon("fa.play-circle"),
                                              self.tr("Single Test"))
        self.single_test_button.setEnabled(False)
        self.single_test_button.clicked.connect(self.on_single_test_clicked)
        self.continuous_test_button = QPushButton(
            qta.icon("mdi.playlist-play"), self.tr("Continuous Test"))
        self.continuous_test_button.setEnabled(False)
        self.continuous_test_button.clicked.connect(
            self.on_continuous_test_clicked)
        self.control_layout.addWidget(self.single_test_button, 8, 0)
        self.control_layout.addWidget(self.continuous_test_button, 8, 1)

        # row 9: test the previous / next sample (both start disabled)
        self.test_previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Test Previous"))
        self.test_previous_button.setEnabled(False)
        self.test_previous_button.clicked.connect(
            self.on_test_previous_clicked)
        self.test_next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                            self.tr("Test Next"))
        self.test_next_button.setEnabled(False)
        self.test_next_button.clicked.connect(self.on_test_next_clicked)
        self.control_layout.addWidget(self.test_previous_button, 9, 0)
        self.control_layout.addWidget(self.test_next_button, 9, 1)

        # chart group
        self.chart_group = QGroupBox(self.tr("Chart"))
        self.chart_layout = QGridLayout(self.chart_group)
        self.result_chart = MixedDistributionChart(show_mode=True,
                                                   toolbar=False)
        self.chart_layout.addWidget(self.result_chart, 0, 0)

        # table group
        self.table_group = QGroupBox(self.tr("Table"))
        self.reference_view = ReferenceResultViewer()
        self.result_view = FittingResultViewer(self.reference_view)
        # a result marked in the result view is added to the reference view
        self.result_view.result_marked.connect(
            lambda result: self.reference_view.add_references([result]))
        self.table_tab = QTabWidget()
        self.table_tab.addTab(self.result_view, qta.icon("fa.cubes"),
                              self.tr("Result"))
        self.table_tab.addTab(self.reference_view, qta.icon("fa5s.key"),
                              self.tr("Reference"))
        self.result_layout = QGridLayout(self.table_group)
        self.result_layout.addWidget(self.table_tab, 0, 0)

        # pack all group
        self.splitter1 = QSplitter(Qt.Orientation.Vertical)
        self.splitter1.addWidget(self.control_group)
        self.splitter1.addWidget(self.chart_group)
        self.splitter2 = QSplitter(Qt.Orientation.Horizontal)
        self.splitter2.addWidget(self.splitter1)
        self.splitter2.addWidget(self.table_group)
        self.main_layout.addWidget(self.splitter2, 0, 0)

    @property
    def distribution_type(self) -> DistributionType:
        """Distribution type paired with the current combo-box selection."""
        selected = self.distribution_combo_box.currentIndex()
        kind, _label = self.distribution_types[selected]
        return kind

    @property
    def n_components(self) -> int:
        """Number of components currently set in the spin box."""
        spin_box = self.n_components_input
        return spin_box.value()

    def show_message(self, title: str, message: str):
        """Show ``message`` in the reusable message box under ``title``.

        Returns once ``exec_()`` on the box returns.
        """
        box = self.normal_msg
        box.setWindowTitle(title)
        box.setText(message)
        box.exec_()

    def show_info(self, message: str):
        """Show ``message`` under the translated "Info" title."""
        title = self.tr("Info")
        self.show_message(title, message)

    def show_warning(self, message: str):
        """Show ``message`` under the translated "Warning" title."""
        title = self.tr("Warning")
        self.show_message(title, message)

    def show_error(self, message: str):
        """Show ``message`` under the translated "Error" title."""
        title = self.tr("Error")
        self.show_message(title, message)

    def on_load_dataset_clicked(self):
        """Open the dataset-loading dialog."""
        dialog = self.load_dataset_dialog
        dialog.show()

    def on_dataset_loaded(self, dataset: GrainSizeDataset):
        """Adopt a newly loaded dataset and enable the testing controls.

        Args:
            dataset: The dataset to test against; its ``n_samples`` and
                ``samples`` are read here.
        """
        self.dataset = dataset
        self.n_samples_display.setText(str(dataset.n_samples))
        self.sample_index_input.setRange(1, dataset.n_samples)
        self.sample_index_input.setEnabled(True)
        self.manual_test_button.setEnabled(True)
        self.single_test_button.setEnabled(True)
        self.continuous_test_button.setEnabled(True)
        self.test_previous_button.setEnabled(True)
        self.test_next_button.setEnabled(True)
        # ``setRange`` only emits ``valueChanged`` when it has to clamp the
        # current value. When the index is already valid (e.g. a second dataset
        # is loaded) no signal fires and the name label would keep showing a
        # sample of the previous dataset, so refresh it explicitly.
        current = self.sample_index_input.value()
        if 1 <= current <= dataset.n_samples:
            self.sample_name_display.setText(dataset.samples[current - 1].name)

    def on_configure_fitting_clicked(self):
        """Open the settings widget that belongs to the selected resolver.

        Anything other than "classic" opens the neural settings.
        """
        use_classic = self.resolver_combo_box.currentText() == "classic"
        settings_widget = (self.classic_setting
                           if use_classic else self.neural_setting)
        settings_widget.show()

    def on_sample_index_changed(self, index):
        """Show the name of the sample at the 1-based ``index``."""
        sample = self.dataset.samples[index - 1]
        self.sample_name_display.setText(sample.name)

    def generate_task(self, query_ref=True):
        """Build an ``SSUTask`` for the currently selected sample.

        Args:
            query_ref: When true, the reference view is asked for a result
                matching the sample. If one is found, the task reuses its
                distribution type and component count and starts from its
                ``last_func_args``. Otherwise (or when false) the distribution
                type and component count chosen in the control group are used.

        Returns:
            The newly created ``SSUTask``.
        """
        # the spin box is 1-based, the sample list 0-based
        sample = self.dataset.samples[self.sample_index_input.value() - 1]

        resolver = self.resolver_combo_box.currentText()
        if resolver == "classic":
            setting = self.classic_setting.setting
        else:
            setting = self.neural_setting.setting

        # Only consult the reference view when its answer can be used.
        query = None  # type: SSUResult
        if query_ref:
            query = self.reference_view.query_reference(sample)
        if query is None:
            return SSUTask(sample,
                           self.distribution_type,
                           self.n_components,
                           resolver=resolver,
                           resolver_setting=setting)
        return SSUTask(sample,
                       query.distribution_type,
                       query.n_components,
                       resolver=resolver,
                       resolver_setting=setting,
                       initial_guess=query.last_func_args)

    def on_fitting_succeeded(self, fitting_result: SSUResult):
        """Record and display a finished result, then continue or unlock the UI.

        During a continuous run the next sample is started right away and the
        buttons stay locked; after the last sample the run is switched off.
        """
        # update chart
        self.result_chart.show_model(fitting_result.view_model)
        self.result_view.add_result(fitting_result)
        self.task_results[fitting_result.task.uuid] = fitting_result
        if self.__continuous_flag:
            current = self.sample_index_input.value()
            if current != self.dataset.n_samples:
                self.sample_index_input.setValue(current + 1)
                self.do_test()
                return
            # Last sample reached: toggling while running ends the run.
            self.on_continuous_test_clicked()
        for button in (self.manual_test_button, self.single_test_button,
                       self.continuous_test_button, self.test_previous_button,
                       self.test_next_button):
            button.setEnabled(True)

    def on_fitting_failed(self, failed_info: str, task: SSUTask):
        """Record the failed task, end any continuous run and report the error."""
        self.failed_task_ids.append(task.uuid)
        if self.__continuous_flag:
            # Toggling while running ends the continuous run.
            self.on_continuous_test_clicked()
        for button in (self.manual_test_button, self.single_test_button,
                       self.continuous_test_button, self.test_previous_button,
                       self.test_next_button):
            button.setEnabled(True)
        self.show_error(failed_info)

    def do_test(self):
        """Lock the test buttons and submit a task for the current sample.

        During a continuous run the continuous button is left enabled, since
        it is relabelled "Cancel" while the run is active.
        """
        for button in (self.manual_test_button, self.single_test_button,
                       self.test_previous_button, self.test_next_button):
            button.setEnabled(False)
        if not self.__continuous_flag:
            self.continuous_test_button.setEnabled(False)
        new_task = self.generate_task()
        self.task_table[new_task.uuid] = new_task
        self.async_worker.execute_task(new_task)

    def on_manual_test_clicked(self):
        """Open the manual fitting panel on a task built without a reference."""
        manual_task = self.generate_task(query_ref=False)
        panel = self.manual_panel
        panel.setup_task(manual_task)
        panel.show()

    def on_single_test_clicked(self):
        """Run one test on the currently selected sample."""
        self.do_test()

    def on_continuous_test_clicked(self):
        """Toggle the continuous run.

        While running, the button reads "Cancel" and a click stops the run;
        otherwise a click starts the run from the current sample.
        """
        running = self.__continuous_flag
        if running:
            self.__continuous_flag = False
            self.continuous_test_button.setText(self.tr("Continuous Test"))
            return
        self.continuous_test_button.setText(self.tr("Cancel"))
        self.__continuous_flag = True
        self.do_test()

    def on_test_previous_clicked(self):
        """Step back one sample and test it; no-op on the first sample."""
        index = self.sample_index_input.value()
        if index != 1:
            self.sample_index_input.setValue(index - 1)
            self.do_test()

    def on_test_next_clicked(self):
        """Step forward one sample and test it; no-op on the last sample."""
        index = self.sample_index_input.value()
        if index != self.dataset.n_samples:
            self.sample_index_input.setValue(index + 1)
            self.do_test()
Пример #9
0
class FittingResultViewer(QDialog):
    PAGE_ROWS = 20
    logger = logging.getLogger("root.QGrain.ui.FittingResultViewer")
    result_marked = Signal(SSUResult)

    def __init__(self, reference_viewer: ReferenceResultViewer, parent=None):
        """Create the result viewer, its helper charts, dialogs and worker.

        Args:
            reference_viewer: Reference viewer kept privately by this viewer.
            parent: Optional parent widget passed on to ``QDialog``.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Fitting Result Viewer"))
        self.__fitting_results = []  # type: list[SSUResult]
        self.retry_tasks = {}  # type: dict[UUID, SSUTask]
        self.__reference_viewer = reference_viewer
        # creates the table and the page combo box that the page helpers
        # below read, so it has to run before them
        self.init_ui()
        self.boxplot_chart = BoxplotChart(parent=self, toolbar=True)
        self.typical_chart = SSUTypicalComponentChart(parent=self,
                                                      toolbar=True)
        self.distance_chart = DistanceCurveChart(parent=self, toolbar=True)
        self.mixed_distribution_chart = MixedDistributionChart(
            parent=self, toolbar=True, use_animation=True)
        self.file_dialog = QFileDialog(parent=self)
        self.async_worker = AsyncWorker()
        self.async_worker.background_worker.task_succeeded.connect(
            self.on_fitting_succeeded)
        self.async_worker.background_worker.task_failed.connect(
            self.on_fitting_failed)
        # initial (empty) page list and table contents
        self.update_page_list()
        self.update_page(self.page_index)

        # reusable message box used by ``show_message``
        self.normal_msg = QMessageBox(self)
        # confirmation box for removing all results; "No" is the default
        self.remove_warning_msg = QMessageBox(self)
        self.remove_warning_msg.setStandardButtons(QMessageBox.No
                                                   | QMessageBox.Yes)
        self.remove_warning_msg.setDefaultButton(QMessageBox.No)
        self.remove_warning_msg.setWindowTitle(self.tr("Warning"))
        self.remove_warning_msg.setText(
            self.tr("Are you sure to remove all SSU results?"))
        # Discard / Retry / Ignore box; "Ignore" is the default
        self.outlier_msg = QMessageBox(self)
        self.outlier_msg.setStandardButtons(QMessageBox.Discard
                                            | QMessageBox.Retry
                                            | QMessageBox.Ignore)
        self.outlier_msg.setDefaultButton(QMessageBox.Ignore)
        # progress box; created without a parent (unlike the boxes above), and
        # its Ok button is added but hidden
        self.retry_progress_msg = QMessageBox()
        self.retry_progress_msg.addButton(QMessageBox.Ok)
        self.retry_progress_msg.button(QMessageBox.Ok).hide()
        self.retry_progress_msg.setWindowTitle(self.tr("Progress"))
        # single-shot timer whose timeout opens the progress box via exec_()
        self.retry_timer = QTimer(self)
        self.retry_timer.setSingleShot(True)
        self.retry_timer.timeout.connect(
            lambda: self.retry_progress_msg.exec_())

    def init_ui(self):
        """Create the result table, paging controls and the context menu.

        Layout rows: 0 = table (spanning three columns), 1 = previous / page
        selector / next, 2 = distance selector. The context menu is shown on
        right-click in the table and its actions are also added to this widget.
        """
        # read-only table with whole-row selection; the (100, 100) size is a
        # placeholder, ``update_page`` sets the real row and column counts
        self.data_table = QTableWidget(100, 100)
        self.data_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.data_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.data_table.setAlternatingRowColors(True)
        self.data_table.setContextMenuPolicy(Qt.CustomContextMenu)
        self.main_layout = QGridLayout(self)
        self.main_layout.addWidget(self.data_table, 0, 0, 1, 3)

        # paging controls
        self.previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Previous"))
        self.previous_button.setToolTip(
            self.tr("Click to back to the previous page."))
        self.previous_button.clicked.connect(self.on_previous_button_clicked)
        self.current_page_combo_box = QComboBox()
        # the first page is added before the signal is connected, so adding it
        # does not trigger ``update_page``
        self.current_page_combo_box.addItem(self.tr("Page {0}").format(1))
        self.current_page_combo_box.currentIndexChanged.connect(
            self.update_page)
        self.next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                       self.tr("Next"))
        self.next_button.setToolTip(self.tr("Click to jump to the next page."))
        self.next_button.clicked.connect(self.on_next_button_clicked)
        self.main_layout.addWidget(self.previous_button, 1, 0)
        self.main_layout.addWidget(self.current_page_combo_box, 1, 1)
        self.main_layout.addWidget(self.next_button, 1, 2)

        # distance function selector; changing it redraws the current page
        self.distance_label = QLabel(self.tr("Distance"))
        self.distance_label.setToolTip(
            self.
            tr("It's the function to calculate the difference (on the contrary, similarity) between two samples."
               ))
        self.distance_combo_box = QComboBox()
        self.distance_combo_box.addItems(built_in_distances)
        # default is selected before the signal is connected
        self.distance_combo_box.setCurrentText("log10MSE")
        self.distance_combo_box.currentTextChanged.connect(
            lambda: self.update_page(self.page_index))
        self.main_layout.addWidget(self.distance_label, 2, 0)
        self.main_layout.addWidget(self.distance_combo_box, 2, 1, 1, 2)
        # context menu of the table
        self.menu = QMenu(self.data_table)
        # NOTE(review): QWidget.setShortcutAutoRepeat takes a shortcut id as
        # its first argument; ``True`` is passed here - confirm the intent.
        self.menu.setShortcutAutoRepeat(True)
        self.mark_action = self.menu.addAction(
            qta.icon("mdi.marker-check"),
            self.tr("Mark Selection(s) as Reference"))
        self.mark_action.triggered.connect(self.mark_selections)
        self.remove_selection_action = self.menu.addAction(
            qta.icon("fa.remove"), self.tr("Remove Selection(s)"))
        self.remove_selection_action.triggered.connect(self.remove_selections)
        self.remove_all_action = self.menu.addAction(qta.icon("fa.remove"),
                                                     self.tr("Remove All"))
        self.remove_all_action.triggered.connect(self.remove_all_results)
        # plotting actions
        self.plot_loss_chart_action = self.menu.addAction(
            qta.icon("mdi.chart-timeline-variant"), self.tr("Plot Loss Chart"))
        self.plot_loss_chart_action.triggered.connect(self.show_distance)
        self.plot_distribution_chart_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"), self.tr("Plot Distribution Chart"))
        self.plot_distribution_chart_action.triggered.connect(
            self.show_distribution)
        self.plot_distribution_animation_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"),
            self.tr("Plot Distribution Chart (Animation)"))
        self.plot_distribution_animation_action.triggered.connect(
            self.show_history_distribution)

        # "Detect Outliers" submenu; the four moment checks share one handler
        # and differ only in the moment name passed to it
        self.detect_outliers_menu = self.menu.addMenu(
            qta.icon("mdi.magnify"), self.tr("Detect Outliers"))
        self.check_nan_and_inf_action = self.detect_outliers_menu.addAction(
            self.tr("Check NaN and Inf"))
        self.check_nan_and_inf_action.triggered.connect(self.check_nan_and_inf)
        self.check_final_distances_action = self.detect_outliers_menu.addAction(
            self.tr("Check Final Distances"))
        self.check_final_distances_action.triggered.connect(
            self.check_final_distances)
        self.check_component_mean_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Mean"))
        self.check_component_mean_action.triggered.connect(
            lambda: self.check_component_moments("mean"))
        self.check_component_std_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component STD"))
        self.check_component_std_action.triggered.connect(
            lambda: self.check_component_moments("std"))
        self.check_component_skewness_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Skewness"))
        self.check_component_skewness_action.triggered.connect(
            lambda: self.check_component_moments("skewness"))
        self.check_component_kurtosis_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Kurtosis"))
        self.check_component_kurtosis_action.triggered.connect(
            lambda: self.check_component_moments("kurtosis"))
        self.check_component_fractions_action = self.detect_outliers_menu.addAction(
            self.tr("Check Component Fractions"))
        self.check_component_fractions_action.triggered.connect(
            self.check_component_fractions)
        self.degrade_results_action = self.detect_outliers_menu.addAction(
            self.tr("Degrade Results"))
        self.degrade_results_action.triggered.connect(self.degrade_results)
        self.try_align_components_action = self.detect_outliers_menu.addAction(
            self.tr("Try Align Components"))
        self.try_align_components_action.triggered.connect(
            self.try_align_components)

        # analysis, dump and export actions
        self.analyse_typical_components_action = self.menu.addAction(
            qta.icon("ei.tags"), self.tr("Analyse Typical Components"))
        self.analyse_typical_components_action.triggered.connect(
            self.analyse_typical_components)
        self.load_dump_action = self.menu.addAction(
            qta.icon("fa.database"), self.tr("Load Binary Dump"))
        self.load_dump_action.triggered.connect(self.load_dump)
        self.save_dump_action = self.menu.addAction(
            qta.icon("fa.save"), self.tr("Save Binary Dump"))
        self.save_dump_action.triggered.connect(self.save_dump)
        self.save_excel_action = self.menu.addAction(
            qta.icon("mdi.microsoft-excel"), self.tr("Save Excel"))
        self.save_excel_action.triggered.connect(
            lambda: self.on_save_excel_clicked(align_components=False))
        self.save_excel_align_action = self.menu.addAction(
            qta.icon("mdi.microsoft-excel"),
            self.tr("Save Excel (Force Alignment)"))
        self.save_excel_align_action.triggered.connect(
            lambda: self.on_save_excel_clicked(align_components=True))
        self.data_table.customContextMenuRequested.connect(self.show_menu)
        # necessary to add actions of menu to this widget itself,
        # otherwise, the shortcuts will not be triggered
        self.addActions(self.menu.actions())

    def show_menu(self, pos: QPoint):
        """Pop up the context menu at the current cursor position.

        ``pos`` (supplied by the signal) is not used; the position is taken
        from ``QCursor.pos()`` instead.
        """
        cursor_position = QCursor.pos()
        self.menu.popup(cursor_position)

    def show_message(self, title: str, message: str):
        """Show ``message`` in the reusable message box under ``title``.

        Returns once ``exec_()`` on the box returns.
        """
        box = self.normal_msg
        box.setWindowTitle(title)
        box.setText(message)
        box.exec_()

    def show_info(self, message: str):
        """Show ``message`` under the translated "Info" title."""
        title = self.tr("Info")
        self.show_message(title, message)

    def show_warning(self, message: str):
        """Show ``message`` under the translated "Warning" title."""
        title = self.tr("Warning")
        self.show_message(title, message)

    def show_error(self, message: str):
        """Show ``message`` under the translated "Error" title."""
        title = self.tr("Error")
        self.show_message(title, message)

    @property
    def distance_name(self) -> str:
        """Name of the distance currently selected in the combo box."""
        selector = self.distance_combo_box
        return selector.currentText()

    @property
    def distance_func(self) -> typing.Callable:
        """Distance function looked up by the currently selected name."""
        selected_name = self.distance_combo_box.currentText()
        return get_distance_func_by_name(selected_name)

    @property
    def page_index(self) -> int:
        """Index of the page currently selected in the page combo box."""
        selector = self.current_page_combo_box
        return selector.currentIndex()

    @property
    def n_pages(self) -> int:
        """Number of pages listed in the page combo box."""
        selector = self.current_page_combo_box
        return selector.count()

    @property
    def n_results(self) -> int:
        """Number of fitting results currently held by the viewer."""
        stored = self.__fitting_results
        return len(stored)

    @property
    def selections(self):
        """Sorted, de-duplicated result indexes of the selected table rows.

        Table rows are numbered per page, so each selected row is shifted by
        the index of the first result on the current page.

        Returns:
            A sorted ``list`` of ``int`` indexes.
        """
        offset = self.page_index * self.PAGE_ROWS
        picked = set()
        for selected_range in self.data_table.selectedRanges():
            # Clamp to one page: valid rows are 0 .. PAGE_ROWS - 1, so the
            # exclusive upper bound must not exceed PAGE_ROWS.
            stop = min(self.PAGE_ROWS, selected_range.bottomRow() + 1)
            picked.update(row + offset
                          for row in range(selected_range.topRow(), stop))
        return sorted(picked)

    def update_page_list(self):
        """Rebuild the page selector to match the current number of results.

        At least one page is always listed. The previously selected page stays
        selected when it still exists; otherwise the last page is selected.
        Signals are blocked while the combo box is rebuilt, so the connected
        ``update_page`` slot is not triggered from here.
        """
        previous_index = self.page_index
        # ceiling division, with a minimum of one (possibly empty) page
        page_count = max(1, -(-self.n_results // self.PAGE_ROWS))
        selector = self.current_page_combo_box
        selector.blockSignals(True)
        selector.clear()
        selector.addItems([
            self.tr("Page {0}").format(number)
            for number in range(1, page_count + 1)
        ])
        selector.setCurrentIndex(min(previous_index, page_count - 1))
        selector.blockSignals(False)

    def update_page(self, page_index: int):
        """Fill the data table with the results that belong to ``page_index``.

        The table is cleared and rebuilt: one row per result of the page, with
        the sample names as vertical header labels and seven columns (resolver,
        distribution type, number of components, number of iterations, spent
        time, final distance and whether a reference was used).

        Args:
            page_index: Zero-based index of the page to display.
        """
        def write(row: int, col: int, value):
            """Put ``value`` as centred text into the cell at (row, col)."""
            if isinstance(value, str):
                text = value
            elif isinstance(value, int):
                text = str(value)
            elif isinstance(value, float):
                text = f"{value: 0.4f}"
            else:
                text = value.__str__()
            item = QTableWidgetItem(text)
            item.setTextAlignment(Qt.AlignCenter)
            self.data_table.setItem(row, col, item)

        # necessary to clear
        self.data_table.clear()
        start = page_index * self.PAGE_ROWS
        if page_index == self.n_pages - 1:
            # the last page may be only partly filled
            end = self.n_results
        else:
            end = start + self.PAGE_ROWS
        page_results = self.__fitting_results[start:end]
        self.data_table.setRowCount(end - start)
        self.data_table.setColumnCount(7)
        self.data_table.setHorizontalHeaderLabels([
            self.tr("Resolver"),
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("N_iterations"),
            self.tr("Spent Time [s]"),
            self.tr("Final Distance"),
            self.tr("Has Reference")
        ])
        self.data_table.setVerticalHeaderLabels(
            [result.sample.name for result in page_results])
        # resolve the distance function once instead of once per row
        distance_func = self.distance_func
        for row, result in enumerate(page_results):
            task = result.task
            write(row, 0, task.resolver)
            write(row, 1, self.get_distribution_name(task.distribution_type))
            write(row, 2, task.n_components)
            write(row, 3, result.n_iterations)
            write(row, 4, result.time_spent)
            write(row, 5,
                  distance_func(result.sample.distribution,
                                result.distribution))
            has_ref = task.initial_guess is not None or task.reference is not None
            write(row, 6, self.tr("Yes") if has_ref else self.tr("No"))

        self.data_table.resizeColumnsToContents()

    def on_previous_button_clicked(self):
        """Step back one page unless the first page is already shown."""
        current = self.page_index
        if current <= 0:
            return
        self.current_page_combo_box.setCurrentIndex(current - 1)

    def on_next_button_clicked(self):
        """Step forward one page unless the last page is already shown."""
        current = self.page_index
        last_page = self.n_pages - 1
        if current >= last_page:
            return
        self.current_page_combo_box.setCurrentIndex(current + 1)

    def get_distribution_name(self, distribution_type: DistributionType):
        """Return the translated display name of ``distribution_type``.

        Raises:
            NotImplementedError: if the type is not Normal, Weibull or
                SkewNormal.
        """
        if distribution_type == DistributionType.Normal:
            return self.tr("Normal")
        if distribution_type == DistributionType.Weibull:
            return self.tr("Weibull")
        if distribution_type == DistributionType.SkewNormal:
            return self.tr("Skew Normal")
        raise NotImplementedError(distribution_type)

    def add_result(self, result: SSUResult):
        """Append ``result`` and refresh the page list and, if needed, the table.

        The table is redrawn only when the list was empty, or when the last
        page is the one being shown and it still has free rows.
        """
        # decided before appending: would the new item show up on this page?
        need_update = self.n_results == 0 or (
            self.page_index == self.n_pages - 1
            and self.n_results % self.PAGE_ROWS != 0)
        self.__fitting_results.append(result)
        self.update_page_list()
        if not need_update:
            return
        self.update_page(self.page_index)

    def add_results(self, results: typing.List[SSUResult]):
        """Append all ``results`` and refresh the page list and, if needed, the table.

        The table is redrawn only when the list was empty, or when the last
        page is the one being shown and it still has free rows.
        """
        # decided before extending: would new items show up on this page?
        need_update = self.n_results == 0 or (
            self.page_index == self.n_pages - 1
            and self.n_results % self.PAGE_ROWS != 0)
        self.__fitting_results.extend(results)
        self.update_page_list()
        if not need_update:
            return
        self.update_page(self.page_index)

    def mark_selections(self):
        """Emit ``result_marked`` once for every selected result, in index order."""
        stored = self.__fitting_results
        for picked in self.selections:
            self.result_marked.emit(stored[picked])

    def remove_results(self, indexes):
        """Remove the results at ``indexes`` and refresh the page list and table.

        Args:
            indexes: Iterable of indexes into the stored results. It may be
                unsorted and may contain duplicates; each index is removed
                once.

        Raises:
            IndexError: if an index is out of range.
        """
        # pop from the highest index down so the remaining indexes stay valid
        for index in sorted(set(indexes), reverse=True):
            self.__fitting_results.pop(index)
        self.update_page_list()
        self.update_page(self.page_index)

    def remove_selections(self):
        """Remove every result that is currently selected in the table."""
        self.remove_results(self.selections)

    def remove_all_results(self):
        """Ask for confirmation and, on "Yes", drop all results and show page 0."""
        answer = self.remove_warning_msg.exec_()
        if answer != QMessageBox.Yes:
            return
        self.__fitting_results.clear()
        self.update_page_list()
        self.update_page(0)

    def show_distance(self):
        """Show the distance series of the first selected result in a chart.

        The series is the one for the currently chosen distance name and the
        chart is titled with the sample name. Does nothing without a
        selection.
        """
        selected = self.selections
        if not selected:
            return
        # only the first selected result is displayed
        result = self.__fitting_results[selected[0]]
        series = result.get_distance_series(self.distance_name)
        self.distance_chart.show_distance_series(series,
                                                 title=result.sample.name)
        self.distance_chart.show()

    def show_distribution(self):
        """Show the view model of the first selected result in the mixed chart.

        Does nothing without a selection.
        """
        selected = self.selections
        if not selected:
            return
        # only the first selected result is displayed
        result = self.__fitting_results[selected[0]]
        self.mixed_distribution_chart.show_model(result.view_model)
        self.mixed_distribution_chart.show()

    def show_history_distribution(self):
        """Pass the first selected result to the mixed chart via ``show_result``.

        Does nothing without a selection.
        """
        selected = self.selections
        if not selected:
            return
        # only the first selected result is displayed
        result = self.__fitting_results[selected[0]]
        self.mixed_distribution_chart.show_result(result)
        self.mixed_distribution_chart.show()

    def load_dump(self):
        """Load SSU results from a binary dump file chosen by the user.

        The file must contain a pickled list of ``SSUResult`` objects. If the
        viewer already holds results, the grain-size classes of the first
        loaded result must match those of the first stored result (same length
        and absolute differences of at most 1e-8). Any problem is reported
        with an error message box and nothing is added.
        """
        filename, _ = self.file_dialog.getOpenFileName(
            self, self.tr("Select a binary dump file of SSU results"), None,
            self.tr("Binary dump (*.dump)"))
        if filename is None or filename == "":
            return
        # SECURITY: unpickling can execute arbitrary code; only dump files
        # from trusted sources should be opened here.
        try:
            # keep the file open only while it is being read
            with open(filename, "rb") as f:
                results = pickle.load(f)  # type: list[SSUResult]
        except (OSError, pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError, ValueError):
            # unreadable, truncated or foreign files are reported, not raised
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        valid = isinstance(results, list) and all(
            isinstance(result, SSUResult) for result in results)
        if not valid:
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        if self.n_results != 0 and len(results) != 0:
            old_classes = self.__fitting_results[0].classes_φ
            new_classes = results[0].classes_φ
            if len(old_classes) != len(new_classes):
                classes_inconsistent = True
            else:
                classes_error = np.abs(old_classes - new_classes)
                classes_inconsistent = not np.all(
                    np.less_equal(classes_error, 1e-8))
            if classes_inconsistent:
                self.show_error(
                    self.
                    tr("The results in the dump file has inconsistent grain-size classes with that in your list."
                       ))
                return
        self.add_results(results)

    def save_dump(self):
        """Pickle all stored results into a dump file chosen by the user.

        Shows a warning and returns when there is no result; returns silently
        when the file dialog yields no filename.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        caption = self.tr("Save the SSU results to binary dump file")
        name_filter = self.tr("Binary dump (*.dump)")
        filename, _ = self.file_dialog.getSaveFileName(self, caption, None,
                                                       name_filter)
        if not filename:
            return
        with open(filename, "wb") as dump_file:
            pickle.dump(self.__fitting_results, dump_file)

    def save_excel(self, filename, align_components=False):
        """Export all SSU results to an Excel workbook saved at ``filename``.

        The workbook holds a README sheet, the sample distributions, the
        information of fitting, the statistic moments of every component, the
        unmixed component distributions and one sheet per component flag.

        Args:
            filename: Path handed to ``Workbook.save``.
            align_components: If False, every component keeps its position
                inside its result. If True, all components are clustered with
                KMeans (one cluster per possible component) and the clusters,
                ordered by the logarithmic mean of their first member, decide
                which column group / sheet a component is written to.

        Does nothing when there is no result. Exceptions raised while
        clustering or saving propagate to the caller.
        """
        if self.n_results == 0:
            return

        results = self.__fitting_results.copy()
        classes_μm = results[0].classes_μm
        n_components_list = [result.n_components for result in results]
        count_dict = Counter(n_components_list)
        max_n_components = max(count_dict.keys())
        self.logger.debug(
            f"N_components: {count_dict}, Max N_components: {max_n_components}"
        )

        # flags[k] is the column group / sheet index of the k-th component,
        # counted over all results in order
        flags = []
        if not align_components:
            for result in results:
                flags.extend(range(result.n_components))
        else:
            n_components_desc = "\n".join([
                self.tr("{0} Component(s): {1}").format(n_components, count)
                for n_components, count in count_dict.items()
            ])
            self.show_info(
                self.tr("N_components distribution of Results:\n{0}").format(
                    n_components_desc))
            stacked_components = []
            for result in results:
                for component in result.components:
                    stacked_components.append(component.distribution)
            stacked_components = np.array(stacked_components)
            cluser = KMeans(n_clusters=max_n_components)
            flags = cluser.fit_predict(stacked_components)
            # distance of every component to every cluster centre
            center_distances = cluser.transform(stacked_components)
            # Components of one result must not share a flag, otherwise they
            # would be written to the same cells and overwrite each other.
            # A component whose flag is already taken within its result is
            # moved to the nearest cluster that is still free. A free one
            # always exists, because a result never has more components than
            # there are clusters.
            flag_index = 0
            for result in results:
                used_flags = set()
                for _ in result.components:
                    flag = int(flags[flag_index])
                    if flag in used_flags:
                        free_flags = [
                            candidate
                            for candidate in range(max_n_components)
                            if candidate not in used_flags
                        ]
                        flag = min(free_flags,
                                   key=lambda candidate: center_distances[
                                       flag_index, candidate])
                        flags[flag_index] = flag
                    used_flags.add(flag)
                    flag_index += 1

            # renumber the flags by the logarithmic mean of the first
            # component found for each flag
            flag_set = set(flags)
            picked = []
            for target_flag in flag_set:
                for i, flag in enumerate(flags):
                    if flag == target_flag:
                        picked.append(
                            (target_flag,
                             logarithmic(classes_μm,
                                         stacked_components[i])["mean"]))
                        break
            picked.sort(key=lambda x: x[1])
            flag_map = {flag: index for index, (flag, _) in enumerate(picked)}
            flags = np.array([flag_map[flag] for flag in flags])

        wb = openpyxl.Workbook()
        prepare_styles(wb)
        ws = wb.active
        ws.title = self.tr("README")
        description = \
            """
            This Excel file was generated by QGrain ({0}).

            Please cite:
            Liu, Y., Liu, X., Sun, Y., 2021. QGrain: An open-source and easy-to-use software for the comprehensive analysis of grain size distributions. Sedimentary Geology 423, 105980. https://doi.org/10.1016/j.sedgeo.2021.105980

            It contanins 4 + max(N_components) sheets:
            1. The first sheet is the sample distributions of SSU results.
            2. The second sheet is used to put the infomation of fitting.
            3. The third sheet is the statistic parameters calculated by statistic moment method.
            4. The fouth sheet is the distributions of unmixed components and their sum of each sample.
            5. Other sheets are the unmixed end-member distributions which were discretely stored.

            The SSU algorithm is implemented by QGrain.

            """.format(QGRAIN_VERSION)

        def write(row, col, value, style="normal_light"):
            """Write into the sheet currently bound to ``ws`` (zero-based)."""
            cell = ws.cell(row + 1, col + 1, value=value)
            cell.style = style

        lines_of_desc = description.split("\n")
        for row, line in enumerate(lines_of_desc):
            write(row, 0, line, style="description")
        ws.column_dimensions[column_to_char(0)].width = 200

        # sheet: sample distributions, one row per result
        ws = wb.create_sheet(self.tr("Sample Distributions"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        for col, value in enumerate(classes_μm, 1):
            write(0, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            for col, value in enumerate(result.sample.distribution, 1):
                write(row, col, value, style=style)
            QCoreApplication.processEvents()

        # sheet: information of fitting, one row per result
        ws = wb.create_sheet(self.tr("Information of Fitting"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        headers = [
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("Resolver"),
            self.tr("Resolver Settings"),
            self.tr("Initial Guess"),
            self.tr("Reference"),
            self.tr("Spent Time [s]"),
            self.tr("N_iterations"),
            self.tr("Final Distance [log10MSE]")
        ]
        for col, value in enumerate(headers, 1):
            write(0, col, value, style="header")
            # every column of this sheet uses the same width
            ws.column_dimensions[column_to_char(col)].width = 10
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            write(row, 1, result.distribution_type.name, style=style)
            write(row, 2, result.n_components, style=style)
            write(row, 3, result.task.resolver, style=style)
            write(row,
                  4,
                  self.tr("Default") if result.task.resolver_setting is None
                  else result.task.resolver_setting.__str__(),
                  style=style)
            write(row,
                  5,
                  self.tr("None") if result.task.initial_guess is None else
                  result.task.initial_guess.__str__(),
                  style=style)
            write(row,
                  6,
                  self.tr("None") if result.task.reference is None else
                  result.task.reference.__str__(),
                  style=style)
            write(row, 7, result.time_spent, style=style)
            write(row, 8, result.n_iterations, style=style)
            write(row, 9, result.get_distance("log10MSE"), style=style)

        # sheet: statistic moments, one column group per component flag
        ws = wb.create_sheet(self.tr("Statistic Moments"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
        ws.column_dimensions[column_to_char(0)].width = 16
        headers = []
        sub_headers = [
            self.tr("Proportion"),
            self.tr("Mean [φ]"),
            self.tr("Mean [μm]"),
            self.tr("STD [φ]"),
            self.tr("STD [μm]"),
            self.tr("Skewness"),
            self.tr("Kurtosis")
        ]
        for i in range(max_n_components):
            write(0,
                  i * len(sub_headers) + 1,
                  self.tr("C{0}").format(i + 1),
                  style="header")
            ws.merge_cells(start_row=1,
                           start_column=i * len(sub_headers) + 2,
                           end_row=1,
                           end_column=(i + 1) * len(sub_headers) + 1)
            headers.extend(sub_headers)
        for col, value in enumerate(headers, 1):
            write(1, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        flag_index = 0
        for row, result in enumerate(results, 2):
            if row % 2 == 0:
                style = "normal_light"
            else:
                style = "normal_dark"
            write(row, 0, result.sample.name, style=style)
            for component in result.components:
                index = flags[flag_index]
                # same order as sub_headers
                values = (
                    component.fraction,
                    component.logarithmic_moments["mean"],
                    component.geometric_moments["mean"],
                    component.logarithmic_moments["std"],
                    component.geometric_moments["std"],
                    component.logarithmic_moments["skewness"],
                    component.logarithmic_moments["kurtosis"],
                )
                for offset, value in enumerate(values, 1):
                    write(row,
                          index * len(sub_headers) + offset,
                          value,
                          style=style)
                flag_index += 1

        # sheet: unmixed components, one row per component plus a sum row
        ws = wb.create_sheet(self.tr("Unmixed Components"))
        ws.merge_cells(start_row=1, start_column=1, end_row=1, end_column=2)
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 16
        for col, value in enumerate(classes_μm, 2):
            write(0, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        row = 1
        for result_index, result in enumerate(results, 1):
            if result_index % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, result.sample.name, style=style)
            ws.merge_cells(start_row=row + 1,
                           start_column=1,
                           end_row=row + result.n_components + 1,
                           end_column=1)
            for component_i, component in enumerate(result.components, 1):
                write(row, 1, self.tr("C{0}").format(component_i), style=style)
                for col, value in enumerate(
                        component.distribution * component.fraction, 2):
                    write(row, col, value, style=style)
                row += 1
            write(row, 1, self.tr("Sum"), style=style)
            for col, value in enumerate(result.distribution, 2):
                write(row, col, value, style=style)
            row += 1

        # sheets: one per component flag
        ws_dict = {}
        flag_set = set(flags)
        for flag in flag_set:
            ws = wb.create_sheet(self.tr("Unmixed EM{0}").format(flag + 1))
            write(0, 0, self.tr("Sample Name"), style="header")
            ws.column_dimensions[column_to_char(0)].width = 16
            for col, value in enumerate(classes_μm, 1):
                write(0, col, value, style="header")
                ws.column_dimensions[column_to_char(col)].width = 10
            ws_dict[flag] = ws

        flag_index = 0
        for row, result in enumerate(results, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"

            for component in result.components:
                flag = flags[flag_index]
                # rebinding ws redirects write() to the sheet of this flag
                ws = ws_dict[flag]
                write(row, 0, result.sample.name, style=style)
                for col, value in enumerate(component.distribution, 1):
                    write(row, col, value, style=style)
                flag_index += 1

        try:
            wb.save(filename)
        finally:
            # close the workbook even when saving failed
            wb.close()

    def on_save_excel_clicked(self, align_components=False):
        """Ask for a filename and export the results to Excel.

        Warns when there is no result and returns silently when no filename
        is chosen. Any exception raised by ``save_excel`` is shown in an
        error message box; success is confirmed with an info message box.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any SSU result."))
            return
        caption = self.tr("Choose a filename to save SSU Results")
        filename, _ = self.file_dialog.getSaveFileName(
            None, caption, None, "Microsoft Excel (*.xlsx)")
        if not filename:
            return
        try:
            self.save_excel(filename, align_components)
            success_text = self.tr(
                "SSU results have been saved to:\n    {0}").format(filename)
            self.show_info(success_text)
        except Exception as e:
            failure_text = self.tr(
                "Error raised while save SSU results to Excel file.\n    {0}"
            ).format(str(e))
            self.show_error(failure_text)

    def on_fitting_succeeded(self, result: SSUResult):
        """Store the result of a retried task in place of the old result.

        The task is looked up in ``retry_tasks`` by its uuid, removed from
        there, the progress message is updated (and closed when no task is
        left) and the current page is redrawn.

        Raises:
            KeyError: if the uuid of the task is not in ``retry_tasks``.
        """
        task_uuid = result.task.uuid
        slot = self.retry_tasks[task_uuid]
        self.__fitting_results[slot] = result
        del self.retry_tasks[task_uuid]

        remaining = len(self.retry_tasks)
        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(remaining))
        if remaining == 0:
            self.retry_progress_msg.close()

        self.logger.debug(
            f"Retried task succeeded, sample name={result.task.sample.name}, distribution_type={result.task.distribution_type.name}, n_components={result.task.n_components}"
        )
        self.update_page(self.page_index)

    def on_fitting_failed(self, failed_info: str, task: SSUTask):
        """Report a failed retry and forget the task.

        The task is removed from ``retry_tasks`` first, the progress message
        is closed when no task is left, then the failure is shown to the user
        and logged as a warning.

        Raises:
            KeyError: if the uuid of the task is not in ``retry_tasks``.
        """
        # the entry must go, otherwise the task would count as still pending
        del self.retry_tasks[task.uuid]
        if not self.retry_tasks:
            self.retry_progress_msg.close()
        error_text = self.tr(
            "Failed to retry task, sample name={0}.\n{1}").format(
                task.sample.name, failed_info)
        self.show_error(error_text)
        self.logger.warning(
            f"Failed to retry task, sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
        )

    def retry_results(self, indexes, results):
        """Re-run the fitting of ``results`` as background tasks.

        For every result a reference result is taken from the reference
        viewer. If the viewer has none for the sample, the most similar one
        among up to five stored results before and five after the result is
        used. The new task copies distribution type, number of components,
        resolver and resolver setting from the reference and starts from its
        last function arguments as the initial guess.

        Args:
            indexes: Positions of ``results`` in the stored result list.
            results: The results to retry, parallel to ``indexes``.
        """
        assert len(indexes) == len(results)
        if len(results) == 0:
            return
        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(len(results)))
        self.retry_timer.start(1)
        for index, result in zip(indexes, results):
            ref_result = self.__reference_viewer.query_reference(result.sample)
            if ref_result is None:
                # clamp the lower bound: a negative start would be counted
                # from the end of the list and drop the preceding neighbours
                nearby_results = (
                    self.__fitting_results[max(index - 5, 0):index] +
                    self.__fitting_results[index + 1:index + 6])
                ref_result = self.__reference_viewer.find_similar(
                    result.sample, nearby_results)
            task = SSUTask(result.sample,
                           ref_result.distribution_type,
                           ref_result.n_components,
                           resolver=ref_result.task.resolver,
                           resolver_setting=ref_result.task.resolver_setting,
                           initial_guess=ref_result.last_func_args)

            self.logger.debug(
                f"Retry task: sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
            )
            # remember which stored result this task is going to replace
            self.retry_tasks[task.uuid] = index
            self.async_worker.execute_task(task)

    def degrade_results(self):
        """Refit results that contain a component with a fraction below 1e-3.

        The affected samples are listed in an info message box, then one
        background task per result is started. Each task drops the redundant
        components (keeping at least one) and passes the mean, std and
        skewness of the kept components as reference.
        """
        flagged = [
            (position, result)
            for position, result in enumerate(self.__fitting_results)
            if any(component.fraction < 1e-3
                   for component in result.components)
        ]
        degrade_indexes = [position for position, _ in flagged]  # type: list[int]
        degrade_results = [result for _, result in flagged]  # type: list[SSUResult]
        self.logger.debug(
            f"Results should be degrade (have a redundant component): {[result.sample.name for result in degrade_results]}"
        )
        if not degrade_results:
            self.show_info(
                self.tr("No fitting result was evaluated as an outlier."))
            return
        sample_names = ", ".join(
            [result.sample.name for result in degrade_results])
        self.show_info(
            self.
            tr("The results below should be degrade (have a redundant component:\n    {0}"
               ).format(sample_names))

        self.retry_progress_msg.setText(
            self.tr("Tasks to be retried: {0}").format(len(degrade_results)))
        self.retry_timer.start(1)
        for index, result in zip(degrade_indexes, degrade_results):
            n_redundant = 0
            reference = []
            for component in result.components:
                if component.fraction < 1e-3:
                    n_redundant += 1
                    continue
                moments = component.logarithmic_moments
                reference.append({
                    "mean": moments["mean"],
                    "std": moments["std"],
                    "skewness": moments["skewness"],
                })
            # never ask for fewer than one component
            if result.n_components > n_redundant:
                n_components = result.n_components - n_redundant
            else:
                n_components = 1
            task = SSUTask(result.sample,
                           result.distribution_type,
                           n_components,
                           resolver=result.task.resolver,
                           resolver_setting=result.task.resolver_setting,
                           reference=reference)
            self.logger.debug(
                f"Retry task: sample name={task.sample.name}, distribution_type={task.distribution_type.name}, n_components={task.n_components}"
            )
            # remember which stored result this task is going to replace
            self.retry_tasks[task.uuid] = index
            self.async_worker.execute_task(task)

    def ask_deal_outliers(self, outlier_results: typing.List[SSUResult],
                          outlier_indexes: typing.List[int]):
        """Ask the user how to treat the given outliers and act on the answer.

        Without outliers only an info message is shown. Otherwise the sample
        names (at most the first 100) are listed in the outlier message box;
        "Discard" removes the results, "Retry" refits them and any other
        answer leaves them untouched.

        Args:
            outlier_results: The results judged as outliers.
            outlier_indexes: Their positions in the stored result list,
                parallel to ``outlier_results``.
        """
        assert len(outlier_indexes) == len(outlier_results)
        n_outliers = len(outlier_results)
        if n_outliers == 0:
            self.show_info(
                self.tr("No fitting result was evaluated as an outlier."))
            return

        if n_outliers > 100:
            # keep the dialog readable: list only the first 100 names
            names = ", ".join(
                [result.sample.name for result in outlier_results[:100]])
            text = self.tr(
                "The fitting results have the component that its fraction is near zero:\n    {0}...(total {1} outliers)\nHow to deal with them?"
            ).format(names, n_outliers)
        else:
            names = ", ".join(
                [result.sample.name for result in outlier_results])
            text = self.tr(
                "The fitting results have the component that its fraction is near zero:\n    {0}\nHow to deal with them?"
            ).format(names)
        self.outlier_msg.setText(text)

        answer = self.outlier_msg.exec_()
        if answer == QMessageBox.Discard:
            self.remove_results(outlier_indexes)
        elif answer == QMessageBox.Retry:
            self.retry_results(outlier_indexes, outlier_results)

    def check_nan_and_inf(self):
        """Collect the results whose ``is_valid`` is false and ask how to treat them.

        Shows a warning and returns when there is no result.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        flagged = [(position, result)
                   for position, result in enumerate(self.__fitting_results)
                   if not result.is_valid]
        outlier_indexes = [position for position, _ in flagged]
        outlier_results = [result for _, result in flagged]
        self.logger.debug(
            f"Outlier results with the nan or inf value(s): {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_final_distances(self):
        """Find results whose final distance is an upper boxplot outlier.

        Needs at least ten results. The distances (for the currently chosen
        distance name) are shown in the boxplot chart. A result is an outlier
        when its distance exceeds the upper quartile by more than 1.5 times
        the interquartile range; the quartiles are the medians of the values
        strictly below / above the overall median. The outliers are passed
        to ``ask_deal_outliers``.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        if self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return

        distances = np.array([
            result.get_distance(self.distance_name)
            for result in self.__fitting_results
        ])
        self.boxplot_chart.show_dataset([distances],
                                        xlabels=[self.distance_name],
                                        ylabel=self.tr("Distance"))
        self.boxplot_chart.show()

        # Quartiles taken as medians of the two halves around the median.
        # Results that lack components have much higher errors, while adding
        # components only lowers the error towards its minimum.
        median = np.median(distances)
        upper_group = distances[distances > median]
        lower_group = distances[distances < median]
        value_1_4 = np.median(lower_group)
        value_3_4 = np.median(upper_group)
        distance_QR = value_3_4 - value_1_4

        # only the upper side counts: a very small error is not an outlier
        flagged = [
            (position, result) for position, (result, distance) in enumerate(
                zip(self.__fitting_results, distances))
            if distance > value_3_4 + distance_QR * 1.5
        ]
        outlier_indexes = [position for position, _ in flagged]
        outlier_results = [result for _, result in flagged]
        self.logger.debug(
            f"Outlier results with too greater distances: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_component_moments(self, key: str):
        """Find results with an abnormal logarithmic moment in any component.

        Needs at least ten results. For every component position the finite
        values of ``logarithmic_moments[key]`` are gathered and shown in the
        boxplot chart. A result is an outlier when one of its components
        lies more than 1.5 times the interquartile range above the upper or
        below the lower quartile of its position. The outliers, ordered by
        index, are passed to ``ask_deal_outliers``.

        Args:
            key: Name of the moment, one of "mean", "std", "skewness" or
                "kurtosis".
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        if self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return

        max_n_components = max(
            (result.n_components for result in self.__fitting_results),
            default=0)
        # moments[i] gathers the finite values of the i-th components
        moments = [[] for _ in range(max_n_components)]
        for result in self.__fitting_results:
            for i, component in enumerate(result.components):
                value = component.logarithmic_moments[key]
                if np.isnan(value) or np.isinf(value):
                    continue
                moments[i].append(value)

        key_label_trans = {
            "mean": self.tr("Mean [φ]"),
            "std": self.tr("STD [φ]"),
            "skewness": self.tr("Skewness"),
            "kurtosis": self.tr("Kurtosis")
        }
        self.boxplot_chart.show_dataset(
            moments,
            xlabels=[f"C{i+1}" for i in range(max_n_components)],
            ylabel=key_label_trans[key])
        self.boxplot_chart.show()

        # result index -> result; a dict, because one result may be flagged
        # by several of its components
        outlier_dict = {}
        for i in range(max_n_components):
            stacked_moments = np.array(moments[i])
            # quartiles taken as medians of the two halves around the median
            median = np.median(stacked_moments)
            upper_group = stacked_moments[stacked_moments > median]
            lower_group = stacked_moments[stacked_moments < median]
            value_1_4 = np.median(lower_group)
            value_3_4 = np.median(upper_group)
            distance_QR = value_3_4 - value_1_4

            for j, result in enumerate(self.__fitting_results):
                if result.n_components <= i:
                    continue
                moment = result.components[i].logarithmic_moments[key]
                if moment > value_3_4 + distance_QR * 1.5 or moment < value_1_4 - distance_QR * 1.5:
                    outlier_dict[j] = result

        outlier_indexes = sorted(outlier_dict)
        outlier_results = [outlier_dict[index] for index in outlier_indexes]
        self.logger.debug(
            f"Outlier results with abnormal {key} values of their components: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def check_component_fractions(self):
        """Report results that contain a component with a near-zero fraction.

        A result is flagged once as soon as any of its components has a
        fraction below 1e-3; the flagged results and their list indexes are
        handed to ``ask_deal_outliers``.
        """
        flagged = [
            (index, result)
            for index, result in enumerate(self.__fitting_results)
            if any(component.fraction < 1e-3
                   for component in result.components)
        ]
        outlier_indexes = [index for index, _ in flagged]
        outlier_results = [result for _, result in flagged]
        names = [result.sample.name for result in outlier_results]
        self.logger.debug(
            f"Outlier results with the component that its fraction is near zero: {names}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def try_align_components(self):
        """Cluster all components and flag results with clashing components.

        The distributions of every component of every result are clustered
        with KMeans into as many clusters as the largest component number
        found among the results.  All components are plotted, coloured by
        cluster, and every result that has two (or more) components assigned
        to the same cluster is handed to ``ask_deal_outliers``.

        Nothing is done (besides a warning) when fewer than 10 results are
        available.
        """
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        elif self.n_results < 10:
            self.show_warning(self.tr("The results in list are too less."))
            return
        # Imported lazily, as in the original code, so matplotlib is only
        # loaded when this feature is actually used.
        import matplotlib.pyplot as plt
        n_components_list = [
            result.n_components for result in self.__fitting_results
        ]
        count_dict = Counter(n_components_list)
        max_n_components = max(count_dict.keys())
        self.logger.debug(
            f"N_components: {count_dict}, Max N_components: {max_n_components}"
        )
        n_components_desc = "\n".join([
            self.tr("{0} Component(s): {1}").format(n_components, count)
            for n_components, count in count_dict.items()
        ])
        self.show_info(
            self.tr("N_components distribution of Results:\n{0}").format(
                n_components_desc))

        x = self.__fitting_results[0].classes_μm
        stacked_components = []
        # Number of rows each result contributes to `stacked_components`;
        # needed to map the flat cluster labels back onto the results.
        component_counts = []
        for result in self.__fitting_results:
            distributions = [
                component.distribution for component in result.components
            ]
            stacked_components.extend(distributions)
            component_counts.append(len(distributions))
        stacked_components = np.array(stacked_components)

        cluser = KMeans(n_clusters=max_n_components)
        flags = cluser.fit_predict(stacked_components)

        figure = plt.figure(figsize=(6, 4))
        cmap = plt.get_cmap("tab10")
        axes = figure.add_subplot(1, 1, 1)
        for flag, distribution in zip(flags, stacked_components):
            axes.plot(x, distribution, c=cmap(flag), zorder=flag)
        axes.set_xscale("log")
        axes.set_xlabel(self.tr("Grain-size [μm]"))
        axes.set_ylabel(self.tr("Frequency"))
        figure.tight_layout()
        figure.show()

        outlier_results = []
        outlier_indexes = []
        offset = 0
        for i, (result, n_rows) in enumerate(
                zip(self.__fitting_results, component_counts)):
            # Always consume exactly this result's labels.  The previous
            # implementation stopped advancing its cursor when it found a
            # duplicate, which shifted the labels of all following results.
            result_flags = flags[offset:offset + n_rows]
            offset += n_rows
            if len(np.unique(result_flags)) < n_rows:
                outlier_results.append(result)
                outlier_indexes.append(i)
        self.logger.debug(
            f"Outlier results that have two components in the same cluster: {[result.sample.name for result in outlier_results]}"
        )
        self.ask_deal_outliers(outlier_results, outlier_indexes)

    def analyse_typical_components(self):
        """Open the typical-component chart for the current results.

        At least 10 results are required; otherwise a warning explaining why
        nothing is shown is displayed instead.
        """
        if self.n_results >= 10:
            chart = self.typical_chart
            chart.show_typical(self.__fitting_results)
            chart.show()
            return

        if self.n_results == 0:
            message = self.tr("There is not any result in the list.")
        else:
            message = self.tr("The results in list are too less.")
        self.show_warning(message)
Пример #10
0
class ReferenceResultViewer(QDialog):
    """Paged table dialog that stores SSU results usable as references.

    Results are shown ``PAGE_ROWS`` per page.  Any result can be marked as a
    reference; ``query_reference`` then returns the marked result whose
    distribution is closest to a given sample, using the distance function
    currently selected in the dialog.
    """
    PAGE_ROWS = 20
    logger = logging.getLogger("root.QGrain.ui.ReferenceResultViewer")

    def __init__(self, parent=None):
        """Create the dialog, its charts, file dialog and message boxes."""
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("SSU Reference Result Viewer"))
        # Every result shown in the table, in insertion order.
        self.__fitting_results = []  # type: typing.List[SSUResult]
        # Results marked as reference, keyed by their uuid.
        self.__reference_map = {}  # type: typing.Dict[UUID, SSUResult]
        self.retry_tasks = {}
        self.init_ui()
        self.distance_chart = DistanceCurveChart(parent=self, toolbar=True)
        self.mixed_distribution_chart = MixedDistributionChart(
            parent=self, toolbar=True, use_animation=True)
        self.file_dialog = QFileDialog(parent=self)
        self.update_page_list()
        self.update_page(self.page_index)

        self.remove_warning_msg = QMessageBox(self)
        self.remove_warning_msg.setStandardButtons(QMessageBox.No
                                                   | QMessageBox.Yes)
        self.remove_warning_msg.setDefaultButton(QMessageBox.No)
        self.remove_warning_msg.setWindowTitle(self.tr("Warning"))
        self.remove_warning_msg.setText(
            self.tr("Are you sure to remove all SSU results?"))

        self.normal_msg = QMessageBox(self)

    def init_ui(self):
        """Build the table, the paging controls and the context menu."""
        self.data_table = QTableWidget(100, 100)
        self.data_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.data_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.data_table.setAlternatingRowColors(True)
        self.data_table.setContextMenuPolicy(Qt.CustomContextMenu)
        self.main_layout = QGridLayout(self)
        self.main_layout.addWidget(self.data_table, 0, 0, 1, 3)

        # Paging controls.
        self.previous_button = QPushButton(
            qta.icon("mdi.skip-previous-circle"), self.tr("Previous"))
        self.previous_button.setToolTip(
            self.tr("Click to back to the previous page."))
        self.previous_button.clicked.connect(self.on_previous_button_clicked)
        self.current_page_combo_box = QComboBox()
        self.current_page_combo_box.addItem(self.tr("Page {0}").format(1))
        self.current_page_combo_box.currentIndexChanged.connect(
            self.update_page)
        self.next_button = QPushButton(qta.icon("mdi.skip-next-circle"),
                                       self.tr("Next"))
        self.next_button.setToolTip(self.tr("Click to jump to the next page."))
        self.next_button.clicked.connect(self.on_next_button_clicked)
        self.main_layout.addWidget(self.previous_button, 1, 0)
        self.main_layout.addWidget(self.current_page_combo_box, 1, 1)
        self.main_layout.addWidget(self.next_button, 1, 2)

        # Distance function selector; changing it refreshes the table
        # because the "Final Distance" column depends on it.
        self.distance_label = QLabel(self.tr("Distance"))
        self.distance_label.setToolTip(
            self.tr(
                "It's the function to calculate the difference (on the contrary, similarity) between two samples."
            ))
        self.distance_combo_box = QComboBox()
        self.distance_combo_box.addItems(built_in_distances)
        self.distance_combo_box.setCurrentText("log10MSE")
        self.distance_combo_box.currentTextChanged.connect(
            lambda: self.update_page(self.page_index))
        self.main_layout.addWidget(self.distance_label, 2, 0)
        self.main_layout.addWidget(self.distance_combo_box, 2, 1, 1, 2)

        # Context menu of the table.
        self.menu = QMenu(self.data_table)
        self.mark_action = self.menu.addAction(
            qta.icon("mdi.marker-check"),
            self.tr("Mark Selection(s) as Reference"))
        self.mark_action.triggered.connect(self.mark_selections)
        self.unmark_action = self.menu.addAction(
            qta.icon("mdi.do-not-disturb"), self.tr("Unmark Selection(s)"))
        self.unmark_action.triggered.connect(self.unmark_selections)
        self.remove_action = self.menu.addAction(
            qta.icon("fa.remove"), self.tr("Remove Selection(s)"))
        self.remove_action.triggered.connect(self.remove_selections)
        self.remove_all_action = self.menu.addAction(qta.icon("fa.remove"),
                                                     self.tr("Remove All"))
        self.remove_all_action.triggered.connect(self.remove_all_results)
        self.plot_loss_chart_action = self.menu.addAction(
            qta.icon("mdi.chart-timeline-variant"), self.tr("Plot Loss Chart"))
        self.plot_loss_chart_action.triggered.connect(self.show_distance)
        self.plot_distribution_chart_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"), self.tr("Plot Distribution Chart"))
        self.plot_distribution_chart_action.triggered.connect(
            self.show_distribution)
        self.plot_distribution_animation_action = self.menu.addAction(
            qta.icon("fa5s.chart-area"),
            self.tr("Plot Distribution Chart (Animation)"))
        self.plot_distribution_animation_action.triggered.connect(
            self.show_history_distribution)

        self.load_dump_action = self.menu.addAction(
            qta.icon("fa.database"), self.tr("Load Binary Dump"))
        # Results loaded through this dialog are marked as references.
        self.load_dump_action.triggered.connect(
            lambda: self.load_dump(mark_ref=True))
        self.save_dump_action = self.menu.addAction(
            qta.icon("fa.save"), self.tr("Save Binary Dump"))
        self.save_dump_action.triggered.connect(self.save_dump)
        self.data_table.customContextMenuRequested.connect(self.show_menu)

    def show_menu(self, pos):
        """Pop up the context menu at the current cursor position."""
        self.menu.popup(QCursor.pos())

    def show_message(self, title: str, message: str):
        """Show a modal message box with the given title and text."""
        self.normal_msg.setWindowTitle(title)
        self.normal_msg.setText(message)
        self.normal_msg.exec_()

    def show_info(self, message: str):
        """Show ``message`` in a box titled "Info"."""
        self.show_message(self.tr("Info"), message)

    def show_warning(self, message: str):
        """Show ``message`` in a box titled "Warning"."""
        self.show_message(self.tr("Warning"), message)

    def show_error(self, message: str):
        """Show ``message`` in a box titled "Error"."""
        self.show_message(self.tr("Error"), message)

    @property
    def distance_name(self) -> str:
        """Name of the distance function selected in the combo box."""
        return self.distance_combo_box.currentText()

    @property
    def distance_func(self) -> typing.Callable:
        """Distance function selected in the combo box."""
        return get_distance_func_by_name(self.distance_combo_box.currentText())

    @property
    def page_index(self) -> int:
        """Zero-based index of the page currently displayed."""
        return self.current_page_combo_box.currentIndex()

    @property
    def n_pages(self) -> int:
        """Number of pages listed in the page combo box."""
        return self.current_page_combo_box.count()

    @property
    def n_results(self) -> int:
        """Number of results stored in the viewer."""
        return len(self.__fitting_results)

    @property
    def selections(self):
        """Sorted indexes (into the result list) of the selected table rows.

        Rows outside the current page or beyond the last stored result are
        ignored, so every returned index is valid for the result list.
        """
        start = self.page_index * self.PAGE_ROWS
        n_results = self.n_results
        indexes = set()
        for selected_range in self.data_table.selectedRanges():
            # A page never shows more than PAGE_ROWS rows (rows 0..PAGE_ROWS-1).
            stop = min(self.PAGE_ROWS, selected_range.bottomRow() + 1)
            for row in range(selected_range.topRow(), stop):
                index = row + start
                if 0 <= index < n_results:
                    indexes.add(index)
        return sorted(indexes)

    def _first_selected_result(self) -> typing.Optional[SSUResult]:
        """Return the first selected result, or None without a selection."""
        selections = self.selections
        if not selections:
            return None
        return self.__fitting_results[selections[0]]

    def update_page_list(self):
        """Rebuild the page combo box to match the number of results.

        The previously displayed page stays selected when it still exists;
        otherwise the last page is selected.  Signals are blocked meanwhile,
        so the table itself is NOT refreshed by this method.
        """
        last_page_index = self.page_index
        if self.n_results == 0:
            n_pages = 1
        else:
            n_pages, left = divmod(self.n_results, self.PAGE_ROWS)
            if left != 0:
                n_pages += 1
        self.current_page_combo_box.blockSignals(True)
        self.current_page_combo_box.clear()
        self.current_page_combo_box.addItems(
            [self.tr("Page {0}").format(i + 1) for i in range(n_pages)])
        self.current_page_combo_box.setCurrentIndex(
            min(max(last_page_index, 0), n_pages - 1))
        self.current_page_combo_box.blockSignals(False)

    def update_page(self, page_index: int):
        """Fill the table with the results of page ``page_index``."""
        def write(row: int, col: int, value):
            # Convert the value to the text shown in the cell.
            if isinstance(value, str):
                pass
            elif isinstance(value, int):
                value = str(value)
            elif isinstance(value, float):
                value = f"{value: 0.4f}"
            else:
                value = value.__str__()
            item = QTableWidgetItem(value)
            item.setTextAlignment(Qt.AlignCenter)
            self.data_table.setItem(row, col, item)

        # necessary to clear
        self.data_table.clear()
        # Clamp the window so a stale or invalid page index can never yield
        # a negative row count.
        start = max(page_index, 0) * self.PAGE_ROWS
        end = min(start + self.PAGE_ROWS, self.n_results)
        page_results = self.__fitting_results[start:end]
        self.data_table.setRowCount(len(page_results))
        self.data_table.setColumnCount(8)
        self.data_table.setHorizontalHeaderLabels([
            self.tr("Resolver"),
            self.tr("Distribution Type"),
            self.tr("N_components"),
            self.tr("N_iterations"),
            self.tr("Spent Time [s]"),
            self.tr("Final Distance"),
            self.tr("Has Reference"),
            self.tr("Is Reference")
        ])
        self.data_table.setVerticalHeaderLabels(
            [result.sample.name for result in page_results])
        # Resolve the selected distance function once, not once per row.
        distance_func = self.distance_func
        for row, result in enumerate(page_results):
            write(row, 0, result.task.resolver)
            write(row, 1,
                  self.get_distribution_name(result.task.distribution_type))
            write(row, 2, result.task.n_components)
            write(row, 3, result.n_iterations)
            write(row, 4, result.time_spent)
            write(row, 5,
                  distance_func(result.sample.distribution,
                                result.distribution))
            has_ref = result.task.initial_guess is not None or result.task.reference is not None
            write(row, 6, self.tr("Yes") if has_ref else self.tr("No"))
            is_ref = result.uuid in self.__reference_map
            write(row, 7, self.tr("Yes") if is_ref else self.tr("No"))

        self.data_table.resizeColumnsToContents()

    def on_previous_button_clicked(self):
        """Go to the previous page, if there is one."""
        if self.page_index > 0:
            self.current_page_combo_box.setCurrentIndex(self.page_index - 1)

    def on_next_button_clicked(self):
        """Go to the next page, if there is one."""
        if self.page_index < self.n_pages - 1:
            self.current_page_combo_box.setCurrentIndex(self.page_index + 1)

    def get_distribution_name(self, distribution_type: DistributionType):
        """Return the translated display name of ``distribution_type``.

        Raises:
            NotImplementedError: If the distribution type is not handled.
        """
        if distribution_type == DistributionType.Normal:
            return self.tr("Normal")
        elif distribution_type == DistributionType.Weibull:
            return self.tr("Weibull")
        elif distribution_type == DistributionType.SkewNormal:
            return self.tr("Skew Normal")
        else:
            raise NotImplementedError(distribution_type)

    def add_result(self, result: SSUResult):
        """Append a single result; see ``add_results``."""
        self.add_results([result])

    def add_results(self, results: typing.List[SSUResult]):
        """Append ``results`` and refresh the table when it is affected.

        The visible page only changes when the list was empty or when the
        last, partially filled page is the one being displayed.
        """
        need_update = self.n_results == 0 or (
            self.page_index == self.n_pages - 1
            and self.n_results % self.PAGE_ROWS != 0)
        self.__fitting_results.extend(results)
        self.update_page_list()
        if need_update:
            self.update_page(self.page_index)

    def mark_results(self, results: typing.List[SSUResult]):
        """Mark ``results`` as references and refresh the table."""
        for result in results:
            self.__reference_map[result.uuid] = result

        self.update_page(self.page_index)

    def unmark_results(self, results: typing.List[SSUResult]):
        """Remove the reference mark from ``results`` and refresh the table."""
        for result in results:
            self.__reference_map.pop(result.uuid, None)

        self.update_page(self.page_index)

    def add_references(self, results: typing.List[SSUResult]):
        """Add ``results`` to the viewer and mark them as references."""
        self.add_results(results)
        self.mark_results(results)

    def mark_selections(self):
        """Mark the selected rows as references."""
        results = [
            self.__fitting_results[selection] for selection in self.selections
        ]
        self.mark_results(results)

    def unmark_selections(self):
        """Remove the reference mark from the selected rows."""
        results = [
            self.__fitting_results[selection] for selection in self.selections
        ]
        self.unmark_results(results)

    def remove_results(self, indexes):
        """Remove the results at ``indexes`` (and their reference marks).

        ``indexes`` may be given in any order and may contain duplicates.
        """
        # Pop from the highest index down so earlier removals do not shift
        # the positions still to be removed.
        for i in sorted(set(indexes), reverse=True):
            removed = self.__fitting_results.pop(i)
            self.__reference_map.pop(removed.uuid, None)
        # Rebuild the page list BEFORE refreshing the table so the refresh
        # never works with a stale page count.
        self.update_page_list()
        self.update_page(self.page_index)

    def remove_selections(self):
        """Remove the selected rows."""
        indexes = self.selections
        self.remove_results(indexes)

    def remove_all_results(self):
        """Remove every result after the user confirmed the warning."""
        res = self.remove_warning_msg.exec_()
        if res == QMessageBox.Yes:
            self.__fitting_results.clear()
            # Without this the removed results would still be returned by
            # query_reference.
            self.__reference_map.clear()
            self.update_page_list()
            self.update_page(0)

    def show_distance(self):
        """Plot the distance series of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.distance_chart.show_distance_series(
            result.get_distance_series(self.distance_name),
            title=result.sample.name)
        self.distance_chart.show()

    def show_distribution(self):
        """Plot the distribution of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.mixed_distribution_chart.show_model(result.view_model)
        self.mixed_distribution_chart.show()

    def show_history_distribution(self):
        """Plot the animated fitting history of the first selected result."""
        result = self._first_selected_result()
        if result is None:
            return
        self.mixed_distribution_chart.show_result(result)
        self.mixed_distribution_chart.show()

    def load_dump(self, mark_ref=False):
        """Load a list of SSU results from a user-selected binary dump.

        Args:
            mark_ref: When true, the loaded results are also marked as
                references.
        """
        filename, _ = self.file_dialog.getOpenFileName(
            self, self.tr("Select a binary dump file of SSU results"), None,
            self.tr("Binary dump (*.dump)"))
        if filename is None or filename == "":
            return
        try:
            with open(filename, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code; only dump
                # files from trusted sources should be opened here.
                results = pickle.load(f)
        except (pickle.UnpicklingError, AttributeError, EOFError,
                ImportError, IndexError):
            # Typical failures for files that are not (compatible) dumps.
            self.logger.exception("Failed to load the binary dump file.")
            self.show_error(self.tr("The binary dump file is invalid."))
            return

        valid = isinstance(results, list) and all(
            isinstance(result, SSUResult) for result in results)
        if not valid:
            self.show_error(self.tr("The binary dump file is invalid."))
            return
        self.add_results(results)
        if mark_ref:
            self.mark_results(results)

    def save_dump(self):
        """Pickle all stored results into a user-selected file."""
        if self.n_results == 0:
            self.show_warning(self.tr("There is not any result in the list."))
            return
        filename, _ = self.file_dialog.getSaveFileName(
            self, self.tr("Save the SSU results to binary dump file"), None,
            self.tr("Binary dump (*.dump)"))
        if filename is None or filename == "":
            return
        with open(filename, "wb") as f:
            pickle.dump(self.__fitting_results, f)

    def find_similar(self, target: GrainSizeSample,
                     ref_results: typing.List[SSUResult]):
        """Return the reference result closest to ``target``.

        The target distribution is linearly interpolated onto the classes of
        each reference (zero outside the target's range) and compared with
        the reference's distribution using the selected distance function.

        Args:
            target: Sample for which a reference is wanted.
            ref_results: Non-empty collection of candidate results.

        Returns:
            The candidate with the smallest distance, or None when no
            distance compares as smaller than infinity (e.g. all are NaN).
        """
        assert len(ref_results) != 0

        start_time = time.time()
        from scipy.interpolate import interp1d
        distance_func = self.distance_func
        min_distance = float("inf")
        min_result = None
        trans_func = interp1d(target.classes_φ,
                              target.distribution,
                              bounds_error=False,
                              fill_value=0.0)
        for result in ref_results:
            # Resample the target onto this reference's classes so both
            # distributions have the same length.
            trans_dist = trans_func(result.classes_φ)
            distance = distance_func(result.distribution, trans_dist)

            if distance < min_distance:
                min_distance = distance
                min_result = result

        self.logger.debug(
            f"It took {time.time()-start_time:0.4f} s to query the reference from {len(ref_results)} results."
        )
        return min_result

    def query_reference(self, sample: GrainSizeSample):
        """Return the marked reference most similar to ``sample``, or None."""
        if len(self.__reference_map) == 0:
            self.logger.debug("No result is marked as reference.")
            return None
        return self.find_similar(sample, list(self.__reference_map.values()))
Пример #11
0
    def init_ui(self):
        """Create the widgets of the generator and arrange them in a grid.

        Layout (row, column) of ``main_layout``:
        (0, 0) sampling settings, (0, 1) save controls,
        (1, 0) the (initially empty) random-parameter group,
        (1, 1) the preview chart.
        """
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.main_layout = QGridLayout(self)

        # --- Sampling settings -------------------------------------------
        self.sampling_group = QGroupBox(self.tr("Sampling"))
        self.control_layout = QGridLayout(self.sampling_group)
        self.minimum_size_label = QLabel(self.tr("Minimum Size [μm]"))
        self.minimum_size_input = QDoubleSpinBox()
        self.minimum_size_input.setDecimals(2)
        self.minimum_size_input.setRange(1e-4, 1e6)
        self.minimum_size_input.setValue(0.0200)
        self.maximum_size_label = QLabel(self.tr("Maximum Size [μm]"))
        self.maximum_size_input = QDoubleSpinBox()
        self.maximum_size_input.setDecimals(2)
        self.maximum_size_input.setRange(1e-4, 1e6)
        self.maximum_size_input.setValue(2000.0000)
        self.control_layout.addWidget(self.minimum_size_label, 0, 0)
        self.control_layout.addWidget(self.minimum_size_input, 0, 1)
        self.control_layout.addWidget(self.maximum_size_label, 0, 2)
        self.control_layout.addWidget(self.maximum_size_input, 0, 3)
        self.n_classes_label = QLabel(self.tr("N<sub>classes</sub>"))
        self.n_classes_input = QSpinBox()
        # QSpinBox.setRange takes ints; the float literal 1e4 is rejected by
        # the bindings on Python versions without implicit float->int
        # conversion.
        self.n_classes_input.setRange(10, 10000)
        self.n_classes_input.setValue(101)
        self.precision_label = QLabel(self.tr("Data Precision"))
        self.precision_input = QSpinBox()
        self.precision_input.setRange(2, 8)
        self.precision_input.setValue(4)
        self.control_layout.addWidget(self.n_classes_label, 1, 0)
        self.control_layout.addWidget(self.n_classes_input, 1, 1)
        self.control_layout.addWidget(self.precision_label, 1, 2)
        self.control_layout.addWidget(self.precision_input, 1, 3)
        self.component_number_label = QLabel(self.tr("N<sub>components</sub>"))
        self.component_number_input = QSpinBox()
        # The range is set before the signal is connected, so adjusting it
        # here does not invoke on_n_components_changed.
        self.component_number_input.setRange(1, 10)
        self.component_number_input.valueChanged.connect(
            self.on_n_components_changed)
        self.preview_button = QPushButton(qta.icon("mdi.animation-play"),
                                          self.tr("Preview"))
        self.preview_button.clicked.connect(self.on_preview_clicked)
        self.control_layout.addWidget(self.component_number_label, 2, 0)
        self.control_layout.addWidget(self.component_number_input, 2, 1)
        self.control_layout.addWidget(self.preview_button, 2, 2, 1, 2)
        self.main_layout.addWidget(self.sampling_group, 0, 0)

        # --- Save controls -----------------------------------------------
        self.save_group = QGroupBox(self.tr("Save"))
        self.save_layout = QGridLayout(self.save_group)
        self.n_samples_label = QLabel(self.tr("N<sub>samples</sub>"))
        self.n_samples_input = QSpinBox()
        self.n_samples_input.setRange(100, 100000)
        self.save_layout.addWidget(self.n_samples_label, 0, 0)
        self.save_layout.addWidget(self.n_samples_input, 0, 1)

        self.cancel_button = QPushButton(qta.icon("mdi.cancel"),
                                         self.tr("Cancel"))
        # Starts disabled.
        self.cancel_button.setEnabled(False)
        self.cancel_button.clicked.connect(self.on_cancel_clicked)
        self.generate_button = QPushButton(qta.icon("mdi.cube-send"),
                                           self.tr("Generate"))
        self.generate_button.clicked.connect(self.on_generate_clicked)
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setOrientation(Qt.Horizontal)
        self.progress_bar.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
        self.save_layout.addWidget(self.cancel_button, 1, 0)
        self.save_layout.addWidget(self.generate_button, 1, 1)
        self.save_layout.addWidget(self.progress_bar, 2, 0, 1, 2)
        self.main_layout.addWidget(self.save_group, 0, 1)

        # --- Random parameters (left empty here) -------------------------
        # Wrapped in tr() like every other user-visible title of the dialog.
        self.param_group = QGroupBox(self.tr("Random Parameter"))
        self.param_layout = QGridLayout(self.param_group)
        self.main_layout.addWidget(self.param_group, 1, 0)

        # --- Preview chart -----------------------------------------------
        self.preview_group = QGroupBox(self.tr("Preview"))
        self.chart_layout = QGridLayout(self.preview_group)

        self.chart = MixedDistributionChart(parent=self, toolbar=False)
        self.chart_layout.addWidget(self.chart, 0, 0)
        self.chart.setSizePolicy(QSizePolicy.MinimumExpanding,
                                 QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.preview_group, 1, 1)
Пример #12
0
class RandomDatasetGenerator(QDialog):
    logger = logging.getLogger("root.ui.RandomGeneratorWidget")
    gui_logger = logging.getLogger("GUI")

    def __init__(self, parent=None):
        """Build the dialog, load the default target and prepare the preview timer.

        Args:
            parent: Optional parent widget forwarded to the base-class
                constructor.
        """
        super().__init__(parent=parent, f=Qt.Window)
        self.setWindowTitle(self.tr("Dataset Generator"))

        # Bookkeeping used by `on_n_components_changed` and the `target`
        # property; it has to exist before the widgets are created.
        self.last_n_components = 0
        self.components = []  # typing.List[RandomGeneratorComponentWidget]
        self.component_series = []
        self.init_ui()

        # Goes through the `target` setter, which adjusts the number of
        # component widgets and refreshes the preview chart.
        self.target = LOESS
        for spin_box, default_value in (
                (self.minimum_size_input, 0.02),
                (self.maximum_size_input, 2000.0),
                (self.n_classes_input, 101),
                (self.precision_input, 4),
        ):
            spin_box.setValue(default_value)

        self.file_dialog = QFileDialog(parent=self)

        def show_random_preview():
            # Each timer tick draws a freshly randomized sample.
            self.update_chart(True)

        self.update_timer = QTimer()
        self.update_timer.timeout.connect(show_random_preview)
        # Set by `on_cancel_clicked`, polled by `on_generate_clicked`.
        self.cancel_flag = False

    def init_ui(self):
        """Create all child widgets and arrange them in a 2 x 2 grid.

        The main grid holds four group boxes: the sampling settings at
        (0, 0), the save controls at (0, 1), the per-component random
        parameters at (1, 0) and the preview chart at (1, 1).
        """
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.main_layout = QGridLayout(self)

        # Sampling group: size range, number of classes, precision and the
        # number of components.
        self.sampling_group = QGroupBox(self.tr("Sampling"))
        self.control_layout = QGridLayout(self.sampling_group)
        self.minimum_size_label = QLabel(self.tr("Minimum Size [μm]"))
        self.minimum_size_input = QDoubleSpinBox()
        self.minimum_size_input.setDecimals(2)
        self.minimum_size_input.setRange(1e-4, 1e6)
        self.minimum_size_input.setValue(0.0200)
        self.maximum_size_label = QLabel(self.tr("Maximum Size [μm]"))
        self.maximum_size_input = QDoubleSpinBox()
        self.maximum_size_input.setDecimals(2)
        self.maximum_size_input.setRange(1e-4, 1e6)
        self.maximum_size_input.setValue(2000.0000)
        self.control_layout.addWidget(self.minimum_size_label, 0, 0)
        self.control_layout.addWidget(self.minimum_size_input, 0, 1)
        self.control_layout.addWidget(self.maximum_size_label, 0, 2)
        self.control_layout.addWidget(self.maximum_size_input, 0, 3)
        self.n_classes_label = QLabel(self.tr("N<sub>classes</sub>"))
        self.n_classes_input = QSpinBox()
        # QSpinBox takes integer bounds; a float literal such as 1e4 is
        # rejected with a TypeError on Python >= 3.10.
        self.n_classes_input.setRange(10, 10000)
        self.n_classes_input.setValue(101)
        self.precision_label = QLabel(self.tr("Data Precision"))
        self.precision_input = QSpinBox()
        self.precision_input.setRange(2, 8)
        self.precision_input.setValue(4)
        self.control_layout.addWidget(self.n_classes_label, 1, 0)
        self.control_layout.addWidget(self.n_classes_input, 1, 1)
        self.control_layout.addWidget(self.precision_label, 1, 2)
        self.control_layout.addWidget(self.precision_input, 1, 3)
        self.component_number_label = QLabel(self.tr("N<sub>components</sub>"))
        self.component_number_input = QSpinBox()
        self.component_number_input.setRange(1, 10)
        # Connected after `setRange`, so the range adjustment itself does not
        # trigger `on_n_components_changed`.
        self.component_number_input.valueChanged.connect(
            self.on_n_components_changed)
        self.preview_button = QPushButton(qta.icon("mdi.animation-play"),
                                          self.tr("Preview"))
        self.preview_button.clicked.connect(self.on_preview_clicked)
        self.control_layout.addWidget(self.component_number_label, 2, 0)
        self.control_layout.addWidget(self.component_number_input, 2, 1)
        self.control_layout.addWidget(self.preview_button, 2, 2, 1, 2)
        self.main_layout.addWidget(self.sampling_group, 0, 0)

        # Save group: number of samples, cancel / generate buttons, progress.
        self.save_group = QGroupBox(self.tr("Save"))
        self.save_layout = QGridLayout(self.save_group)
        self.n_samples_label = QLabel(self.tr("N<sub>samples</sub>"))
        self.n_samples_input = QSpinBox()
        self.n_samples_input.setRange(100, 100000)
        self.save_layout.addWidget(self.n_samples_label, 0, 0)
        self.save_layout.addWidget(self.n_samples_input, 0, 1)

        self.cancel_button = QPushButton(qta.icon("mdi.cancel"),
                                         self.tr("Cancel"))
        # Only enabled while `on_generate_clicked` is running.
        self.cancel_button.setEnabled(False)
        self.cancel_button.clicked.connect(self.on_cancel_clicked)
        self.generate_button = QPushButton(qta.icon("mdi.cube-send"),
                                           self.tr("Generate"))
        self.generate_button.clicked.connect(self.on_generate_clicked)
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setOrientation(Qt.Horizontal)
        self.progress_bar.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
        self.save_layout.addWidget(self.cancel_button, 1, 0)
        self.save_layout.addWidget(self.generate_button, 1, 1)
        self.save_layout.addWidget(self.progress_bar, 2, 0, 1, 2)
        self.main_layout.addWidget(self.save_group, 0, 1)

        # Parameter group: filled dynamically by `on_n_components_changed`.
        # The title goes through `self.tr` like every other group title.
        self.param_group = QGroupBox(self.tr("Random Parameter"))
        self.param_layout = QGridLayout(self.param_group)
        self.main_layout.addWidget(self.param_group, 1, 0)

        # Preview group: chart of the current (mean or random) sample.
        self.preview_group = QGroupBox(self.tr("Preview"))
        self.chart_layout = QGridLayout(self.preview_group)

        self.chart = MixedDistributionChart(parent=self, toolbar=False)
        self.chart_layout.addWidget(self.chart, 0, 0)
        self.chart.setSizePolicy(QSizePolicy.MinimumExpanding,
                                 QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.preview_group, 1, 1)

    @staticmethod
    def to_points(x, y):
        """Pair `x` and `y` element-wise and wrap every pair in a `QPointF`.

        Pairing stops at the end of the shorter input.
        """
        return list(map(QPointF, x, y))

    def on_n_components_changed(self, n_components: int):
        """Grow or shrink the list of component widgets to `n_components`.

        Connected to `component_number_input.valueChanged`. New widgets are
        named ``AC<index>`` and appended to `param_layout`; surplus widgets
        are removed from the end of the list.

        Args:
            n_components: The requested number of component widgets.
        """
        if self.last_n_components < n_components:
            for component_index in range(self.last_n_components, n_components):
                component = RandomGeneratorComponentWidget(
                    name=f"AC{component_index+1}")
                component.value_changed.connect(self.on_value_changed)
                self.param_layout.addWidget(component, component_index + 1, 0)
                self.components.append(component)

        if self.last_n_components > n_components:
            # Pop from the tail until the list has the requested length.
            # Indexing with the original positions while popping (as done
            # previously) skipped widgets and ran out of range whenever the
            # count dropped by more than one.
            while len(self.components) > n_components:
                removed = self.components.pop()
                removed.value_changed.disconnect(self.on_value_changed)
                self.param_layout.removeWidget(removed)
                # removeWidget does not unparent the widget: hide it so it is
                # no longer drawn, then schedule it for deletion so discarded
                # widgets do not pile up as hidden children.
                removed.hide()
                removed.deleteLater()

        self.last_n_components = n_components

    def on_preview_clicked(self):
        """Toggle the animated preview driven by `update_timer`."""
        if not self.update_timer.isActive():
            # Start the animation; the timer fires every 200 ms.
            self.preview_button.setText(self.tr("Stop"))
            self.update_timer.start(200)
            return

        # Stop the animation and redraw the chart once without `random=True`.
        self.preview_button.setText(self.tr("Preview"))
        self.update_timer.stop()
        self.update_chart()

    def on_cancel_clicked(self):
        """Request cancellation of the running generation task.

        Only raises `cancel_flag`; `on_generate_clicked` polls the flag
        between steps and aborts when it is set.
        """
        self.cancel_flag = True

    def on_generate_clicked(self):
        """Generate the requested number of samples and save them to Excel.

        Stops a running preview, asks for the output filename, then delegates
        the actual work to `_generate_and_save`. The cancel / generate
        buttons and `cancel_flag` are always restored afterwards, even when
        generating or writing raises, so the dialog never gets stuck with a
        disabled "Generate" button or a stale cancel request.
        """
        if self.update_timer.isActive():
            self.preview_button.setText(self.tr("Preview"))
            self.update_timer.stop()
            self.update_chart()

        filename, _ = self.file_dialog.getSaveFileName(
            self, self.tr("Choose a filename to save the generated dataset"),
            None, "Microsoft Excel (*.xlsx)")
        if not filename:
            return
        n_samples = self.n_samples_input.value()
        dataset = self.get_random_dataset(n_samples)
        if dataset is None:
            # `get_random_dataset` returns None when the minimum and maximum
            # sizes are equal; there is nothing to generate in that case.
            self.logger.error(
                "Can not generate the dataset: the minimum size is equal to "
                "the maximum size.")
            return

        self.cancel_button.setEnabled(True)
        self.generate_button.setEnabled(False)
        try:
            finished = self._generate_and_save(filename, n_samples, dataset)
            if finished:
                self.progress_bar.setValue(100)
                self.progress_bar.setFormat(self.tr("Task finished"))
            else:
                self.progress_bar.setFormat(self.tr("Task canceled"))
                self.progress_bar.setValue(0)
        finally:
            self.cancel_button.setEnabled(False)
            self.generate_button.setEnabled(True)
            # Also cleared on success, so a late click on "Cancel" can not
            # abort the next run immediately.
            self.cancel_flag = False

    def _generate_and_save(self, filename, n_samples, dataset):
        """Draw `n_samples` samples from `dataset` and write the workbook.

        Args:
            filename: Path of the ``.xlsx`` file to write.
            n_samples: Number of samples to draw.
            dataset: The object returned by `get_random_dataset`.

        Returns:
            True when the workbook was saved, False when `cancel_flag` was
            raised before the work finished.
        """

        def report(progress):
            # The progress bar takes integers; truncate explicitly instead of
            # relying on an implicit float conversion. Processing events
            # keeps the GUI responsive and lets a "Cancel" click through.
            self.progress_bar.setValue(int(progress))
            QCoreApplication.processEvents()

        # generate samples (0 - 50 %)
        format_str = self.tr("Generating {0} samples: %p%").format(n_samples)
        self.progress_bar.setFormat(format_str)
        self.progress_bar.setValue(0)

        samples = []
        for i in range(n_samples):
            if self.cancel_flag:
                return False
            samples.append(dataset.get_sample(i))
            report((i + 1) / n_samples * 50)

        # save file to excel file (50 - 100 %)
        format_str = self.tr("Writing {0} samples to excel file: %p%").format(
            n_samples)
        self.progress_bar.setFormat(format_str)
        self.progress_bar.setValue(50)

        wb = openpyxl.Workbook()
        prepare_styles(wb)

        ws = wb.active
        ws.title = self.tr("README")
        description = \
            """
            This Excel file was generated by QGrain ({0}).

            Please cite:
            Liu, Y., Liu, X., Sun, Y., 2021. QGrain: An open-source and easy-to-use software for the comprehensive analysis of grain size distributions. Sedimentary Geology 423, 105980. https://doi.org/10.1016/j.sedgeo.2021.105980

            It contanins n_components + 3 sheets:
            1. The first sheet is the random settings which were used to generate random parameters.
            2. The second sheet is the generated dataset.
            3. The third sheet is random parameters which were used to calulate the component distributions and their mixture.
            4. The left sheets are the component distributions of all samples.

            Artificial dataset
                Using skew normal distribution as the base distribution of each component (i.e. end-member).
                Skew normal distribution has three parameters, shape, location and scale.
                Where shape controls the skewness, location and scale are simliar to that of the Normal distribution.
                When shape = 0, it becomes Normal distribution.
                The weight parameter controls the fraction of the component, where fraction_i = weight_i / sum(weight_i).
                By assigning the mean and std of each parameter, random parameters was generate by the `scipy.stats.truncnorm.rvs` function of Scipy.

            Sampling settings
                Minimum size [μm]: {1},
                Maximum size [μm]: {2},
                N_classes: {3},
                Precision: {4},
                Noise: {5},
                N_samples: {6}

            """.format(QGRAIN_VERSION,
                       self.minimum_size_input.value(),
                       self.maximum_size_input.value(),
                       self.n_classes_input.value(),
                       self.precision_input.value(),
                       self.precision_input.value()+1,
                       n_samples)

        def write(row, col, value, style="normal_light"):
            # Zero-based helper; it always targets the sheet currently bound
            # to `ws`, which is rebound for every sheet created below.
            cell = ws.cell(row + 1, col + 1, value=value)
            cell.style = style

        lines_of_desc = description.split("\n")
        for row, line in enumerate(lines_of_desc):
            write(row, 0, line, style="description")
        ws.column_dimensions[column_to_char(0)].width = 200

        # Sheet 1: mean / std of every random parameter per component.
        ws = wb.create_sheet(self.tr("Random Settings"))
        write(0, 0, self.tr("Parameter"), style="header")
        ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
        write(0, 1, self.tr("Shape"), style="header")
        ws.merge_cells(start_row=1, start_column=2, end_row=1, end_column=3)
        write(0, 3, self.tr("Location"), style="header")
        ws.merge_cells(start_row=1, start_column=4, end_row=1, end_column=5)
        write(0, 5, self.tr("Scale"), style="header")
        ws.merge_cells(start_row=1, start_column=6, end_row=1, end_column=7)
        write(0, 7, self.tr("Weight"), style="header")
        ws.merge_cells(start_row=1, start_column=8, end_row=1, end_column=9)
        ws.column_dimensions[column_to_char(0)].width = 16
        for col in range(1, 9):
            ws.column_dimensions[column_to_char(col)].width = 16
            if col % 2 == 0:
                write(1, col, self.tr("Mean"), style="header")
            else:
                write(1, col, self.tr("STD"), style="header")
        for row, comp_params in enumerate(self.target, 2):
            if row % 2 == 1:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, self.tr("Component{0}").format(row - 1), style=style)
            for i, key in enumerate(["shape", "loc", "scale", "weight"]):
                mean, std = comp_params[key]
                write(row, i * 2 + 1, mean, style=style)
                write(row, i * 2 + 2, std, style=style)

        # Sheet 2: the mixed distribution of every sample (50 - 60 %).
        ws = wb.create_sheet(self.tr("Dataset"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.column_dimensions[column_to_char(0)].width = 24
        for col, value in enumerate(dataset.classes_μm, 1):
            write(0, col, value, style="header")
            ws.column_dimensions[column_to_char(col)].width = 10
        for row, sample in enumerate(samples, 1):
            if row % 2 == 0:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, sample.name, style=style)
            for col, value in enumerate(sample.distribution, 1):
                write(row, col, value, style=style)

            if self.cancel_flag:
                return False
            report(50 + (row / n_samples) * 10)

        # Sheet 3: the random parameters of every sample (60 - 70 %).
        ws = wb.create_sheet(self.tr("Parameters"))
        write(0, 0, self.tr("Sample Name"), style="header")
        ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
        ws.column_dimensions[column_to_char(0)].width = 24
        for i in range(dataset.n_components):
            write(0,
                  4 * i + 1,
                  self.tr("Component{0}").format(i + 1),
                  style="header")
            ws.merge_cells(start_row=1,
                           start_column=4 * i + 2,
                           end_row=1,
                           end_column=4 * i + 5)
            for j, header_name in enumerate([
                    self.tr("Shape"),
                    self.tr("Location"),
                    self.tr("Scale"),
                    self.tr("Weight")
            ]):
                write(1, 4 * i + 1 + j, header_name, style="header")
                ws.column_dimensions[column_to_char(4 * i + 1 + j)].width = 16
        for row, sample in enumerate(samples, 2):
            if row % 2 == 1:
                style = "normal_dark"
            else:
                style = "normal_light"
            write(row, 0, sample.name, style=style)
            for i, comp_param in enumerate(sample.parameter.components):
                write(row, 4 * i + 1, comp_param.shape, style=style)
                write(row, 4 * i + 2, comp_param.loc, style=style)
                write(row, 4 * i + 3, comp_param.scale, style=style)
                write(row, 4 * i + 4, comp_param.weight, style=style)
            if self.cancel_flag:
                return False
            report(60 + (row / n_samples) * 10)

        # Remaining sheets: one per component (70 - 100 %).
        for i in range(dataset.n_components):
            ws = wb.create_sheet(self.tr("Component{0}").format(i + 1))
            write(0, 0, self.tr("Sample Name"), style="header")
            ws.column_dimensions[column_to_char(0)].width = 24
            for col, value in enumerate(dataset.classes_μm, 1):
                write(0, col, value, style="header")
                ws.column_dimensions[column_to_char(col)].width = 10
            for row, sample in enumerate(samples, 1):
                if row % 2 == 0:
                    style = "normal_dark"
                else:
                    style = "normal_light"
                write(row, 0, sample.name, style=style)
                for col, value in enumerate(sample.components[i].distribution,
                                            1):
                    write(row, col, value, style=style)
            if self.cancel_flag:
                return False
            # One finished sheet is worth 30 / n_components percent. The old
            # expression multiplied by n_components instead of dividing by
            # it and therefore ran past 100.
            report(70 + (i + 1) / dataset.n_components * 30)
        wb.save(filename)
        wb.close()
        return True

    @property
    def target(self):
        """List with the `target` of every component widget, in widget order."""
        return [widget.target for widget in self.components]

    @target.setter
    def target(self, values):
        """Apply one target per component widget and refresh the preview.

        The number of component widgets is first adjusted to `len(values)`
        through `component_number_input`. Widgets and values are then paired
        in order; surplus values (e.g. beyond the spin box maximum) are
        ignored.

        Args:
            values: Sequence with one target setting per component.
        """
        if len(values) != len(self.components):
            self.component_number_input.setValue(len(values))
            # `setValue` only emits `valueChanged` when the value really
            # changes. If the spin box already shows the requested count
            # while the widget list is out of sync (e.g. right after
            # `init_ui`, when the box shows 1 but no widget exists yet),
            # the widgets have to be synchronised explicitly.
            shown_count = self.component_number_input.value()
            if len(self.components) != shown_count:
                self.on_n_components_changed(shown_count)
        for comp, comp_target in zip(self.components, values):
            # Silence the widgets while assigning so the chart is refreshed
            # once below rather than once per component; always unblock,
            # even when the assignment raises.
            comp.blockSignals(True)
            try:
                comp.target = comp_target
            finally:
                comp.blockSignals(False)
        self.update_chart()

    def get_random_sample(self):
        """Return the first sample of a fresh one-sample random dataset.

        The sample is renamed to the translated "Artificial Sample" label.
        """
        sample = self.get_random_dataset(n_samples=1).get_sample(0)
        sample.name = self.tr("Artificial Sample")
        return sample

    def get_random_mean(self):
        """Return the sample built from `RandomSetting(self.target).mean_param`.

        A one-sample dataset is created with the current sampling settings
        and asked for the sample that corresponds to those parameters; the
        sample is named with the translated "Artificial Sample" label.
        """
        one_sample_dataset = self.get_random_dataset(n_samples=1)
        setting = RandomSetting(self.target)
        return one_sample_dataset.get_sample_by_params(
            self.tr("Artificial Sample"), setting.mean_param)

    def get_random_dataset(self, n_samples):
        """Create a random dataset from the current sampling settings.

        Args:
            n_samples: Number of samples, forwarded unchanged.

        Returns:
            The result of the module-level `get_random_dataset`, or None when
            the minimum and maximum sizes are equal.
        """
        lower = self.minimum_size_input.value()
        upper = self.maximum_size_input.value()
        class_count = self.n_classes_input.value()
        if lower == upper:
            return None
        # Tolerate the two bounds being entered the wrong way round.
        lower, upper = min(lower, upper), max(lower, upper)
        digits = self.precision_input.value()

        # The noise setting is always one more than the data precision.
        return get_random_dataset(target=self.target,
                                  n_samples=n_samples,
                                  min_μm=lower,
                                  max_μm=upper,
                                  n_classes=class_count,
                                  precision=digits,
                                  noise=digits + 1)

    def on_value_changed(self):
        """Redraw the preview chart.

        Connected to the `value_changed` signal of every component widget in
        `on_n_components_changed`.
        """
        self.update_chart()

    def update_chart(self, random=False):
        """Show a sample in the preview chart.

        Args:
            random: When True, show the sample from `get_random_sample`;
                otherwise show the one from `get_random_mean`.
        """
        if self.minimum_size_input.value() == self.maximum_size_input.value():
            # `get_random_dataset` returns None for an empty size range, so
            # no sample can be built; keep the current chart instead of
            # failing with an AttributeError (on every timer tick while the
            # preview animation is running).
            return
        sample = self.get_random_sample() if random else self.get_random_mean()
        self.chart.show_model(sample.view_model)

    def closeEvent(self, event):
        """Accept the close request, canceling a running generation first.

        The cancel button is only enabled while `on_generate_clicked` is
        working, so its state doubles as the "task running" indicator.
        """
        generation_running = self.cancel_button.isEnabled()
        if generation_running:
            self.on_cancel_clicked()
        event.accept()