コード例 #1
0
class TestLMFitModel(unittest.TestCase):
    """Tests for ``LMFitModel`` using the PTCDA example data set.

    The sliced data, the four PTCDA orbitals and the expected reference
    outputs are loaded once per class in ``setUpClass``; a fresh model and
    crosshair are created for every test in ``setUp``.
    """

    @classmethod
    def setUpClass(cls):
        """Load the shared example data and expected reference results."""

        file_path = os.path.dirname(os.path.realpath(__file__))
        input_path = os.path.join(file_path, '..', '..', 'example', 'data')

        sliced_path = os.path.join(input_path, 'example5_6584.hdf5')
        cls.sliced_data = SlicedData.init_from_hdf5(sliced_path)

        orbital_paths = [
            'PTCDA_C.cube', 'PTCDA_D.cube', 'PTCDA_E.cube', 'PTCDA_F.cube'
        ]
        # The enumerate index is passed as the orbital ID; test_PTCDA refers
        # to the corresponding weights as 'w_0' ... 'w_3'.
        cls.orbitals = [
            OrbitalData.init_from_file(os.path.join(input_path, path), ID)
            for ID, path in enumerate(orbital_paths)
        ]

        output_path = os.path.join(file_path, 'output')
        cls.expected = np.loadtxt(os.path.join(output_path, 'weights_PTCDA'))
        cls.background_expected = np.loadtxt(
            os.path.join(output_path, 'background_expected'))

    def setUp(self):
        """Create a fresh model and crosshair so tests do not share state."""

        self.lmfit = LMFitModel(TestLMFitModel.sliced_data,
                                TestLMFitModel.orbitals)
        self.crosshair = CrosshairAnnulusModel()

    def test_set_crosshair(self):
        """The crosshair passed to set_crosshair is stored on the model."""

        self.lmfit.set_crosshair(self.crosshair)

        self.assertEqual(self.lmfit.crosshair, self.crosshair)

    def test_set_axis(self):
        """set_axis_by_step_size yields the same axis as np.linspace."""

        step_size = 0.24
        range_ = [-3, 3]
        axis = np.linspace(*range_,
                           num=step_size_to_num(range_, step_size),
                           endpoint=True)

        self.lmfit.set_axis_by_step_size(range_, step_size)
        npt.assert_almost_equal(self.lmfit.get_sliced_kmap(0).x_axis, axis)

    def test_set_slices(self):
        """set_slices accepts index lists, separately and combined."""

        self.lmfit.set_slices([1, 2, 3])
        self.lmfit.set_slices([0, 1, 2, 3], combined=True)

    def test_set_background_equation(self):
        """A valid equation is stored with its variables; invalid raises."""

        self.lmfit.set_background_equation('np.exp(a)')
        npt.assert_equal(self.lmfit.background_equation, ['np.exp(a)', ['a']])

        # Unbalanced parenthesis must be rejected.
        self.assertRaises(ValueError, self.lmfit.set_background_equation,
                          'np.exp(a')

    def test_parameters(self):
        """Default bounds/values of weight, background and E_kin params."""

        self.lmfit.set_background_equation('np.exp(a)')

        self.assertEqual(self.lmfit.parameters['w_1'].min, 0)
        self.assertEqual(self.lmfit.parameters['a'].value, 0)
        self.assertEqual(self.lmfit.parameters['E_kin'].max, 150)

    def test_edit_parameter(self):
        """edit_parameter updates the stored parameter value."""

        self.lmfit.edit_parameter('w_1', value=1.5)
        self.assertEqual(self.lmfit.parameters['w_1'].value, 1.5)

    def test_background(self):
        """The evaluated background matches the stored reference output."""

        range_, dk = [-3.0, 3.0], 0.025
        self.lmfit.set_axis_by_step_size(range_, dk)

        self.lmfit.set_background_equation(
            '(np.exp(-x**2-y**2)-np.exp(-(x-1)**2-(y-1)**2))/2')

        background = self.lmfit._get_background(variables={})

        npt.assert_almost_equal(background, TestLMFitModel.background_expected)

    def test_settings(self):
        """Settings of one model can be applied to another model."""

        lmfit_new = LMFitModel(TestLMFitModel.sliced_data,
                               TestLMFitModel.orbitals)

        lmfit_new.set_settings(self.lmfit.get_settings())

    def test_PTCDA(self):
        """Fit the four PTCDA orbitals and compare weights to the reference.

        Skipped unless the 'dk3D' setting of the 'orbital' config category
        is 0.12, which this test requires.
        """

        if float(config.get_key('orbital', 'dk3D')) != 0.12:
            # Report the test as skipped instead of silently passing.
            self.skipTest("Test 'test_PTCDA' from the 'test_lmfit' module "
                          "requires the 'dk3D' setting from the 'orbital' "
                          "category to be set to 0.12.")

        # Weight parameter names follow the orbital IDs set in setUpClass.
        weight_names = ['w_%d' % i for i in range(len(self.orbitals))]

        # Set certain parameters not being fitted but desired to be changed
        range_, dk = [-3.0, 3.0], 0.04
        self.lmfit.set_axis_by_step_size(range_, dk)
        self.lmfit.set_polarization('toroid', 'p')
        self.lmfit.set_background_equation('c')

        # Set certain fit parameter to desired value
        self.lmfit.edit_parameter('E_kin', value=27.2)
        self.lmfit.edit_parameter('alpha', value=40)
        self.lmfit.edit_parameter('c', value=1, vary=True)

        # Activate fitting for all weights
        for name in weight_names:
            self.lmfit.edit_parameter(name, vary=True)

        # Set slices to be used
        self.lmfit.set_slices('all', combined=False)

        results = self.lmfit.fit()

        # Each result is indexed as result[1].params, so the fit result
        # object is the second element; transpose to one row per weight.
        weights = np.array([
            [result[1].params[name].value for name in weight_names]
            for result in results
        ]).T

        npt.assert_almost_equal(weights, TestLMFitModel.expected, decimal=5)
コード例 #2
0
ファイル: lmfittab.py プロジェクト: gsm-matthijs/kMap
class LMFitTab(LMFitBaseTab, LMFitTab_UI):
    """Tab for fitting orbital weights to sliced data with an LMFitModel.

    After each fit, ``fit_finished`` is emitted with
    ``[sliced_data, orbitals]``, the fit results and the model settings.
    """

    fit_finished = pyqtSignal(list, list, dict)

    def __init__(self, sliced_data, orbitals):
        """Create the model, build the GUI and draw the initial plots."""

        # The model must exist before _setup, which reads from it.
        self.model = LMFitModel(sliced_data, orbitals)

        # Setup GUI
        super(LMFitTab, self).__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_sliced_plot()
        self.refresh_selected_plot()
        kmap = self.refresh_sum_plot()
        self.refresh_residual_plot(weight_sum_data=kmap)

    def get_title(self):
        """Return the fixed title of this tab."""

        return 'LM-Fit'

    def trigger_fit(self):
        """Run the fit and emit the data, results and settings."""

        results = self.model.fit()
        settings = self.model.get_settings()
        data = [self.model.sliced_data, self.model.orbitals]

        self.fit_finished.emit(data, results, settings)

    def change_slice(self):
        """Re-apply the current slice policy and refresh dependent plots."""

        # Delegate so that the policy -> slices mapping lives in one place.
        self._change_slice_policy(self.lmfit_options.get_slice_policy())

        self.refresh_sliced_plot()
        self.refresh_residual_plot()

    def change_axis(self):
        """Apply the interpolation axis to the model and refresh all plots."""

        axis = self.interpolation.get_axis()
        self.model.set_axis(axis)

        self.refresh_sliced_plot()
        self.refresh_selected_plot()
        kmap = self.refresh_sum_plot()
        self.refresh_residual_plot(weight_sum_data=kmap)

    def _change_slice_policy(self, slice_policy):
        """Configure the model's slices for the given policy.

        'all' uses every slice separately, 'only one' uses the slice
        currently selected on the slider, and any other value uses all
        slices combined.
        """

        axis = self.slider.get_axis()

        if slice_policy == 'all':
            self.model.set_slices('all', axis_index=axis, combined=False)

        elif slice_policy == 'only one':
            index = self.slider.get_index()
            self.model.set_slices([index], axis_index=axis, combined=False)

        else:
            self.model.set_slices('all', axis_index=axis, combined=True)

    def _change_region(self, *args):
        """Forward a region change to the model and refresh the residual."""

        self.model.set_region(*args)
        self.refresh_residual_plot()

    def _change_background(self, *args):
        """Set a new background equation and add its new variables to the tree."""

        new_variables = self.model.set_background_equation(*args)
        for variable in new_variables:
            self.tree.add_equation_parameter(variable)

        self.refresh_sum_plot()
        self.refresh_residual_plot()

    def _setup(self):
        """Create option widgets and the tree, init the model, lay out widgets."""

        LMFitBaseTab._setup(self)

        self.orbital_options = LMFitOrbitalOptions()
        self.tree = LMFitTree(self.model.orbitals, self.model.parameters)
        self.interpolation = LMFitInterpolation()
        self.lmfit_options = LMFitOptions(self)

        self.change_axis()
        self.model.set_crosshair(self.crosshair.model)
        self._change_background(self.lmfit_options.get_background())

        layout = QVBoxLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setSpacing(3)
        self.scroll_area.widget().setLayout(layout)
        layout.insertWidget(0, self.slider)
        layout.insertWidget(1, self.orbital_options)
        layout.insertWidget(2, self.interpolation)
        layout.insertWidget(3, self.lmfit_options)
        layout.insertWidget(4, self.colormap)
        layout.insertWidget(5, self.crosshair)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Connect widget signals to model updates and plot refreshes."""

        LMFitBaseTab._connect(self)

        self.interpolation.interpolation_changed.connect(self.change_axis)

        self.tree.value_changed.connect(self._refresh_orbital_plots)
        self.tree.vary_changed.connect(self.update_chi2_label)
        self.lmfit_options.background_changed.connect(self._change_background)
        self.lmfit_options.fit_triggered.connect(self.trigger_fit)
        self.lmfit_options.method_changed.connect(self.model.set_fit_method)
        self.lmfit_options.slice_policy_changed.connect(
            self._change_slice_policy)

        self.orbital_options.symmetrization_changed.connect(
            self.model.set_symmetrization)
        self.orbital_options.symmetrization_changed.connect(
            self._refresh_orbital_plots)
        self.orbital_options.polarization_changed.connect(
            self.model.set_polarization)
        self.orbital_options.polarization_changed.connect(
            self._refresh_orbital_plots)
コード例 #3
0
class LMFitTab(LMFitBaseTab, LMFitTab_UI):
    """Tab fitting orbital weights to a sliced tab's data via LMFitModel.

    Fit results are merged across fits (keyed by each result's first
    element) and emitted together with the model settings through
    ``fit_finished``.
    """
    fit_finished = pyqtSignal(list, dict)

    def __init__(self, sliced_tab, orbital_tab, max_orbitals=-1):
        """Create the model from the given tabs and build the GUI.

        Args:
            sliced_tab: Tab providing the sliced data via ``get_data()``.
            orbital_tab: Tab providing the orbitals via ``get_orbitals()``.
            max_orbitals: If not -1, only the first ``max_orbitals``
                orbitals are passed to the model.
        """
        self.sliced_tab = sliced_tab
        self.orbital_tab = orbital_tab

        orbitals = orbital_tab.get_orbitals()
        if max_orbitals != -1:
            orbitals = orbitals[:max_orbitals]
        self.model = LMFitModel(sliced_tab.get_data(), orbitals)

        # Setup GUI
        super(LMFitTab, self).__init__()
        self.setupUi(self)

        self._setup()
        self._connect()

        self.refresh_all()

    @classmethod
    def init_from_save(cls, save, dependencies, tab_widget):
        """Recreate a tab from the ``save``/``dependencies`` of save_state.

        The sliced and orbital tabs are looked up by ID in ``tab_widget``
        and locked; all widget and plot states are restored from ``save``.
        """
        sliced_tab = tab_widget.get_tab_by_ID(dependencies['sliced_tab'])
        orbital_tab = tab_widget.get_tab_by_ID(dependencies['orbital_tab'])
        # Orbitals loaded after this tab was created are not reflected in
        # the saved results, so do not load them.
        # NOTE(review): the "- 2" presumably skips two non-orbital tree
        # entries; confirm against LMFitTree.save_state.
        tab = cls(sliced_tab, orbital_tab, max_orbitals=len(save['tree']) - 2)

        tab.locked_tabs = [sliced_tab, orbital_tab]
        tab.title = save['title']
        tab.tree.restore_state(save['tree'])
        tab.interpolation.restore_state(save['interpolation'])
        tab.lmfit_options.restore_state(save['lmfit_options'])
        tab.orbital_options.restore_state(save['orbital_options'])
        tab.slider.restore_state(save['slider'])
        tab.colormap.restore_state(save['colormap'])
        tab.crosshair.restore_state(save['crosshair'])
        tab.sliced_plot.set_colormap(save['colorscales']['sliced'])
        tab.sliced_plot.set_levels(save['levels']['sliced'])
        tab.sum_plot.set_colormap(save['colorscales']['sum'])
        tab.sum_plot.set_levels(save['levels']['sum'])
        tab.residual_plot.set_levels(save['levels']['residual'])
        tab.residual_plot.set_colormap(save['colorscales']['residual'])
        tab.selected_plot.set_levels(save['levels']['selected'])
        tab.selected_plot.set_colormap(save['colorscales']['selected'])

        return tab

    def get_title(self):
        """Return the title of this tab."""
        return self.title

    def get_data(self):
        """Return ``[sliced_data, orbitals]`` used by the model."""
        return [self.model.sliced_data, self.model.orbitals]

    def trigger_fit(self):
        """Run a fit, merge its results into earlier ones and emit them.

        A ValueError raised by the model is logged as a warning and the fit
        button is updated instead of propagating the error.
        """
        try:
            new_results = self.model.fit()
        except ValueError as e:
            logging.getLogger('kmap').warning(str(e))
            self.lmfit_options.update_fit_button()

            return

        if not hasattr(self, 'results'):
            self.results = new_results

        else:
            # Results are (key, value) pairs: newer results replace older
            # ones with the same key; the merged list is sorted by key.
            results_dict = dict(self.results)
            results_dict.update(new_results)
            self.results = sorted(results_dict.items(), key=lambda x: x[0])

        settings = self.model.get_settings()

        self.fit_finished.emit(self.results, settings)

    def change_slice(self):
        """Re-apply the current slice policy and refresh dependent plots."""
        # Delegate so that every policy (including custom index lists) is
        # honoured instead of being replaced by 'all'.
        self._change_slice_policy(self.lmfit_options.get_slice_policy())

        self.refresh_sliced_plot()
        self.refresh_residual_plot()

    def change_axis(self):
        """Apply the interpolation axis to the model and refresh all plots."""
        axis = self.interpolation.get_axis()
        self.model.set_axis(axis)

        self.refresh_all()

    def save_state(self):
        """Extend the base save state with this tab's widgets and deps."""

        save, dependencies = super().save_state()

        save.update({'orbital_options': self.orbital_options.save_state(),
                     'interpolation': self.interpolation.save_state(),
                     'lmfit_options': self.lmfit_options.save_state(),
                     'tree': self.tree.save_state()})

        dependencies.update({'sliced_tab': self.sliced_tab.ID,
                             'orbital_tab': self.orbital_tab.ID})

        return save, dependencies

    def _change_slice_policy(self, slice_policy):
        """Configure the model's slices for the given policy.

        'all' uses every slice separately, 'only one' the slider's current
        slice, 'all combined' every slice combined; any other value is read
        as whitespace-separated slice indices (e.g. "0 2 5").
        """
        axis = self.slider.get_axis()

        if slice_policy == 'all':
            self.model.set_slices('all', axis_index=axis, combined=False)

        elif slice_policy == 'only one':
            index = self.slider.get_index()
            self.model.set_slices([index], axis_index=axis, combined=False)

        elif slice_policy == 'all combined':
            self.model.set_slices('all', axis_index=axis, combined=True)

        else:
            # split() without argument tolerates repeated/trailing spaces.
            indices = [int(e) for e in slice_policy.split()]
            self.model.set_slices(indices, axis_index=axis, combined=False)

    def _change_method(self, method):
        """Switch the fit method, toggling matrix-inversion mode as needed."""

        self._change_to_matrix_state(method == 'matrix_inversion')

        self.model.set_fit_method(method)

    def _change_to_matrix_state(self, state):
        """Switch the tree into (``state=True``) or out of matrix state.

        When entering it and the background equation has no variable 'c',
        ``lmfit_options._pre_factor_background()`` is called first.
        """

        if state:
            variables = self.model.background_equation[1]
            if 'c' not in variables:
                self.lmfit_options._pre_factor_background()

        self.tree._change_to_matrix_state(state)

    def _change_region(self, *args):
        """Forward a region change to the model and refresh all plots."""
        self.model.set_region(*args)

        self.refresh_all()

    def _change_background(self, *args):
        """Set a new background equation and add its new variables to the tree."""

        new_variables = self.model.set_background_equation(*args)
        for variable in new_variables:
            self.tree.add_equation_parameter(variable)

        self.refresh_sum_plot()
        self.refresh_residual_plot()

    def _setup(self):
        """Create option widgets and the tree, init the model, lay out widgets."""
        LMFitBaseTab._setup(self)

        self.orbital_options = LMFitOrbitalOptions()
        self.tree = LMFitTree(self.model.orbitals, self.model.parameters)
        self.interpolation = LMFitInterpolation()
        self.lmfit_options = LMFitOptions(self)

        self.change_axis()
        self.model.set_crosshair(self.crosshair.model)
        self._change_background(self.lmfit_options.get_background())

        layout = QVBoxLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setSpacing(3)
        self.scroll_area.widget().setLayout(layout)
        layout.insertWidget(0, self.slider)
        layout.insertWidget(1, self.orbital_options)
        layout.insertWidget(2, self.interpolation)
        layout.insertWidget(3, self.lmfit_options)
        layout.insertWidget(4, self.colormap)
        layout.insertWidget(5, self.crosshair)

        self.layout.insertWidget(1, self.tree)

    def _connect(self):
        """Connect widget signals to model updates and plot refreshes."""
        LMFitBaseTab._connect(self)

        self.interpolation.interpolation_changed.connect(self.change_axis)

        self.tree.value_changed.connect(self._refresh_orbital_plots)
        self.tree.vary_changed.connect(self.update_chi2_label)
        self.lmfit_options.background_changed.connect(self._change_background)
        self.lmfit_options.fit_triggered.connect(self.trigger_fit)
        self.lmfit_options.method_changed.connect(self._change_method)
        self.lmfit_options.slice_policy_changed.connect(
            self._change_slice_policy)
        self.lmfit_options.region_changed.connect(self._change_region)

        self.orbital_options.symmetrization_changed.connect(
            self.model.set_symmetrization)
        self.orbital_options.symmetrization_changed.connect(
            self._refresh_orbital_plots)
        self.orbital_options.polarization_changed.connect(
            self.model.set_polarization)
        self.orbital_options.polarization_changed.connect(
            self._refresh_orbital_plots)