class GalvoRegistrator:
    """Register galvo scanner positions to camera pixel coordinates.

    The galvos are stepped over a regular grid of analog output values. The
    camera images the resulting spot at each position. A first-order 2D
    polynomial is then fitted that maps camera coordinates to galvo values.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but ignored.
        self.cam = CamActuator()
        self.cam.initializeCamera()

    def registration(self, grid_points_x=3, grid_points_y=3):
        """Run the galvo-to-camera registration.

        Parameters
        ----------
        grid_points_x, grid_points_y : int
            Number of grid positions along each galvo axis. The positions are
            spread evenly over the open interval (-10, 10) (presumably volts);
            the interval end points themselves are excluded.

        Returns
        -------
        transformation : numpy.ndarray
            Result of ``CoordinateTransformations.polynomial2DFit`` with
            ``order=1``. ``[:, :, 0]`` is the x mapping and ``[:, :, 1]`` is
            the y mapping.

        Notes
        -----
        The DAQ task and the camera are released when this method returns,
        also when acquisition fails part-way. The camera cannot be reused
        by this instance afterwards.
        """
        x_coords = np.linspace(-10, 10, grid_points_x + 2)[1:-1]
        y_coords = np.linspace(-10, 10, grid_points_y + 2)[1:-1]

        # One (x, y) galvo coordinate per row.
        galvo_coordinates = np.reshape(np.meshgrid(x_coords, y_coords), (2, -1),
                                       order='F').transpose()
        camera_coordinates = np.zeros(galvo_coordinates.shape)

        image_dir = os.path.join(os.getcwd(), 'CoordinatesManager',
                                 'Registration_Images', '2P')

        galvothread = DAQmission()
        try:
            for i, (pos_x, pos_y) in enumerate(galvo_coordinates):
                galvothread.sendSingleAnalog('galvosx', pos_x)
                galvothread.sendSingleAnalog('galvosy', pos_y)
                # Presumably lets the galvos settle before the snapshot.
                time.sleep(1)

                image = self.cam.SnapImage(0.06)
                plt.imsave(os.path.join(image_dir, 'image_' + str(i) + '.png'),
                           image)

                camera_coordinates[i, :] = readRegistrationImages.gaussian_fitting(
                    image)

            print('Galvo Coordinate')
            print(galvo_coordinates)
            print('Camera coordinates')
            print(camera_coordinates)
        finally:
            # Always release the DAQ task and the camera, even on failure.
            del galvothread
            self.cam.Exit()

        transformation = CoordinateTransformations.polynomial2DFit(
            camera_coordinates, galvo_coordinates, order=1)

        print('Transformation found for x:')
        print(transformation[:, :, 0])
        print('Transformation found for y:')
        print(transformation[:, :, 1])
        return transformation
Example #2
0
class DMDRegistator:
    """Register DMD pixel coordinates to camera pixel coordinates.

    A touching-squares pattern is projected at each point of a 3 x 2 grid of
    DMD positions. The touching point is located in the camera image, and a
    transformation between camera and DMD coordinates is fitted.
    """

    def __init__(self, parent):
        # ``parent`` is accepted for API compatibility but is not used.
        self.DMD = DMDActuator.DMDActuator()
        self.cam = CamActuator()
        self.cam.initializeCamera()

    def registration(self, laser='640', points=6):
        """Project the calibration grid, locate it on the camera and fit.

        Parameters
        ----------
        laser : str
            Not used here; kept for backward compatibility.
        points : int
            Not used here; kept for backward compatibility. The grid always
            contains 3 (x) * 2 (y) = 6 points, matching ``kx=3, ky=2`` below.

        Returns
        -------
        The result of ``findTransformationCurvefit``.

        Notes
        -----
        The DMD is disconnected and the camera released when this method
        returns, also on failure.
        """
        x_coords = np.linspace(0, 1024, 5)[1:-1]  # 3 positions along axis 0
        y_coords = np.linspace(0, 768, 4)[1:-1]   # 2 positions along axis 1

        # One (x, y) DMD coordinate per row -> shape (6, 2).
        x_mesh, y_mesh = np.meshgrid(x_coords, y_coords)
        dmd_coordinates = np.stack((np.ravel(x_mesh), np.ravel(y_mesh)), axis=1)

        camera_coordinates = np.zeros(dmd_coordinates.shape)

        try:
            for i, (x, y) in enumerate(dmd_coordinates.astype(int)):
                # Draw the pattern at the grid point's pixel position
                # (not at the loop index).
                mask = self.create_registration_image(x, y)
                self.DMD.send_data_to_DMD(mask)
                self.DMD.start_projection()

                image = self.cam.SnapImage(0.4)
                camera_coordinates[i, :] = touchingCoordinateFinder(
                    image, method='curvefit')

                self.DMD.stop_projection()
                self.DMD.free_memory()
        finally:
            self.DMD.disconnect_DMD()
            self.cam.Exit()

        transformation = findTransformationCurvefit(camera_coordinates,
                                                    dmd_coordinates,
                                                    kx=3,
                                                    ky=2)
        return transformation

    @staticmethod
    def create_registration_image(x, y, sigma=75):
        """Return a 1024 x 768 mask with two squares touching at (x, y).

        One ``sigma``-sized square is drawn up-left of (x, y) and one
        down-right of it, with value 255 on a zero background.
        """
        array = np.zeros((1024, 768))
        array[skd.draw.rectangle((x - sigma, y - sigma), (x, y))] = 255
        array[skd.draw.rectangle((x + sigma, y + sigma), (x, y))] = 255
        return array
class DMDRegistator:
    """Register DMD pixel coordinates to camera pixel coordinates.

    A calibration pattern (circle or touching squares) is projected at each
    point of a DMD grid. The pattern is located in the camera image, and a
    first-order 2D polynomial mapping camera to DMD coordinates is fitted.
    """

    def __init__(self, DMD, *args, **kwargs):
        # ``DMD`` is an already-connected DMD actuator supplied by the caller.
        self.DMD = DMD
        self.cam = CamActuator()
        self.cam.initializeCamera()

    def registration(self,
                     laser='640',
                     grid_points_x=2,
                     grid_points_y=3,
                     registration_pattern='circles'):
        """Project the calibration grid, locate it on the camera and fit.

        Parameters
        ----------
        laser : str
            Not used here; kept for backward compatibility.
        grid_points_x, grid_points_y : int
            Number of grid points along the DMD's 768- and 1024-pixel axes.
            End points of each axis are excluded.
        registration_pattern : str
            ``'squares'`` for touching squares; any other value draws circles.

        Returns
        -------
        transformation : numpy.ndarray
            Result of ``CoordinateTransformations.polynomial2DFit`` with
            ``order=1``.

        Notes
        -----
        DMD memory is freed and the camera released when this method returns,
        also on failure.
        """
        x_coords = np.linspace(0, 768, grid_points_x + 2)[1:-1]
        y_coords = np.linspace(0, 1024, grid_points_y + 2)[1:-1]

        x_mesh, y_mesh = np.meshgrid(x_coords, y_coords)
        # One (x, y) DMD coordinate per row.
        dmd_coordinates = np.stack((np.ravel(x_mesh), np.ravel(y_mesh)), axis=1)

        camera_coordinates = np.zeros(dmd_coordinates.shape)

        try:
            for i in range(dmd_coordinates.shape[0]):
                x = int(dmd_coordinates[i, 0])
                y = int(dmd_coordinates[i, 1])

                if registration_pattern == 'squares':
                    mask = self.create_registration_image_touching_squares(x, y)
                else:
                    mask = self.create_registration_image_circle(x, y)

                self.DMD.send_data_to_DMD(mask)
                self.DMD.start_projection()

                image = self.cam.SnapImage(0.01)
                plt.imsave(
                    os.getcwd() +
                    '/CoordinatesManager/Registration_Images/TouchingSquares/image_'
                    + str(i) + '.png', image)
                camera_coordinates[
                    i, :] = readRegistrationImages.touchingCoordinateFinder(
                        image, method='curvefit')

                self.DMD.stop_projection()

            print('DMD coordinates:')
            print(dmd_coordinates)
            print('Found camera coordinates:')
            print(camera_coordinates)
        finally:
            # Release DMD memory and the camera even if acquisition failed.
            self.DMD.free_memory()
            self.cam.Exit()

        transformation = CoordinateTransformations.polynomial2DFit(
            camera_coordinates, dmd_coordinates, order=1)
        print('Transformation found for x:')
        print(transformation[:, :, 0])
        print('Transformation found for y:')
        print(transformation[:, :, 1])
        return transformation

    @staticmethod
    def create_registration_image_touching_squares(x, y, sigma=75):
        """Return a 768 x 1024 mask with two squares touching at (x, y)."""
        array = np.zeros((768, 1024))
        array[skimage.draw.rectangle((x - sigma, y - sigma), (x, y))] = 255
        array[skimage.draw.rectangle((x + sigma, y + sigma), (x, y))] = 255
        return array

    @staticmethod
    def create_registration_image_circle(x, y, sigma=75):
        """Return a 768 x 1024 mask with a filled circle of radius ``sigma`` at (x, y)."""
        array = np.zeros((768, 1024))
        # ``skimage.draw.circle`` was removed in scikit-image 0.19 in favour of
        # ``disk``; fall back to ``circle`` only on older installations.
        if hasattr(skimage.draw, 'disk'):
            rr, cc = skimage.draw.disk((x, y), sigma)
        else:
            rr, cc = skimage.draw.circle(x, y, sigma)
        array[rr, cc] = 255
        return array
Example #4
0
class RegistrationThread(QThread):
    """Background thread that registers a device to the camera.

    Select the device with ``set_device_to_register`` ('galvos', 'dmd' or
    'stage'), then ``start()`` the thread. When it finishes, the
    transformations gathered so far are written as JSON to
    ``CoordinatesManager/Registration/transformation.txt`` and emitted through
    ``sig_finished_registration``.
    """

    sig_finished_registration = pyqtSignal(dict)

    def __init__(self, parent, laser=None):
        QThread.__init__(self)
        # One completion flag per touchingCoordinateFinder_Thread (3 masks).
        self.flag_finished = [0, 0, 0]
        self.backend = parent
        self.dmd = self.backend.DMD

        # Accept a single laser or a list of lasers.
        if not isinstance(laser, list):
            self.laser_list = [laser]
        else:
            self.laser_list = laser

        self.dict_transformators = {}

        self.dict_transformations = {}
        self.dtype_ref_co = np.dtype([('camera', int, (3, 2)),
                                      ('dmd', int, (3, 2)),
                                      ('galvos', int, (3, 2)),
                                      ('stage', int, (3, 2))])
        self.reference_coordinates = {}

    def set_device_to_register(self, device_1, device_2='camera'):
        """Select the device to register ('galvos', 'dmd' or 'stage')."""
        self.device_1 = device_1
        self.device_2 = device_2

    def run(self):
        """Thread entry point: gather reference points, fit, save and emit."""
        # Make sure registration can only start when camera is connected
        try:
            self.cam = CamActuator()
            self.cam.initializeCamera()
        except Exception:
            print(sys.exc_info())
            self.backend.ui_widget.normalOutputWritten(
                'Unable to connect Hamamatsu camera')
            return

        try:
            self.cam.setROI(0, 0, 2048, 2048)

            if self.device_1 == 'galvos':
                reference_coordinates = self.gather_reference_coordinates_galvos()
                self.dict_transformations['camera-galvos'] = findTransform(
                    reference_coordinates[0], reference_coordinates[1])
            elif self.device_1 == 'dmd':
                # Gather and fit per laser so every laser gets a transformation
                # from its own reference points.
                for laser in self.laser_list:
                    reference_coordinates = \
                        self._gather_reference_coordinates_dmd_single_laser(laser)
                    self.dict_transformations['camera-dmd-' + laser] = findTransform(
                        reference_coordinates[0], reference_coordinates[1])
            elif self.device_1 == 'stage':
                reference_coordinates = self.gather_reference_coordinates_stage()
                self.dict_transformations['camera-stage'] = findTransform(
                    reference_coordinates[0], reference_coordinates[1])
        finally:
            # Release the camera even if a registration step failed.
            self.cam.Exit()

        ## Save transformation to file
        with open('CoordinatesManager/Registration/transformation.txt',
                  'w') as json_file:
            dict_transformations_list_format = {
                key: value.tolist()
                for key, value in self.dict_transformations.items()
            }
            json.dump(dict_transformations_list_format, json_file)

        self.sig_finished_registration.emit(self.dict_transformations)

    def gather_reference_coordinates_stage(self):
        """Image a projected full-white mask at three stage positions.

        Returns ``np.array([camera_coordinates, stage_coordinates])``.
        """
        image = np.zeros((2048, 2048, 3))
        stage_coordinates = np.array([[-2800, 100], [-2500, 400],
                                      [-1900, -200]])

        self.backend.loadMask(mask=np.ones((768, 1024)))
        self.backend.startProjection()

        try:
            for idx, pos in enumerate(stage_coordinates):
                stage_movement_thread = StagemovementAbsoluteThread(pos[0], pos[1])
                stage_movement_thread.start()
                time.sleep(0.5)
                stage_movement_thread.quit()
                stage_movement_thread.wait()

                # Each snapshot is stored in its own channel of ``image``.
                image[:, :, idx] = self.cam.SnapImage(0.04)

            camera_coordinates = find_subimage_location(image, save=True)
        finally:
            self.backend.stopProjection()
            self.backend.freeMemory()

        return np.array([camera_coordinates, stage_coordinates])

    def gather_reference_coordinates_galvos(self):
        """Point the galvos at three positions and locate the spot on camera.

        Returns ``np.array([camera_coordinates, galvo_coordinates])``.
        """
        galvothread = DAQmission()

        camera_coordinates = np.zeros((3, 2))
        galvo_coordinates = np.array([[0, 3], [3, -3], [-3, -3]])

        try:
            for i, (pos_x, pos_y) in enumerate(galvo_coordinates):
                galvothread.sendSingleAnalog('galvosx', pos_x)
                galvothread.sendSingleAnalog('galvosy', pos_y)

                # NOTE(review): no settle delay before snapping here, unlike
                # GalvoRegistrator.registration — confirm this is intended.
                image = self.cam.SnapImage(0.04)

                camera_coordinates[i, :] = gaussian_fitting(image)
        finally:
            del galvothread
        return np.array([camera_coordinates, galvo_coordinates])

    def gather_reference_coordinates_dmd(self):
        """Gather DMD reference coordinates for every laser in ``laser_list``.

        Kept for backward compatibility: returns the reference coordinates of
        the last laser processed, as
        ``np.array([camera_coordinates, dmd_coordinates, galvo_coordinates])``.
        Returns None when ``laser_list`` is empty.
        """
        reference_coordinates = None
        for laser in self.laser_list:
            reference_coordinates = \
                self._gather_reference_coordinates_dmd_single_laser(laser)
        return reference_coordinates

    def _gather_reference_coordinates_dmd_single_laser(self, laser):
        """Switch ``laser`` on, run the DMD registration for it, switch it off."""
        galvo_coordinates = np.zeros((3, 2))
        self.flag_finished = [0, 0, 0]

        self.backend.ui_widget.sig_control_laser.emit(laser, 5)
        try:
            self.registration_single_laser(laser)
        finally:
            # Make sure the laser is switched off even if registration failed.
            self.backend.ui_widget.sig_control_laser.emit(laser, 0)

        return np.array(
            [self.camera_coordinates, self.dmd_coordinates, galvo_coordinates])

    def registration_single_laser(self, laser):
        """Project the three registration masks and locate them on camera.

        Fills ``self.camera_coordinates`` (through the finder threads'
        callbacks) and ``self.dmd_coordinates`` (from file). Blocks until all
        three finder threads have reported back.
        """
        self.camera_coordinates = np.zeros((3, 2))
        self.touchingCoordinateFinder = []

        for i in range(3):
            finder = touchingCoordinateFinder_Thread(i, method='curvefit')
            finder.sig_finished_coordinatefinder.connect(
                self.touchingCoordinateFinder_finished)
            self.touchingCoordinateFinder.append(finder)

        for i in range(3):
            self.loadFileName = './CoordinatesManager/Registration_Images/TouchingSquares/registration_mask_' + str(
                i) + '.png'

            # Transpose because mask in file is rotated by 90 degrees.
            mask = np.transpose(plt.imread(self.loadFileName))

            self.backend.loadMask(mask)
            self.backend.startProjection()

            time.sleep(0.5)
            self.image = self.cam.SnapImage(0.0015)
            time.sleep(0.5)

            self.backend.stopProjection()
            self.backend.freeMemory()

            # Start touchingCoordinateFinder thread
            self.touchingCoordinateFinder[i].put_image(self.image)
            self.touchingCoordinateFinder[i].start()

        self.dmd_coordinates = self.read_dmd_coordinates_from_file()

        # Block till all touchingCoordinateFinder_Thread threads are finished
        while np.prod(self.flag_finished) == 0:
            time.sleep(0.1)

    def read_dmd_coordinates_from_file(self):
        """Read comma-separated integer DMD coordinates, one point per line."""
        with open(
                './CoordinatesManager/Registration_Images/TouchingSquares/positions.txt',
                'r') as file:
            rows = [ln.strip().split(',') for ln in file.readlines()]

        self.dmd_coordinates = np.asarray(rows).astype(int)
        return self.dmd_coordinates

    def touchingCoordinateFinder_finished(self, sig):
        """Store the coordinates found by finder thread ``sig`` and flag it done."""
        # np.flip reverses the coordinate pair (presumably (row, col) -> (x, y)).
        self.camera_coordinates[sig, :] = np.flip(
            self.touchingCoordinateFinder[sig].coordinates)
        self.flag_finished[sig] = 1
class StageWidget(QWidget):
    """Widget that images a grid of stage positions for stage registration."""

    def __init__(self, parent=None, *args, stage_port="COM12", **kwargs):
        """Create the widget and connect to the Ludl stage.

        ``stage_port`` is the serial port of the Ludl stage (default
        ``"COM12"``, the previously hard-coded value).
        """
        super().__init__(*args, **kwargs)

        self.main_application = parent

        self.set_image_saving_location(
            os.getcwd() +
            '/CoordinatesManager/Registration_Images/StageRegistration/')

        self.init_gui()

        self.ludlStage = LudlStage(stage_port)

    def init_gui(self):
        """Build the group box with the 'Collect data' button."""
        layout = QGridLayout()

        self.box = roundQGroupBox()
        self.box.setTitle("Stage")
        box_layout = QGridLayout()
        self.box.setLayout(box_layout)

        self.setLayout(layout)

        self.collect_data_button = QPushButton('Collect data')
        self.collect_data_button.clicked.connect(self.start_aqcuisition)

        box_layout.addWidget(self.collect_data_button)

        layout.addWidget(self.box)

    def set_image_saving_location(self, filepath):
        """Set the directory images are saved into."""
        self.image_file_path = filepath

    def start_aqcuisition(self):
        """Image a 3 x 3 local grid around each of four global stage positions.

        Each image is saved with ``save_image`` as ``<global><local>.png``
        (``A0`` .. ``D8``). The camera is released afterwards, also on failure.
        """
        global_pos = np.array(
            ((-5000, -5000), (-5000, 5000), (5000, -5000), (5000, 5000)))
        global_pos_name = ['A', 'B', 'C', 'D']

        delta = 200
        local_pos = np.transpose(
            np.reshape(
                np.meshgrid(np.array((-delta, 0, delta)),
                            np.array((-delta, 0, delta))), (2, -1)))
        local_pos_name = [str(i) for i in range(local_pos.shape[0])]

        # Offset variables are used to generate replicates for good statistics
        offset_y = offset_x = delta

        total = len(local_pos_name) * len(global_pos_name)

        self.cam = CamActuator()
        self.cam.initializeCamera()
        try:
            cnt = 0
            for (global_x, global_y), global_name in zip(global_pos,
                                                         global_pos_name):
                for (local_x, local_y), local_name in zip(local_pos,
                                                          local_pos_name):
                    x = global_x + local_x + offset_x
                    y = global_y + local_y + offset_y

                    self.ludlStage.moveAbs(x=x, y=y)

                    image = self.cam.SnapImage(0.04)
                    self.save_image(global_name + local_name, image)

                    cnt += 1
                    print(str(cnt) + '/' + str(total))
        finally:
            self.cam.Exit()

    def save_image(self, filename, image):
        """Save ``image`` as ``<image_file_path>/<filename>.png``."""
        plt.imsave(os.path.join(self.image_file_path, filename + '.png'), image)
Example #6
0
class CoordinatesWidgetUI(QWidget):
    """Main coordinate-control window.

    Lets the user draw ROIs on a camera image and turn them into masks. The
    ROIs are forwarded to the DMD and galvo widgets, which also host the
    manual and stage registration widgets.
    """

    sig_cast_mask_coordinates_to_dmd = pyqtSignal(list)
    sig_cast_mask_coordinates_to_galvo = pyqtSignal(list)
    sig_start_registration = pyqtSignal()
    sig_finished_registration = pyqtSignal()
    sig_cast_camera_image = pyqtSignal(np.ndarray)

    def __init__(self, parent=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.main_application = parent
        self.init_gui()

    def closeEvent(self, event):
        """Disconnect the DMD if one is attached, then quit the application."""
        dmd = getattr(self, 'DMD', None)
        if dmd is not None:
            dmd.disconnect_DMD()

        QtWidgets.QApplication.quit()
        event.accept()

    @staticmethod
    def _lock_view_to_frame(image_view, size=2048):
        """Fix the visible range of ``image_view`` to a ``size`` x ``size`` frame."""
        image_view.getView().setLimits(xMin=0,
                                       xMax=size,
                                       yMin=0,
                                       yMax=size,
                                       minXRange=size,
                                       minYRange=size,
                                       maxXRange=size,
                                       maxYRange=size)

    def init_gui(self):
        """Build the image/mask views and the side panel of control widgets."""
        self.setWindowTitle("Coordinate control")

        self.layout = QGridLayout()
        self.setMinimumSize(1250, 1000)
        self.setLayout(self.layout)

        self.image_mask_stack = QTabWidget()

        self.selection_view = DrawingWidget(self)
        self.selection_view.enable_drawing(True)
        self._lock_view_to_frame(self.selection_view)
        self.selection_view.ui.roiBtn.hide()
        self.selection_view.ui.menuBtn.hide()
        self.selection_view.ui.normGroup.hide()
        self.selection_view.ui.roiPlot.hide()

        self.mask_view = SquareImageView()
        self._lock_view_to_frame(self.mask_view)
        self.mask_view.ui.roiBtn.hide()
        self.mask_view.ui.menuBtn.hide()
        self.mask_view.ui.normGroup.hide()
        self.mask_view.ui.roiPlot.hide()
        self.mask_view.ui.histogram.hide()

        self.image_mask_stack.addTab(self.selection_view, 'Select')
        self.image_mask_stack.addTab(self.mask_view, 'Mask')

        self.layout.addWidget(self.image_mask_stack, 0, 0, 5, 1)

        # ---------------------- Mask generation Container  --------------

        self.maskGeneratorContainer = roundQGroupBox()
        self.maskGeneratorContainer.setFixedSize(320, 220)
        self.maskGeneratorContainer.setTitle("Mask generator")
        self.maskGeneratorContainerLayout = QGridLayout()

        self.maskGeneratorLayout = QGridLayout()
        self.maskGeneratorContainer.setLayout(
            self.maskGeneratorContainerLayout)

        self.loadMaskFromFileButton = QPushButton('Open mask')
        self.loadMaskFromFileButton.clicked.connect(self.load_mask_from_file)

        self.addRoiButton = QPushButton("Add ROI")
        self.createMaskButton = QPushButton("Create mask")
        self.removeSelectionButton = QPushButton("Remove ROIs")
        self.addRoiButton.clicked.connect(self.add_polygon_roi)

        self.createMaskButton.clicked.connect(self.create_mask)
        self.removeSelectionButton.clicked.connect(self.remove_selection)

        self.maskGeneratorContainerLayout.addWidget(self.addRoiButton, 1, 0)
        self.maskGeneratorContainerLayout.addWidget(self.createMaskButton, 2,
                                                    0)
        self.maskGeneratorContainerLayout.addWidget(self.removeSelectionButton,
                                                    1, 1)
        self.selectionOptionsContainer = roundQGroupBox()
        self.selectionOptionsContainer.setTitle('Options')
        self.selectionOptionsLayout = QGridLayout()
        self.fillContourButton = QCheckBox()
        self.invertMaskButton = QCheckBox()
        self.thicknessSpinBox = QSpinBox()
        self.thicknessSpinBox.setRange(1, 25)
        self.selectionOptionsLayout.addWidget(QLabel('Fill contour:'), 0, 0)
        self.selectionOptionsLayout.addWidget(self.fillContourButton, 0, 1)
        self.selectionOptionsLayout.addWidget(QLabel('Invert mask:'), 1, 0)
        self.selectionOptionsLayout.addWidget(self.invertMaskButton, 1, 1)
        self.selectionOptionsLayout.addWidget(QLabel('Thickness:'), 2, 0)
        self.selectionOptionsLayout.addWidget(self.thicknessSpinBox, 2, 1)
        self.selectionOptionsContainer.setLayout(self.selectionOptionsLayout)

        self.snapFovButton = QPushButton('Image FOV')
        self.snapFovButton.clicked.connect(self.snap_fov)

        self.maskGeneratorContainerLayout.addWidget(self.snapFovButton, 0, 0,
                                                    1, 1)
        self.maskGeneratorContainerLayout.addWidget(
            self.loadMaskFromFileButton, 0, 1, 1, 1)
        self.maskGeneratorContainerLayout.addWidget(
            self.selectionOptionsContainer, 2, 1, 2, 1)

        self.layout.addWidget(self.maskGeneratorContainer, 0, 1)

        self.DMDWidget = DMDWidget.DMDWidget()
        self.layout.addWidget(self.DMDWidget, 1, 1)

        self.DMDWidget.sig_request_mask_coordinates.connect(
            lambda: self.cast_mask_coordinates('dmd'))
        self.sig_cast_mask_coordinates_to_dmd.connect(
            self.DMDWidget.receive_mask_coordinates)
        self.DMDWidget.sig_start_registration.connect(
            lambda: self.sig_start_registration.emit())
        self.DMDWidget.sig_finished_registration.connect(
            lambda: self.sig_finished_registration.emit())

        self.GalvoWidget = GalvoWidget.GalvoWidget()
        self.layout.addWidget(self.GalvoWidget, 2, 1)

        self.GalvoWidget.sig_request_mask_coordinates.connect(
            lambda: self.cast_mask_coordinates('galvo'))
        self.sig_cast_mask_coordinates_to_galvo.connect(
            self.GalvoWidget.receive_mask_coordinates)
        self.GalvoWidget.sig_start_registration.connect(
            lambda: self.sig_start_registration.emit())
        self.GalvoWidget.sig_finished_registration.connect(
            lambda: self.sig_finished_registration.emit())

        self.ManualRegistrationWidget = ManualRegistration.ManualRegistrationWidget(
            self)
        self.ManualRegistrationWidget.sig_request_camera_image.connect(
            self.cast_camera_image)
        self.sig_cast_camera_image.connect(
            self.ManualRegistrationWidget.receive_camera_image)

        self.layout.addWidget(self.ManualRegistrationWidget, 3, 1)

        self.StageRegistrationWidget = StageRegistrationWidget.StageWidget(
            self)
        self.layout.addWidget(self.StageRegistrationWidget, 4, 1)

    def cast_transformation_to_DMD(self, transformation, laser):
        """Store ``transformation`` for ``laser`` in the DMD widget and save it."""
        self.DMDWidget.transform[laser] = transformation
        self.DMDWidget.save_transformation()

    def cast_transformation_to_galvos(self, sig):
        """Store the galvo transformation ``sig`` in the galvo widget and save it."""
        self.GalvoWidget.transform = sig
        self.GalvoWidget.save_transformation()

    def cast_camera_image(self):
        """Emit the current selection image, if one is loaded."""
        image = self.selection_view.image
        if isinstance(image, np.ndarray):
            self.sig_cast_camera_image.emit(image)

    def snap_fov(self):
        """Project full white on the DMD, snap one camera image and show it."""
        self.DMDWidget.interupt_projection()

        self.DMDWidget.project_full_white()

        self.cam = CamActuator()
        self.cam.initializeCamera()
        try:
            image = self.cam.SnapImage(0.04)
        finally:
            # Release the camera even if the snapshot failed.
            self.cam.Exit()
        self.selection_view.setImage(image)

    def cast_mask_coordinates(self, receiver):
        """Send the ROIs and mask options to the DMD ('dmd') or galvo widget."""
        list_of_rois = self.get_list_of_rois()

        sig = [
            list_of_rois,
            self.fillContourButton.isChecked(),
            self.thicknessSpinBox.value(),
            self.invertMaskButton.isChecked()
        ]

        if receiver == 'dmd':
            self.sig_cast_mask_coordinates_to_dmd.emit(sig)
        else:
            self.sig_cast_mask_coordinates_to_galvo.emit(sig)

    def get_list_of_rois(self):
        """Return one (n_vertices, 2) array of view coordinates per drawn ROI."""
        list_of_rois = []

        for roi in self.selection_view.roilist:
            roi_origin = roi.pos()
            # Each handle entry is indexed by [1] to get its position
            # (pyqtgraph returns (name, position) pairs).
            handle_positions = [item[1] for item in roi.getLocalHandlePositions()]

            # Handle positions are local to the ROI; shift by its origin.
            vertices = np.array(
                [[vertex.x() + roi_origin.x(), vertex.y() + roi_origin.y()]
                 for vertex in handle_positions],
                dtype=float).reshape(-1, 2)

            list_of_rois.append(vertices)

        return list_of_rois

    def create_mask(self):
        """Build a binary mask from the drawn ROIs and show it in the mask tab."""
        flag_fill_contour = self.fillContourButton.isChecked()
        flag_invert_mode = self.invertMaskButton.isChecked()
        contour_thickness = self.thicknessSpinBox.value()

        list_of_rois = self.get_list_of_rois()

        self.mask = ProcessImage.CreateBinaryMaskFromRoiCoordinates(
            list_of_rois,
            fill_contour=flag_fill_contour,
            contour_thickness=contour_thickness,
            invert_mask=flag_invert_mode)

        self.mask_view.setImage(self.mask)

    def remove_selection(self):
        """Remove all drawn ROIs."""
        self.selection_view.clear_rois()

    def set_camera_image(self, sig):
        """Show ``sig`` in the selection view."""
        self.selection_view.setImage(sig)

    def add_polygon_roi(self):
        """Add a movable rectangular polygon ROI scaled to the current view."""
        view = self.selection_view
        rect = view.getView().viewRect()

        x = rect.x() * 0.3
        y = rect.y() * 0.3
        a = (rect.width() + x) * 0.3
        b = (rect.height() + y) * 0.3
        c = (rect.width() + x) * 0.7
        d = (rect.height() + y) * 0.7
        polygon_roi = pg.PolyLineROI([[a, b], [c, b], [c, d], [a, d]],
                                     pen=view.pen,
                                     closed=True,
                                     movable=True,
                                     removable=True)

        view.getView().addItem(polygon_roi)
        view.append_to_roilist(polygon_roi)

    def load_mask_from_file(self):
        """
        Open a file manager to browse through files, load image file
        """
        # Qt name filters separate patterns with spaces, not commas; with
        # commas only "*.jpg" would match.
        self.loadFileName, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select file', './CoordinatesManager/Images/',
            "Images (*.png *.tiff *.jpg)")
        if not self.loadFileName:
            # Dialog was cancelled; nothing to load.
            return
        try:
            image = plt.imread(self.loadFileName)

            self.mask = image
            self.mask_view.setImage(self.mask)
        except Exception:
            print('fail to load file.')
Example #7
0
class FocusFinder():
    
    def __init__(self, source_of_image="PMT", init_search_range=0.010,
                 total_step_number=5, imaging_conditions=None,
                 motor_handle=None, camera_handle=None, twophoton_handle=None,
                 *args, **kwargs):
        """
        Set up the focus finder and connect the devices it needs.

        Parameters
        ----------
        source_of_image : str, optional
            The input source of image, "PMT" or "Camera". The default is "PMT".
        init_search_range : float, optional
            Half-width of the search window around the current objective
            position. The default is 0.010.
        total_step_number : int, optional
            Number of steps in total to find optimal focus. The default is 5.
        imaging_conditions : dict, optional
            Parameters for imaging. For PMT, ``'edge_volt'`` gives the scanning
            voltage. For camera, it is meant to hold the AOTF voltage and
            exposure time. The default is ``{'edge_volt': 5}``; a new dict is
            created per instance.
        motor_handle : optional
            Handle to control the PI motor. A new ``PIMotor()`` is connected
            when None (default).
        camera_handle : optional
            Camera handle used when ``source_of_image == "Camera"``. A new
            camera is initialised when None (default).
        twophoton_handle : optional
            Handle to control Insight X3. Currently not used by this method.

        Returns
        -------
        None.

        """
        super().__init__(*args, **kwargs)

        # Avoid a shared mutable default: build a fresh dict per instance.
        if imaging_conditions is None:
            imaging_conditions = {'edge_volt': 5}

        # Half-width of the initial search window.
        self.init_search_range = init_search_range

        # Number of steps in total to find optimal focus.
        self.total_step_number = total_step_number

        # Parameters for imaging.
        self.imaging_conditions = imaging_conditions

        if motor_handle is None:
            # Connect the objective if the handle is not provided.
            self.pi_device_instance = PIMotor()
        else:
            self.pi_device_instance = motor_handle

        # Current position of the focus.
        self.current_pos = self.pi_device_instance.GetCurrentPos()

        # Number of steps already tried.
        self.steps_taken = 0
        # The focus degree of previous position.
        self.previous_degree_of_focus = 0
        # Number of going backwards.
        self.turning_point = 0
        # The input source of image.
        self.source_of_image = source_of_image
        if source_of_image == "PMT":
            self.galvo = RasterScan(Daq_sample_rate=500000,
                                    edge_volt=self.imaging_conditions['edge_volt'])
        elif source_of_image == "Camera":
            if camera_handle is None:
                # If no camera instance fed in, initialize camera.
                self.HamamatsuCam_ins = CamActuator()
                self.HamamatsuCam_ins.initializeCamera()
            else:
                self.HamamatsuCam_ins = camera_handle
    
    def gaussian_fit(self, move_to_focus=True, plot_fit=True):
        """
        Find the focus by fitting a Gaussian to focus degrees sampled around
        the current position.

        ``total_step_number`` positions are sampled evenly over
        ``current_pos +/- init_search_range``. If the fit fails or yields
        non-finite values, the sampled position with the highest focus degree
        is used instead.

        Parameters
        ----------
        move_to_focus : bool, optional
            Move the objective to the found position. The default is True.
        plot_fit : bool, optional
            Show a (blocking) matplotlib plot of the data and the fit when the
            fit succeeds. The default is True.

        Returns
        -------
        max_focus_pos : float
            Position with the highest (fitted) focus degree.
        """
        # The upper edge.
        upper_position = self.current_pos + self.init_search_range
        # The lower edge.
        lower_position = self.current_pos - self.init_search_range

        # Generate the sampling positions.
        sample_positions = np.linspace(lower_position, upper_position, self.total_step_number)

        # Go through each position and write down the focus degree.
        degree_of_focus_list = [self.evaluate_focus(round(each_pos, 6))
                                for each_pos in sample_positions]
        print(degree_of_focus_list)

        try:
            interpolated_fitted_curve = np.asarray(
                ProcessImage.gaussian_fit(degree_of_focus_list))
            if (interpolated_fitted_curve.size == 0
                    or not np.all(np.isfinite(interpolated_fitted_curve))):
                # Route unusable fits to the fallback below.
                raise ValueError("Gaussian fit produced no usable values.")

            # Generate the interpolated new focus position axis.
            x_axis_new = np.linspace(lower_position, upper_position, len(interpolated_fitted_curve))

            # Position with the highest fitted focus degree.
            max_focus_pos = x_axis_new[np.argmax(interpolated_fitted_curve)]

            if plot_fit:
                plt.plot(sample_positions, np.asarray(degree_of_focus_list), 'b+:', label='data')
                plt.plot(x_axis_new, interpolated_fitted_curve, 'ro:', label='fit')
                plt.legend()
                plt.title('Fig. Fit for focus degree')
                plt.xlabel('Position')
                plt.ylabel('Focus degree')
                plt.show()

            max_focus_pos = round(max_focus_pos, 6)
            print(max_focus_pos)
        except Exception:
            print("Fitting failed. Find max in the list.")

            max_focus_pos = sample_positions[degree_of_focus_list.index(max(degree_of_focus_list))]
            print(max_focus_pos)

        # Move only when requested (previously a successful fit always moved,
        # and moved twice when move_to_focus was True).
        if move_to_focus:
            self.pi_device_instance.move(max_focus_pos)

        return max_focus_pos
        
    def bisection(self):
        """
        Search for the best focus by repeatedly halving the objective range.

        The search interval starts as
        [current_pos - init_search_range, current_pos + init_search_range].
        In the first round the focus degree is measured at the middle, at the
        top edge and at the bottom edge. In each later round only the middle
        of the current interval is measured. After every round, the edge whose
        recorded focus degree is lower is replaced by the middle point, so
        the interval shrinks towards the better edge.

        Returns
        -------
        mid_position : float or bool
            The last evaluated middle position after ``total_step_number``
            rounds, or False when no cell was detected in the first image.
        """
        top_pos = self.current_pos + self.init_search_range
        bot_pos = self.current_pos - self.init_search_range

        for round_no in range(1, self.total_step_number + 1):
            if round_no == 1:
                # First round: middle, then both edges.
                mid_position = (top_pos + bot_pos)/2
                mid_degree = self.evaluate_focus(mid_position)
                print("mid focus degree: {}".format(round(mid_degree, 5)))

                # Stop early when the image just taken contains no cell.
                # NOTE(review): self.galvo_image is only refreshed by
                # evaluate_focus for the "PMT" source; with the camera source
                # this reads a stale/missing attribute -- confirm intent.
                if not ProcessImage.if_theres_cell(self.galvo_image.astype('float32')):
                    print('no cell')
                    mid_position = False
                    break

                top_degree = self.evaluate_focus(obj_position = top_pos)
                print("top focus degree: {}".format(round(top_degree, 5)))
                bot_degree = self.evaluate_focus(obj_position = bot_pos)
                print("bot focus degree: {}".format(round(bot_degree, 5)))

                # Each edge maps to [position, focus degree].
                edges = {"top": [top_pos, top_degree],
                         "bot": [bot_pos, bot_degree]}
            else:
                # Later rounds: only the middle of the current interval.
                top_pos = edges["top"][0]
                bot_pos = edges["bot"][0]

                mid_position = (top_pos + bot_pos)/2
                mid_degree = self.evaluate_focus(mid_position)

                print("Current focus degree: {}".format(round(mid_degree, 5)))

            # The weaker edge is moved to the middle point.
            weaker_edge = "bot" if edges["top"][1] > edges["bot"][1] else "top"
            edges[weaker_edge] = [mid_position, mid_degree]

            print("The upper pos: {}; The lower: {}".format(edges["top"][0], edges["bot"][0]))

        return mid_position
                
                
    
    def evaluate_focus(self, obj_position = None):
        """
        Evaluate the focus degree at a given objective position.

        Parameters
        ----------
        obj_position : float, optional
            Objective position to move to before imaging. If None, the image
            is taken at the current position. The default is None.

        Returns
        -------
        degree_of_focus : float
            Focus measure of the acquired image: ``ProcessImage.local_entropy``
            for the "PMT" source, ``ProcessImage.variance_of_laplacian`` for
            the "Camera" source.

        Raises
        ------
        ValueError
            If ``self.source_of_image`` is neither "PMT" nor "Camera", or if
            no AOTF channel key (a key containing 'AO') is present in
            ``self.imaging_conditions`` for the camera source.
        """
        # Validate up front. Previously an unknown source fell through and
        # failed with UnboundLocalError on `degree_of_focus`.
        if self.source_of_image not in ("PMT", "Camera"):
            raise ValueError(
                "Unknown source_of_image: {!r} (expected 'PMT' or 'Camera')".format(
                    self.source_of_image))

        if obj_position is not None:
            self.pi_device_instance.move(obj_position)

        # Get the image.
        if self.source_of_image == "PMT":
            self.galvo_image = self.galvo.run()
            plt.figure()
            plt.imshow(self.galvo_image)
            plt.show()

            if False:  # Debug: set to True to save every evaluated image.
                with skimtiff.TiffWriter(os.path.join(r'M:\tnw\ist\do\projects\Neurophotonics\Brinkslab\Data\Xin\2020-11-17 gaussian fit auto-focus cells\trial_11', str(obj_position).replace(".", "_")+ '.tif')) as tif:
                    tif.save(self.galvo_image.astype('float32'), compress=0)

            degree_of_focus = ProcessImage.local_entropy(self.galvo_image.astype('float32'))

        else:  # "Camera"
            # First configure the AOTF.
            self.AOTF_runner = DAQmission()
            # Find the AOTF channel key, like '488AO'. As before, the last
            # matching key wins.
            aotf_keys = [key for key in self.imaging_conditions if 'AO' in key]
            if not aotf_keys:
                raise ValueError(
                    "No AOTF channel key (containing 'AO') in imaging_conditions.")
            AOTF_channel_key = aotf_keys[-1]

            try:
                # Set the AOTF, then snap an image from the camera.
                self.AOTF_runner.sendSingleDigital('blankingall', True)
                self.AOTF_runner.sendSingleAnalog(AOTF_channel_key, self.imaging_conditions[AOTF_channel_key])

                self.camera_image = self.HamamatsuCam_ins.SnapImage(self.imaging_conditions['exposure_time'])
                time.sleep(0.5)
            finally:
                # Always set the AOTF back, so it is not left driven if
                # setting it or snapping the image fails.
                self.AOTF_runner.sendSingleDigital('blankingall', False)
                self.AOTF_runner.sendSingleAnalog(AOTF_channel_key, 0)

            plt.figure()
            plt.imshow(self.camera_image)
            plt.show()

            if False:  # Debug: set to True to save every evaluated image.
                with skimtiff.TiffWriter(os.path.join(r'M:\tnw\ist\do\projects\Neurophotonics\Brinkslab\Data\Xin\2021-03-06 Camera AF\beads', str(obj_position).replace(".", "_")+ '.tif')) as tif:
                    tif.save(self.camera_image.astype('float32'), compress=0)

            degree_of_focus = ProcessImage.variance_of_laplacian(self.camera_image.astype('float32'))

        time.sleep(0.2)

        return degree_of_focus
Exemple #8
0
class MainGUI(QWidget):
    """Cell-selection widget.

    Loads an image (from a .tif file or a camera snap) and runs MaskRCNN cell
    detection on it. The user can click detected cells and draw free-hand
    ROIs. Both are combined into a binary mask (``self.final_mask``), whose
    contours are emitted to the DMD widget through ``signal_DMDcontour``.
    """

    # signal_DMDmask is a dictionary with laser specification as key and binary mask as content.
    signal_DMDmask = pyqtSignal(dict)
    # Payload: [contours, flag_fill_contour, contour_thickness, flag_invert_mode];
    # see emit_mask_contour.
    signal_DMDcontour = pyqtSignal(list)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setFont(QFont("Arial"))

        self.setWindowTitle("Cell Selection")
        self.layout = QGridLayout(self)

        # Free-hand ROI items drawn in the mask-edit view.
        self.roi_list_freehandl_added = []
        # Indices (into the MaskRCNN results) of the cells the user selected.
        self.selected_ML_Index = []
        # 'cell{i}_verts' -> contours of cell i, 'cell{i}_ROIitem' -> its pg.PolyLineROI.
        self.selected_cells_infor_dict = {}
        # Highlight patch drawn on the detection figure for each selected cell index.
        self._cell_patches = {}
        # No detection yet: canvas clicks are ignored until analysis has run.
        self.NumCells = 0

        # RGB weights used to tint the masks in the mask-edit view.
        self.mask_color_multiplier = [1, 1, 0]
        # =============================================================================
        #         Container for image display
        # =============================================================================
        graphContainer = StylishQT.roundQGroupBox()
        graphContainerLayout = QGridLayout()

        self.Imgviewtabs = QTabWidget()

        MLmaskviewBox = QWidget()
        MLmaskviewBoxLayout = QGridLayout()

        self.Matdisplay_Figure = Figure()
        self.Matdisplay_Canvas = FigureCanvas(self.Matdisplay_Figure)
        self.Matdisplay_Canvas.setFixedWidth(500)
        self.Matdisplay_Canvas.setFixedHeight(500)
        self.Matdisplay_Canvas.mpl_connect('button_press_event', self._onclick)

        self.Matdisplay_toolbar = NavigationToolbar(self.Matdisplay_Canvas,
                                                    self)

        MLmaskviewBoxLayout.addWidget(self.Matdisplay_toolbar, 0, 0)
        MLmaskviewBoxLayout.addWidget(self.Matdisplay_Canvas, 1, 0)

        MLmaskviewBox.setLayout(MLmaskviewBoxLayout)

        self.Imgviewtabs.addTab(MLmaskviewBox, "MaskRCNN")

        # =============================================================================
        #         Mask editing tab
        # =============================================================================
        MLmaskEditBox = QWidget()
        MLmaskEditBoxLayout = QGridLayout()

        self.Mask_edit_view = DrawingWidget(self)
        self.Mask_edit_view.enable_drawing(False)  # Disable drawing first
        self.Mask_edit_viewItem = self.Mask_edit_view.getImageItem()
        self.Mask_edit_view_getView = self.Mask_edit_view.getView()

        # Hide the image-view controls that are not used here.
        self.Mask_edit_view.ui.roiBtn.hide()
        self.Mask_edit_view.ui.menuBtn.hide()
        self.Mask_edit_view.ui.normGroup.hide()
        self.Mask_edit_view.ui.roiPlot.hide()

        MLmaskEditBoxLayout.addWidget(self.Mask_edit_view, 0, 0)

        MLmaskEditBox.setLayout(MLmaskEditBoxLayout)

        self.Imgviewtabs.addTab(MLmaskEditBox, "Mask edit")

        graphContainerLayout.addWidget(self.Imgviewtabs, 0, 0)
        graphContainer.setLayout(graphContainerLayout)

        # =============================================================================
        #         Operation container
        # =============================================================================
        operationContainer = StylishQT.roundQGroupBox()
        operationContainerLayout = QGridLayout()

        self.init_ML_button = QPushButton('Initialize ML', self)
        operationContainerLayout.addWidget(self.init_ML_button, 0, 0)
        self.init_ML_button.clicked.connect(self.init_ML)

        #---------------------Load image from file-----------------------------
        self.textbox_loadimg = QLineEdit(self)
        operationContainerLayout.addWidget(self.textbox_loadimg, 1, 0)

        self.button_import_img_browse = QPushButton('Browse', self)
        operationContainerLayout.addWidget(self.button_import_img_browse, 1, 1)
        self.button_import_img_browse.clicked.connect(self.get_img_file_tif)

        self.run_ML_button = QPushButton('Analysis', self)
        operationContainerLayout.addWidget(self.run_ML_button, 2, 0)
        self.run_ML_button.clicked.connect(self.run_ML_onImg_and_display)

        self.generate_MLmask_button = QPushButton('Mask', self)
        operationContainerLayout.addWidget(self.generate_MLmask_button, 2, 1)
        self.generate_MLmask_button.clicked.connect(self.generate_MLmask)

        self.update_MLmask_button = QPushButton('Update mask', self)
        operationContainerLayout.addWidget(self.update_MLmask_button, 3, 0)
        self.update_MLmask_button.clicked.connect(self.update_mask)

        self.enable_modify_MLmask_button = QPushButton('Enable free-hand',
                                                       self)
        self.enable_modify_MLmask_button.setCheckable(True)
        operationContainerLayout.addWidget(self.enable_modify_MLmask_button, 4,
                                           0)
        self.enable_modify_MLmask_button.clicked.connect(self.enable_free_hand)

        self.clear_roi_button = QPushButton('Clear ROIs', self)
        operationContainerLayout.addWidget(self.clear_roi_button, 5, 0)
        self.clear_roi_button.clicked.connect(self.clear_edit_roi)

        # NOTE: generate_transformed_mask() and emit_mask() read
        # self.maskLaserComboBox, which is not created here. Re-add a laser
        # combo box (items '640', '532', '488') before wiring those up.

        self.emit_transformed_mask_button = QPushButton('Emit mask', self)
        operationContainerLayout.addWidget(self.emit_transformed_mask_button,
                                           7, 1)
        self.emit_transformed_mask_button.clicked.connect(
            self.emit_mask_contour)

        operationContainer.setLayout(operationContainerLayout)

        # =============================================================================
        #         Mask para container
        # =============================================================================
        MaskparaContainer = StylishQT.roundQGroupBox()
        MaskparaContainerContainerLayout = QGridLayout()

        #----------------------------------------------------------------------
        self.fillContourButton = QCheckBox()
        self.invertMaskButton = QCheckBox()
        self.thicknessSpinBox = QSpinBox()
        self.thicknessSpinBox.setRange(1, 25)
        MaskparaContainerContainerLayout.addWidget(QLabel('Fill contour:'), 0,
                                                   0)
        MaskparaContainerContainerLayout.addWidget(self.fillContourButton, 0,
                                                   1)
        MaskparaContainerContainerLayout.addWidget(QLabel('Invert mask:'), 1,
                                                   0)
        MaskparaContainerContainerLayout.addWidget(self.invertMaskButton, 1, 1)
        MaskparaContainerContainerLayout.addWidget(QLabel('Thickness:'), 2, 0)
        MaskparaContainerContainerLayout.addWidget(self.thicknessSpinBox, 2, 1)

        MaskparaContainer.setLayout(MaskparaContainerContainerLayout)
        # =============================================================================
        #         Device operation container
        # =============================================================================
        deviceOperationContainer = StylishQT.roundQGroupBox()
        deviceOperationContainerLayout = QGridLayout()

        #----------------------------------------------------------------------
        self.CamExposureBox = QDoubleSpinBox(self)
        self.CamExposureBox.setDecimals(6)
        self.CamExposureBox.setMinimum(0)
        self.CamExposureBox.setMaximum(100)
        self.CamExposureBox.setValue(0.001501)
        self.CamExposureBox.setSingleStep(0.001)
        deviceOperationContainerLayout.addWidget(self.CamExposureBox, 0, 1)
        deviceOperationContainerLayout.addWidget(QLabel("Exposure time:"), 0,
                                                 0)

        # A single snap button. It was previously created and connected twice
        # at the same grid cell, which triggered two snaps per click.
        cam_snap_button = QPushButton('Cam snap', self)
        deviceOperationContainerLayout.addWidget(cam_snap_button, 0, 2)
        cam_snap_button.clicked.connect(self.cam_snap)

        deviceOperationContainer.setLayout(deviceOperationContainerLayout)

        self.layout.addWidget(graphContainer, 0, 0, 3, 1)
        self.layout.addWidget(operationContainer, 0, 1)
        self.layout.addWidget(MaskparaContainer, 1, 1)
        self.layout.addWidget(deviceOperationContainer, 2, 1)
        self.setLayout(self.layout)

    #%%
    # =============================================================================
    #     MaskRCNN detection part
    # =============================================================================
    def init_ML(self):
        """Initialize the detector instance and load the model."""
        self.ProcessML = ProcessImageML()

    def get_img_file_tif(self):
        """Let the user pick a .tif file, load it and prepare it for MaskRCNN."""
        self.img_tif_filePath, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Single File',
            'M:/tnw/ist/do/projects/Neurophotonics/Brinkslab/Data', "(*.tif)")
        self.textbox_loadimg.setText(self.img_tif_filePath)

        # getOpenFileName returns an empty string (not None) when the dialog
        # is cancelled, so test truthiness instead of comparing with None.
        if not self.img_tif_filePath:
            return

        self.Rawimage = imread(self.img_tif_filePath)
        self._prepare_target_image()

    def _prepare_target_image(self):
        """Convert self.Rawimage for MaskRCNN, display it and reset both masks."""
        self.MLtargetedImg_raw = self.Rawimage.copy()

        self.MLtargetedImg = self.convert_for_MaskRCNN(self.MLtargetedImg_raw)

        self.show_raw_image(self.MLtargetedImg)

        mask_shape = (self.MLtargetedImg.shape[0], self.MLtargetedImg.shape[1])
        self.addedROIitemMask = np.zeros(mask_shape)
        self.MLmask = np.zeros(mask_shape)

    def show_raw_image(self, image):
        """Display a single image on the detection figure and in the mask-edit view."""
        self.Matdisplay_Figure.clear()
        ax1 = self.Matdisplay_Figure.add_subplot(111)
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.imshow(image)

        self.Matdisplay_Figure.tight_layout()
        self.Matdisplay_Canvas.draw()

        RGB_image = gray2rgb(image)
        self.Mask_edit_viewItem.setImage(RGB_image)

    def convert_for_MaskRCNN(self, input_img):
        """Convert the image size and bit-depth to make it suitable for MaskRCNN detection.

        An image larger than 1024 px in either dimension is resized to
        1024x1024; smaller images keep their size. Intensities are then
        min-max scaled to 0-255 and returned as uint8. A constant image maps
        to all zeros.
        """
        if input_img.shape[0] > 1024 or input_img.shape[1] > 1024:
            resized_img = resize(input_img, [1024, 1024],
                                 preserve_range=True).astype(input_img.dtype)
        else:
            # Previously `resized_img` was left undefined for images that did
            # not need resizing, raising UnboundLocalError.
            resized_img = input_img

        minval = np.min(resized_img)
        maxval = np.max(resized_img)
        value_range = maxval - minval
        if value_range == 0:
            # Flat image: avoid a division by zero.
            return np.zeros(resized_img.shape, dtype=np.uint8)

        return ((resized_img.astype(np.float64) - minval) / value_range *
                255).astype(np.uint8)

    def run_ML_onImg_and_display(self):
        """Run MaskRCNN on the input image and display the unmasked result."""
        if not hasattr(self, 'ProcessML'):
            print('MaskRCNN is not initialized; press "Initialize ML" first.')
            return
        if not hasattr(self, 'MLtargetedImg'):
            print('No image loaded; browse a file or snap from the camera first.')
            return

        self.Matdisplay_Figure.clear()
        ax1 = self.Matdisplay_Figure.add_subplot(111)

        # Depends on show_mask or not, the returned figure will be input raw image with mask or not.
        self.MLresults, self.Matdisplay_Figure_axis, self.unmasked_fig = self.ProcessML.DetectionOnImage(
            self.MLtargetedImg, axis=ax1, show_mask=False, show_bbox=False)
        self.Mask = self.MLresults['masks']
        self.Label = self.MLresults['class_ids']
        self.Score = self.MLresults['scores']
        self.Bbox = self.MLresults['rois']

        self.SelectedCellIndex = 0
        self.NumCells = int(len(self.Label))
        self.selected_ML_Index = []
        self.selected_cells_infor_dict = {}
        # The figure was cleared, so any previous highlight patches are gone.
        self._cell_patches = {}

        self.Matdisplay_Figure_axis.imshow(self.unmasked_fig.astype(np.uint8))

        self.Matdisplay_Figure.tight_layout()
        self.Matdisplay_Canvas.draw()

    #%%
    # =============================================================================
    #     Configure click event to add clicked cell mask
    # =============================================================================

    def _onclick(self, event):
        """Toggle selection of the detected cell under the mouse click."""
        if self.NumCells <= 0:
            return
        # Clicks outside the axes carry no data coordinates.
        if event.xdata is None or event.ydata is None:
            return

        mask_height, mask_width = np.shape(self.Mask)[:2]
        # Clamp to valid indices (the upper bound must be shape - 1, not shape).
        xcoor = min(max(int(event.xdata), 0), mask_width - 1)
        ycoor = min(max(int(event.ydata), 0), mask_height - 1)

        # Search for the mask corresponding to the clicked cell.
        hit_index = None
        for EachCell in range(self.NumCells):
            if self.Mask[ycoor, xcoor, EachCell]:
                hit_index = EachCell
                break
        if hit_index is None:
            # Background click. Previously this re-toggled the last selected cell.
            return

        self.SelectedCellIndex = hit_index
        verts_key = 'cell{}_verts'.format(hit_index)

        if hit_index not in self.selected_ML_Index:
            # Get the selected cell's contour coordinates and mask patch.
            self.contour_verts, self.Cell_patch = self.get_cell_polygon(
                self.Mask[:, :, hit_index])

            if self.Cell_patch is not None:
                self.Matdisplay_Figure_axis.add_patch(self.Cell_patch)
                self._cell_patches[hit_index] = self.Cell_patch
            self.Matdisplay_Canvas.draw()

            self.selected_ML_Index.append(hit_index)
            self.selected_cells_infor_dict[verts_key] = self.contour_verts
        else:
            # Deselect: remove the patch of *this* cell. Previously the most
            # recently drawn patch was removed, whichever cell it belonged to.
            patch = self._cell_patches.pop(hit_index, None)
            if patch is not None:
                patch.remove()
            self.Matdisplay_Canvas.draw()
            self.selected_ML_Index.remove(hit_index)
            self.selected_cells_infor_dict.pop(verts_key, None)

    def get_cell_polygon(self, mask):
        """Return (contours, polygon_patch) for a single binary cell mask.

        ``contours`` is the list returned by ``find_contours`` in (row, col)
        image coordinates. ``polygon_patch`` is a matplotlib Polygon, in (x, y),
        built from the last contour. It is None when the mask is empty.
        """
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                               dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        # Subtract the padding so the contours are in unpadded image coordinates.
        contours = [verts - 1 for verts in find_contours(padded_mask, 0.5)]
        if not contours:
            return contours, None

        # Flip (row, col) to (x, y) for matplotlib.
        contour_polygon = mpatches.Polygon(
            np.fliplr(contours[-1]), facecolor=self.random_colors(1)[0])

        return contours, contour_polygon

    def random_colors(self, N, bright=True):
        """
        Generate random colors.
        To get visually distinct colors, generate them in HSV space then
        convert to RGB.
        """
        brightness = 1.0 if bright else 0.7
        hsv = [(i / N, 1, brightness) for i in range(N)]
        colors = [colorsys.hsv_to_rgb(*c) for c in hsv]
        random.shuffle(colors)
        return colors

    #%%
    # =============================================================================
    #     For mask generation
    # =============================================================================

    def generate_MLmask(self):
        """Generate the binary mask of all selected cells and add their editable ROIs."""
        self.MLmask = np.zeros(
            (self.MLtargetedImg.shape[0], self.MLtargetedImg.shape[1]))

        for selected_index in self.selected_ML_Index:
            self.MLmask = np.add(self.MLmask, self.Mask[:, :, selected_index])

        self.intergrate_into_final_mask()

        if self.selected_ML_Index:
            self.add_rois_of_selected()

    def add_rois_of_selected(self):
        """
        Create an editable polygon ROI in the mask-edit view for every selected
        cell, from the contours stored in ``selected_cells_infor_dict``.
        """
        view = self.Mask_edit_view.getView()

        for selected_index in self.selected_ML_Index:
            roi_key = 'cell{}_ROIitem'.format(selected_index)
            contours = self.selected_cells_infor_dict['cell{}_verts'.format(
                selected_index)]

            # Drop an ROI created by a previous call, so it is not left
            # orphaned in the view.
            old_roi = self.selected_cells_infor_dict.pop(roi_key, None)
            if old_roi is not None:
                view.removeItem(old_roi)

            for contour in contours:
                # (row, col) -> (x, y). fliplr does not modify the stored
                # contour; the former in-place column swap flipped the stored
                # verts back and forth on every call.
                contour_xy = np.fliplr(contour)

                # Down sample the coordinates otherwise it will be too dense.
                contour_xy_sparse = np.delete(
                    contour_xy, np.arange(2, contour_xy.shape[0] - 3, 2), 0)

                self.selected_cells_infor_dict[roi_key] = pg.PolyLineROI(
                    positions=contour_xy_sparse, closed=True)

                view.addItem(self.selected_cells_infor_dict[roi_key])

    def update_mask(self):
        """
        Regenerate the masks from the ML-cell ROIs and the free-hand ROIs (in
        case they were edited) and show the result in the image view.

        !!!ISSUE: getLocalHandlePositions: moving handles changes the position read out, dragging roi as a whole doesn't.
        """
        # Binary mask from ML detection, rebuilt from the (possibly edited) ROI items only.
        if self.selected_ML_Index:
            roi_dict = {
                key: item
                for key, item in self.selected_cells_infor_dict.items()
                if 'ROIitem' in key
            }

            self.MLmask = ProcessImage.ROIitem2Mask(
                roi_dict,
                mask_resolution=(self.MLtargetedImg.shape[0],
                                 self.MLtargetedImg.shape[1]))

        # intergrate_into_final_mask also rebuilds the free-hand ROI mask.
        self.intergrate_into_final_mask()

    # =============================================================================
    #     For free-hand rois
    # =============================================================================

    def enable_free_hand(self):
        """Enable or disable free-hand drawing according to the toggle button."""
        self.Mask_edit_view.enable_drawing(
            self.enable_modify_MLmask_button.isChecked())

    def add_freehand_roi(self, roi):
        """Register a free-hand ROI drawn in the DrawingWidget."""
        self.roi_list_freehandl_added.append(roi)

    def _remove_cell_patches(self):
        """Remove all highlighted-cell patches from the detection figure."""
        for patch in self._cell_patches.values():
            try:
                patch.remove()
            except (ValueError, NotImplementedError):
                # Patch already detached, e.g. the figure was cleared.
                pass
        self._cell_patches = {}
        self.Matdisplay_Canvas.draw()

    def clear_edit_roi(self):
        """
        Remove all free-hand ROIs and all ML-selected cells, and reset the masks.
        """
        view = self.Mask_edit_view.getView()

        for roi in self.roi_list_freehandl_added:
            view.removeItem(roi)

        self.roi_list_freehandl_added = []

        # Remove the ROI items generated for the selected cells.
        for roiItemkey, roi_item in self.selected_cells_infor_dict.items():
            if 'ROIitem' in roiItemkey:
                view.removeItem(roi_item)

        self.selected_cells_infor_dict = {}
        # The stored contours were dropped with the dictionary, so drop the
        # selection too. Otherwise a later 'Mask' click raised KeyError on
        # 'cell{i}_verts'.
        self.selected_ML_Index = []
        self._remove_cell_patches()

        self.MLmask = np.zeros(
            (self.MLtargetedImg.shape[0], self.MLtargetedImg.shape[1]))
        self.intergrate_into_final_mask()

    def intergrate_into_final_mask(self):
        """Combine the ML mask and the free-hand ROI mask into self.final_mask.

        The tinted masks are overlaid on the image in the mask-edit view.
        final_mask is resized back to the raw image size when the image was
        downscaled for MaskRCNN.
        """
        # Binary mask of added rois
        self.addedROIitemMask = ProcessImage.ROIitem2Mask(
            self.roi_list_freehandl_added,
            mask_resolution=(self.MLtargetedImg.shape[0],
                             self.MLtargetedImg.shape[1]))
        # Display the RGB mask, ML mask plus free-hand added.
        self.Mask_edit_viewItem.setImage(gray2rgb(self.addedROIitemMask) * self.mask_color_multiplier + \
                                         gray2rgb(self.MLmask) * self.mask_color_multiplier + gray2rgb(self.MLtargetedImg))

        self.final_mask = self.MLmask + self.addedROIitemMask

        # In case the input image is 2048*2048, and it is resized to fit in MaskRCNN, need to convert back to original size for DMD tranformation.
        if self.final_mask.shape[0] != self.Rawimage.shape[
                0] or self.final_mask.shape[1] != self.Rawimage.shape[1]:
            self.final_mask = resize(
                self.final_mask,
                [self.Rawimage.shape[0], self.Rawimage.shape[1]],
                preserve_range=True).astype(self.final_mask.dtype)

        plt.figure()
        plt.imshow(self.final_mask)
        plt.show()

    # =============================================================================
    # For DMD transformation and mask generation
    # =============================================================================
    def generate_transformed_mask(self):
        """Transform final_mask into DMD space for the laser chosen in maskLaserComboBox.

        NOTE: requires self.maskLaserComboBox, which __init__ does not create.
        """
        self.read_transformations_from_file()
        target_laser = self.maskLaserComboBox.currentText()
        self.final_DMD_mask = self.finalmask_to_DMD_mask(
            laser=target_laser, dict_transformations=self.dict_transformations)

        plt.figure()
        plt.imshow(self.final_DMD_mask)
        plt.show()

    def emit_mask_contour(self):
        """Use find_contours to get a list of (n,2)-ndarrays consisting of n (row, column) coordinates along the contour,
           and then feed the list of signal:[list_of_rois, flag_fill_contour, contour_thickness, flag_invert_mode] to the 
           receive_mask_coordinates function in DMDWidget.
        """
        contours = find_contours(self.final_mask, 0.5)

        sig = [
            contours,
            self.fillContourButton.isChecked(),
            self.thicknessSpinBox.value(),
            self.invertMaskButton.isChecked()
        ]

        self.signal_DMDcontour.emit(sig)

    def emit_mask(self):
        """Emit the transformed DMD mask keyed by 'camera-dmd-<laser>'.

        NOTE: requires self.maskLaserComboBox, which __init__ does not create.
        """
        target_laser = self.maskLaserComboBox.currentText()
        final_DMD_mask_dict = {}
        final_DMD_mask_dict['camera-dmd-' + target_laser] = self.final_DMD_mask

        self.signal_DMDmask.emit(final_DMD_mask_dict)

    def read_transformations_from_file(self):
        """Load self.dict_transformations from the registration JSON file, if present."""
        try:
            with open(
                    r'M:\tnw\ist\do\projects\Neurophotonics\Brinkslab\People\Xin Meng\Code\Python_test\DMDManager\Registration\transformation.txt',
                    'r') as json_file:
                self.dict_transformations = json.load(json_file)
        except (OSError, ValueError):
            # Missing/unreadable file or invalid JSON (JSONDecodeError is a ValueError).
            print(
                'No transformation could be loaded from previous registration run.'
            )
            return

    def finalmask_to_DMD_mask(self,
                              laser,
                              dict_transformations,
                              flag_fill_contour=True,
                              contour_thickness=1,
                              flag_invert_mode=False,
                              mask_resolution=(1024, 768)):
        """
        Transform self.final_mask into a DMD mask.

        Contours of the final binary mask are transformed with
        ``dict_transformations`` for ``laser`` and rasterized again. The
        keyword arguments are now forwarded; previously they were silently
        ignored in favour of hard-coded values equal to the defaults.
        """
        self.final_DMD_mask = ProcessImage.binarymask_to_DMD_mask(
            self.final_mask, laser, dict_transformations,
            flag_fill_contour=flag_fill_contour,
            contour_thickness=contour_thickness,
            flag_invert_mode=flag_invert_mode,
            mask_resolution=mask_resolution)

        return self.final_DMD_mask

    def closeEvent(self, event):
        QtWidgets.QApplication.quit()
        event.accept()

    #%%
    def cam_snap(self):
        """Get an image from the camera and prepare it for MaskRCNN."""
        self.cam = CamActuator()
        self.cam.initializeCamera()

        try:
            exposure_time = self.CamExposureBox.value()
            self.Rawimage = self.cam.SnapImage(exposure_time)
        finally:
            # Release the camera even if the snap fails.
            self.cam.Exit()
        print('Snap finished')

        self._prepare_target_image()
class GalvoRegistrator:
    """Calibrate galvo scanner voltages against camera pixel coordinates."""

    def __init__(self, *args, **kwargs):
        # The camera opened here is released at the end of registration().
        self.cam = CamActuator()
        self.cam.initializeCamera()

    def registration(self, grid_points_x=3, grid_points_y=3):
        """
        Calibrate galvo voltages against camera pixel positions.

        Drives the galvos to a ``grid_points_x`` x ``grid_points_y`` grid of
        voltages, snaps a camera image at each point, locates the spot with
        ``readRegistrationImages.gaussian_fitting`` and fits first-order
        polynomial transforms between the two coordinate systems. With the
        default 3 x 3 grid the voltages are {-5, 0, 5} on each axis, i.e. 9
        points from (-5, -5) to (5, 5).

        Each image is saved as
        ``<cwd>/CoordinatesManager/Registration_Images/2P/image_<i>.png``;
        the directory is created if it does not exist. The DAQ task and the
        camera opened in ``__init__`` are released when this method returns
        or raises.

        NOTE(review): because the camera is closed here, calling this method
        a second time on the same instance presumably fails. Confirm this is
        intended.

        Parameters
        ----------
        grid_points_x : int, optional
            Number of galvo x voltages, evenly spaced strictly inside
            (-10, 10). The default is 3.
        grid_points_y : int, optional
            Number of galvo y voltages, evenly spaced strictly inside
            (-10, 10). The default is 3.

        Returns
        -------
        transformation_cam2galvo
            First-order fit from ``CoordinateTransformations.polynomial2DFit``
            mapping camera coordinates to galvo coordinates. ``[:, :, 0]``
            holds the x part and ``[:, :, 1]`` the y part (presumably a numpy
            array). The inverse fit (galvo -> camera) is computed and printed
            but not returned.
        """
        galvothread = DAQmission()

        # Evenly spaced voltages strictly inside (-10, 10): drop both ends.
        x_coords = np.linspace(-10, 10, grid_points_x + 2)[1:-1]
        y_coords = np.linspace(-10, 10, grid_points_y + 2)[1:-1]

        # Flatten the grid into an (N, 2) array of (x, y) pairs; order='F'
        # makes y vary fastest between consecutive points.
        galvo_coordinates = np.reshape(np.meshgrid(x_coords, y_coords), (2, -1),
                                       order='F').transpose()
        camera_coordinates = np.zeros(galvo_coordinates.shape)

        image_dir = os.path.join(os.getcwd(), 'CoordinatesManager',
                                 'Registration_Images', '2P')
        # Create the output directory up front so imsave cannot fail halfway
        # through the scan.
        os.makedirs(image_dir, exist_ok=True)

        try:
            for i, (galvo_x, galvo_y) in enumerate(galvo_coordinates):
                galvothread.sendSingleAnalog('galvosx', galvo_x)
                galvothread.sendSingleAnalog('galvosy', galvo_y)
                # Presumably lets the mirrors settle before imaging.
                time.sleep(1)

                image = self.cam.SnapImage(0.06)
                plt.imsave(os.path.join(image_dir, 'image_' + str(i) + '.png'),
                           image)

                camera_coordinates[i, :] = readRegistrationImages.gaussian_fitting(
                    image)
        finally:
            # Always release the DAQ task and the camera, even if the scan
            # raises.
            del galvothread
            self.cam.Exit()

        print('Galvo Coordinate')
        print(galvo_coordinates)
        print('Camera coordinates')
        print(camera_coordinates)

        transformation_cam2galvo = CoordinateTransformations.polynomial2DFit(
            camera_coordinates, galvo_coordinates, order=1)

        transformation_galvo2cam = CoordinateTransformations.polynomial2DFit(
            galvo_coordinates, camera_coordinates, order=1)

        print('Transformation found for x:')
        print(transformation_cam2galvo[:, :, 0])
        print('Transformation found for y:')
        print(transformation_cam2galvo[:, :, 1])

        print('galvo2cam found for x:')
        print(transformation_galvo2cam[:, :, 0])
        print('galvo2cam found for y:')
        print(transformation_galvo2cam[:, :, 1])

        return transformation_cam2galvo