Exemplo n.º 1
0
    def test_settings(self):
        """Apply the settings exported by ``self.lmfit`` to a new model.

        A second ``LMFitModel`` is built from the class-level test fixtures
        and receives ``self.lmfit.get_settings()`` through ``set_settings``.
        NOTE(review): the body as shown asserts nothing, so it only checks
        that the round trip does not raise.
        """
        fresh_model = LMFitModel(
            TestLMFitModel.sliced_data,
            TestLMFitModel.orbitals,
        )
        exported_settings = self.lmfit.get_settings()
        fresh_model.set_settings(exported_settings)
Exemplo n.º 2
0
class LMFitResultTab(LMFitBaseTab, LMFitTab_UI):
    """Tab displaying LMFit results, one result per fitted slice.

    Each entry of ``results`` is indexable: ``entry[1]`` is the fit result
    whose ``.params`` feed the result tree and the plots, and ``entry[0]``
    is taken to be the slice index the result belongs to (assumption based
    on how the newer version of this tab uses it -- confirm). ``results``
    is assumed to be sorted by that slice index.
    """

    # Emitted by ``plot`` with (fit results, orbitals, slice axis).
    open_plot_tab = pyqtSignal(list, list, Axis)

    def __init__(self, data, results, settings):
        """Build the model and GUI and draw the initial plots.

        Args:
            data: Positional arguments unpacked into ``LMFitModel``.
            results: Non-empty list of per-slice results (see class doc);
                the first entry is shown initially.
            settings: Applied to the new model with ``set_settings``.
        """
        self.results = results
        self.current_result = self.results[0]

        self.model = LMFitModel(*data)
        self.model.set_settings(settings)

        # Setup GUI
        super().__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_sliced_plot()
        self.refresh_selected_plot()
        # The map returned by the sum plot is passed on as the residual
        # plot's weight_sum_data.
        kmap = self.refresh_sum_plot()
        self.refresh_residual_plot(weight_sum_data=kmap)

    def get_title(self):
        """Return the tab title."""
        return 'Results'

    def change_slice(self):
        """Show the result belonging to the slider's current slice."""
        index = self.slider.get_index()
        self.current_result = self._result_for_index(index)

        self.update_tree()
        self.result.result = self.current_result[1]
        self.refresh_sliced_plot()
        self._refresh_orbital_plots()

    def _result_for_index(self, index):
        """Return the stored result that belongs to slider position ``index``.

        Looks the result up by its slice index (``result[0]``) instead of
        by list position, so results that do not start at slice 0 (or skip
        slices) no longer raise IndexError or show the wrong fit. Positions
        before the first / after the last fitted slice are clamped to the
        first / last result. Assumes ``self.results`` is sorted by
        ascending slice index -- confirm with the fitting code.
        """
        import bisect

        slices = [result[0] for result in self.results]
        # Last result whose slice index is <= index; -1 (before the first
        # fitted slice) is clamped to the first result.
        position = bisect.bisect_right(slices, index) - 1
        return self.results[max(position, 0)]

    def refresh_selected_plot(self):
        """Redraw the selected-orbital plot with the current parameters."""
        params = self.current_result[1].params
        super().refresh_selected_plot(params)

    def refresh_sum_plot(self):
        """Redraw the sum plot with the current parameters.

        Returns:
            Whatever the base implementation returns (used by ``__init__``
            as the residual weight).
        """
        params = self.current_result[1].params
        return super().refresh_sum_plot(params)

    def refresh_residual_plot(self, weight_sum_data=None):
        """Redraw the residual plot with the current parameters."""
        params = self.current_result[1].params
        super().refresh_residual_plot(params, weight_sum_data)

    def update_tree(self):
        """Show the current result's parameters in the result tree."""
        params = self.current_result[1].params
        self.tree.update_result(params)

    def print_result(self):
        """Print the fit report of the currently displayed result."""
        report = self.result.get_fit_report()

        print(report)

    def print_covariance_matrix(self):
        """Print the covariance matrix of the currently displayed result."""
        cov_matrix = self.result.get_covariance_matrix()

        print(cov_matrix)

    def plot(self):
        """Emit ``open_plot_tab`` with all fit results, orbitals and axis."""
        results = [result[1] for result in self.results]
        orbitals = self.model.orbitals
        axis = self.model.sliced_data.axes[self.model.slice_policy[0]]
        self.open_plot_tab.emit(results, orbitals, axis)

    def _setup(self):
        """Create the result widgets and place them in the tab's layouts."""
        LMFitBaseTab._setup(self)

        self.result = LMFitResult(self.current_result[1], self.model)
        self.tree = LMFitResultTree(self.model.orbitals,
                                    self.current_result[1].params,
                                    self.model.background_equation[1])
        self.crosshair._set_model(self.model.crosshair)

        layout = QVBoxLayout()
        self.scroll_area.widget().setLayout(layout)
        # Append in display order instead of inserting at hard-coded
        # indices, some of which were past the end of the fresh layout.
        layout.addWidget(self.slider)
        layout.addWidget(self.result)
        layout.addWidget(self.colormap)
        layout.addWidget(self.crosshair)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Wire the result widget's signals to this tab's handlers."""
        self.result.print_triggered.connect(self.print_result)
        self.result.cov_matrix_requested.connect(self.print_covariance_matrix)
        self.result.plot_requested.connect(self.plot)

        LMFitBaseTab._connect(self)
Exemplo n.º 3
0
class LMFitResultTab(LMFitBaseTab, LMFitTab_UI):
    """Tab displaying LMFit results, one result per fitted slice.

    Each entry of ``results`` is indexable: ``entry[0]`` is the slice index
    the fit belongs to and ``entry[1]`` is the fit result whose ``.params``
    feed the result tree and the plots. ``results`` is assumed to be sorted
    by ascending slice index -- confirm with the fitting code.
    """

    # Emitted by ``plot`` with (fit results, orbitals, fitted slice axis,
    # per-slice residual sums, background_equation[1]).
    open_plot_tab = pyqtSignal(list, list, Axis, list, list)

    def __init__(self, lmfit_tab, results, settings):
        """Build the model from ``lmfit_tab``'s data and set up the GUI.

        Args:
            lmfit_tab: Source tab; its ``get_data()`` is unpacked into
                ``LMFitModel`` and its ``ID`` is saved as a dependency.
            results: Non-empty list of per-slice results (see class doc);
                the first entry is shown initially.
            settings: Applied to the new model with ``set_settings`` and
                kept for ``save_state``.
        """
        self.results = results
        self.lmfit_tab = lmfit_tab
        self.current_result = self.results[0]
        self.settings = settings
        self.model = LMFitModel(*self.lmfit_tab.get_data())
        self.model.set_settings(settings)

        # Setup GUI
        super().__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_all()

    @classmethod
    def init_from_save(cls, save, dependencies, tab_widget):
        """Recreate a results tab from state produced by ``save_state``.

        Args:
            save: Saved state; must contain ``'results'``, ``'settings'``,
                ``'slider'``, ``'crosshair'``, ``'colormap'``,
                ``'colorscales'`` and ``'levels'``.
            dependencies: Must map ``'lmfittab'`` to the ID of the LMFit
                tab the results came from.
            tab_widget: Provides ``get_tab_by_ID`` to resolve that ID.

        Returns:
            The restored tab, with the source LMFit tab in ``locked_tabs``.
        """
        results = save['results']
        settings = save['settings']
        lmfit_tab = tab_widget.get_tab_by_ID(dependencies['lmfittab'])

        # Use cls (not the hard-coded class name) so subclasses restore
        # instances of their own type.
        tab = cls(lmfit_tab, results, settings)
        tab.slider.restore_state(save['slider'])
        tab.crosshair.restore_state(save['crosshair'])
        tab.colormap.restore_state(save['colormap'])

        colorscales = save['colorscales']
        levels = save['levels']
        # Kept in the original call order in case set_colormap and
        # set_levels interact.
        tab.sliced_plot.set_colormap(colorscales['sliced'])
        tab.sliced_plot.set_levels(levels['sliced'])
        tab.sum_plot.set_colormap(colorscales['sum'])
        tab.sum_plot.set_levels(levels['sum'])
        tab.residual_plot.set_levels(levels['residual'])
        tab.residual_plot.set_colormap(colorscales['residual'])
        tab.selected_plot.set_levels(levels['selected'])
        tab.selected_plot.set_colormap(colorscales['selected'])

        tab.locked_tabs = [lmfit_tab]

        return tab

    def save_state(self):
        """Extend the base tab's saved state with this tab's own state.

        Returns:
            ``(save, dependencies)``: ``save`` gains the title, widget
            states, results and settings; ``dependencies`` records the
            source LMFit tab's ID under ``'lmfittab'``.
        """
        save, dependencies = super().save_state()

        save.update({
            'title': self.title,
            'crosshair': self.crosshair.save_state(),
            'slider': self.slider.save_state(),
            'colormap': self.colormap.save_state(),
            'results': self.results,
            'settings': self.settings,
        })
        dependencies.update({'lmfittab': self.lmfit_tab.ID})

        return save, dependencies

    def get_title(self):
        """Return the tab title."""
        return 'Results'

    def change_slice(self):
        """Show the result belonging to the slider's current slice."""
        index = self.slider.get_index()
        self.current_result = self._result_for_index(index)

        self.update_tree()
        self.result.result = self.current_result[1]
        self.refresh_sliced_plot()
        self._refresh_orbital_plots()

    def _result_for_index(self, index):
        """Return the stored result that belongs to slider position ``index``.

        Positions before the first / after the last fitted slice are
        clamped to the first / last result. Inside the fitted range the
        result is looked up by its slice index (``result[0]``), so the
        lookup stays correct even if the fitted slices are not contiguous.
        Assumes ``self.results`` is sorted by ascending slice index.
        """
        import bisect

        slices = [result[0] for result in self.results]
        # Last result whose slice index is <= index; -1 (before the first
        # fitted slice) is clamped to the first result.
        position = bisect.bisect_right(slices, index) - 1
        return self.results[max(position, 0)]

    def refresh_selected_plot(self):
        """Redraw the selected-orbital plot with the current parameters."""
        params = self.current_result[1].params
        super().refresh_selected_plot(params)

    def refresh_sum_plot(self):
        """Redraw the sum plot with the current parameters.

        Returns:
            Whatever the base implementation returns.
        """
        params = self.current_result[1].params
        return super().refresh_sum_plot(params)

    def refresh_residual_plot(self, weight_sum_data=None):
        """Redraw the residual plot with the current parameters."""
        params = self.current_result[1].params
        super().refresh_residual_plot(params, weight_sum_data)

    def update_tree(self):
        """Show the current result's parameters in the result tree."""
        params = self.current_result[1].params
        self.tree.update_result(params)

    def print_result(self):
        """Print the fit report of the currently displayed result."""
        report = self.result.get_fit_report()

        print(report)

    def print_covariance_matrix(self):
        """Print the covariance matrix of the currently displayed result."""
        cov_matrix = self.result.get_covariance_matrix()

        print(cov_matrix)

    def get_orbitals(self):
        """Return the orbitals of the underlying model."""
        return self.model.orbitals

    def plot(self):
        """Emit ``open_plot_tab`` to plot the fit over the fitted slices.

        Emits the fit results, the orbitals, the slice axis restricted to
        the fitted slice indices, the per-slice sum of absolute residuals
        (NaNs ignored) and ``background_equation[1]`` of the model.
        """
        indices = [result[0] for result in self.results]
        results = [result[1] for result in self.results]
        orbitals = self.model.orbitals
        axis = self.model.sliced_data.axes[self.model.slice_policy[0]]
        axis = axis.sublist(indices)
        kmaps = abs(self.get_residual_kmaps())
        # Collapse both map axes, leaving one number per fitted slice.
        residuals = list(np.nansum(np.nansum(kmaps, axis=1), axis=1))

        self.open_plot_tab.emit(results, orbitals, axis, residuals,
                                self.model.background_equation[1])

    def export_to_hdf5(self):
        """Ask for a file name and write the residual k-maps to HDF5.

        The dialog starts in the ``paths/hdf5_export_start`` config
        directory unless that key is the string ``'None'``. Nothing is
        written if the dialog is cancelled.
        """
        path = config.get_key('paths', 'hdf5_export_start')
        if path == 'None':
            file_name, _ = QFileDialog.getSaveFileName(
                None, 'Save .hdf5 File (*.hdf5)')
        else:
            start_path = str(__directory__ / path)
            file_name, _ = QFileDialog.getSaveFileName(
                None, 'Save .hdf5 File (*.hdf5)', start_path)

        if not file_name:
            return

        kmaps = self.get_residual_kmaps()

        slice_axis = self.slider.data.axes[0]
        x_axis = self.slider.data.axes[1]
        y_axis = self.slider.data.axes[2]
        # NOTE(review): slice_axis is the slider's full slice axis, while
        # kmaps holds one map per entry in self.results. If only some slices
        # were fitted, axis_1_range may not match the data -- confirm.

        # The context manager closes the file even if writing a dataset fails.
        with h5py.File(file_name, 'w') as h5file:
            h5file.create_dataset('name', data='Residual')
            h5file.create_dataset('axis_1_label', data=slice_axis.label)
            h5file.create_dataset('axis_2_label', data=x_axis.label)
            h5file.create_dataset('axis_3_label', data=y_axis.label)
            h5file.create_dataset('axis_1_units', data=slice_axis.units)
            h5file.create_dataset('axis_2_units', data=x_axis.units)
            h5file.create_dataset('axis_3_units', data=y_axis.units)
            h5file.create_dataset('axis_1_range', data=slice_axis.range)
            h5file.create_dataset('axis_2_range', data=x_axis.range)
            h5file.create_dataset('axis_3_range', data=y_axis.range)
            h5file.create_dataset('data', data=kmaps, dtype='f8',
                                  compression='gzip', compression_opts=9)

    def get_residual_kmaps(self):
        """Return the residual map of every stored result, stacked.

        Each residual is computed by ``self.model.get_residual`` for the
        slice the result belongs to (``result[0]``), using that result's
        fitted parameters. Previously the list position was passed instead,
        which paired the parameters with the wrong slice whenever the fitted
        slices did not start at 0.

        Returns:
            ``np.ndarray`` with the residuals' ``.data`` stacked along a
            new first axis, in the order of ``self.results``.
        """
        kmaps = [self.model.get_residual(result[0], result[1].params).data
                 for result in self.results]

        return np.array(kmaps)

    def _setup(self):
        """Create the result widgets and place them in the tab's layouts."""
        LMFitBaseTab._setup(self)

        self.result = LMFitResult(self.current_result[1], self.model)
        self.tree = LMFitResultTree(
            self.model.orbitals, self.current_result[1].params,
            self.model.background_equation[1])
        self.crosshair._set_model(self.model.crosshair)

        layout = QVBoxLayout()
        self.scroll_area.widget().setLayout(layout)
        # Append in display order instead of inserting at hard-coded
        # indices, some of which were past the end of the fresh layout.
        layout.addWidget(self.slider)
        layout.addWidget(self.result)
        layout.addWidget(self.colormap)
        layout.addWidget(self.crosshair)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Wire the result widget's signals to this tab's handlers."""
        self.result.print_triggered.connect(self.print_result)
        self.result.cov_matrix_requested.connect(self.print_covariance_matrix)
        self.result.plot_requested.connect(self.plot)

        LMFitBaseTab._connect(self)