class LMFitTab(LMFitBaseTab, LMFitTab_UI):
    """Tab wiring an ``LMFitModel`` to its option widgets, tree and plots.

    ``fit_finished`` is emitted by :meth:`trigger_fit` with three
    arguments: ``[sliced_data, orbitals]`` as held by the model, the
    value returned by ``model.fit()`` and the value returned by
    ``model.get_settings()``.
    """

    fit_finished = pyqtSignal(list, list, dict)

    def __init__(self, sliced_data, orbitals):
        """Create the model, build the GUI and draw every plot once."""
        # The model has to exist before _setup() runs: the tree widget is
        # constructed from self.model.orbitals / self.model.parameters.
        self.model = LMFitModel(sliced_data, orbitals)

        # Setup GUI
        super(LMFitTab, self).__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_sliced_plot()
        self.refresh_selected_plot()
        # refresh_sum_plot() hands back the data that the residual plot
        # takes as its ``weight_sum_data`` argument.
        weight_sum = self.refresh_sum_plot()
        self.refresh_residual_plot(weight_sum_data=weight_sum)

    def get_title(self):
        """Return the fixed title of this tab."""
        return 'LM-Fit'

    def trigger_fit(self):
        """Run the fit and emit ``fit_finished`` with data, results, settings."""
        fit_results = self.model.fit()
        fit_settings = self.model.get_settings()
        fit_data = [self.model.sliced_data, self.model.orbitals]

        self.fit_finished.emit(fit_data, fit_results, fit_settings)

    def change_slice(self):
        """Push the slider position and slice policy to the model, then redraw.

        Only the sliced and the residual plot are refreshed here.
        """
        axis = self.slider.get_axis()
        policy = self.lmfit_options.get_slice_policy()

        # The slider index is only queried for the 'only one' policy; every
        # other policy selects all slices.
        if policy == 'only one':
            indices = self.slider.get_index()
        else:
            indices = 'all'

        self.model.set_slices(indices, axis_index=axis,
                              combined=(policy == 'all combined'))

        self.refresh_sliced_plot()
        self.refresh_residual_plot()

    def change_axis(self):
        """Hand the interpolation axis to the model and redraw all plots."""
        self.model.set_axis(self.interpolation.get_axis())

        self.refresh_sliced_plot()
        self.refresh_selected_plot()
        weight_sum = self.refresh_sum_plot()
        self.refresh_residual_plot(weight_sum_data=weight_sum)

    def _change_slice_policy(self, slice_policy):
        """Translate a slice policy string into ``model.set_slices`` arguments.

        ``'all'`` and ``'only one'`` are handled explicitly; any other value
        selects all slices with ``combined=True``.
        """
        axis = self.slider.get_axis()

        if slice_policy == 'all':
            indices, combined = 'all', False
        elif slice_policy == 'only one':
            indices, combined = [self.slider.get_index()], False
        else:
            indices, combined = 'all', True

        self.model.set_slices(indices, axis_index=axis, combined=combined)

    def _change_region(self, *args):
        """Forward the region arguments to the model and redraw the residual."""
        self.model.set_region(*args)
        self.refresh_residual_plot()

    def _change_background(self, *args):
        """Set the background equation on the model and update tree and plots.

        Every variable returned by ``model.set_background_equation`` is
        added to the tree as an equation parameter.
        """
        added_variables = self.model.set_background_equation(*args)
        for name in added_variables:
            self.tree.add_equation_parameter(name)

        self.refresh_sum_plot()
        self.refresh_residual_plot()

    def _setup(self):
        """Create the tab-specific widgets and place them in the layouts."""
        LMFitBaseTab._setup(self)

        self.orbital_options = LMFitOrbitalOptions()
        self.tree = LMFitTree(self.model.orbitals, self.model.parameters)
        self.interpolation = LMFitInterpolation()
        self.lmfit_options = LMFitOptions(self)

        # Bring the model in line with the initial widget state.
        self.change_axis()
        self.model.set_crosshair(self.crosshair.model)
        self._change_background(self.lmfit_options.get_background())

        options_layout = QVBoxLayout()
        options_layout.setContentsMargins(3, 3, 3, 3)
        options_layout.setSpacing(3)
        self.scroll_area.widget().setLayout(options_layout)

        # Top-to-bottom order of the option widgets in the scroll area.
        stacked_widgets = (
            self.slider,
            self.orbital_options,
            self.interpolation,
            self.lmfit_options,
            self.colormap,
            self.crosshair,
        )
        for position, widget in enumerate(stacked_widgets):
            options_layout.insertWidget(position, widget)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Connect the widget signals to the model and to the refresh slots."""
        LMFitBaseTab._connect(self)

        self.interpolation.interpolation_changed.connect(self.change_axis)

        self.tree.value_changed.connect(self._refresh_orbital_plots)
        self.tree.vary_changed.connect(self.update_chi2_label)

        options = self.lmfit_options
        options.background_changed.connect(self._change_background)
        options.fit_triggered.connect(self.trigger_fit)
        options.method_changed.connect(self.model.set_fit_method)
        options.slice_policy_changed.connect(self._change_slice_policy)

        # For both orbital option signals the model setter is connected
        # before the plot refresh; keep that connection order.
        orbital_signals = (
            (self.orbital_options.symmetrization_changed,
             self.model.set_symmetrization),
            (self.orbital_options.polarization_changed,
             self.model.set_polarization),
        )
        for signal, model_setter in orbital_signals:
            signal.connect(model_setter)
            signal.connect(self._refresh_orbital_plots)
class LMFitTab(LMFitBaseTab, LMFitTab_UI):
    """Tab wiring an ``LMFitModel`` to its option widgets, tree and plots.

    The tab is built from a sliced-data tab and an orbital tab and can be
    saved (:meth:`save_state`) and recreated (:meth:`init_from_save`).

    ``fit_finished`` is emitted by :meth:`trigger_fit` with the accumulated
    results and the value returned by ``model.get_settings()``.
    """

    fit_finished = pyqtSignal(list, dict)

    def __init__(self, sliced_tab, orbital_tab, max_orbitals=-1):
        """Create the model from both source tabs and build the GUI.

        Args:
            sliced_tab: Tab whose ``get_data()`` provides the sliced data.
            orbital_tab: Tab whose ``get_orbitals()`` provides the orbitals.
            max_orbitals: When not ``-1``, only ``orbitals[:max_orbitals]``
                are passed to the model.
        """
        self.sliced_tab = sliced_tab
        self.orbital_tab = orbital_tab

        sliced_data = sliced_tab.get_data()
        orbitals = orbital_tab.get_orbitals()
        if max_orbitals != -1:
            orbitals = orbitals[:max_orbitals]
        self.model = LMFitModel(sliced_data, orbitals)

        # Setup GUI
        super().__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_all()

    @classmethod
    def init_from_save(cls, save, dependencies, tab_widget):
        """Recreate a tab from a saved state.

        Args:
            save: Mapping as produced by :meth:`save_state`; the keys read
                here are 'title', 'tree', 'interpolation', 'lmfit_options',
                'orbital_options', 'slider', 'colormap', 'crosshair',
                'colorscales' and 'levels'.
            dependencies: Mapping holding the IDs of the source tabs under
                'sliced_tab' and 'orbital_tab'.
            tab_widget: Object resolving those IDs via ``get_tab_by_ID``.

        Returns:
            The restored tab instance.
        """
        sliced_tab = tab_widget.get_tab_by_ID(dependencies['sliced_tab'])
        orbital_tab = tab_widget.get_tab_by_ID(dependencies['orbital_tab'])

        # max orbitals: If orbitals were loaded after lmfit was created they
        # will not be reflected in the results -> not load them
        # NOTE(review): the "- 2" presumably discounts two non-orbital
        # entries of the saved tree - confirm against LMFitTree.save_state.
        tab = cls(sliced_tab, orbital_tab,
                  max_orbitals=len(save['tree']) - 2)
        tab.locked_tabs = [sliced_tab, orbital_tab]
        tab.title = save['title']

        # Each of these widgets restores itself from the entry saved under
        # its own attribute name; the sequence matches the original code.
        restorable_widgets = ('tree', 'interpolation', 'lmfit_options',
                              'orbital_options', 'slider', 'colormap',
                              'crosshair')
        for widget_name in restorable_widgets:
            getattr(tab, widget_name).restore_state(save[widget_name])

        # (setter, save section, plot key), applied strictly in this order.
        # The colormap/levels order differs between plots on purpose: it
        # mirrors the original call sequence exactly.
        plot_steps = (
            (tab.sliced_plot.set_colormap, 'colorscales', 'sliced'),
            (tab.sliced_plot.set_levels, 'levels', 'sliced'),
            (tab.sum_plot.set_colormap, 'colorscales', 'sum'),
            (tab.sum_plot.set_levels, 'levels', 'sum'),
            (tab.residual_plot.set_levels, 'levels', 'residual'),
            (tab.residual_plot.set_colormap, 'colorscales', 'residual'),
            (tab.selected_plot.set_levels, 'levels', 'selected'),
            (tab.selected_plot.set_colormap, 'colorscales', 'selected'),
        )
        for apply_setting, section, plot_key in plot_steps:
            apply_setting(save[section][plot_key])

        return tab

    def get_title(self):
        """Return the title stored on this tab."""
        return self.title

    def get_data(self):
        """Return ``[sliced_data, orbitals]`` as held by the model."""
        return [self.model.sliced_data, self.model.orbitals]

    def trigger_fit(self):
        """Run the fit, merge its results into ``self.results`` and emit them.

        If ``model.fit()`` raises ``ValueError`` the message is logged as a
        warning on the 'kmap' logger, the fit button is updated and nothing
        is emitted.
        """
        try:
            new_results = self.model.fit()

        except ValueError as error:
            logging.getLogger('kmap').warning(str(error))
            self.lmfit_options.update_fit_button()
            return

        if hasattr(self, 'results'):
            # Results are handled as (key, value) pairs: new entries replace
            # stored ones with the same key, then everything is re-sorted
            # by key.
            merged = dict(self.results)
            merged.update(dict(new_results))
            self.results = sorted(merged.items(), key=lambda pair: pair[0])
        else:
            # First fit on this tab: keep the results exactly as returned.
            self.results = new_results

        self.fit_finished.emit(self.results, self.model.get_settings())

    def change_slice(self):
        """Push the slider position and slice policy to the model, then redraw.

        Only the sliced and the residual plot are refreshed here.
        """
        axis = self.slider.get_axis()
        policy = self.lmfit_options.get_slice_policy()

        # The slider index is only queried for the 'only one' policy; every
        # other policy selects all slices.
        if policy == 'only one':
            indices = self.slider.get_index()
        else:
            indices = 'all'

        self.model.set_slices(indices, axis_index=axis,
                              combined=(policy == 'all combined'))

        self.refresh_sliced_plot()
        self.refresh_residual_plot()

    def change_axis(self):
        """Hand the interpolation axis to the model and redraw everything."""
        self.model.set_axis(self.interpolation.get_axis())
        self.refresh_all()

    def save_state(self):
        """Extend the base class state with this tab's widgets and sources.

        Returns:
            Tuple ``(save, dependencies)``; ``dependencies`` carries the IDs
            of the sliced tab and the orbital tab.
        """
        save, dependencies = super().save_state()

        widget_states = {
            'orbital_options': self.orbital_options.save_state(),
            'interpolation': self.interpolation.save_state(),
            'lmfit_options': self.lmfit_options.save_state(),
            'tree': self.tree.save_state(),
        }
        save.update(widget_states)

        source_tabs = {
            'sliced_tab': self.sliced_tab.ID,
            'orbital_tab': self.orbital_tab.ID,
        }
        dependencies.update(source_tabs)

        return save, dependencies

    def _change_slice_policy(self, slice_policy):
        """Translate a slice policy string into ``model.set_slices`` arguments.

        Besides 'all', 'only one' and 'all combined', any other string is
        parsed as integer slice indices separated by single spaces.
        """
        axis = self.slider.get_axis()

        if slice_policy == 'all':
            indices, combined = 'all', False
        elif slice_policy == 'only one':
            indices, combined = [self.slider.get_index()], False
        elif slice_policy == 'all combined':
            indices, combined = 'all', True
        else:
            indices = list(map(int, slice_policy.split(' ')))
            combined = False

        self.model.set_slices(indices, axis_index=axis, combined=combined)

    def _change_method(self, method):
        """Switch the matrix state of the GUI, then set the fit method."""
        self._change_to_matrix_state(method == 'matrix_inversion')
        self.model.set_fit_method(method)

    def _change_to_matrix_state(self, state):
        """Enter or leave the matrix state.

        When entering and 'c' is not among the variables of the model's
        background equation, ``lmfit_options._pre_factor_background()`` is
        called first. The tree is notified in either case.
        """
        if state and 'c' not in self.model.background_equation[1]:
            self.lmfit_options._pre_factor_background()

        self.tree._change_to_matrix_state(state)

    def _change_region(self, *args):
        """Forward the region arguments to the model and redraw everything."""
        self.model.set_region(*args)
        self.refresh_all()

    def _change_background(self, *args):
        """Set the background equation on the model and update tree and plots.

        Every variable returned by ``model.set_background_equation`` is
        added to the tree as an equation parameter.
        """
        added_variables = self.model.set_background_equation(*args)
        for name in added_variables:
            self.tree.add_equation_parameter(name)

        self.refresh_sum_plot()
        self.refresh_residual_plot()

    def _setup(self):
        """Create the tab-specific widgets and place them in the layouts."""
        LMFitBaseTab._setup(self)

        self.orbital_options = LMFitOrbitalOptions()
        self.tree = LMFitTree(self.model.orbitals, self.model.parameters)
        self.interpolation = LMFitInterpolation()
        self.lmfit_options = LMFitOptions(self)

        # Bring the model in line with the initial widget state.
        self.change_axis()
        self.model.set_crosshair(self.crosshair.model)
        self._change_background(self.lmfit_options.get_background())

        options_layout = QVBoxLayout()
        options_layout.setContentsMargins(3, 3, 3, 3)
        options_layout.setSpacing(3)
        self.scroll_area.widget().setLayout(options_layout)

        # Top-to-bottom order of the option widgets in the scroll area.
        stacked_widgets = (
            self.slider,
            self.orbital_options,
            self.interpolation,
            self.lmfit_options,
            self.colormap,
            self.crosshair,
        )
        for position, widget in enumerate(stacked_widgets):
            options_layout.insertWidget(position, widget)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Connect the widget signals to the model and to the tab's slots."""
        LMFitBaseTab._connect(self)

        self.interpolation.interpolation_changed.connect(self.change_axis)

        self.tree.value_changed.connect(self._refresh_orbital_plots)
        self.tree.vary_changed.connect(self.update_chi2_label)

        options = self.lmfit_options
        options.background_changed.connect(self._change_background)
        options.fit_triggered.connect(self.trigger_fit)
        options.method_changed.connect(self._change_method)
        options.slice_policy_changed.connect(self._change_slice_policy)
        options.region_changed.connect(self._change_region)

        # For both orbital option signals the model setter is connected
        # before the plot refresh; keep that connection order.
        orbital_signals = (
            (self.orbital_options.symmetrization_changed,
             self.model.set_symmetrization),
            (self.orbital_options.polarization_changed,
             self.model.set_polarization),
        )
        for signal, model_setter in orbital_signals:
            signal.connect(model_setter)
            signal.connect(self._refresh_orbital_plots)