示例#1
0
    def setUp(self):
        """Load image/calibration/mask fixtures and attach an oblique-angle detector absorption correction."""
        self.img_data = ImgModel()
        self.img_data.load("Data/CbnCorrectionOptimization/Mg2SiO4_091.tif")

        self.calibration_data = CalibrationModel(self.img_data)
        self.calibration_data.load("Data/CbnCorrectionOptimization/LaB6_40keV side.poni")

        self.mask_data = MaskModel()
        self.mask_data.load_mask("Data/CbnCorrectionOptimization/Mg2SiO4_91_combined.mask")

        # only the second element (fit2d-style parameters) of the calibration parameters is used
        _, fit2d_parameter = self.calibration_data.get_calibration_parameter()
        tilt, tilt_rotation = fit2d_parameter['tilt'], fit2d_parameter['tiltPlanRotation']

        geometry = self.calibration_data.spectrum_geometry
        image_shape = (2048, 2048)
        self.tth_array = geometry.twoThetaArray(image_shape)
        self.azi_array = geometry.chiArray(image_shape)

        self.oiadac_correction = ObliqueAngleDetectorAbsorptionCorrection(
            self.tth_array,
            self.azi_array,
            detector_thickness=40,
            absorption_length=465.5,
            tilt=tilt,
            rotation=tilt_rotation,
        )
        self.img_data.add_img_correction(self.oiadac_correction, "oiadac")
 def setUp(self):
     """Create a calibration model whose integrate_1d is a mock returning a flat dummy pattern."""
     x_values = np.linspace(0, 30)
     y_values = np.ones(x_values.shape)
     self.calibration_model = CalibrationModel()
     self.calibration_model.integrate_1d = MagicMock(return_value=(x_values, y_values))
     self.jcpds = jcpds()
    def setUp(self):
        """Wire a CalibrationController to fresh models, a CalibrationWidget and mocked integration."""
        self.img_model = ImgModel()
        self.mask_model = MaskModel()

        model = CalibrationModel(self.img_model)
        model._calibrants_working_dir = os.path.join(data_path, 'calibrants')
        # integration methods are replaced by MagicMock instances
        model.integrate_1d = MagicMock()
        model.integrate_2d = MagicMock()
        self.calibration_model = model

        self.calibration_widget = CalibrationWidget()
        self.working_dir = {}
        self.calibration_controller = CalibrationController(
            working_dir=self.working_dir, img_model=self.img_model, mask_model=self.mask_model,
            widget=self.calibration_widget, calibration_model=self.calibration_model)
class CalibrationModelTest(unittest.TestCase):
    def setUp(self):
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)

    def tearDown(self):
        del self.img_model
        del self.calibration_model.cake_geometry
        del self.calibration_model.spectrum_geometry
        del self.calibration_model
        gc.collect()

    def test_loading_calibration_gives_right_pixel_size(self):
        self.calibration_model.spectrum_geometry.load(os.path.join(data_path, 'CeO2_Pilatus1M.poni'))
        self.assertEqual(self.calibration_model.spectrum_geometry.pixel1, 0.000172)


        self.calibration_model.load(os.path.join(data_path,'LaB6_40keV_MarCCD.poni'))
        self.assertEqual(self.calibration_model.spectrum_geometry.pixel1, 0.000079)
 def setUp(self):
     """Start a QApplication and build a CalibrationController around fresh models."""
     self.app = QtGui.QApplication(sys.argv)
     self.img_model, self.mask_model = ImgModel(), MaskModel()
     self.calibration_model = CalibrationModel(self.img_model)
     self.calibration_model._calibrants_working_dir = os.path.join(data_path, 'calibrants')
     self.calibration_widget = CalibrationWidget()
     self.working_dir = {}
     controller_kwargs = dict(working_dir=self.working_dir,
                              img_model=self.img_model,
                              mask_model=self.mask_model,
                              widget=self.calibration_widget,
                              calibration_model=self.calibration_model)
     self.calibration_controller = CalibrationController(**controller_kwargs)
示例#6
0
    def __init__(self, use_settings=True):
        """
        Create the main widget, the shared data models and the tab controllers.
        :param use_settings: if True, load_settings() is called before the controllers are created
        """
        self.use_settings = use_settings
        self.widget = MainWidget()

        # data models shared by all controllers
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.mask_model = MaskModel()
        self.spectrum_model = PatternModel()
        self.phase_model = PhaseModel()

        self.settings_directory = os.path.join(os.path.expanduser("~"), '.Dioptas')
        self.working_directories = dict.fromkeys(
            ('calibration', 'mask', 'image', 'spectrum', 'overlay', 'phase'), '')

        if use_settings:
            self.load_settings()

        # taken after load_settings(), which may replace the working directory dict
        directories, main_widget = self.working_directories, self.widget
        self.calibration_controller = CalibrationController(
            directories, main_widget.calibration_widget,
            self.img_model, self.mask_model, self.calibration_model)
        self.mask_controller = MaskController(
            directories, main_widget.mask_widget,
            self.img_model, self.mask_model)
        self.integration_controller = IntegrationController(
            directories, main_widget.integration_widget,
            self.img_model, self.mask_model, self.calibration_model,
            self.spectrum_model, self.phase_model)

        self.create_signals()
        self.update_title()
        self.current_tab_index = 0
    def setUp(self):
        """Create models, an IntegrationWidget and the spectrum/phase controllers, then load a spectrum."""
        self.app = QtGui.QApplication(sys.argv)
        self.image_model = ImgModel()
        calibration_model = CalibrationModel()
        self.calibration_model = calibration_model
        calibration_model.load(os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))
        self.spectrum_model = SpectrumModel()
        self.phase_model = PhaseModel()

        widget = IntegrationWidget()
        self.widget = widget
        widget.spectrum_view._auto_range = True
        self.phase_tw = widget.phase_tw

        self.spectrum_controller = SpectrumController(
            {}, widget, self.image_model, None, self.calibration_model, self.spectrum_model)
        self.controller = PhaseController(
            {}, widget, self.calibration_model, self.spectrum_model, self.phase_model)
        self.spectrum_model.load_spectrum(os.path.join(data_path, 'spectrum_001.xy'))
示例#8
0
    def setUp(self):
        """Create an ImageController wired to fresh models and an IntegrationWidget."""
        self.working_dir = {'image': ''}
        self.widget = IntegrationWidget()
        self.image_model, self.mask_model = ImgModel(), MaskModel()
        self.spectrum_model = PatternModel()
        self.calibration_model = CalibrationModel(self.image_model)

        self.controller = ImageController(
            working_dir=self.working_dir, widget=self.widget,
            img_model=self.image_model, mask_model=self.mask_model,
            spectrum_model=self.spectrum_model, calibration_model=self.calibration_model)
示例#9
0
    def setUp(self):
        """Create a calibrated model with mocked 1D integration plus pattern/phase controllers, then load a pattern."""
        self.image_model = ImgModel()

        model = CalibrationModel()
        self.calibration_model = model
        model.is_calibrated = True
        model.spectrum_geometry.wavelength = 0.31E-10
        # integrate_1d is replaced by a mock returning the model's own tth/int attributes
        model.integrate_1d = MagicMock(return_value=(model.tth, model.int))

        self.spectrum_model = PatternModel()
        self.phase_model = PhaseModel()

        self.widget = IntegrationWidget()
        self.widget.pattern_widget._auto_range = True
        self.phase_tw = self.widget.phase_tw

        self.spectrum_controller = PatternController(
            {}, self.widget, self.image_model, None, model, self.spectrum_model)
        self.controller = PhaseController(
            {}, self.widget, model, self.spectrum_model, self.phase_model)
        self.spectrum_controller.load(os.path.join(data_path, 'spectrum_001.xy'))
示例#10
0
    def __init__(self, use_settings=True):
        """
        Build the main window, its data models and the calibration/mask/integration controllers.
        :param use_settings: when True, previously saved settings are loaded before the controllers are created
        """
        self.use_settings = use_settings
        self.widget = MainWidget()

        # models shared between the controllers
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.mask_model = MaskModel()
        self.spectrum_model = PatternModel()
        self.phase_model = PhaseModel()

        self.settings_directory = os.path.join(os.path.expanduser("~"), '.Dioptas')
        directory_keys = ('calibration', 'mask', 'image', 'spectrum', 'overlay', 'phase')
        self.working_directories = {key: '' for key in directory_keys}

        if use_settings:
            self.load_settings()

        main_widget = self.widget
        self.calibration_controller = CalibrationController(
            self.working_directories, main_widget.calibration_widget,
            self.img_model, self.mask_model, self.calibration_model)
        self.mask_controller = MaskController(
            self.working_directories, main_widget.mask_widget,
            self.img_model, self.mask_model)
        self.integration_controller = IntegrationController(
            self.working_directories, main_widget.integration_widget,
            self.img_model, self.mask_model, self.calibration_model,
            self.spectrum_model, self.phase_model)

        self.create_signals()
        self.update_title()
        self.current_tab_index = 0
示例#11
0
 def setUp(self):
     """Create an image model and a calibration model bound to it."""
     img_model = ImgModel()
     self.img_model = img_model
     self.calibration_model = CalibrationModel(img_model)
示例#12
0
class CalibrationModelTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.app = QtGui.QApplication([])

    @classmethod
    def tearDownClass(cls):
        cls.app.quit()
        cls.app.deleteLater()

    def setUp(self):
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)

    def tearDown(self):
        del self.img_model
        if hasattr(self.calibration_model, 'cake_geometry'):
            del self.calibration_model.cake_geometry
        del self.calibration_model.spectrum_geometry
        del self.calibration_model
        gc.collect()

    def test_loading_calibration_gives_right_pixel_size(self):
        self.calibration_model.spectrum_geometry.load(os.path.join(data_path, 'CeO2_Pilatus1M.poni'))
        self.assertEqual(self.calibration_model.spectrum_geometry.pixel1, 0.000172)

        self.calibration_model.load(os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))
        self.assertEqual(self.calibration_model.spectrum_geometry.pixel1, 0.000079)

    def test_find_peaks_automatic(self):
        self.load_pilatus_1M_and_find_peaks()
        self.assertEqual(len(self.calibration_model.points), 6)
        for points in self.calibration_model.points:
            self.assertGreater(len(points), 0)

    def test_find_peak(self):
        """
        Tests the find_peak function for several maxima and pick points

        """
        points_and_pick_points = [
            [[30, 50], [31, 49]],
            [[30, 50], [34, 46]],
            [[5, 5],  [3, 3]],
            [[298, 298], [299, 299]]
        ]

        for data in points_and_pick_points:
            self.img_model._img_data = np.zeros((300, 300))

            point = data[0]
            pick_point = data[1]
            self.img_model._img_data[point[0], point[1]] = 100

            peak_point = self.calibration_model.find_peak(pick_point[0], pick_point[1], 10, 0)
            self.assertEqual(peak_point[0][0], point[0])
            self.assertEqual(peak_point[0][1], point[1])

    def test_search_peaks_on_ring(self):
        """
        Tests to search on the first ring of the calibrant after an inital calibration
        """
        pass

    def load_pilatus_1M_and_find_peaks(self):
        self.img_model.load(os.path.join(data_path, 'CeO2_Pilatus1M.tif'))
        self.calibration_model.find_peaks_automatic(517.664434674, 647.529865592, 0)
        self.calibration_model.find_peaks_automatic(667.380513299, 525.252854758, 1)
        self.calibration_model.find_peaks_automatic(671.110095329, 473.571503774, 2)
        self.calibration_model.find_peaks_automatic(592.788872703, 350.495296791, 3)
        self.calibration_model.find_peaks_automatic(387.395462348, 390.987901686, 4)
        self.calibration_model.find_peaks_automatic(367.94835605, 554.290314848, 5)

    def test_calibration_with_supersampling(self):
        self.load_pilatus_1M_and_find_peaks()
        self.calibration_model.set_calibrant(os.path.join(calibrant_path, 'LaB6.D'))
        self.calibration_model.calibrate()
        normal_poni1 = self.calibration_model.spectrum_geometry.poni1
        self.img_model.set_supersampling(2)
        self.calibration_model.set_supersampling(2)
        self.calibration_model.calibrate()
        self.assertAlmostEqual(normal_poni1, self.calibration_model.spectrum_geometry.poni1, places=5)

    def test_calibration1(self):
        self.img_model.load(os.path.join(data_path, 'LaB6_40keV_MarCCD.tif'))
        self.calibration_model.find_peaks_automatic(1179.6, 1129.4, 0)
        self.calibration_model.find_peaks_automatic(1268.5, 1119.8, 1)
        self.calibration_model.set_calibrant(os.path.join(calibrant_path, 'LaB6.D'))
        self.calibration_model.calibrate()

        self.assertGreater(self.calibration_model.spectrum_geometry.poni1, 0)
        self.assertAlmostEqual(self.calibration_model.spectrum_geometry.dist, 0.18, delta=0.01)
        self.assertGreater(self.calibration_model.cake_geometry.poni1, 0)

    def test_calibration2(self):
        self.img_model.load(os.path.join(data_path, 'LaB6_OffCenter_PE.tif'))
        self.calibration_model.find_peaks_automatic(1245.2, 1919.3, 0)
        self.calibration_model.find_peaks_automatic(1334.0, 1823.7, 1)
        self.calibration_model.start_values['dist'] = 500e-3
        self.calibration_model.start_values['pixel_height'] = 200e-6
        self.calibration_model.start_values['pixel_width'] = 200e-6
        self.calibration_model.set_calibrant(os.path.join(calibrant_path, 'LaB6.D'))
        self.calibration_model.calibrate()

        self.assertGreater(self.calibration_model.spectrum_geometry.poni1, 0)
        self.assertAlmostEqual(self.calibration_model.spectrum_geometry.dist, 0.500, delta=0.01)
        self.assertGreater(self.calibration_model.cake_geometry.poni1, 0)

    def test_calibration3(self):
        self.load_pilatus_1M_and_find_peaks()
        self.calibration_model.start_values['wavelength'] = 0.406626e-10
        self.calibration_model.start_values['pixel_height'] = 172e-6
        self.calibration_model.start_values['pixel_width'] = 172e-6
        self.calibration_model.set_calibrant(os.path.join(calibrant_path, 'LaB6.D'))
        self.calibration_model.calibrate()

        self.assertGreater(self.calibration_model.spectrum_geometry.poni1, 0)
        self.assertAlmostEqual(self.calibration_model.spectrum_geometry.dist, 0.100, delta=0.02)
        self.assertGreater(self.calibration_model.cake_geometry.poni1, 0)

    def test_get_pixel_ind(self):
        self.img_model.load(os.path.join(data_path, 'image_001.tif'))
        self.calibration_model.load(os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))

        self.calibration_model.integrate_1d(1000)

        tth_array = self.calibration_model.spectrum_geometry.ttha
        azi_array = self.calibration_model.spectrum_geometry.chia

        for i in range(100):
            ind1 = np.random.random_integers(0, 2023)
            ind2 = np.random.random_integers(0, 2023)

            tth = tth_array[ind1, ind2]
            azi = azi_array[ind1, ind2]

            result_ind1, result_ind2 = self.calibration_model.get_pixel_ind(tth, azi)

            self.assertAlmostEqual(ind1, result_ind1, places=3)
            self.assertAlmostEqual(ind2, result_ind2, places=3)
示例#13
0
class MainController(object):
    """
    Creates the main controller for Dioptas. Creates all the data objects and connects them with the other controllers
    """

    def __init__(self, use_settings=True):
        self.use_settings = use_settings

        self.widget = MainWidget()
        # create data
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.mask_model = MaskModel()
        self.spectrum_model = PatternModel()
        self.phase_model = PhaseModel()

        self.settings_directory = os.path.join(os.path.expanduser("~"), '.Dioptas')
        self.working_directories = {'calibration': '', 'mask': '', 'image': '', 'spectrum': '', 'overlay': '',
                                    'phase': ''}

        if use_settings:
            self.load_settings()

        self.calibration_controller = CalibrationController(self.working_directories,
                                                            self.widget.calibration_widget,
                                                            self.img_model,
                                                            self.mask_model,
                                                            self.calibration_model)
        self.mask_controller = MaskController(self.working_directories,
                                              self.widget.mask_widget,
                                              self.img_model,
                                              self.mask_model)
        self.integration_controller = IntegrationController(self.working_directories,
                                                            self.widget.integration_widget,
                                                            self.img_model,
                                                            self.mask_model,
                                                            self.calibration_model,
                                                            self.spectrum_model,
                                                            self.phase_model)
        self.create_signals()
        self.update_title()

        self.current_tab_index = 0

    def show_window(self):
        """
        Displays the main window on the screen and makes it active.
        """
        self.widget.show()

        if _platform == "darwin":
            self.widget.setWindowState(self.widget.windowState() & ~QtCore.Qt.WindowMinimized | QtCore.Qt.WindowActive)
            self.widget.activateWindow()
            self.widget.raise_()

    def create_signals(self):
        """
        Creates subscriptions for changing tabs and also newly loaded files which will update the title of the main
        window.
        """
        self.widget.tabWidget.currentChanged.connect(self.tab_changed)
        self.widget.closeEvent = self.close_event
        self.img_model.img_changed.connect(self.update_title)
        self.spectrum_model.pattern_changed.connect(self.update_title)

    def tab_changed(self, ind):
        """
        Function which is called when a tab has been selected (calibration, mask, or integration). Performs
        needed initialization tasks.
        :param ind: index for the tab selected (2 - integration, 1 = mask, 0 - calibration)
        :return:
        """
        old_index = self.current_tab_index
        self.current_tab_index = ind

        # get the old view range
        old_view_range = None
        old_hist_levels = None
        if old_index == 0:  # calibration tab
            old_view_range = self.widget.calibration_widget.img_widget.img_view_box.targetRange()
            old_hist_levels = self.widget.calibration_widget.img_widget.img_histogram_LUT.getExpLevels()
        elif old_index == 1:  # mask tab
            old_view_range = self.widget.mask_widget.img_widget.img_view_box.targetRange()
            old_hist_levels = self.widget.mask_widget.img_widget.img_histogram_LUT.getExpLevels()
        elif old_index == 2:
            old_view_range = self.widget.integration_widget.img_widget.img_view_box.targetRange()
            old_hist_levels = self.widget.integration_widget.img_widget.img_histogram_LUT.getExpLevels()

        # update the GUI
        if ind == 2:  # integration tab
            self.mask_model.set_supersampling()
            self.integration_controller.image_controller.plot_mask()
            self.integration_controller.widget.calibration_lbl.setText(self.calibration_model.calibration_name)
            self.integration_controller.image_controller._auto_scale = False
            self.integration_controller.spectrum_controller.image_changed()
            self.integration_controller.image_controller.update_img()
            self.widget.integration_widget.img_widget.set_range(x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.integration_widget.img_widget.img_histogram_LUT.setLevels(*old_hist_levels)
        elif ind == 1:  # mask tab
            self.mask_controller.plot_mask()
            self.mask_controller.plot_image()
            self.widget.mask_widget.img_widget.set_range(x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.mask_widget.img_widget.img_histogram_LUT.setLevels(*old_hist_levels)
        elif ind == 0:  # calibration tab
            self.calibration_controller.plot_mask()
            try:
                self.calibration_controller.update_calibration_parameter_in_view()
            except (TypeError, AttributeError):
                pass
            self.widget.calibration_widget.img_widget.set_range(x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.calibration_widget.img_widget.img_histogram_LUT.setLevels(*old_hist_levels)

    def update_title(self):
        """
        Updates the title bar of the main window. The title bar will always show the current version of Dioptas,
        the loaded image and/or spectrum filename (each distinct non-empty name once) and the current calibration
        name.
        """
        img_filename = os.path.basename(self.img_model.filename)
        spec_filename = os.path.basename(self.spectrum_model.pattern_filename)
        calibration_name = self.calibration_model.calibration_name
        title = 'Dioptas ' + __version__
        copyright_note = u' - © 2015 C. Prescher'

        if not img_filename and not spec_filename:
            self.widget.setWindowTitle(title + copyright_note)
            self.widget.integration_widget.img_frame.setWindowTitle(title + copyright_note)
            return

        # Collect the parts to show; the spectrum name is skipped when it is empty or identical to the image
        # name, which avoids duplicated names and dangling separators.
        parts = []
        if img_filename:
            parts.append(img_filename)
        if spec_filename and spec_filename != img_filename:
            parts.append(spec_filename)
        if calibration_name is not None:
            parts.append('calibration: ' + calibration_name)

        title += ' - [' + ', '.join(parts) + ']' + copyright_note
        self.widget.setWindowTitle(title)
        self.widget.integration_widget.img_frame.setWindowTitle(title)

    def load_settings(self):
        """
        Loads previously saved Dioptas settings.
        """
        if os.path.exists(self.settings_directory):
            self.load_directories()
            self.load_xml_settings()

    def load_directories(self):
        """
        Loads previously used Dioptas directory paths from the csv file in the settings directory.
        Entries from the file override the current values; keys missing from the file keep their current value.
        Blank or malformed rows are ignored.
        """
        working_directories_path = os.path.join(self.settings_directory, 'working_directories.csv')
        if os.path.exists(working_directories_path):
            # newline='' is what the csv module requires for correct line-ending handling
            with open(working_directories_path, 'r', newline='') as csv_file:
                loaded = dict(row for row in csv.reader(csv_file) if len(row) == 2)
            self.working_directories.update(loaded)

    def load_xml_settings(self):
        """
        Loads previously used Dioptas settings. Currently this is only the calibration.
        :return:
        """
        xml_settings_path = os.path.join(self.settings_directory, "settings.xml")
        if os.path.exists(xml_settings_path):
            root = ET.parse(xml_settings_path).getroot()
            # findtext returns None for a missing element and '' for an empty one; both mean "no calibration"
            calibration_path = root.findtext("filenames/calibration")
            if calibration_path and os.path.exists(calibration_path):
                self.calibration_model.load(calibration_path)

    def save_settings(self):
        """
        Saves current settings of Dioptas in the users directory.
        """
        if not os.path.exists(self.settings_directory):
            os.mkdir(self.settings_directory)
        self.save_directories()
        self.save_xml_settings()

    def save_directories(self):
        """
        Currently used working directories for images, spectra, etc. are saved as csv file in the users directory for
        reuse when Dioptas is started again. Each directory is written as one (key, path) row.
        """
        working_directories_path = os.path.join(self.settings_directory, 'working_directories.csv')
        # the with-block guarantees the file is flushed and closed
        with open(working_directories_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            for key, value in self.working_directories.items():
                writer.writerow([key, value])

    def save_xml_settings(self):
        """
        Currently used settings of Dioptas are saved in to an xml file in the users directory for reuse when Dioptas is
        started again. Right now this is only saving the calibration filename.
        """
        root = ET.Element("DioptasSettings")
        filenames = ET.SubElement(root, "filenames")
        calibration_filename = ET.SubElement(filenames, "calibration")
        calibration_filename.text = self.calibration_model.filename
        tree = ET.ElementTree(root)
        tree.write(os.path.join(self.settings_directory, "settings.xml"))

    def close_event(self, _):
        """
        Intervention of the Dioptas close event to save settings before closing the Program.
        """
        if self.use_settings:
            self.save_settings()
        QtGui.QApplication.closeAllWindows()
        QtGui.QApplication.quit()
示例#14
0
class MainController(object):
    """
    Creates a the main controller for Dioptas. Creates all the data objects and connects them with the other controllers
    """
    def __init__(self, use_settings=True):
        """
        Create the main widget, the shared data models and the three tab controllers.
        :param use_settings: if True, load_settings() runs before the controllers are created
        """
        self.use_settings = use_settings
        self.widget = MainWidget()

        # data models shared by all controllers
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.mask_model = MaskModel()
        self.spectrum_model = PatternModel()
        self.phase_model = PhaseModel()

        self.settings_directory = os.path.join(os.path.expanduser("~"), '.Dioptas')
        self.working_directories = dict.fromkeys(
            ('calibration', 'mask', 'image', 'spectrum', 'overlay', 'phase'), '')

        if use_settings:
            self.load_settings()

        # read after load_settings(), which may replace the working directory dict
        directories, main_widget = self.working_directories, self.widget
        self.calibration_controller = CalibrationController(
            directories, main_widget.calibration_widget,
            self.img_model, self.mask_model, self.calibration_model)
        self.mask_controller = MaskController(
            directories, main_widget.mask_widget,
            self.img_model, self.mask_model)
        self.integration_controller = IntegrationController(
            directories, main_widget.integration_widget,
            self.img_model, self.mask_model, self.calibration_model,
            self.spectrum_model, self.phase_model)

        self.create_signals()
        self.update_title()
        self.current_tab_index = 0

    def show_window(self):
        """
        Displays the main window on the screen and makes it active.
        """
        self.widget.show()

        if _platform == "darwin":
            self.widget.setWindowState(self.widget.windowState()
                                       & ~QtCore.Qt.WindowMinimized
                                       | QtCore.Qt.WindowActive)
            self.widget.activateWindow()
            self.widget.raise_()

    def create_signals(self):
        """
        Creates subscriptions for changing tabs and also newly loaded files which will update the title of the main
                window.
        """
        self.widget.tabWidget.currentChanged.connect(self.tab_changed)
        self.widget.closeEvent = self.close_event
        self.img_model.img_changed.connect(self.update_title)
        self.spectrum_model.pattern_changed.connect(self.update_title)

    def tab_changed(self, ind):
        """
        Function which is called when a tab has been selected (calibration, mask, or integration). Performs
        needed initialization tasks.
        :param ind: index for the tab selected (2 - integration, 1 = mask, 0 - calibration)
        :return:
        """
        old_index = self.current_tab_index
        self.current_tab_index = ind

        # get the old view range
        old_view_range = None
        old_hist_levels = None
        if old_index == 0:  # calibration tab
            old_view_range = self.widget.calibration_widget.img_widget.img_view_box.targetRange(
            )
            old_hist_levels = self.widget.calibration_widget.img_widget.img_histogram_LUT.getExpLevels(
            )
        elif old_index == 1:  # mask tab
            old_view_range = self.widget.mask_widget.img_widget.img_view_box.targetRange(
            )
            old_hist_levels = self.widget.mask_widget.img_widget.img_histogram_LUT.getExpLevels(
            )
        elif old_index == 2:
            old_view_range = self.widget.integration_widget.img_widget.img_view_box.targetRange(
            )
            old_hist_levels = self.widget.integration_widget.img_widget.img_histogram_LUT.getExpLevels(
            )

        # update the GUI
        if ind == 2:  # integration tab
            self.mask_model.set_supersampling()
            self.integration_controller.image_controller.plot_mask()
            self.integration_controller.widget.calibration_lbl.setText(
                self.calibration_model.calibration_name)
            self.integration_controller.image_controller._auto_scale = False
            self.integration_controller.spectrum_controller.image_changed()
            self.integration_controller.image_controller.update_img()
            self.widget.integration_widget.img_widget.set_range(
                x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.integration_widget.img_widget.img_histogram_LUT.setLevels(
                *old_hist_levels)
        elif ind == 1:  # mask tab
            self.mask_controller.plot_mask()
            self.mask_controller.plot_image()
            self.widget.mask_widget.img_widget.set_range(
                x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.mask_widget.img_widget.img_histogram_LUT.setLevels(
                *old_hist_levels)
        elif ind == 0:  # calibration tab
            self.calibration_controller.plot_mask()
            try:
                self.calibration_controller.update_calibration_parameter_in_view(
                )
            except (TypeError, AttributeError):
                pass
            self.widget.calibration_widget.img_widget.set_range(
                x_range=old_view_range[0], y_range=old_view_range[1])
            self.widget.calibration_widget.img_widget.img_histogram_LUT.setLevels(
                *old_hist_levels)

    def update_title(self):
        """
        Updates the title bar of the main window. The title bar will always show the current version of Dioptas, the
        image or spectrum filenames loaded and the current calibration name.

        The title reads ``Dioptas <version> - [<image>, <spectrum>, calibration: <name>] - © 2015 C. Prescher``.
        Empty file names are left out, and the spectrum name is left out when it equals the image name. When neither
        an image nor a spectrum is loaded, the bracketed part is omitted completely.
        """
        img_filename = os.path.basename(self.img_model.filename)
        spec_filename = os.path.basename(self.spectrum_model.pattern_filename)
        calibration_name = self.calibration_model.calibration_name
        title = 'Dioptas ' + __version__

        # Strings are compared by value/truthiness: ``is ''`` relies on string interning (SyntaxWarning on py >= 3.8).
        if img_filename or spec_filename:
            # Only non-empty, distinct names are collected so that a file never shows up twice and no empty
            # ", " entry is produced.
            entries = []
            if img_filename:
                entries.append(img_filename)
            if spec_filename and spec_filename != img_filename:
                entries.append(spec_filename)
            if calibration_name is not None:
                entries.append('calibration: ' + calibration_name)
            title += ' - [' + ', '.join(entries) + ']'
        title += u' - © 2015 C. Prescher'

        self.widget.setWindowTitle(title)
        self.widget.integration_widget.img_frame.setWindowTitle(title)

    def load_settings(self):
        """
        Restores the Dioptas settings stored by an earlier session, if there are any.

        When the settings directory is missing nothing happens; otherwise the working directories are loaded first
        and the XML settings second.
        """
        settings_exist = os.path.exists(self.settings_directory)
        if not settings_exist:
            return
        self.load_directories()
        self.load_xml_settings()

    def load_directories(self):
        """
        Loads previously used Dioptas directory paths.

        If ``working_directories.csv`` exists in the settings directory, ``self.working_directories`` is replaced by
        a dict built from its ``key, path`` rows. Rows that do not consist of exactly two fields (e.g. blank lines)
        are skipped instead of making ``dict()`` raise.
        """
        working_directories_path = os.path.join(self.settings_directory,
                                                'working_directories.csv')
        if not os.path.exists(working_directories_path):
            return
        # the context manager closes the file handle (previously leaked); newline='' is what the csv module expects
        with open(working_directories_path, 'r', newline='') as csv_file:
            self.working_directories = dict(
                row for row in csv.reader(csv_file) if len(row) == 2)

    def load_xml_settings(self):
        """
        Loads previously used Dioptas settings. Currently this is only the calibration.

        If ``settings.xml`` exists in the settings directory and its ``filenames/calibration`` entry names an
        existing file, that calibration is loaded into the calibration model. A settings file that lacks these
        elements or has an empty calibration entry is ignored instead of raising an AttributeError.
        """
        xml_settings_path = os.path.join(self.settings_directory,
                                         "settings.xml")
        if not os.path.exists(xml_settings_path):
            return
        root = ET.parse(xml_settings_path).getroot()
        filenames = root.find("filenames")
        if filenames is None:
            return
        calibration_element = filenames.find("calibration")
        if calibration_element is None:
            return
        calibration_path = calibration_element.text
        # text is None when no calibration was saved; checking it first avoids testing the literal path 'None'
        if calibration_path and os.path.exists(calibration_path):
            self.calibration_model.load(calibration_path)

    def save_settings(self):
        """
        Persists the current Dioptas settings in the user's settings directory.

        The directory is created when it does not exist yet; afterwards the working directories and then the XML
        settings are written into it.
        """
        settings_directory = self.settings_directory
        if not os.path.exists(settings_directory):
            os.mkdir(settings_directory)
        for save in (self.save_directories, self.save_xml_settings):
            save()

    def save_directories(self):
        """
        Currently used working directories for images, spectra, etc. are saved as csv file in the users directory for
        reuse when Dioptas is started again.

        Each ``key, path`` pair of ``self.working_directories`` is written exactly once, as one row of
        ``working_directories.csv`` in the settings directory.
        """
        working_directories_path = os.path.join(self.settings_directory,
                                                'working_directories.csv')
        # `with` guarantees the file is flushed and closed; newline='' prevents blank rows on Windows
        with open(working_directories_path, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(
                [key, value] for key, value in self.working_directories.items())

    def save_xml_settings(self):
        """
        Writes the settings that should survive a restart to ``settings.xml`` in the settings directory.

        The document layout is ``<DioptasSettings><filenames><calibration>...</calibration></filenames>
        </DioptasSettings>``; right now only the filename of the calibration model is stored.
        """
        settings_root = ET.Element("DioptasSettings")
        filenames_node = ET.SubElement(settings_root, "filenames")
        ET.SubElement(filenames_node, "calibration").text = self.calibration_model.filename
        xml_path = os.path.join(self.settings_directory, "settings.xml")
        ET.ElementTree(settings_root).write(xml_path)

    def close_event(self, _):
        """
        Intervention of the Dioptas close event: stores the settings (only when settings are in use), then closes
        all windows and quits the Qt application. The event argument itself is ignored.
        """
        should_save = self.use_settings
        if should_save:
            self.save_settings()
        application = QtGui.QApplication
        application.closeAllWindows()
        application.quit()
class TestCalibrationController(unittest.TestCase):
    """
    Tests of the CalibrationController with a real CalibrationWidget; integrate_1d/integrate_2d of the calibration
    model are replaced by MagicMocks so that only their invocation is checked.
    """

    @classmethod
    def setUpClass(cls):
        # one QApplication is shared by all tests of this class
        cls.app = QtGui.QApplication([])

    @classmethod
    def tearDownClass(cls):
        cls.app.quit()
        cls.app.deleteLater()

    def setUp(self):
        self.img_model = ImgModel()
        self.mask_model = MaskModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.calibration_model._calibrants_working_dir = os.path.join(
            data_path, 'calibrants')
        self.calibration_model.integrate_1d = MagicMock()
        self.calibration_model.integrate_2d = MagicMock()

        self.calibration_widget = CalibrationWidget()
        self.working_dir = {}
        self.calibration_controller = CalibrationController(
            working_dir=self.working_dir,
            img_model=self.img_model,
            mask_model=self.mask_model,
            widget=self.calibration_widget,
            calibration_model=self.calibration_model)

    def tearDown(self):
        del self.img_model
        del self.mask_model
        del self.calibration_model.cake_geometry
        del self.calibration_model.spectrum_geometry
        del self.calibration_model
        gc.collect()

    def test_automatic_calibration(self):
        self.calibration_controller.load_img(
            os.path.join(data_path, 'LaB6_40keV_MarCCD.tif'))
        self.calibration_controller.search_peaks(1179.6, 1129.4)
        self.calibration_controller.search_peaks(1268.5, 1119.8)
        self.calibration_controller.widget.sv_wavelength_txt.setText('0.31')
        self.calibration_controller.widget.sv_distance_txt.setText('200')
        self.calibration_controller.widget.sv_pixel_width_txt.setText('79')
        self.calibration_controller.widget.sv_pixel_height_txt.setText('79')
        calibrant_index = self.calibration_widget.calibrant_cb.findText('LaB6')
        self.calibration_controller.widget.calibrant_cb.setCurrentIndex(
            calibrant_index)

        QTest.mouseClick(self.calibration_widget.calibrate_btn,
                         QtCore.Qt.LeftButton)
        self.calibration_model.integrate_1d.assert_called_once_with()
        self.calibration_model.integrate_2d.assert_called_once_with()
        # NOTE(review): `call_count` only exists if QProgressDialog.setValue was replaced by a mock somewhere
        # outside this class (not visible here) - confirm that this patching is in place
        self.assertEqual(QtGui.QProgressDialog.setValue.call_count, 15)

        calibration_parameter = self.calibration_model.get_calibration_parameter(
        )[0]
        self.assertAlmostEqual(calibration_parameter['dist'], .1967, places=4)

    def test_loading_and_saving_of_calibration_files(self):
        calibration_path = os.path.join(data_path, 'calibration.poni')
        self.calibration_controller.load_calibration(
            os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))
        try:
            self.calibration_controller.save_calibration(calibration_path)
            self.assertTrue(os.path.exists(calibration_path))
        finally:
            # clean up even when saving or the assertion fails, so no stale file is left in the test data folder
            if os.path.exists(calibration_path):
                os.remove(calibration_path)
 def setUp(self):
     """Creates a fresh image model and a calibration model bound to it before every test."""
     image_model = ImgModel()
     self.img_model = image_model
     self.calibration_model = CalibrationModel(image_model)
示例#17
0
class CbnAbsorptionCorrectionOptimizationTest(unittest.TestCase):
    """
    Fits the parameters of a CbnCorrection with lmfit so that the integrated pattern of a Mg2SiO4 image becomes as
    smooth as possible within a 2-theta region, and saves a plot of the result. Needs the files in
    ``Data/CbnCorrectionOptimization``.
    """

    def setUp(self):
        # creating Data objects
        self.img_data = ImgModel()
        self.img_data.load("Data/CbnCorrectionOptimization/Mg2SiO4_091.tif")
        self.calibration_data = CalibrationModel(self.img_data)
        self.calibration_data.load("Data/CbnCorrectionOptimization/LaB6_40keV side.poni")
        self.mask_data = MaskModel()
        self.mask_data.load_mask("Data/CbnCorrectionOptimization/Mg2SiO4_91_combined.mask")

        # creating the ObliqueAngleDetectorAbsorptionCorrection from the fit2d tilt parameters of the calibration
        _, fit2d_parameter = self.calibration_data.get_calibration_parameter()
        detector_tilt = fit2d_parameter['tilt']
        detector_tilt_rotation = fit2d_parameter['tiltPlanRotation']

        # NOTE(review): the detector shape is hard-coded to 2048 x 2048 pixels - confirm it matches the image
        self.tth_array = self.calibration_data.spectrum_geometry.twoThetaArray((2048, 2048))
        self.azi_array = self.calibration_data.spectrum_geometry.chiArray((2048, 2048))

        self.oiadac_correction = ObliqueAngleDetectorAbsorptionCorrection(
                self.tth_array, self.azi_array,
                detector_thickness=40,
                absorption_length=465.5,
                tilt=detector_tilt,
                rotation=detector_tilt_rotation,
        )
        self.img_data.add_img_correction(self.oiadac_correction, "oiadac")

    def tearDown(self):
        del self.calibration_data.cake_geometry
        del self.calibration_data.spectrum_geometry

    def _create_cbn_correction(self, params):
        """Builds a CbnCorrection for the current angle arrays from the values of the lmfit Parameters `params`."""
        return CbnCorrection(
                tth_array=self.tth_array,
                azi_array=self.azi_array,
                diamond_thickness=params['diamond_thickness'].value,
                seat_thickness=params['seat_thickness'].value,
                small_cbn_seat_radius=params['small_cbn_seat_radius'].value,
                large_cbn_seat_radius=params['large_cbn_seat_radius'].value,
                tilt=params['tilt'].value,
                tilt_rotation=params['tilt_rotation'].value,
                cbn_abs_length=params['cbn_abs_length'].value
        )

    def test_the_world(self):
        params = Parameters()
        params.add("diamond_thickness", value=2, min=1.9, max=2.3)
        params.add("seat_thickness", value=5.3, min=4.0, max=6.6, vary=False)
        params.add("small_cbn_seat_radius", value=0.2, min=0.10, max=0.5, vary=True)
        params.add("large_cbn_seat_radius", value=1.95, min=1.85, max=2.05, vary=False)
        params.add("tilt", value=3.3, min=0, max=8)
        params.add("tilt_rotation", value=0, min=-15, max=+15)
        params.add("cbn_abs_length", value=14.05, min=12, max=16)

        region = [8, 26]

        # the angle arrays are converted from radians to degrees before they are handed to CbnCorrection
        self.tth_array = 180.0 / np.pi * self.tth_array
        self.azi_array = 180.0 / np.pi * self.azi_array

        def fcn2min(params):
            """lmfit residual: squared first differences of the smoothed pattern inside `region`."""
            self.img_data.add_img_correction(self._create_cbn_correction(params), "cbn")
            try:
                tth, intensity = self.calibration_data.integrate_1d(mask=self.mask_data.get_mask())
            finally:
                # always remove the trial correction, so it cannot leak into the next evaluation
                self.img_data.delete_img_correction("cbn")
            ind = np.where((tth > region[0]) & (tth < region[1]))
            intensity = gaussian_filter1d(intensity, 20)
            return np.diff(intensity[ind]) ** 2

        def output_values(param1, iteration, residual):
            """Iteration callback: reports the current parameter values via report_fit."""
            report_fit(param1)

        result = minimize(fcn2min, params, iter_cb=output_values)
        # lmfit >= 0.9 returns the optimized values in result.params and leaves `params` unchanged; older versions
        # updated `params` in place, therefore fall back to it
        fitted_params = getattr(result, 'params', params)
        report_fit(fitted_params)

        # plotting result:
        self.img_data.add_img_correction(self._create_cbn_correction(fitted_params), "cbn")
        tth, intensity = self.calibration_data.integrate_1d(mask=self.mask_data.get_mask())
        ind = np.where((tth > region[0]) & (tth < region[1]))
        tth = tth[ind]
        intensity = intensity[ind]
        intensity_smooth = gaussian_filter1d(intensity, 10)

        intensity_diff1 = np.diff(intensity)
        intensity_diff1_smooth = np.diff(intensity_smooth)
        intensity_diff2 = np.diff(intensity_diff1)
        intensity_diff2_smooth = np.diff(intensity_diff1_smooth)

        plt.figure()
        plt.subplot(3, 1, 1)
        plt.plot(tth, intensity)
        plt.plot(tth, intensity_smooth)
        plt.subplot(3, 1, 2)
        plt.plot(intensity_diff1)
        plt.plot(intensity_diff1_smooth)
        plt.subplot(3, 1, 3)
        plt.plot(intensity_diff2)
        plt.plot(intensity_diff2_smooth)
        if not os.path.exists("Results"):
            os.makedirs("Results")
        plt.savefig("Results/optimize_cbn_absorption.png", dpi=300)
        plt.close()

        # NOTE: "open" is the macOS command for showing a file in its default viewer
        os.system("open " + "Results/optimize_cbn_absorption.png")
class PhaseControllerTest(unittest.TestCase):
    """
    GUI tests of the PhaseController: adding and removing jcpds phases, propagating pressure and temperature to the
    phases, and the height of the phase lines drawn in the spectrum view.
    """

    def setUp(self):
        self.app = QtGui.QApplication(sys.argv)
        self.image_model = ImgModel()
        self.calibration_model = CalibrationModel()
        self.calibration_model.load(os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))
        self.spectrum_model = SpectrumModel()
        self.phase_model = PhaseModel()
        self.widget = IntegrationWidget()
        self.widget.spectrum_view._auto_range = True
        self.phase_tw = self.widget.phase_tw

        # NOTE(review): the spectrum controller is not used directly by the tests below; presumably it is created
        # so that the spectrum view gets wired to the spectrum model
        self.spectrum_controller = SpectrumController({}, self.widget, self.image_model, None,
                                                      self.calibration_model, self.spectrum_model)
        self.controller = PhaseController({}, self.widget, self.calibration_model, self.spectrum_model,
                                          self.phase_model)
        self.spectrum_model.load_spectrum(os.path.join(data_path, 'spectrum_001.xy'))

    def tearDown(self):
        del self.app

    def _assert_phase_count(self, count):
        """Asserts that the phase table, the phase model and the spectrum view each hold exactly `count` phases."""
        self.assertEqual(self.phase_tw.rowCount(), count)
        self.assertEqual(len(self.phase_model.phases), count)
        self.assertEqual(len(self.widget.spectrum_view.phases), count)

    def _assert_phase_pressure(self, ind, phase, pressure):
        """Asserts that `phase` (table row `ind`) has `pressure` in the model and in the table."""
        self.assertEqual(phase.pressure, pressure)
        self.assertEqual(self.widget.get_phase_pressure(ind), pressure)

    def _assert_phase_temperature(self, ind, phase, temperature):
        """
        Asserts the temperature of `phase` (table row `ind`): phases with thermal expansion take `temperature`,
        all others stay at 298 and show no temperature in the table.
        """
        if phase.has_thermal_expansion():
            self.assertEqual(phase.temperature, temperature)
            self.assertEqual(self.widget.get_phase_temperature(ind), temperature)
        else:
            self.assertEqual(phase.temperature, 298)
            self.assertEqual(self.widget.get_phase_temperature(ind), None)

    def test_manual_deleting_phases(self):
        self.load_phases()
        QtGui.QApplication.processEvents()

        self._assert_phase_count(6)
        self.assertEqual(self.phase_tw.currentRow(), 5)

        # removing without an explicit selection removes the currently selected (last) phase
        self.controller.remove_btn_click_callback()
        self._assert_phase_count(5)
        self.assertEqual(self.phase_tw.currentRow(), 4)

        self.widget.select_phase(1)
        self.controller.remove_btn_click_callback()
        self._assert_phase_count(4)
        self.assertEqual(self.phase_tw.currentRow(), 1)

        self.widget.select_phase(0)
        self.controller.remove_btn_click_callback()
        self._assert_phase_count(3)
        self.assertEqual(self.phase_tw.currentRow(), 0)

        self.controller.remove_btn_click_callback()
        self.controller.remove_btn_click_callback()
        self.controller.remove_btn_click_callback()
        self._assert_phase_count(0)
        self.assertEqual(self.phase_tw.currentRow(), -1)

        # removing from an empty table must not fail
        self.controller.remove_btn_click_callback()
        self._assert_phase_count(0)
        self.assertEqual(self.phase_tw.currentRow(), -1)

    def test_automatic_deleting_phases(self):
        self.load_phases()
        self.load_phases()
        self._assert_phase_count(12)
        self.controller.clear_phases()
        self._assert_phase_count(0)
        self.assertEqual(self.phase_tw.currentRow(), -1)

        multiplier = 1
        for dummy_index in range(multiplier):
            self.load_phases()

        self.assertEqual(self.phase_tw.rowCount(), multiplier * 6)
        self.controller.clear_phases()
        self._assert_phase_count(0)
        self.assertEqual(self.phase_tw.currentRow(), -1)

    def test_pressure_change(self):
        self.load_phases()
        pressure = 200
        self.widget.phase_pressure_sb.setValue(pressure)
        for ind, phase in enumerate(self.phase_model.phases):
            self._assert_phase_pressure(ind, phase, pressure)

    def test_temperature_change(self):
        self.load_phases()
        temperature = 1500
        self.widget.phase_temperature_sb.setValue(temperature)
        for ind, phase in enumerate(self.phase_model.phases):
            self._assert_phase_temperature(ind, phase, temperature)

    def test_apply_to_all_for_new_added_phase_in_table_widget(self):
        temperature = 1500
        pressure = 200
        self.widget.phase_temperature_sb.setValue(temperature)
        self.widget.phase_pressure_sb.setValue(pressure)
        self.load_phases()
        for ind, phase in enumerate(self.phase_model.phases):
            self._assert_phase_pressure(ind, phase, pressure)
            self._assert_phase_temperature(ind, phase, temperature)

    def test_apply_to_all_for_new_added_phase_d_positions(self):
        pressure = 50
        self.load_phase('au_Anderson.jcpds')
        self.widget.phase_pressure_sb.setValue(pressure)
        self.load_phase('au_Anderson.jcpds')

        reflections1 = self.phase_model.get_lines_d(0)
        reflections2 = self.phase_model.get_lines_d(1)
        self.assertTrue(np.array_equal(reflections1, reflections2))

    def test_to_not_show_lines_in_legend(self):
        # NOTE(review): this test contains no assertion; it only checks that the calls below do not raise
        self.load_phases()
        self.phase_tw.selectRow(1)
        QTest.mouseClick(self.widget.phase_del_btn, QtCore.Qt.LeftButton)
        self.widget.spectrum_view.hide_phase(1)

    def test_auto_scaling_of_lines_in_spectrum_view(self):
        spectrum_view = self.widget.spectrum_view

        spectrum_view_range = spectrum_view.view_box.viewRange()
        spectrum_y = spectrum_view.plot_item.getData()[1]
        expected_maximum_height = np.max(spectrum_y) - spectrum_view_range[1][0]

        self.load_phase('au_Anderson.jcpds')
        phase_plot = spectrum_view.phases[0]
        line_heights = []
        for line in phase_plot.line_items:
            line_data = line.getData()
            height = line_data[1][1] - line_data[1][0]
            line_heights.append(height)

        self.assertAlmostEqual(expected_maximum_height, np.max(line_heights))

        # the expectation is recomputed after adding the phase, i.e. the view range must not have changed the result
        spectrum_view_range = spectrum_view.view_box.viewRange()
        spectrum_y = spectrum_view.plot_item.getData()[1]
        expected_maximum_height = np.max(spectrum_y) - spectrum_view_range[1][0]

        self.assertAlmostEqual(expected_maximum_height, np.max(line_heights))

    def test_line_height_in_spectrum_view_after_zooming(self):
        spectrum_view = self.widget.spectrum_view
        self.load_phase('au_Anderson.jcpds')

        spectrum_view.view_box.setRange(xRange=[17, 30])
        spectrum_view.emit_sig_range_changed()

        # only lines inside the zoomed x-range (17, 30) are considered
        phase_plot = spectrum_view.phases[0]
        line_heights = []
        for line in phase_plot.line_items:
            line_data = line.getData()
            if (line_data[0][0] > 17) and (line_data[0][1] < 30):
                height = line_data[1][1] - line_data[1][0]
                line_heights.append(height)

        spectrum_view_range = spectrum_view.view_box.viewRange()
        spectrum_x, spectrum_y = spectrum_view.plot_item.getData()
        in_x_range = (spectrum_x > spectrum_view_range[0][0]) & (spectrum_x < spectrum_view_range[0][1])
        spectrum_y_max_in_range = np.max(spectrum_y[in_x_range])
        expected_maximum_height = spectrum_y_max_in_range - spectrum_view_range[1][0]

        self.assertAlmostEqual(expected_maximum_height, np.max(line_heights))

    def load_phases(self):
        """Adds six phases (ar, ag, au_Anderson, mo, pt, re) from the jcpds test directory, in this order."""
        for filename in ('ar.jcpds', 'ag.jcpds', 'au_Anderson.jcpds', 'mo.jcpds', 'pt.jcpds', 're.jcpds'):
            self.load_phase(filename)

    def load_phase(self, filename):
        """Adds the phase stored in `filename` (relative to jcpds_path) through the controller's add callback."""
        self.controller.add_btn_click_callback(os.path.join(jcpds_path, filename))
示例#19
0
class MainController(object):
    """
    Creates the main controller for Dioptas. Loads all the data objects and connects them with the other controllers.
    """

    def __init__(self, use_settings=True):
        """
        :param use_settings: when True, previously saved settings are loaded on start-up and the current settings are
                             saved again when the main window is closed
        """
        self.use_settings = use_settings

        self.widget = MainWidget()
        # create data
        self.img_model = ImgModel()
        self.calibration_model = CalibrationModel(self.img_model)
        self.mask_model = MaskModel()
        self.spectrum_model = SpectrumModel()
        self.phase_model = PhaseModel()

        self.settings_directory = os.path.join(os.path.expanduser("~"), '.Dioptas')
        self.working_directories = {'calibration': '', 'mask': '', 'image': '', 'spectrum': '', 'overlay': '',
                                    'phase': ''}

        # settings must be loaded before the controllers are created: load_directories replaces the
        # working_directories dict, which is handed to the controllers below
        if use_settings:
            self.load_settings()
        # create controller
        self.calibration_controller = CalibrationController(self.working_directories,
                                                            self.widget.calibration_widget,
                                                            self.img_model,
                                                            self.mask_model,
                                                            self.calibration_model)
        self.mask_controller = MaskController(self.working_directories,
                                              self.widget.mask_widget,
                                              self.img_model,
                                              self.mask_model)
        self.integration_controller = IntegrationController(self.working_directories,
                                                            self.widget.integration_widget,
                                                            self.img_model,
                                                            self.mask_model,
                                                            self.calibration_model,
                                                            self.spectrum_model,
                                                            self.phase_model)
        self.create_signals()
        self.set_title()

    def show_window(self):
        """Shows the main window un-minimized and brings it to the front."""
        self.widget.show()
        self.widget.setWindowState(self.widget.windowState() & ~QtCore.Qt.WindowMinimized | QtCore.Qt.WindowActive)
        self.widget.activateWindow()
        self.widget.raise_()

    def create_signals(self):
        """Connects tab changes, the window close event and model changes to the corresponding handlers."""
        self.widget.tabWidget.currentChanged.connect(self.tab_changed)
        self.widget.closeEvent = self.close_event
        self.img_model.subscribe(self.set_title)
        self.spectrum_model.spectrum_changed.connect(self.set_title)

    def tab_changed(self, ind):
        """
        Refreshes the view of the newly selected tab.

        :param ind: index of the selected tab (0: calibration, 1: mask, 2: integration)
        """
        if ind == 2:
            self.mask_model.set_supersampling()
            self.integration_controller.image_controller.plot_mask()
            self.integration_controller.widget.calibration_lbl.setText(self.calibration_model.calibration_name)
            self.integration_controller.image_controller._auto_scale = False
            self.integration_controller.spectrum_controller.image_changed()
            self.integration_controller.image_controller.update_img()
        elif ind == 1:
            self.mask_controller.plot_mask()
            self.mask_controller.plot_image()
        elif ind == 0:
            self.calibration_controller.plot_mask()
            try:
                self.calibration_controller.update_calibration_parameter_in_view()
            except (TypeError, AttributeError):
                # deliberately ignored: updating the view can fail with these errors (e.g. presumably when no
                # calibration exists yet)
                pass

    def set_title(self):
        """
        Sets the title of the main window and of the integration image frame.

        The title reads ``Dioptas <version> - [<image>, <spectrum>, calibration: <name>] - © 2015 C. Prescher``.
        Empty file names are left out, and the spectrum name is left out when it equals the image name. When neither
        an image nor a spectrum is loaded, the bracketed part is omitted completely.
        """
        img_filename = os.path.basename(self.img_model.filename)
        spec_filename = os.path.basename(self.spectrum_model.spectrum_filename)
        calibration_name = self.calibration_model.calibration_name
        title = 'Dioptas ' + __version__

        # strings are compared by value/truthiness; ``is ''`` relies on interning (SyntaxWarning on py >= 3.8)
        if img_filename or spec_filename:
            # only non-empty, distinct names, so no file is listed twice and no empty ", " entry appears
            entries = []
            if img_filename:
                entries.append(img_filename)
            if spec_filename and spec_filename != img_filename:
                entries.append(spec_filename)
            if calibration_name is not None:
                entries.append('calibration: ' + calibration_name)
            title += ' - [' + ', '.join(entries) + ']'
        title += u' - © 2015 C. Prescher'

        self.widget.setWindowTitle(title)
        self.widget.integration_widget.img_frame.setWindowTitle(title)

    def load_settings(self):
        """Loads the working directories and the XML settings, if the settings directory exists."""
        if os.path.exists(self.settings_directory):
            self.load_directories()
            self.load_xml_settings()

    def load_directories(self):
        """
        Replaces ``self.working_directories`` with the ``key, path`` rows of ``working_directories.csv`` (if that file
        exists). Rows that do not have exactly two fields, e.g. blank lines, are skipped.
        """
        working_directories_path = os.path.join(self.settings_directory, 'working_directories.csv')
        if not os.path.exists(working_directories_path):
            return
        # the context manager closes the file (previously leaked); newline='' is what the csv module expects
        with open(working_directories_path, 'r', newline='') as csv_file:
            self.working_directories = dict(row for row in csv.reader(csv_file) if len(row) == 2)

    def load_xml_settings(self):
        """
        Loads the calibration named in ``settings.xml`` (``filenames/calibration``) if that file exists. Missing
        elements or an empty entry are ignored instead of raising an AttributeError.
        """
        xml_settings_path = os.path.join(self.settings_directory, "settings.xml")
        if not os.path.exists(xml_settings_path):
            return
        root = ET.parse(xml_settings_path).getroot()
        filenames = root.find("filenames")
        if filenames is None:
            return
        calibration_element = filenames.find("calibration")
        if calibration_element is None:
            return
        calibration_path = calibration_element.text
        # text is None when no calibration was saved; checking it first avoids testing the literal path 'None'
        if calibration_path and os.path.exists(calibration_path):
            self.calibration_model.load(calibration_path)

    def save_settings(self):
        """Creates the settings directory if needed and writes the working directories and the XML settings."""
        if not os.path.exists(self.settings_directory):
            os.mkdir(self.settings_directory)
        self.save_directories()
        self.save_xml_settings()

    def save_directories(self):
        """Writes every ``key, path`` pair of ``self.working_directories`` once into ``working_directories.csv``."""
        working_directories_path = os.path.join(self.settings_directory, 'working_directories.csv')
        # `with` flushes and closes the file; newline='' prevents blank rows on Windows
        with open(working_directories_path, 'w', newline='') as csv_file:
            csv.writer(csv_file).writerows(
                [key, value] for key, value in self.working_directories.items())

    def save_xml_settings(self):
        """Writes the filename of the calibration model to ``settings.xml`` in the settings directory."""
        root = ET.Element("DioptasSettings")
        filenames = ET.SubElement(root, "filenames")
        calibration_filename = ET.SubElement(filenames, "calibration")
        calibration_filename.text = self.calibration_model.filename
        tree = ET.ElementTree(root)
        tree.write(os.path.join(self.settings_directory, "settings.xml"))

    def close_event(self, _):
        """
        Replacement for the main widget's closeEvent (installed in create_signals): saves the settings when they are
        in use, then closes all windows and quits the application.
        """
        if self.use_settings:
            self.save_settings()
        QtGui.QApplication.closeAllWindows()
        QtGui.QApplication.quit()
class CalibrationControllerTest(unittest.TestCase):
    def setUp(self):
        # Build a fresh QApplication and fresh models, widget and controller for
        # every test, wired together through the controller's keyword arguments.
        self.app = QtGui.QApplication(sys.argv)
        self.img_model = ImgModel()
        self.mask_model = MaskModel()
        self.calibration_model = CalibrationModel(self.img_model)
        calibrants_dir = os.path.join(data_path, 'calibrants')
        self.calibration_model._calibrants_working_dir = calibrants_dir
        self.calibration_widget = CalibrationWidget()
        self.working_dir = {}
        controller_kwargs = dict(
            working_dir=self.working_dir,
            img_model=self.img_model,
            mask_model=self.mask_model,
            widget=self.calibration_widget,
            calibration_model=self.calibration_model,
        )
        self.calibration_controller = CalibrationController(**controller_kwargs)

    def tearDown(self):
        # Release the image model, the calibration model's geometry attributes,
        # the calibration model itself and the application object, then force
        # a garbage-collection pass.
        del self.img_model
        for geometry_attr in ('cake_geometry', 'spectrum_geometry'):
            delattr(self.calibration_model, geometry_attr)
        del self.calibration_model, self.app
        gc.collect()

    def load_pilatus_1M_and_pick_peaks(self):
        """Load the CeO2 Pilatus 1M image, pick six peaks and enter the start values.

        Steps:
        - click the automatic peak-number increment checkbox and assert it is unchecked;
        - search peaks at six fixed positions;
        - fill in wavelength, distance and pixel sizes;
        - select the CeO2 calibrant;
        - size the mask to the loaded image.
        """
        widget = self.calibration_widget
        controller = self.calibration_controller

        controller.load_img(os.path.join(data_path,'CeO2_Pilatus1M.tif'))
        QTest.mouseClick(widget.automatic_peak_num_inc_cb, QtCore.Qt.LeftButton)
        self.assertFalse(widget.automatic_peak_num_inc_cb.isChecked())

        picked_positions = (
            (517.664434674, 647.529865592),
            (667.380513299, 525.252854758),
            (671.110095329, 473.571503774),
            (592.788872703, 350.495296791),
            (387.395462348, 390.987901686),
            (367.94835605, 554.290314848),
        )
        for pos_x, pos_y in picked_positions:
            controller.search_peaks(pos_x, pos_y)

        start_value_inputs = (
            (widget.sv_wavelength_txt, '0.406626'),
            (widget.sv_distance_txt, '200'),
            (widget.sv_pixel_width_txt, '172'),
            (widget.sv_pixel_height_txt, '172'),
        )
        for line_edit, text in start_value_inputs:
            line_edit.setText(text)
        widget.calibrant_cb.setCurrentIndex(widget.calibrant_cb.findText('CeO2'))

        self.mask_model.set_dimension(self.img_model.img_data.shape)

    def test_automatic_calibration1(self):
        """Calibrate the LaB6 MarCCD image from two picked peaks; must not raise."""
        controller = self.calibration_controller
        controller.load_img(os.path.join(data_path,'LaB6_40keV_MarCCD.tif'))
        for peak_x, peak_y in ((1179.6, 1129.4), (1268.5, 1119.8)):
            controller.search_peaks(peak_x, peak_y)

        widget = controller.widget
        widget.sv_wavelength_txt.setText('0.31')
        widget.sv_distance_txt.setText('200')
        widget.sv_pixel_width_txt.setText('79')
        widget.sv_pixel_height_txt.setText('79')
        lab6_index = self.calibration_widget.calibrant_cb.findText('LaB6')
        widget.calibrant_cb.setCurrentIndex(lab6_index)

        controller.calibrate()
        widget.cake_view.set_vertical_line_pos(1419.8, 653.4)

    def test_automatic_calibration2(self):
        """Calibrate the off-center LaB6 PE image from two picked peaks; must not raise."""
        controller = self.calibration_controller
        controller.load_img(os.path.join(data_path,'LaB6_OffCenter_PE.tif'))
        for peak_x, peak_y in ((1245.2, 1919.3), (1334.0, 1823.7)):
            controller.search_peaks(peak_x, peak_y)

        widget = controller.widget
        widget.sv_wavelength_txt.setText('0.3344')
        widget.sv_distance_txt.setText('500')
        widget.sv_pixel_width_txt.setText('200')
        widget.sv_pixel_height_txt.setText('200')
        lab6_index = self.calibration_widget.calibrant_cb.findText('LaB6')
        widget.calibrant_cb.setCurrentIndex(lab6_index)

        controller.calibrate()
        widget.cake_view.set_vertical_line_pos(206.5, 171.6)

    def test_automatic_calibration3(self):
        """Check that start values reach the model and the integrate button calibrates.

        After clicking the integrate button, the fitted ``dist`` parameter must
        equal 0.2086 to four decimal places.
        """
        self.load_pilatus_1M_and_pick_peaks()

        start_values = self.calibration_widget.get_start_values()
        expected_start_values = (
            ('wavelength', 0.406626e-10),
            ('pixel_height', 172e-6),
            ('pixel_width', 172e-6),
        )
        for key, expected in expected_start_values:
            self.assertAlmostEqual(start_values[key], expected)

        self.calibration_controller.load_calibrant()
        self.assertAlmostEqual(self.calibration_model.calibrant.wavelength, 0.406626e-10)

        QTest.mouseClick(self.calibration_widget.integrate_btn, QtCore.Qt.LeftButton)
        fitted_parameters = self.calibration_model.get_calibration_parameter()[0]
        self.assertAlmostEqual(fitted_parameters['dist'], .2086, places=4)

    def test_automatic_calibration_with_supersampling(self):
        """Calibration must run with image and calibration supersampling set to 2."""
        self.load_pilatus_1M_and_pick_peaks()
        for model in (self.img_model, self.calibration_model):
            model.set_supersampling(2)
        self.calibration_controller.calibrate()

    def test_automatic_calibration_with_supersampling_and_mask(self):
        """Calibration must run with supersampling 2 and a threshold mask enabled.

        Pixels below 1 in the image data are masked before calibrating.
        """
        factor = 2
        self.load_pilatus_1M_and_pick_peaks()
        self.img_model.set_supersampling(factor)
        self.mask_model.mask_below_threshold(self.img_model._img_data, 1)
        self.mask_model.set_supersampling(factor)
        self.calibration_model.set_supersampling(factor)
        self.calibration_controller.widget.use_mask_cb.setChecked(True)
        self.calibration_controller.calibrate()

    def test_calibrating_one_image_size_and_loading_different_image_size(self):
        """Integration must keep working after loading an image of another size.

        The calibration is done on the Pilatus 1M CeO2 image. A different
        detector image is then loaded and both integrations are re-run.
        Previously the same Pilatus image was reloaded, so the size change
        was never exercised.
        """
        self.load_pilatus_1M_and_pick_peaks()
        self.calibration_controller.calibrate()
        self.calibration_model.integrate_1d()
        self.calibration_model.integrate_2d()
        # Image from a different detector (also used by test_automatic_calibration1).
        self.calibration_controller.load_img(os.path.join(data_path, 'LaB6_40keV_MarCCD.tif'))
        self.calibration_model.integrate_1d()
        self.calibration_model.integrate_2d()

    def test_loading_and_saving_of_calibration_files(self):
        """Load a .poni calibration, save it again and check the file is written.

        The saved file is removed in ``finally`` so a failing assertion does
        not leave ``calibration.poni`` behind in the test data directory.
        """
        output_path = os.path.join(data_path, 'calibration.poni')
        self.calibration_controller.load_calibration(os.path.join(data_path, 'LaB6_40keV_MarCCD.poni'))
        try:
            self.calibration_controller.save_calibration(output_path)
            self.assertTrue(os.path.exists(output_path))
        finally:
            # Guarded so a failed save doesn't mask the real error with a
            # FileNotFoundError from the cleanup.
            if os.path.exists(output_path):
                os.remove(output_path)