Esempio n. 1
0
class RenderThread(QThread):
    """Background thread that renders the Mandelbrot set progressively.

    ``render()`` records the requested view and either starts the thread or asks
    the running render to restart. ``run()`` renders in up to ``NumPasses`` passes
    of increasing iteration depth and emits ``renderedImage`` after each pass that
    was not interrupted by a restart. ``stop()`` asks the thread to exit.
    """

    ColormapSize = 512

    # Emitted with the rendered image and the scale factor it was rendered at.
    renderedImage = Signal(QImage, float)

    def __init__(self, parent=None):
        super(RenderThread, self).__init__(parent)

        # ``mutex`` guards the render parameters and the restart/abort flags;
        # ``condition`` lets an idle worker sleep until render() or stop().
        self.mutex = QMutex()
        self.condition = QWaitCondition()
        self.centerX = 0.0
        self.centerY = 0.0
        self.scaleFactor = 0.0
        self.resultSize = QSize()

        self.restart = False
        self.abort = False

        # Palette of ColormapSize colours for wavelengths from 380 nm up to 780 nm.
        self.colormap = [
            self.rgbFromWaveLength(380.0 + (i * 400.0 / RenderThread.ColormapSize))
            for i in range(RenderThread.ColormapSize)
        ]

    def stop(self):
        """Asks the render loop to exit and waits up to 2000 ms for the thread to finish."""
        self.mutex.lock()
        self.abort = True
        # Wake the worker if it is idle in condition.wait() so it notices the abort.
        self.condition.wakeOne()
        self.mutex.unlock()

        self.wait(2000)

    def render(self, centerX, centerY, scaleFactor, resultSize):
        """Requests a render of the given view, restarting an in-progress render if needed.

        Args:
            centerX (float): real part of the point at the image centre.
            centerY (float): imaginary part of the point at the image centre.
            scaleFactor (float): complex-plane distance covered by one pixel.
            resultSize (QSize): size of the image to produce.
        """
        # The lock is held until ``locker`` goes out of scope when this method returns.
        locker = QMutexLocker(self.mutex)

        self.centerX = centerX
        self.centerY = centerY
        self.scaleFactor = scaleFactor
        self.resultSize = resultSize

        if not self.isRunning():
            self.start(QThread.LowPriority)
        else:
            self.restart = True
            self.condition.wakeOne()

    def run(self):
        """Thread entry point: render, emit, then sleep until the next request or stop()."""
        while True:
            # Snapshot the request so render() can change it while we work.
            self.mutex.lock()
            resultSize = self.resultSize
            scaleFactor = self.scaleFactor
            centerX = self.centerX
            centerY = self.centerY
            self.mutex.unlock()

            halfWidth = resultSize.width() // 2
            halfHeight = resultSize.height() // 2
            image = QImage(resultSize, QImage.Format_RGB32)

            NumPasses = 8
            curpass = 0

            while curpass < NumPasses:
                # The iteration budget roughly quadruples with every pass.
                MaxIterations = (1 << (2 * curpass + 6)) + 32
                # Escape threshold on |c|; any radius >= 2 guarantees divergence.
                Limit = 4
                allBlack = True

                for y in range(-halfHeight, halfHeight):
                    if self.restart:
                        break
                    if self.abort:
                        return

                    ay = 1j * (centerY + (y * scaleFactor))

                    for x in range(-halfWidth, halfWidth):
                        c0 = centerX + (x * scaleFactor) + ay
                        c = c0
                        numIterations = 0

                        # Manually unrolled 4x, presumably to cut per-iteration loop overhead.
                        while numIterations < MaxIterations:
                            numIterations += 1
                            c = c*c + c0
                            if abs(c) >= Limit:
                                break
                            numIterations += 1
                            c = c*c + c0
                            if abs(c) >= Limit:
                                break
                            numIterations += 1
                            c = c*c + c0
                            if abs(c) >= Limit:
                                break
                            numIterations += 1
                            c = c*c + c0
                            if abs(c) >= Limit:
                                break

                        if numIterations < MaxIterations:
                            image.setPixel(x + halfWidth, y + halfHeight,
                                           self.colormap[numIterations % RenderThread.ColormapSize])
                            allBlack = False
                        else:
                            image.setPixel(x + halfWidth, y + halfHeight, qRgb(0, 0, 0))

                # If the first (cheapest) pass found no escaping point, jump straight
                # to pass 4 instead of rendering the intermediate passes.
                if allBlack and curpass == 0:
                    curpass = 4
                else:
                    if not self.restart:
                        self.renderedImage.emit(image, scaleFactor)
                    curpass += 1

            self.mutex.lock()
            # Sleep until render() or stop() wakes us, unless a restart or an abort
            # is already pending. The abort check matters: for an image with no rows
            # the per-row abort check above never runs, and an unconditional wait
            # would leave stop() blocked until its timeout with the thread alive.
            if not self.restart and not self.abort:
                self.condition.wait(self.mutex)
            if self.abort:
                self.mutex.unlock()
                return
            self.restart = False
            self.mutex.unlock()

    def rgbFromWaveLength(self, wave):
        """Returns an approximate visible-spectrum colour for a wavelength.

        Args:
            wave (float): wavelength in nanometres; values outside 380-780 give black.

        Returns:
            int: the colour packed with ``qRgb``.
        """
        r = 0.0
        g = 0.0
        b = 0.0

        if wave >= 380.0 and wave <= 440.0:
            r = -1.0 * (wave - 440.0) / (440.0 - 380.0)
            b = 1.0
        elif wave >= 440.0 and wave <= 490.0:
            g = (wave - 440.0) / (490.0 - 440.0)
            b = 1.0
        elif wave >= 490.0 and wave <= 510.0:
            g = 1.0
            b = -1.0 * (wave - 510.0) / (510.0 - 490.0)
        elif wave >= 510.0 and wave <= 580.0:
            r = (wave - 510.0) / (580.0 - 510.0)
            g = 1.0
        elif wave >= 580.0 and wave <= 645.0:
            r = 1.0
            g = -1.0 * (wave - 645.0) / (645.0 - 580.0)
        elif wave >= 645.0 and wave <= 780.0:
            r = 1.0

        # Dim the colour towards both ends of the visible range.
        s = 1.0
        if wave > 700.0:
            s = 0.3 + 0.7 * (780.0 - wave) / (780.0 - 700.0)
        elif wave < 420.0:
            s = 0.3 + 0.7 * (wave - 380.0) / (420.0 - 380.0)

        r = pow(r * s, 0.8)
        g = pow(g * s, 0.8)
        b = pow(b * s, 0.8)

        # qRgb takes integer channels; convert explicitly rather than relying on
        # implicit float-to-int conversion, which newer bindings may reject.
        return qRgb(int(r * 255), int(g * 255), int(b * 255))
class GraphViewMixin:
    """Provides the graph view for the DS form.

    Mixin that draws the objects selected in the object tree, and the
    relationships between them, as a node-link graph. It also manages the
    item palette dock and the graph-related menu actions.
    """

    # Emitted after the graph is built or an item is dropped onto the view.
    graph_created = Signal()

    # Extent passed to ObjectItem / RelationshipItem when creating graph nodes.
    _node_extent = 64
    # Width passed to ArcItem, derived from the node extent.
    _arc_width = 0.25 * _node_extent
    # Desired distance between neighbouring nodes, used as the layout 'spread'.
    _arc_length_hint = 3 * _node_extent

    def __init__(self, *args, **kwargs):
        """Sets up graph view state, the item palette models and the zoom widget action."""
        super().__init__(*args, **kwargs)
        # Lookups filled by receive_objects_added / receive_relationships_added.
        self._added_objects = {}
        self._added_relationships = {}
        ui = self.ui
        self.object_class_list_model = ObjectClassListModel(self, self.db_mngr, self.db_map)
        self.relationship_class_list_model = RelationshipClassListModel(self, self.db_mngr, self.db_map)
        ui.listView_object_class.setModel(self.object_class_list_model)
        ui.listView_relationship_class.setModel(self.relationship_class_list_model)
        self.hidden_items = []
        self.rejected_items = []
        self.entity_item_selection = []
        self.zoom_widget_action = None
        # Orient the palette splitter to match the dock's initial location.
        palette_area = self.dockWidgetArea(ui.dockWidget_item_palette)
        self._handle_item_palette_dock_location_changed(palette_area)
        ui.treeView_object.qsettings = self.qsettings
        self.live_demo = GraphViewDemo(self)
        self.setup_zoom_widget_action()

    def add_menu_actions(self):
        """Appends toggle actions for the entity graph and item palette docks to the View menu."""
        super().add_menu_actions()
        view_menu = self.ui.menuView
        view_menu.addSeparator()
        for dock in (self.ui.dockWidget_entity_graph, self.ui.dockWidget_item_palette):
            view_menu.addAction(dock.toggleViewAction())

    def restore_dock_widgets(self):
        """Restores dock widgets through the base class, then hides the live demo."""
        super().restore_dock_widgets()
        demo = self.live_demo
        demo.hide()

    def connect_signals(self):
        """Connects graph view, dock, menu, zoom and item palette signals to their slots."""
        super().connect_signals()
        ui = self.ui
        zoom = self.zoom_widget_action
        signal_slot_pairs = (
            (ui.graphicsView.context_menu_requested, self.show_graph_view_context_menu),
            (ui.graphicsView.item_dropped, self._handle_item_dropped),
            (ui.dockWidget_entity_graph.visibilityChanged, self._handle_entity_graph_visibility_changed),
            (ui.dockWidget_item_palette.visibilityChanged, self._handle_item_palette_visibility_changed),
            (ui.dockWidget_item_palette.dockLocationChanged, self._handle_item_palette_dock_location_changed),
            (ui.actionHide_selected.triggered, self.hide_selected_items),
            (ui.actionShow_hidden.triggered, self.show_hidden_items),
            (ui.actionPrune_selected.triggered, self.prune_selected_items),
            (ui.actionRestore_pruned.triggered, self.restore_pruned_items),
            (ui.actionLive_graph_demo.triggered, self.show_demo),
            # Refresh action enabled-state just before each menu opens.
            (ui.menuGraph.aboutToShow, self._handle_menu_graph_about_to_show),
            (ui.menuHelp.aboutToShow, self._handle_menu_help_about_to_show),
            (zoom.minus_pressed, self._handle_zoom_minus_pressed),
            (zoom.plus_pressed, self._handle_zoom_plus_pressed),
            (zoom.reset_pressed, self._handle_zoom_reset_pressed),
            # Clicking the 'New...' entry of the Item palette opens the add forms.
            (ui.listView_object_class.clicked, self._add_more_object_classes),
            (ui.listView_relationship_class.clicked, self._add_more_relationship_classes),
        )
        for signal, slot in signal_slot_pairs:
            signal.connect(slot)

    def setup_zoom_widget_action(self):
        """Creates the zoom widget action (parented to the View menu) and appends it,
        after a separator, to the Graph menu."""
        graph_menu = self.ui.menuGraph
        self.zoom_widget_action = ZoomWidgetAction(self.ui.menuView)
        graph_menu.addSeparator()
        graph_menu.addAction(self.zoom_widget_action)

    def init_models(self):
        """Initializes the base models, then populates both item palette list models."""
        super().init_models()
        for palette_model in (self.object_class_list_model, self.relationship_class_list_model):
            palette_model.populate_list()

    def receive_object_classes_added(self, db_map_data):
        """Forwards newly added object classes to the Item palette object class model.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_object_classes_added(db_map_data)
        palette_model = self.object_class_list_model
        palette_model.receive_entity_classes_added(db_map_data)

    def receive_object_classes_updated(self, db_map_data):
        """Forwards updated object classes to the Item palette model and refreshes graph icons.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_object_classes_updated(db_map_data)
        palette_model = self.object_class_list_model
        palette_model.receive_entity_classes_updated(db_map_data)
        self.refresh_icons(db_map_data)

    def receive_object_classes_removed(self, db_map_data):
        """Forwards removed object classes to the Item palette object class model.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_object_classes_removed(db_map_data)
        palette_model = self.object_class_list_model
        palette_model.receive_entity_classes_removed(db_map_data)

    def receive_relationship_classes_added(self, db_map_data):
        """Forwards newly added relationship classes to the Item palette relationship class model.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_relationship_classes_added(db_map_data)
        palette_model = self.relationship_class_list_model
        palette_model.receive_entity_classes_added(db_map_data)

    def receive_relationship_classes_updated(self, db_map_data):
        """Forwards updated relationship classes to the Item palette model and refreshes graph icons.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_relationship_classes_updated(db_map_data)
        palette_model = self.relationship_class_list_model
        palette_model.receive_entity_classes_updated(db_map_data)
        self.refresh_icons(db_map_data)

    def receive_relationship_classes_removed(self, db_map_data):
        """Forwards removed relationship classes to the Item palette relationship class model.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_relationship_classes_removed(db_map_data)
        palette_model = self.relationship_class_list_model
        palette_model.receive_entity_classes_removed(db_map_data)

    def receive_objects_added(self, db_map_data):
        """Runs when objects are added to the db.
        Stores a (class_id, name) -> id lookup consumed by ``add_object``.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_objects_added(db_map_data)
        lookup = {}
        for added in db_map_data.get(self.db_map, []):
            lookup[added["class_id"], added["name"]] = added["id"]
        self._added_objects = lookup

    def receive_objects_updated(self, db_map_data):
        """Runs when objects are updated in the db. Refreshes the names of matching object items.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_objects_updated(db_map_data)
        updated_ids = set(entry["id"] for entry in db_map_data.get(self.db_map, []))
        for item in self.ui.graphicsView.items():
            if not isinstance(item, ObjectItem) or item.entity_id not in updated_ids:
                continue
            item.refresh_name()

    def receive_objects_removed(self, db_map_data):
        """Runs when objects are removed from the db. Wipes the matching items out of the graph.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_objects_removed(db_map_data)
        self.receive_entities_removed(db_map_data)

    def receive_relationships_added(self, db_map_data):
        """Runs when relationships are added to the db.
        Stores a (class_id, object_id_list) -> id lookup consumed by ``add_relationship``.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_relationships_added(db_map_data)
        lookup = {}
        for added in db_map_data.get(self.db_map, []):
            lookup[added["class_id"], added["object_id_list"]] = added["id"]
        self._added_relationships = lookup

    def receive_relationships_removed(self, db_map_data):
        """Runs when relationships are removed from the db. Wipes the matching items out of the graph.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        super().receive_relationships_removed(db_map_data)
        self.receive_entities_removed(db_map_data)

    def receive_entities_removed(self, db_map_data):
        """Calls ``wipe_out`` on every graph entity item whose id is among the removed ones.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        removed_ids = set(entry["id"] for entry in db_map_data.get(self.db_map, []))
        for item in self.ui.graphicsView.items():
            if not isinstance(item, EntityItem) or item.entity_id not in removed_ids:
                continue
            item.wipe_out()

    def refresh_icons(self, db_map_data):
        """Runs when entity classes are updated in the db. Refreshes icons of entities of those classes.

        Args:
            db_map_data (dict): list of dictionary-items keyed by DiffDatabaseMapping instance.
        """
        updated_class_ids = set(entry["id"] for entry in db_map_data.get(self.db_map, []))
        for item in self.ui.graphicsView.items():
            if not isinstance(item, EntityItem) or item.entity_class_id not in updated_class_ids:
                continue
            item.refresh_icon()

    @Slot("QModelIndex")
    def _add_more_object_classes(self, index):
        """Runs when the user clicks on the Item palette Object class view.
        Opens the form to add more object classes if the clicked index is the 'New...' one.

        Args:
            index (QModelIndex): The clicked index.
        """
        new_entry_index = index.model().new_index
        if index == new_entry_index:
            self.show_add_object_classes_form()

    @Slot("QModelIndex")
    def _add_more_relationship_classes(self, index):
        """Runs when the user clicks on the Item palette Relationship class view.
        Opens the form to add more relationship classes if the clicked index is the 'New...' one.

        Args:
            index (QModelIndex): The clicked index.
        """
        new_entry_index = index.model().new_index
        if index == new_entry_index:
            self.show_add_relationship_classes_form()

    @Slot()
    def _handle_zoom_minus_pressed(self):
        """Zooms the graph view out."""
        view = self.ui.graphicsView
        view.zoom_out()

    @Slot()
    def _handle_zoom_plus_pressed(self):
        """Zooms the graph view in."""
        view = self.ui.graphicsView
        view.zoom_in()

    @Slot()
    def _handle_zoom_reset_pressed(self):
        """Resets the zoom level of the graph view."""
        view = self.ui.graphicsView
        view.reset_zoom()

    @Slot()
    def _handle_menu_graph_about_to_show(self):
        """Enables Graph menu actions only when the graph dock is visible and they have something to act on."""
        ui = self.ui
        graph_visible = ui.dockWidget_entity_graph.isVisible()
        can_act_on_selection = graph_visible and bool(self.entity_item_selection)
        ui.actionHide_selected.setEnabled(can_act_on_selection)
        ui.actionShow_hidden.setEnabled(graph_visible and bool(self.hidden_items))
        ui.actionPrune_selected.setEnabled(can_act_on_selection)
        ui.actionRestore_pruned.setEnabled(graph_visible and bool(self.rejected_items))
        self.zoom_widget_action.setEnabled(graph_visible)

    @Slot()
    def _handle_menu_help_about_to_show(self):
        """Disables the live demo action while the demo is running."""
        demo_running = self.live_demo.is_running()
        self.ui.actionLive_graph_demo.setEnabled(not demo_running)

    @Slot("Qt.DockWidgetArea")
    def _handle_item_palette_dock_location_changed(self, area):
        """Runs when the item palette dock widget location changes.
        Stacks the palette lists vertically on the left/right sides, horizontally elsewhere."""
        side_areas = Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea
        orientation = Qt.Vertical if area & side_areas else Qt.Horizontal
        self.ui.splitter_object_relationship_class.setOrientation(orientation)

    @Slot(bool)
    def _handle_entity_graph_visibility_changed(self, visible):
        """Rebuilds the graph when its dock becomes visible and syncs the item palette's visibility with it."""
        if visible:
            self.build_graph()
        graph_dock_visible = self.ui.dockWidget_entity_graph.isVisible()
        self.ui.dockWidget_item_palette.setVisible(graph_dock_visible)

    @Slot(bool)
    def _handle_item_palette_visibility_changed(self, visible):
        """Shows the entity graph dock whenever the item palette becomes visible."""
        if not visible:
            return
        self.ui.dockWidget_entity_graph.show()

    @Slot("QItemSelection", "QItemSelection")
    def _handle_object_tree_selection_changed(self, selected, deselected):
        """Handles the selection change in the base class, then rebuilds the graph if its dock is visible."""
        super()._handle_object_tree_selection_changed(selected, deselected)
        if not self.ui.dockWidget_entity_graph.isVisible():
            return
        self.build_graph()

    @busy_effect
    def build_graph(self, timeit=False):
        """Builds the graph for the current object tree selection in a new scene.

        Work-in-progress relationship items from the previous scene are carried over.
        If there is nothing to draw, a "Nothing to show." text item is shown instead.
        Emits ``graph_created`` when done.

        Args:
            timeit (bool): If True, reports the build time through ``self.msg``.
        """
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # recommended clock for measuring elapsed time.
        tic = time.perf_counter()
        new_items = self._get_new_items()
        wip_relationship_items = self._get_wip_relationship_items()
        scene = self.new_scene()
        if not new_items and not wip_relationship_items:
            item = QGraphicsTextItem("Nothing to show.")
            scene.addItem(item)
        else:
            if new_items:
                object_items = new_items[0]
                self._add_new_items(scene, *new_items)  # pylint: disable=no-value-for-parameter
            else:
                object_items = []
            if wip_relationship_items:
                # Completed wip members get merged into matching new object items.
                self._add_wip_relationship_items(scene, wip_relationship_items, object_items)
            self.hidden_items.clear()
        self.extend_scene()
        toc = time.perf_counter()
        if timeit:
            self.msg.emit("Graph built in {} seconds\t".format(toc - tic))
        self.graph_created.emit()

    def _get_selected_object_ids(self):
        """Returns the ids of the objects selected in the object tree.

        Selecting the root selects every object; selecting an object class
        selects all objects of that class.

        Returns:
            set
        """
        tree_model = self.object_tree_model
        root_index = tree_model.root_index
        if self.ui.treeView_object.selectionModel().isSelected(root_index):
            return {obj["id"] for obj in self.db_mngr.get_objects(self.db_map)}
        selected_ids = set()
        for index in tree_model.selected_object_indexes:
            object_item = index.model().item_from_index(index)
            selected_ids.add(object_item.db_map_id(self.db_map))
        for index in tree_model.selected_object_class_indexes:
            class_item = index.model().item_from_index(index)
            class_id = class_item.db_map_id(self.db_map)
            selected_ids.update(
                obj["id"] for obj in self.db_mngr.get_objects(self.db_map, class_id=class_id)
            )
        return selected_ids

    def _get_graph_data(self):
        """Returns data for making graph according to selection in Object tree.

        Nodes are numbered with the selected, non-rejected objects first, followed by
        every non-rejected relationship with at least two members among those objects.
        Each such relationship contributes one arc per member found among the objects.

        Returns:
            list: integer object ids
            list: integer relationship ids
            list: arc source indices (relationship node indices)
            list: arc destination indices (object node indices)
        """
        rejected_entity_ids = {x.entity_id for x in self.rejected_items}
        object_ids = list(self._get_selected_object_ids() - rejected_entity_ids)
        # Node index of each object id; a dict lookup avoids a linear
        # list.index() scan for every relationship member.
        object_ind_lookup = {object_id: ind for ind, object_id in enumerate(object_ids)}
        src_inds = []
        dst_inds = []
        relationship_ids = []
        relationship_ind = len(object_ids)
        for relationship in self.db_mngr.get_relationships(self.db_map):
            if relationship["id"] in rejected_entity_ids:
                continue
            member_ids = (int(x) for x in relationship["object_id_list"].split(","))
            object_inds = [
                object_ind_lookup[object_id] for object_id in member_ids if object_id in object_ind_lookup
            ]
            if len(object_inds) < 2:
                continue
            relationship_ids.append(relationship["id"])
            src_inds.extend([relationship_ind] * len(object_inds))
            dst_inds.extend(object_inds)
            relationship_ind += 1
        return object_ids, relationship_ids, src_inds, dst_inds

    def _get_new_items(self):
        """Returns new items for the graph, laid out from the current selection.

        Returns:
            tuple: (ObjectItem list, RelationshipItem list, ArcItem list), or an
            empty list when there are no nodes to draw.
        """
        object_ids, relationship_ids, src_inds, dst_inds = self._get_graph_data()
        node_count = len(object_ids) + len(relationship_ids)
        distances = self.shortest_path_matrix(node_count, src_inds, dst_inds, self._arc_length_hint)
        if distances is None:
            return []
        xs, ys = self.vertex_coordinates(distances)
        object_items = [
            ObjectItem(self, xs[i], ys[i], self._node_extent, entity_id=object_id)
            for i, object_id in enumerate(object_ids)
        ]
        # Relationship nodes come right after the object nodes.
        offset = len(object_items)
        relationship_items = [
            RelationshipItem(self, xs[offset + i], ys[offset + i], self._node_extent, entity_id=relationship_id)
            for i, relationship_id in enumerate(relationship_ids)
        ]
        arc_items = [
            ArcItem(relationship_items[rel_ind - offset], object_items[obj_ind], self._arc_width)
            for rel_ind, obj_ind in zip(src_inds, dst_inds)
        ]
        return object_items, relationship_items, arc_items

    def _get_wip_relationship_items(self):
        """Detaches wip relationship items, with their arcs and member objects, from the current scene.

        Returns:
            list: RelationshipItem instances (empty if the view has no scene)
        """
        scene = self.ui.graphicsView.scene()
        if not scene:
            return []
        detached = []
        for item in scene.items():
            if not isinstance(item, RelationshipItem) or not item.is_wip:
                continue
            for arc in item.arc_items:
                scene.removeItem(arc)
            for member in {arc.obj_item for arc in item.arc_items}:
                scene.removeItem(member)
            scene.removeItem(item)
            detached.append(item)
        return detached

    @staticmethod
    def _add_new_items(scene, object_items, relationship_items, arc_items):
        """Adds object, relationship and arc items to ``scene``, in that order."""
        for group in (object_items, relationship_items, arc_items):
            for item in group:
                scene.addItem(item)

    @staticmethod
    def _add_wip_relationship_items(scene, wip_relationship_items,
                                    new_object_items):
        """Adds wip relationship items to the given scene, merging completed members with existing
        object items by entity id.

        Args:
            scene (QGraphicsScene)
            wip_relationship_items (list)
            new_object_items (list)
        """
        items_by_entity_id = {object_item.entity_id: object_item for object_item in new_object_items}
        for rel_item in wip_relationship_items:
            scene.addItem(rel_item)
            for arc in rel_item.arc_items:
                scene.addItem(arc)
            for member in {arc.obj_item for arc in rel_item.arc_items}:
                scene.addItem(member)
                member._merge_target = items_by_entity_id.get(member.entity_id)
                if member._merge_target:
                    member.merge_into_target(force=True)

    @staticmethod
    def shortest_path_matrix(N, src_inds, dst_inds, spread):
        """Returns the shortest-path matrix.

        Each arc (src_inds[k], dst_inds[k]) gets length ``spread``, and distances are
        computed on the undirected graph. Infinite distances are replaced by
        ``spread * 3`` and zero distances (e.g. the diagonal) by ``spread * 1e-6``.

        Args:
            N (int): The number of nodes in the graph.
            src_inds (list): Source indices
            dst_inds (list): Destination indices
            spread (int): The desired 'distance' between neighbours

        Returns:
            numpy.ndarray or None: N x N distance matrix, or None when N is 0.
        """
        if not N:
            return None
        arc_lengths = np.zeros((N, N))
        sources = arr(src_inds)
        targets = arr(dst_inds)
        try:
            arc_lengths[sources, targets] = spread
            arc_lengths[targets, sources] = spread
        except IndexError:
            # NOTE(review): presumably covers the no-arc case, where an empty list
            # becomes a non-integer array that NumPy refuses as an index.
            pass
        distances = dijkstra(arc_lengths, directed=False)
        distances[distances == np.inf] = spread * 3
        distances[distances == 0] = spread * 1e-6
        return distances

    @staticmethod
    def sets(N):
        """Returns sets of vertex pair indices covering every pair (i, j) with i < j.

        The pairs (k, k + n) on each upper diagonal n are split into two groups of
        alternating blocks of n consecutive pairs. Within a group no vertex appears
        twice, so a whole group can be updated with one vectorized operation.

        Args:
            N (int): number of vertices

        Returns:
            list: integer arrays of shape (M, 2); empty groups are omitted.
        """
        sets = []
        for n in range(1, N):
            pairs = np.zeros((N - n, 2), int)  # pairs on diagonal n
            pairs[:, 0] = np.arange(N - n)
            pairs[:, 1] = pairs[:, 0] + n
            mask = np.mod(range(N - n), 2 * n) < n
            s1 = pairs[mask]
            s2 = pairs[~mask]
            # Test emptiness via .size: ndarray.any() looks at element values and
            # only worked here because column 1 is always >= 1.
            if s1.size:
                sets.append(s1)
            if s2.size:
                sets.append(s2)
        return sets

    @staticmethod
    def vertex_coordinates(matrix,
                           heavy_positions=None,
                           iterations=10,
                           weight_exp=-2,
                           initial_diameter=1000):
        """Returns x and y coordinates for each vertex in the graph, computed using VSGD-MS.

        Vertices start at seeded random positions. Over ``iterations`` rounds, with an
        exponentially decaying step, pairs of vertices are moved so that their
        Euclidean distances approach the target distances in ``matrix``.

        Args:
            matrix (numpy.ndarray): N x N target distances, e.g. from ``shortest_path_matrix``.
                Entries are raised to ``weight_exp``, so off-diagonal entries presumably
                must be positive.
            heavy_positions (dict, optional): Maps a vertex index to a point exposing
                ``x()`` and ``y()`` (presumably QPointF). Those vertices stay pinned there.
            iterations (int): Number of refinement rounds.
            weight_exp (float): Exponent turning target distances into pair weights.
            initial_diameter (float): Side of the square spanned by the random initial layout.

        Returns:
            tuple: x coordinates and y coordinates, one entry per vertex.
        """
        if heavy_positions is None:
            heavy_positions = dict()
        N = len(matrix)
        if N == 0:
            return [], []
        if N == 1:
            return [0], [0]
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)  # Upper triangular except diagonal
        # Private generator seeded with 0. It yields the same sequence as
        # np.random.seed(0) plus the global functions, without resetting NumPy's
        # global random state as a side effect.
        rng = np.random.RandomState(0)
        # Random layout with initial diameter
        layout = rng.rand(N, 2) * initial_diameter - initial_diameter / 2
        heavy_ind = np.array(list(heavy_positions), dtype=int)
        heavy_pos = np.array([[pos.x(), pos.y()] for pos in heavy_positions.values()])
        # Test emptiness via .size: .any() is False for a lone pinned vertex 0,
        # which would silently ignore that pin.
        has_heavy = heavy_ind.size > 0
        if has_heavy:
            layout[heavy_ind, :] = heavy_pos
        weights = matrix**weight_exp  # pair weights (lower for distant pairs with negative weight_exp)
        maxstep = 1 / np.min(weights[mask])
        minstep = 1 / np.max(weights[mask])
        # Exponential decay of the allowed adjustment. With a single iteration
        # there is nothing to decay, and (iterations - 1) would be zero.
        lambda_ = np.log(minstep / maxstep) / (iterations - 1) if iterations > 1 else 0.0
        sets = GraphViewMixin.sets(N)  # groups of vertex pairs
        for iteration in range(iterations):
            step = maxstep * np.exp(lambda_ * iteration)  # how big adjustments are allowed?
            # Shuffle vertices so pairs are not processed in the same order every iteration.
            rand_order = rng.permutation(N)
            for p in sets:
                v1, v2 = rand_order[p[:, 0]], rand_order[p[:, 1]]  # arrays of vertex1 and vertex2
                # current distance between the two vertices of each pair
                dist = ((layout[v1, 0] - layout[v2, 0])**2 +
                        (layout[v1, 1] - layout[v2, 1])**2)**0.5
                r = (matrix[v1, v2] - dist)[:, None] / 2 * (
                    layout[v1] - layout[v2]) / dist[:, None]  # desired change
                dx1 = r * np.minimum(1, weights[v1, v2] * step)[:, None]
                layout[v1, :] += dx1  # update positions symmetrically
                layout[v2, :] -= dx1
                if has_heavy:
                    layout[heavy_ind, :] = heavy_pos
        return layout[:, 0], layout[:, 1]

    def new_scene(self):
        """Tears down the current scene, installs a fresh ShrinkingScene, connects its signals and returns it."""
        self.tear_down_scene()
        fresh_scene = ShrinkingScene(100.0, 100.0, None)
        self.ui.graphicsView.setScene(fresh_scene)
        fresh_scene.changed.connect(self._handle_scene_changed)
        fresh_scene.selectionChanged.connect(self._handle_scene_selection_changed)
        return fresh_scene

    def tear_down_scene(self):
        """Schedules the view's current scene, if there is one, for deletion via ``deleteLater``."""
        scene = self.ui.graphicsView.scene()
        if scene:
            scene.deleteLater()

    def extend_scene(self):
        """Fits the scene rect to the bounding rect of all items and re-initializes the zoom."""
        view = self.ui.graphicsView
        scene = view.scene()
        scene.setSceneRect(scene.itemsBoundingRect())
        view.init_zoom()

    @Slot(name="_handle_scene_selection_changed")
    def _handle_scene_selection_changed(self):
        """Filters parameters by the objects and relationships selected in the graph.

        Relationships that cascade from the selected objects count as selected too.
        """
        selected_items = self.ui.graphicsView.scene().selectedItems()
        self.entity_item_selection = [item for item in selected_items if isinstance(item, EntityItem)]
        db_map = self.db_map
        selected_objs = {db_map: []}
        selected_rels = {db_map: []}
        for item in selected_items:
            if isinstance(item, ObjectItem):
                selected_objs[db_map].append(item.db_representation)
            elif isinstance(item, RelationshipItem):
                selected_rels[db_map].append(item.db_representation)
        cascading_rels = self.db_mngr.find_cascading_relationships(self.db_mngr._to_ids(selected_objs))
        selected_rels = self._extend_merge(selected_rels, cascading_rels)
        # NOTE(review): class ids are only added here, never removed; presumably the
        # sets are reset elsewhere before this runs. Confirm.
        for class_kind, selection in (("object class", selected_objs), ("relationship class", selected_rels)):
            for some_db_map, entries in selection.items():
                class_ids = self.selected_ent_cls_ids[class_kind].setdefault(some_db_map, set())
                class_ids.update({entry["class_id"] for entry in entries})
        self.selected_ent_ids["object"] = self._db_map_class_id_data(selected_objs)
        self.selected_ent_ids["relationship"] = self._db_map_class_id_data(selected_rels)
        self.update_filter()

    @Slot(list)
    def _handle_scene_changed(self, region):
        """Grows the scene rect so it also covers any changed rectangle lying outside it.

        Args:
            region (list): rectangles of the scene that changed.
        """
        scene = self.ui.graphicsView.scene()
        current_rect = scene.sceneRect()
        if not any(not current_rect.contains(rect) for rect in region):
            return
        grown_rect = current_rect
        for rect in region:
            grown_rect = grown_rect.united(rect)
        scene.setSceneRect(grown_rect)

    @Slot("QPoint", "QString", name="_handle_item_dropped")
    def _handle_item_dropped(self, pos, text):
        """Runs when an item is dropped from Item palette onto the view.
        Creates the object or relationship template at the drop position.

        Args:
            pos (QPoint): drop position in view coordinates
            text (str): drop payload of the form '<entity type>:<entity class id>'
        """
        scene = self.ui.graphicsView.scene()
        if not scene:
            scene = self.new_scene()
        # Clear placeholder text such as the "Nothing to show." item from build_graph.
        # NOTE(review): this removes every QGraphicsTextItem in the scene; confirm that
        # entity items don't use QGraphicsTextItem children (e.g. for labels).
        text_items = [item for item in scene.items() if isinstance(item, QGraphicsTextItem)]
        for text_item in text_items:
            scene.removeItem(text_item)
        scene_pos = self.ui.graphicsView.mapToScene(pos)
        entity_type, class_id_text = text.split(":")
        entity_class_id = int(class_id_text)
        if entity_type == "object class":
            new_object = ObjectItem(self, scene_pos.x(), scene_pos.y(), self._node_extent,
                                    entity_class_id=entity_class_id)
            scene.addItem(new_object)
            self.ui.graphicsView.setFocus()
            new_object.edit_name()
        elif entity_type == "relationship class":
            self.add_wip_relationship(scene, scene_pos, entity_class_id)
        self.extend_scene()
        self.graph_created.emit()

    def add_wip_relationship(self,
                             scene,
                             pos,
                             relationship_class_id,
                             center_item=None,
                             center_dimension=None):
        """Makes items for a wip relationship and adds them to the scene.

        One relationship item and one object item per dimension of the
        relationship class are laid out with ``shortest_path_matrix`` /
        ``vertex_coordinates``, offset, and linked with wip arcs.

        Args:
            scene (QGraphicsScene): scene to add the new items to
            pos (QPointF): offset added to the layout coordinates; ignored when
                ``center_item`` is given
            relationship_class_id (int): id of the relationship class in
                ``self.db_map``
            center_item (ObjectItem, optional): if given, the layout is offset
                by the center of this item's scene bounding rect instead of
                ``pos``
            center_dimension (int, optional): index of the wip object item that
                ``center_item`` is merged into; used only together with
                ``center_item``
        """
        relationship_class = self.db_mngr.get_item(self.db_map,
                                                   "relationship class",
                                                   relationship_class_id)
        if not relationship_class:
            return
        object_class_id_list = [
            int(id_)
            for id_ in relationship_class["object_class_id_list"].split(",")
        ]
        dimension_count = len(object_class_id_list)
        # Vertices 0..dimension_count-1 are the objects; the last vertex
        # (index dimension_count) is the relationship, linked to every object.
        rel_inds = [dimension_count] * dimension_count
        obj_inds = list(range(dimension_count))
        d = self.shortest_path_matrix(dimension_count + 1, rel_inds, obj_inds,
                                      self._arc_length_hint)
        if d is None:
            return
        x, y = self.vertex_coordinates(d)
        # Offset the layout either by the center item or by the given position.
        if center_item:
            center = center_item.sceneBoundingRect().center()
            x_offset, y_offset = center.x(), center.y()
        else:
            x_offset, y_offset = pos.x(), pos.y()
        x += x_offset
        y += y_offset
        relationship_item = RelationshipItem(
            self,
            x[-1],
            y[-1],
            self._node_extent,
            entity_class_id=relationship_class_id)
        object_items = list()
        arc_items = list()
        for i, object_class_id in enumerate(object_class_id_list):
            object_item = ObjectItem(self,
                                     x[i],
                                     y[i],
                                     self._node_extent,
                                     entity_class_id=object_class_id)
            object_items.append(object_item)
            arc_item = ArcItem(relationship_item,
                               object_item,
                               self._arc_width,
                               is_wip=True)
            arc_items.append(arc_item)
        entity_items = object_items + [relationship_item]
        for item in entity_items + arc_items:
            scene.addItem(item)
        if center_item and center_dimension is not None:
            center_item._merge_target = object_items[center_dimension]
            center_item.merge_into_target()

    def add_object(self, object_class_id, name):
        """Inserts a new object into ``self.db_map`` through the db manager.

        Args:
            object_class_id (int): class of the new object
            name (str): name of the new object

        Returns:
            int, NoneType: the id stored in ``self._added_objects`` under
            ``(object_class_id, name)`` (presumably filled in by a db manager
            callback), or None if there is none.
        """
        self.db_mngr.add_objects(
            {self.db_map: [{"class_id": object_class_id, "name": name}]})
        added_id = self._added_objects.get((object_class_id, name))
        self._added_objects.clear()
        return added_id

    def update_object(self, object_id, name):
        """Renames an existing object in ``self.db_map``.

        Args:
            object_id (int): id of the object to update
            name (str): new name for the object
        """
        payload = {"id": object_id, "name": name}
        self.db_mngr.update_objects({self.db_map: [payload]})

    def add_relationship(self, class_id, object_id_list, object_name_list):
        """Adds a relationship to ``self.db_map``.

        The relationship is named
        ``<class name>_<object name 1>__<object name 2>__...``.

        Args:
            class_id (int): id of the relationship class
            object_id_list (list): ids of the member objects, in dimension order
            object_name_list (list): names of the member objects, used to build
                the relationship name

        Returns:
            int, NoneType: the id stored in ``self._added_relationships`` under
            ``(class_id, "<id1>,<id2>,...")``, or None if there is none.
        """
        class_name = self.db_mngr.get_item(self.db_map, "relationship class",
                                           class_id)["name"]
        name = class_name + "_" + "__".join(object_name_list)
        relationship = {
            'name': name,
            'object_id_list': object_id_list,
            'class_id': class_id
        }
        self.db_mngr.add_relationships({self.db_map: [relationship]})
        # Look the new id up by class id and comma-joined member object ids.
        object_id_key = ",".join(str(id_) for id_ in object_id_list)
        relationship_id = self._added_relationships.get(
            (class_id, object_id_key))
        self._added_relationships.clear()
        return relationship_id

    @Slot("QPoint")
    def show_graph_view_context_menu(self, global_pos):
        """Pops up the graph view context menu and runs the chosen action.

        Args:
            global_pos (QPoint): where to show the menu, in global coordinates
        """
        menu = GraphViewContextMenu(self, global_pos)
        handlers = {
            "Hide selected": self.hide_selected_items,
            "Show hidden": self.show_hidden_items,
            "Prune selected": self.prune_selected_items,
            "Restore pruned": self.restore_pruned_items,
        }
        handler = handlers.get(menu.get_action())
        if handler is not None:
            handler()
        menu.deleteLater()

    @Slot(bool)
    def hide_selected_items(self, checked=False):
        """Hides every selected entity item and records it in ``hidden_items``.

        Args:
            checked (bool): unused
        """
        self.hidden_items.extend(self.entity_item_selection)
        for entity_item in self.entity_item_selection:
            entity_item.set_all_visible(False)

    @Slot(bool)
    def show_hidden_items(self, checked=False):
        """Makes every previously hidden item visible again and forgets them.

        Does nothing while the graph view has no scene.

        Args:
            checked (bool): unused
        """
        if self.ui.graphicsView.scene():
            for hidden in self.hidden_items:
                hidden.set_all_visible(True)
            self.hidden_items.clear()

    @Slot(bool)
    def prune_selected_items(self, checked=False):
        """Adds the selected entity items to ``rejected_items`` and rebuilds the graph.

        Args:
            checked (bool): unused
        """
        rejected = self.rejected_items
        rejected.extend(self.entity_item_selection)
        self.build_graph()

    @Slot(bool)
    def restore_pruned_items(self, checked=False):
        """Empties ``rejected_items`` and rebuilds the graph.

        Args:
            checked (bool): unused
        """
        rejected = self.rejected_items
        rejected.clear()
        self.build_graph()

    @Slot(bool)
    def show_demo(self, checked=False):
        """Shows the live demo widget.

        Args:
            checked (bool): unused
        """
        demo = self.live_demo
        demo.show()

    def show_object_item_context_menu(self, global_pos, main_item):
        """Pops up the context menu of an object item and applies the chosen action.

        Generic entity actions are tried first. Otherwise the item can be
        renamed, or a wip relationship of one of the classes listed in
        ``menu.relationship_class_dict`` is created around ``main_item``.

        Args:
            global_pos (QPoint): where to show the menu, in global coordinates
            main_item (spinetoolbox.widgets.graph_view_graphics_items.ObjectItem):
                the item the menu was requested for
        """
        menu = ObjectItemContextMenu(self, global_pos, main_item)
        option = menu.get_action()
        handled = self._apply_entity_context_menu_option(option)
        if not handled:
            if option in ('Set name', 'Rename'):
                main_item.edit_name()
            elif option in menu.relationship_class_dict:
                rel_class = menu.relationship_class_dict[option]
                self.add_wip_relationship(
                    self.ui.graphicsView.scene(),
                    global_pos,
                    rel_class["id"],
                    center_item=main_item,
                    center_dimension=rel_class['dimension'])
        menu.deleteLater()

    def show_relationship_item_context_menu(self, global_pos):
        """Pops up the context menu of a relationship item and applies the chosen generic action.

        Args:
            global_pos (QPoint): where to show the menu, in global coordinates
        """
        menu = RelationshipItemContextMenu(self, global_pos)
        chosen = menu.get_action()
        self._apply_entity_context_menu_option(chosen)
        menu.deleteLater()

    def _apply_entity_context_menu_option(self, option):
        if option == 'Hide':
            self.hide_selected_items()
        elif option == 'Prune':
            self.prune_selected_items()
        elif option == 'Remove':
            self.remove_graph_items()
        else:
            return False
        return True

    @Slot("bool", name="remove_graph_items")
    def remove_graph_items(self, checked=False):
        """Removes all selected items in the graph.

        Work-in-progress items are discarded with ``wipe_out()``. The other
        items are grouped by entity type and removed from ``self.db_map``
        through the db manager. The db manager is only called when there is
        something to remove.

        Args:
            checked (bool): unused
        """
        # Snapshot the selection: wipe_out() may alter what is selected.
        selection = list(self.entity_item_selection)
        if not selection:
            return
        typed_data = {}
        for item in selection:
            if item.is_wip:
                item.wipe_out()
            else:
                typed_data.setdefault(item.entity_type,
                                      []).append(item.db_representation)
        if typed_data:
            self.db_mngr.remove_items({self.db_map: typed_data})

    def closeEvent(self, event=None):
        """Handles close window event.

        Makes the live demo floating, lets the base class handle the event,
        then tears down the graph scene.

        Args:
            event (QEvent): Closing event if 'X' is clicked.
        """
        # Detach the live demo before the window goes away.
        self.live_demo.setFloating(True)
        super().closeEvent(event)
        # Scene teardown happens after the base class has processed the close.
        self.tear_down_scene()
Esempio n. 3
0
class WorkerSignals(QObject):
    """Container for the signals a worker emits.

    Only the payload types are fixed here. What each payload means is decided
    by the worker code that emits them, which is not visible in this class.
    """
    # No payload.
    finished = Signal()
    # Tuple payload — presumably exception info such as
    # (exctype, value, traceback); TODO confirm against the emitting worker.
    error = Signal(tuple)
    # Arbitrary Python object — presumably the worker's return value.
    result = Signal(object)
    # Integer payload — presumably a progress value; range not defined here.
    progress = Signal(int)
Esempio n. 4
0
class ImageViewer(QWidget):
    """Widget that displays an image, optionally alongside a processed version.

    When a processed image is given, two radio buttons (and the space bar)
    switch the view between the original and the processed image. An
    "Export..." button then saves the processed image as a PNG file.
    """

    # Re-emitted from the inner DynamicView's viewChanged signal.
    viewChanged = Signal(QRect, float, int, int)

    def __init__(self, original, processed, title=None, parent=None):
        """Builds the viewer.

        Args:
            original: image to display; if None, ``processed`` is used as the
                original too. Images are presumably numpy arrays as used by
                OpenCV (``.shape`` is read, ``cv.imwrite`` saves them) — TODO
                confirm.
            processed: processed image, or None for a plain single-image viewer.
            title (str, optional): bold, centered caption above the view.
            parent (QWidget, optional): parent widget.

        Raises:
            ValueError: if both ``original`` and ``processed`` are None.
        """
        super().__init__(parent)
        if original is None:
            if processed is None:
                raise ValueError(
                    self.tr('ImageViewer.__init__: Empty image received'))
            original = processed
        self.original = original
        self.processed = processed
        # Start on the processed image when there is one.
        self.view = DynamicView(
            self.original if self.processed is None else self.processed)

        self.original_radio = QRadioButton(self.tr('Original'))
        self.original_radio.setToolTip(
            self.tr('Show the original image for comparison'))
        self.process_radio = QRadioButton(self.tr('Processed'))
        self.process_radio.setToolTip(
            self.tr('Show result of the current processing'))
        self.zoom_label = QLabel()
        # The 100%/Fit buttons are wired to the view below but are not
        # currently placed in the tool layout.
        full_button = QToolButton()
        full_button.setText(self.tr('100%'))
        fit_button = QToolButton()
        fit_button.setText(self.tr('Fit'))
        # shape[:2] is (rows, columns) for both color and grayscale images.
        height, width = self.original.shape[:2]
        # Translate the template before formatting so the lookup can match;
        # sizes are shown in the conventional width x height order.
        size_label = QLabel(self.tr('[{}x{} px]').format(width, height))
        export_button = QToolButton()
        export_button.setText(self.tr('Export...'))

        tool_layout = QHBoxLayout()
        if processed is not None:
            tool_layout.addWidget(self.original_radio)
            tool_layout.addWidget(self.process_radio)
            tool_layout.addStretch()
        tool_layout.addWidget(QLabel(self.tr('Zoom:')))
        tool_layout.addWidget(self.zoom_label)
        tool_layout.addStretch()
        tool_layout.addWidget(size_label)
        if processed is not None:
            tool_layout.addWidget(export_button)
            self.original_radio.setChecked(False)
            self.process_radio.setChecked(True)
            self.toggle_mode(False)

        vert_layout = QVBoxLayout()
        if title is not None:
            self.title_label = QLabel(title)
            modify_font(self.title_label, bold=True)
            self.title_label.setAlignment(Qt.AlignCenter)
            vert_layout.addWidget(self.title_label)
        else:
            self.title_label = None
        vert_layout.addWidget(self.view)
        vert_layout.addLayout(tool_layout)
        self.setLayout(vert_layout)

        self.original_radio.toggled.connect(self.toggle_mode)
        fit_button.clicked.connect(self.view.zoom_fit)
        full_button.clicked.connect(self.view.zoom_full)
        export_button.clicked.connect(self.export_image)
        self.view.viewChanged.connect(self.forward_changed)

    def update_processed(self, image):
        """Replaces the processed image and refreshes the view.

        Ignored for viewers created without a processed image.
        """
        if self.processed is None:
            return
        self.processed = image
        self.toggle_mode(self.original_radio.isChecked())

    def update_original(self, image):
        """Replaces the original image.

        The view is refreshed only when it is showing the original: either
        there is no processed image, or the "Original" radio is checked. This
        keeps the displayed image in sync with the radio buttons.
        """
        self.original = image
        if self.processed is None or self.original_radio.isChecked():
            self.toggle_mode(True)

    def change_view(self, rect, scaling, horizontal, vertical):
        """Forwards a view state to the inner DynamicView."""
        self.view.change_view(rect, scaling, horizontal, vertical)

    def forward_changed(self, rect, scaling, horizontal, vertical):
        """Updates the zoom label and re-emits the inner view's viewChanged."""
        self.zoom_label.setText('{:.2f}%'.format(scaling * 100))
        # Second positional argument is presumably ``bold`` (see the
        # ``bold=True`` call above), i.e. bold at exactly 100% — TODO confirm.
        modify_font(self.zoom_label, scaling == 1)
        self.viewChanged.emit(rect, scaling, horizontal, vertical)

    def get_rect(self):
        """Returns the inner view's rect."""
        return self.view.get_rect()

    def keyPressEvent(self, event):
        """Space swaps the checked radio button (original <-> processed)."""
        if event.key() == Qt.Key_Space:
            if self.original_radio.isChecked():
                self.process_radio.setChecked(True)
            else:
                self.original_radio.setChecked(True)
        QWidget.keyPressEvent(self, event)

    def toggle_mode(self, toggled):
        """Shows the original if ``toggled``, else the processed image (if any)."""
        if toggled:
            self.view.set_image(self.original)
        elif self.processed is not None:
            self.view.set_image(self.processed)

    def export_image(self):
        """Asks for a file name and saves the processed image there as PNG."""
        settings = QSettings()
        filename = QFileDialog.getSaveFileName(
            self, self.tr('Export image...'), settings.value('save_folder'),
            self.tr('PNG images (*.png)'))[0]
        if not filename:
            return
        # Case-insensitive, so "IMAGE.PNG" does not become "IMAGE.PNG.png".
        if not filename.lower().endswith('.png'):
            filename += '.png'
        cv.imwrite(filename, self.processed)

    def set_title(self, title):
        """Changes the caption; no-op if the viewer was created without one."""
        if self.title_label is not None:
            self.title_label.setText(title)
Esempio n. 5
0
class MemoryLoad(QObject):
    """Sends a memory image file bank by bank through ``usbif``.

    A timer (20 ms interval) sends one bank per tick. Only banks whose switch
    is checked are sent. Banks ``0o44`` and above all share ``aux_switch``.
    ``finished`` is emitted once all banks are processed or the file has no
    more data.
    """

    finished = Signal()

    # Banks below this index have their own entry in ``switches``; higher
    # banks are controlled by ``aux_switch``.
    _AUX_BANK_START = 0o44

    def __init__(self, usbif, write_msg, num_banks, bank_size, switches, aux_switch=None):
        """
        Args:
            usbif: object whose ``send(msg)`` transmits a write message.
            write_msg: default callable building a message from ``addr``,
                ``data`` and ``parity`` keyword arguments.
            num_banks (int): number of banks to consider.
            bank_size (int): number of 16-bit words per bank.
            switches: per-bank check boxes for banks below ``0o44``.
            aux_switch (optional): check box for banks ``0o44`` and above;
                those banks are skipped when it is None.
        """
        QObject.__init__(self)

        self._usbif = usbif
        self._default_write_msg = write_msg
        self._write_msg = write_msg
        self._num_banks = num_banks
        self._bank_size = bank_size

        self._switches = switches
        self._aux_switch = aux_switch
        self._bank = 0
        self._load_data = array.array('H')

        self._timer = QTimer()
        self._timer.timeout.connect(self._load_next_bank)

    def load_memory(self, filename, write_msg=None):
        """Reads ``filename`` as 16-bit words and starts sending the banks.

        Args:
            filename (str): memory image path.
            write_msg (optional): message factory for this load; defaults to
                the one given to the constructor.
        """
        import sys

        if write_msg is None:
            write_msg = self._default_write_msg

        self._write_msg = write_msg
        self._load_data = array.array('H')
        with open(filename, 'rb') as f:
            self._load_data.fromfile(f, os.path.getsize(filename) // 2)
        # Words are stored big-endian in the file; convert them to host order
        # (swapping only on little-endian hosts keeps big-endian hosts correct).
        if sys.byteorder == 'little':
            self._load_data.byteswap()
        self._bank = 0
        self._timer.start(20)

    def _switch_for_bank(self, bank):
        """Returns the switch controlling ``bank`` (may be None for aux banks)."""
        if bank < self._AUX_BANK_START:
            return self._switches[bank]
        return self._aux_switch

    def _load_next_bank(self):
        """Timer slot: sends the next selected bank, or finishes the load."""
        # Advance to the next bank whose switch is checked.
        while self._bank < self._num_banks:
            sw = self._switch_for_bank(self._bank)
            if sw is not None and sw.isChecked():
                # Mark the bank being sent; _complete_load clears tristate.
                sw.setCheckState(Qt.PartiallyChecked)
                break

            self._bank += 1

        if self._bank >= self._num_banks:
            self._complete_load()
            return

        bank_addr = self._bank * self._bank_size
        words = self._load_data[bank_addr:bank_addr + self._bank_size]

        # The file ran out of data before this bank.
        if len(words) == 0:
            self._complete_load()
            return

        for a, w in enumerate(words):
            d, p = agc.unpack_word(w)
            self._usbif.send(self._write_msg(addr=bank_addr + a, data=d, parity=p))

        self._bank += 1

    def _complete_load(self):
        """Stops the timer, resets the switches' tristate and emits ``finished``."""
        self._timer.stop()
        for sw in self._switches:
            sw.setTristate(False)
            sw.update()

        if self._aux_switch:
            self._aux_switch.setTristate(False)
            self._aux_switch.update()

        self.finished.emit()
Esempio n. 6
0
class Script(QObject):
    """A script: a Flow, its variables, a logger and the WUIScript widget showing them.

    Declares ``name_changed(str)``; the code in this class does not emit it.
    """

    name_changed = Signal(str)

    def __init__(self, main_window, name, config=None):
        """Creates the script, its flow, variables handler and widget.

        Args:
            main_window: main window, passed on to the Flow.
            name (str): script name; replaced by ``config['name']`` when
                ``config`` is given.
            config (dict, optional): saved script data with keys 'name',
                'variables' and 'flow' (the keys get_json_data() produces).
        """
        super().__init__()

        self.main_window = main_window
        self.widget = WUIScript()

        # GENERAL ATTRIBUTES
        self.logger = Logger(self)
        self.variables = []
        self.variables_handler = None
        self.name = name
        self.flow = None
        self.thumbnail_source = ''  # URL to the Script's thumbnail picture
        self.code_preview_widget = CodePreview_Widget()

        # The two branches build the handler and the flow in opposite orders;
        # kept as-is in case construction order matters.
        if config:
            self.name = config['name']
            self.variables_handler = VariablesHandler(self,
                                                      config['variables'])
            self.flow = Flow(main_window, self, config['flow'])
            self.variables_handler.flow = self.flow
        else:
            self.flow = Flow(main_window, self)
            self.variables_handler = VariablesHandler(self)
            self.variables_handler.flow = self.flow

        # variables list widget
        self.widget.ui.variables_scrollArea.setWidget(
            self.variables_handler.list_widget)
        self.widget.ui.add_variable_push_button.clicked.connect(
            self.add_var_clicked)
        self.widget.ui.new_var_name_lineEdit.returnPressed.connect(
            self.new_var_line_edit_return_pressed)
        self.widget.ui.algorithm_data_flow_radioButton.toggled.connect(
            self.flow.algorithm_mode_data_flow_toggled)
        self.widget.ui.viewport_update_mode_sync_radioButton.toggled.connect(
            self.flow.viewport_update_mode_sync_toggled)

        # flow
        self.widget.ui.splitter.insertWidget(0, self.flow)

        # code preview
        self.widget.ui.source_code_groupBox.layout().addWidget(
            self.code_preview_widget)

        # logs
        self.widget.ui.logs_scrollArea.setWidget(self.logger)
        self.widget.ui.splitter.setSizes([700, 0])

    def show_NI_code(self, ni):
        """Called from Flow when the selection changed."""
        self.code_preview_widget.set_new_NI(ni)

    def _create_var_from_line_edit(self):
        """Creates a variable named after the new-variable line edit's text."""
        self.variables_handler.create_new_var(
            self.widget.ui.new_var_name_lineEdit.text())

    def add_var_clicked(self):
        """Slot for the add-variable button."""
        self._create_var_from_line_edit()

    def new_var_line_edit_return_pressed(self):
        """Slot for Return pressed in the new-variable line edit."""
        self._create_var_from_line_edit()

    def get_json_data(self):
        """Returns the script as a dict with keys 'name', 'variables', 'flow'."""
        script_dict = {
            'name': self.name,
            'variables': self.variables_handler.get_json_data(),
            'flow': self.flow.get_json_data()
        }

        return script_dict
Esempio n. 7
0
class PSO(QThread):
    """Trains an RBFN with particle swarm optimisation inside a QThread.

    Progress is reported through the ``sig_*`` signals. When training ends, an
    RBFN loaded with the best position found is emitted via ``sig_rbfn``.
    """

    sig_console = Signal(str)
    sig_current_iter_time = Signal(int)
    sig_current_error = Signal(float)
    sig_iter_error = Signal(float, float, float)
    sig_indicate_busy = Signal()
    sig_rbfn = Signal(RBFN)

    def __init__(self,
                 iter_times,
                 population_size,
                 inertia_weight,
                 cognitive_const_upper,
                 social_const_upper,
                 v_max,
                 nneuron,
                 dataset,
                 sd_max=1,
                 is_multicore=True):
        """
        Args:
            iter_times (int): number of swarm iterations.
            population_size (int): number of individuals.
            inertia_weight: inertia weight passed to every position update.
            cognitive_const_upper: upper bound of the random cognitive factor.
            social_const_upper: upper bound of the random social factor.
            v_max: passed to each Individual.
            nneuron: number of neurons (Individual and RBFN).
            dataset: training data for the individuals.
            sd_max: passed to each Individual and to the RBFN.
            is_multicore (bool): evaluate fitness in a multiprocessing pool.
        """
        super().__init__()
        self.abort = False
        self.iter_times = iter_times
        self.population_size = population_size
        self.inertia_weight = inertia_weight
        self.cognitive_const_upper = cognitive_const_upper
        self.social_const_upper = social_const_upper
        self.dataset = dataset
        self.is_multicore = is_multicore

        self.population = [
            Individual(self.dataset, nneuron, v_max, sd_max)
            for _ in range(self.population_size)
        ]
        self.rbfn = RBFN(nneuron, (0, 40), sd_max)

    def _evaluate_and_track(self, best_so_far):
        """Evaluates the swarm, updates the overall best and reports errors.

        Returns:
            tuple: (best individual of this evaluation, deep copy of the
            lower-error one among ``best_so_far`` and that individual)
        """
        round_best = self.__get_best_individual()
        best_so_far = copy.deepcopy(
            min((best_so_far, round_best), key=operator.attrgetter('err')))
        self.__show_errs(round_best, best_so_far)
        return round_best, best_so_far

    def run(self):
        """Thread body: iterate the swarm, then report and emit the best RBFN."""
        best_so_far = copy.deepcopy(self.population[0])
        for iteration in range(self.iter_times):
            if self.abort:
                break
            self.sig_current_iter_time.emit(iteration)

            round_best, best_so_far = self._evaluate_and_track(best_so_far)

            # Move every individual towards this round's best position.
            for member in self.population:
                member.update_position(
                    self.inertia_weight,
                    random.uniform(0, self.cognitive_const_upper),
                    random.uniform(0, self.social_const_upper),
                    round_best.position)

        self.sig_indicate_busy.emit()
        self.sig_console.emit('Selecting the best individual...')
        _, best_so_far = self._evaluate_and_track(best_so_far)
        self.sig_console.emit('The least error: %f' % best_so_far.err)
        self.sig_console.emit('The best individual: \n{}'.format(
            best_so_far.position))
        self.rbfn.load_model(best_so_far.position)
        self.sig_rbfn.emit(self.rbfn)

    @Slot()
    def stop(self):
        """Requests a stop; ``run`` checks the flag at the start of each iteration."""
        if self.isRunning():
            self.sig_console.emit(
                'WARNING: User interrupts running thread. The thread will be '
                'stop in next iteration. Please wait a second...')

        self.abort = True

    def __get_best_individual(self):
        """Update every individual's fitness and return the fittest one."""
        if not self.is_multicore:
            return max(self.population,
                       key=lambda member: member.update_fitness())
        with mp.Pool() as pool:
            fitnesses = pool.map(get_indiv_fitness_update, self.population)
            for member, fitness in zip(self.population, fitnesses):
                member.fitness = fitness
        return max(self.population, key=operator.attrgetter('fitness'))

    def __show_errs(self, global_best, total_best):
        """Emits each error, then the mean, round-best and overall-best errors."""
        for member in self.population:
            # Short pause between emissions, presumably to avoid flooding the GUI.
            time.sleep(0.001)
            self.sig_current_error.emit(member.err)
        mean_err = sum(m.err for m in self.population) / len(self.population)
        self.sig_iter_error.emit(mean_err, global_best.err, total_best.err)
Esempio n. 8
0
class ButtonList(QWidget):
    """A vertical list of buttons, with a button below it that adds new ones.

    Right-clicking a list button opens a menu with a delete action.
    """

    # Emitted after a button is added; carries the new button.
    signalBtnAdded = Signal(QPushButton)
    # Emitted before a button is scheduled for deletion; carries that button.
    signalBtnDeleted = Signal(QPushButton)
    # Emitted when a list button is clicked; carries the clicked button.
    signalBtnClicked = Signal(QPushButton)

    def __init__(self, addStr: str, parent: "QWidget | None" = None) -> None:
        """Creates the list, initially containing only the add button.

        Args:
            addStr (str): text shown on the add button.
            parent (QWidget, optional): parent widget.
        """
        super().__init__(parent=parent)
        # Gives every added button a unique default label, even after deletions.
        self._btnCounter = 0

        self.addIterm = QPushButton(addStr, self)
        self.addIterm.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        self.addIterm.clicked.connect(self.addBtn_clicked)

        # Layout holding the add button.
        self.layDown = QVBoxLayout()
        self.layDown.addWidget(self.addIterm)
        # Layout holding the list buttons. It is kept separate from the add
        # button's layout to avoid layout glitches when buttons are deleted.
        self.layUp = QVBoxLayout()
        # Outer layout.
        lay = QVBoxLayout()
        # Equivalent to the obsolete QLayout.setMargin(3), which Qt 6 removed.
        lay.setContentsMargins(3, 3, 3, 3)
        lay.addLayout(self.layUp)
        lay.addLayout(self.layDown)
        # Trailing stretch stops the parent from stretching the buttons vertically.
        lay.addStretch(1)
        self.setLayout(lay)

    @Slot()
    def addBtn_clicked(self) -> QPushButton:
        """Appends a new button to the list and returns it.

        The button's clicks are forwarded through ``signalBtnClicked``. Its
        right-click menu offers deletion.
        """
        button = QPushButton('btn{}'.format(self._btnCounter))
        self._btnCounter += 1
        button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        button.clicked.connect(lambda: self.button_clicked(button))

        # Right-click menu with a single "Delete" action.
        menu = QMenu(button)
        delQAction = QAction("删除", button)
        delQAction.triggered.connect(lambda: self.deleteButton(button))
        menu.addAction(delQAction)

        def on_context_menu(point):
            # Map the button-relative click position to global coordinates.
            menu.exec_(button.mapToGlobal(point))

        button.setContextMenuPolicy(Qt.CustomContextMenu)
        button.customContextMenuRequested.connect(on_context_menu)

        self.layUp.addWidget(button)
        self.signalBtnAdded.emit(button)
        return button

    @Slot(QPushButton)
    def deleteButton(self, button: QPushButton):
        """Emits ``signalBtnDeleted`` and schedules ``button`` for deletion."""
        self.signalBtnDeleted.emit(button)
        button.deleteLater()

    @Slot(QPushButton)
    def button_clicked(self, button: QPushButton):
        """Re-emits a list button's click as ``signalBtnClicked``."""
        self.signalBtnClicked.emit(button)
Esempio n. 9
0
class Mouse_Monitor(QObject):
    """Re-emits mouse and keyboard listener events as Qt signals.

    A ``mouse.Listener`` and a ``keyboard.Listener`` are started on
    construction. Pressing '?' toggles auto-click mode.
    """
    # Text description of a pointer move.
    Mouse_Moved = Signal(str)
    # Text description of a button press or release.
    Mouse_Clicked = Signal(str)
    # Text description of a scroll.
    Mouse_Scrolled = Signal(str)
    # True on left press (not while in auto-click mode), False on left release.
    Left_Mouse_Clicked = Signal(bool)
    # New auto-click mode state, emitted when '?' is pressed.
    Toggle_Auto_Click_Enabled = Signal(bool)

    def __init__(self, parent=None):
        """Creates the monitor and starts the mouse and keyboard listeners."""
        super().__init__(parent)
        # Current auto-click mode; flipped by On_Keyboard_Press.
        self._is_in_autoclick_mode = False

        self.listener = mouse.Listener(on_move=self.On_Mouse_Moved,
                                       on_click=self.On_Mouse_Clicked,
                                       on_scroll=self.On_Mouse_Scrolled)
        self.listener.start()
        self.keyboardlistener = keyboard.Listener(
            on_press=self.On_Keyboard_Press,
            on_release=self.On_Keyboard_Release)
        self.keyboardlistener.start()

    def On_Mouse_Moved(self, x, y):
        """Emits Mouse_Moved with the new pointer position."""
        self.Mouse_Moved.emit('Pointer moved to {0}'.format((x, y)))

    def On_Mouse_Clicked(self, x, y, button, pressed):
        """Emits Left_Mouse_Clicked for the left button and Mouse_Clicked for any button.

        Args:
            x, y: pointer position.
            button: the ``mouse.Button`` whose state changed.
            pressed (bool): True on press, False on release.
        """
        if button == mouse.Button.left:
            button_name = "Left"
            # Presses are not reported in auto-click mode; releases always are.
            if pressed and not self._is_in_autoclick_mode:
                self.Left_Mouse_Clicked.emit(True)
            elif not pressed:
                self.Left_Mouse_Clicked.emit(False)
        elif button == mouse.Button.right:
            button_name = "Right"
        else:
            button_name = "Middle"
        action_name = "Pressed" if pressed else "Released"
        self.Mouse_Clicked.emit('{0} Button Was {1} at {2},{3}'.format(
            button_name, action_name, x, y))

    def On_Mouse_Scrolled(self, x, y, dx, dy):
        """Emits Mouse_Scrolled with the scroll direction and position."""
        self.Mouse_Scrolled.emit('Scrolled {0} at {1}'.format(
            'down' if dy < 0 else 'up', (x, y)))

    def On_Keyboard_Press(self, key):
        """Toggles auto-click mode when '?' is pressed and emits the new state."""
        try:
            char = key.char
        except AttributeError:
            # Special keys (shift, arrows, ...) have no ``char``.
            return
        if char == "?":
            self._is_in_autoclick_mode = not self._is_in_autoclick_mode
            self.Toggle_Auto_Click_Enabled.emit(self._is_in_autoclick_mode)

    def On_Keyboard_Release(self, key):
        """Prints a message when '?' is released."""
        try:
            char = key.char
        except AttributeError:
            # Special keys have no ``char``.
            return
        if char == "?":
            print('{0} released'.format(key))
Esempio n. 10
0
class MeasuredDataModel(BaseModel):
    """Measured data for the GUI table and the QML chart.

    The table data goes into ``self._model`` and ``self._headers_model``,
    which are presumably created by BaseModel (not visible here). The chart is
    updated directly through QML LineSeries references registered with
    setUpperSeries/setLowerSeries.
    """

    # (y_obs column, sy_obs column) of self._model plotted for each data type.
    _DATA_TYPE_COLUMNS = {
        "Sum": (1, 2),
        "Difference": (3, 4),
        "Up": (5, 6),
        "Down": (7, 8),
    }

    def __init__(self, parent=None):
        super().__init__(parent)
        self._y_obs_column = 1
        self._sy_obs_column = 2
        self._y_max = 1
        self._y_min = 0
        # References to QML LineSeries (for 2 charts).
        self._upperSeriesRefs = []
        self._lowerSeriesRefs = []
        self._log = logger.getLogger(self.__class__.__module__)

    def _setModelsFromProjectDict(self):
        """
        Create the model needed for GUI measured data table and chart.

        For each experiment in ``self._project_dict['experiments']``, the
        non-None entries of its 'measured_pattern' become the columns of
        ``self._model``, and their keys become ``self._headers_model``. Both
        models are cleared per experiment, so only the last experiment is kept.
        """
        self._log.info("Starting to set Model from Project Dict")

        for experiment_dict in self._project_dict['experiments'].values():

            reduced_experiment_dict = {
                key: value
                for key, value in experiment_dict['measured_pattern'].items()
                if value is not None
            }

            column_count = len(reduced_experiment_dict)
            # Columns are assumed to share one length; an experiment without
            # any measured column yields an empty table instead of crashing.
            first_column = next(iter(reduced_experiment_dict.values()), [])
            row_count = len(first_column)

            self._model.blockSignals(True)
            self._headers_model.blockSignals(True)

            self._model.clear()
            self._model.setColumnCount(column_count)
            self._model.setRowCount(row_count)

            self._headers_model.clear()
            self._headers_model.setColumnCount(column_count)
            self._headers_model.setRowCount(1)

            # Add all the measured columns to self._model.
            for column_index, (data_id, data_list) in enumerate(
                    reduced_experiment_dict.items()):
                index = self._headers_model.index(0, column_index)
                self._headers_model.setData(index, data_id, Qt.DisplayRole)

                for row_index, value in enumerate(data_list):
                    index = self._model.index(row_index, column_index)
                    self._model.setData(index, value, Qt.DisplayRole)

            self._model.blockSignals(False)
            self._headers_model.blockSignals(False)

            # Emit signal which is caught by the QStandardItemModel-based
            # QML GUI elements in order to update their views
            self._model.layoutChanged.emit()
            self._headers_model.layoutChanged.emit()

            # Update chart series here, as this method is significantly
            # faster, compared to the updating at the QML GUI side via the
            # QStandardItemModel
            self._updateQmlChartViewSeries()

        self._log.info("Finished setting Model from Project Dict")

    def _updateQmlChartViewSeries(self):
        """
        Updates QML LineSeries of ChartView.

        Upper/lower series are y_obs + sy_obs and y_obs - sy_obs against
        column 0, using the columns selected by setDataType.
        """
        self._log.info("Starting update of ChartView")

        # Index of the self._model column used as x on the chart
        x_column = 0

        # Get values from model
        x_list = []
        y_obs_lower_list = []
        y_obs_upper_list = []
        for row_index in range(self._model.rowCount()):
            x = self._model.data(self._model.index(row_index, x_column))
            y_obs = self._model.data(
                self._model.index(row_index, self._y_obs_column))
            sy_obs = self._model.data(
                self._model.index(row_index, self._sy_obs_column))
            x_list.append(x)
            y_obs_lower_list.append(y_obs - sy_obs)
            y_obs_upper_list.append(y_obs + sy_obs)

        # max()/min() raise on an empty model (e.g. a data type switch before
        # any data is loaded); keep the previous axis range in that case.
        if y_obs_upper_list:
            self._setYMax(max(y_obs_upper_list))
            self._setYMin(min(y_obs_lower_list))

        # Build the series as lists of QPointF
        upper_series = [
            QPointF(x, y) for x, y in zip(x_list, y_obs_upper_list)
        ]
        lower_series = [
            QPointF(x, y) for x, y in zip(x_list, y_obs_lower_list)
        ]

        # Replace series
        for s in self._upperSeriesRefs:
            s.replace(upper_series)
        for s in self._lowerSeriesRefs:
            s.replace(lower_series)

        self._log.info("Finished update of ChartView")

    def onProjectChanged(self):
        """
        Reimplement BaseModel method, as we do not want to update measured data every time.
        """
        pass

    @Slot(QtCharts.QXYSeries)
    def setLowerSeries(self, series):
        """
        Sets lower series to be a reference to the QML LineSeries of ChartView.
        """
        self._lowerSeriesRefs.append(series)

    @Slot(QtCharts.QXYSeries)
    def setUpperSeries(self, series):
        """
        Sets upper series to be a reference to the QML LineSeries of ChartView.
        """
        self._upperSeriesRefs.append(series)

    @Slot(str)
    def setDataType(self, type):
        """
        Sets data type to be displayed on QML ChartView.

        Args:
            type (str): "Sum", "Difference", "Up" or "Down". Any other value
                keeps the current columns. (The name shadows the builtin but is
                kept for compatibility with existing callers.)
        """
        self._log.debug(type)
        columns = self._DATA_TYPE_COLUMNS.get(type)
        if columns is not None:
            self._y_obs_column, self._sy_obs_column = columns
        self._updateQmlChartViewSeries()

    def _yMax(self):
        """
        Returns max value for Y-axis.
        """
        return self._y_max

    def _yMin(self):
        """
        Returns min value for Y-axis.
        """
        return self._y_min

    def _setYMax(self, value):
        """
        Sets max value for Y-axis.
        """
        self._y_max = value
        self._yMaxChanged.emit()

    def _setYMin(self, value):
        """
        Sets min value for Y-axis.
        """
        self._y_min = value
        self._yMinChanged.emit()

    _yMaxChanged = Signal()
    _yMinChanged = Signal()

    yMax = Property(float, _yMax, notify=_yMaxChanged)
    yMin = Property(float, _yMin, notify=_yMinChanged)
Esempio n. 11
0
 class commu(QObject):
     someSignal = Signal()
Esempio n. 12
0
class WorkerManager(QAbstractListModel):
    """
    Manager to handle our worker queues and state.
    Also functions as a Qt data model for a view
    displaying progress for each worker.

    Bookkeeping (per instance):

    * ``_workers`` maps job_id -> worker for workers that have not yet
      reported ``finished``;
    * ``_state`` maps job_id -> state dict (``status``/``progress`` keys) for
      every job seen since the last :meth:`cleanup`; it also defines the rows
      of the model, in insertion order.
    """

    status = Signal(str)

    def __init__(self):
        super().__init__()

        # Per-instance containers. As class attributes these dicts were
        # shared (and mutated) by every WorkerManager instance.
        self._workers = {}
        self._state = {}

        # Create a threadpool for our workers.
        self.threadpool = QThreadPool()
        self.max_threads = self.threadpool.maxThreadCount()
        print("Multithreading with maximum %d threads" % self.max_threads)

        # Publish a running/waiting summary on ``status`` every 100 ms.
        self.status_timer = QTimer()
        self.status_timer.setInterval(100)
        self.status_timer.timeout.connect(self.notify_status)
        self.status_timer.start()

    def notify_status(self):
        """
        Emit a "<running> running, <waiting> waiting, <threads> threads" summary.

        The split is an estimate: it assumes the pool runs as many of the
        registered (unfinished) workers as it has threads.
        """
        n_workers = len(self._workers)
        running = min(n_workers, self.max_threads)
        waiting = max(0, n_workers - self.max_threads)
        self.status.emit("{} running, {} waiting, {} threads".format(
            running, waiting, self.max_threads))

    def enqueue(self, worker):
        """
        Enqueue a worker to run (at some point) by passing it to the QThreadPool.

        The worker must expose ``job_id`` and a ``signals`` object providing
        ``error``, ``status``, ``progress`` and ``finished`` signals.
        """
        worker.signals.error.connect(self.receive_error)
        worker.signals.status.connect(self.receive_status)
        worker.signals.progress.connect(self.receive_progress)
        worker.signals.finished.connect(self.done)

        # Register the job before handing it to the pool so any signal the
        # worker emits always finds its bookkeeping entries.
        self._workers[worker.job_id] = worker
        # Set default status to waiting, 0 progress.
        self._state[worker.job_id] = DEFAULT_STATE.copy()

        self.threadpool.start(worker)

        self.layoutChanged.emit()

    def receive_status(self, job_id, status):
        """Record a status update for ``job_id`` and refresh attached views."""
        self._state[job_id]["status"] = status
        self.layoutChanged.emit()

    def receive_progress(self, job_id, progress):
        """Record a progress update for ``job_id`` and refresh attached views."""
        self._state[job_id]["progress"] = progress
        self.layoutChanged.emit()

    def receive_error(self, job_id, message):
        """Report a worker error on stdout."""
        print(job_id, message)

    def done(self, job_id):
        """
        Task/worker complete. Remove it from the active workers
        dictionary. We leave it in ``_state``, as this is used
        to display past/complete workers too.
        """
        # pop() tolerates a duplicate or unknown ``finished`` notification.
        self._workers.pop(job_id, None)
        self.layoutChanged.emit()

    def cleanup(self):
        """
        Remove any complete/failed workers from ``_state``.
        """
        for job_id, s in list(self._state.items()):
            if s["status"] in (STATUS_COMPLETE, STATUS_ERROR):
                del self._state[job_id]
        self.layoutChanged.emit()

    # Model interface
    def data(self, index, role):
        """
        Return ``(job_id, state_dict)`` for the row of ``index`` when ``role``
        is ``Qt.DisplayRole``; return None for other roles or for a row outside
        the model (e.g. the -1 row of an invalid index).
        """
        if role != Qt.DisplayRole:
            return None
        row = index.row()
        if not 0 <= row < len(self._state):
            return None
        # Rows follow the insertion order of ``_state``.
        job_id = list(self._state)[row]
        return job_id, self._state[job_id]

    def rowCount(self, index=None):
        """Return one row per tracked job; the parent index is ignored."""
        return len(self._state)
Esempio n. 13
0
class KnechtCollectVariants(QObject):
    """ Helper class for KnechtEditor to collect variants from the current model """
    # Emitted by _collect when reset collection was requested, no reset preset
    # was found and no PlmXml path was collected.
    reset_missing = Signal()
    # Maximum nesting depth of preset references followed while collecting.
    recursion_limit = 3

    def __init__(self, view):
        """ Collects variants, so basically Name and Value fields of model items.

        :param modules.itemview.treeview.KnechtTreeView view: The Tree View to collect variants from
        """
        super(KnechtCollectVariants, self).__init__()
        self.view = view
        # Current reference nesting depth while collecting a preset.
        self.recursion_depth = 0

    @staticmethod
    def _camera_item_override(index, variants) -> KnechtVariantList:
        """ When directly collecting a camera item, knecht-setting send_camera_data should
            be ignored. knecht_deltagen module will only check variant commands of type camera_command
        """
        if index.siblingAtColumn(Kg.TYPE).data(Qt.DisplayRole) == Kg.xml_tag_by_user_type[Kg.camera_item]:
            for v in variants.variants:
                if v.item_type == 'camera_command':
                    v.item_type = 'command'

        return variants

    def collect_current_index(self, collect_reset: bool=True) -> KnechtVariantList:
        """ Collect variants from the current model index """
        index, __ = self.view.editor.get_current_selection()

        variants = self.collect_index(index, collect_reset)
        variants = self._camera_item_override(index, variants)

        return variants

    def collect_index(self, index: QModelIndex, collect_reset: bool=True) -> KnechtVariantList:
        """ Collect variants from the given model index (empty list for a missing/invalid index) """
        self.recursion_depth = 0
        src_model = self.view.model().sourceModel()

        if not index or not index.isValid():
            return KnechtVariantList()

        return self._collect(index, src_model, collect_reset)

    def _collect(self, index: QModelIndex, src_model: KnechtModel, collect_reset: bool=True
                 ) -> KnechtVariantList:
        """ Collect variants for the item at index.

            Reset presets are collected first when the 'reset' setting is on, collect_reset
            is True and the item is not a camera item. A reference item is replaced by the
            preset it points to. Variant and output items contribute only themselves.
        """
        variants = KnechtVariantList()
        current_item = src_model.get_item(index)
        reset_found = False

        if not current_item:
            return variants

        if KnechtSettings.dg['reset'] and collect_reset and current_item.userType != Kg.camera_item:
            reset_found = self._collect_reset_preset(variants, src_model)

        if current_item.userType == Kg.reference:
            # Current item is reference, use referenced item instead
            ref_id = current_item.reference
            current_item = src_model.id_mgr.get_preset_from_id(ref_id)

            if not current_item:
                return variants

        if current_item.userType in (Kg.variant, Kg.output_item):
            self._add_variant(current_item, variants, src_model)
            return variants

        variants.preset_name = current_item.data(Kg.NAME)
        variants.preset_id = current_item.preset_id
        self._collect_preset_variants(current_item, variants, src_model)

        if not reset_found and collect_reset and variants.plm_xml_path is None:
            self.reset_missing.emit()

        return variants

    def _collect_reset_preset(self, variants_ls: KnechtVariantList, src_model: KnechtModel):
        """ Collect every preset whose TYPE column is 'reset'; return True if at least one exists. """
        reset_presets = list()

        for item in src_model.id_mgr.iterate_presets():
            if item.data(Kg.TYPE) == 'reset':
                reset_presets.append(item)

        if not reset_presets:
            return False

        for reset_preset in reset_presets:
            self._collect_preset_variants(reset_preset, variants_ls, src_model)

        return True

    def _collect_preset_variants(self, preset_item: KnechtItem, variants_ls: KnechtVariantList,
                                 src_model: KnechtModel) -> None:
        """ Top-level entry for collecting one preset: resets the depth counter, then recurses. """
        self.recursion_depth = 0
        self._collect_preset_variants_recursive(preset_item, variants_ls, src_model)

    def _collect_preset_variants_recursive(self, preset_item: KnechtItem, variants_ls: KnechtVariantList,
                                           src_model: KnechtModel) -> None:
        """ Collect the variants of preset_item in order-column order, following references
            at most recursion_limit levels deep.
        """
        if self.recursion_depth > self.recursion_limit:
            LOGGER.warning('Recursion limit reached while collecting references! Aborting further collections!')
            return

        if preset_item.userType == Kg.camera_item:
            self._add_camera_variants(preset_item, variants_ls, src_model)
            return

        for child in self._order_children(preset_item):
            self._add_variant(child, variants_ls, src_model)

            if child.userType != Kg.reference:
                continue

            ref_preset = self._collect_single_reference(child, src_model)

            # The referenced preset may not exist; check before touching any
            # attribute. Identity check only: truthiness is left to the final
            # `if ref_preset` below, as output/plmxml items are added regardless.
            if ref_preset is None:
                continue

            if ref_preset.userType in (Kg.output_item, Kg.plmxml_item):
                self._add_variant(ref_preset, variants_ls, src_model)
                continue

            if ref_preset.userType == Kg.camera_item:
                self._add_camera_variants(ref_preset, variants_ls, src_model)
                continue

            if ref_preset:
                # Recurse directly: _collect_preset_variants resets the depth
                # counter to 0, which would make recursion_limit ineffective
                # and let circular references recurse without bound.
                self.recursion_depth += 1
                try:
                    self._collect_preset_variants_recursive(ref_preset, variants_ls, src_model)
                finally:
                    self.recursion_depth -= 1

    @staticmethod
    def _add_variant(item: KnechtItem, variants: KnechtVariantList, src_model: KnechtModel) -> None:
        """ Add a variant item to variants, or record an output / PlmXml path; other item types are ignored. """
        if item.userType == Kg.variant:
            index = src_model.get_index_from_item(item)
            variants.add(index, item.data(Kg.NAME), item.data(Kg.VALUE), item.data(Kg.TYPE))
        elif item.userType == Kg.output_item:
            variants.output_path = item.data(Kg.VALUE)
            LOGGER.debug('Collected output path: %s', item.data(Kg.VALUE))
        elif item.userType == Kg.plmxml_item:
            variants.plm_xml_path = item.data(Kg.VALUE)
            LOGGER.debug('Collected PlmXml path: %s', item.data(Kg.VALUE))

    @staticmethod
    def _add_camera_variants(item: KnechtItem, variants: KnechtVariantList, src_model: KnechtModel) -> None:
        """ Convert Camera Preset items to camera command variants """
        for child in item.iter_children():
            camera_tag, camera_value = child.data(Kg.NAME), child.data(Kg.VALUE)

            if camera_tag in KnechtImageCameraInfo.rtt_camera_cmds:
                index = src_model.get_index_from_item(child)
                camera_cmd = KnechtImageCameraInfo.rtt_camera_cmds.get(camera_tag)
                camera_value = camera_value.replace(' ', '')

                try:
                    camera_cmd = camera_cmd.format(*camera_value.split(','))
                except Exception as e:
                    # The unformatted command template is still added below.
                    LOGGER.warning('Camera Info Tag Value does not match %s\n%s', camera_value, e)

                variants.add(index, camera_tag, camera_cmd, 'camera_command')

                LOGGER.debug('Collecting Camera Command %s: %s', camera_tag, camera_cmd)

    @classmethod
    def _order_children(cls, preset_item: KnechtItem) -> List[KnechtItem]:
        """ The children list of an item corresponds to the source indices which
            do not necessarily reflect the item order by order column.
            We create a list ordered by the order column value of each child.
        """
        return cls.order_items_by_order_column(preset_item.iter_children())

    @staticmethod
    def order_items_by_order_column(items: List[KnechtItem]):
        """ Return items sorted ascending by the integer value of their ORDER column.

            Items sharing an order value end up in reverse input order, because each
            new item is inserted before existing items with an equal value.
        """
        sorted_orders, item_ls = list(), list()

        for item in items:
            order = int(item.data(Kg.ORDER))
            # sorted_orders mirrors item_ls, so the bisect position is also the
            # insert position for the item (no re-sorting per iteration).
            insert_idx = bisect_left(sorted_orders, order)
            sorted_orders.insert(insert_idx, order)
            item_ls.insert(insert_idx, item)

        return item_ls

    @staticmethod
    def _collect_single_reference(item, src_model) -> Union[KnechtItem, None]:
        """ Return the preset referenced by item, or an empty KnechtItem when item has no reference id.

            May return None when the id cannot be resolved (see the return annotation).
        """
        ref_id: QUuid = item.reference

        if ref_id:
            return src_model.id_mgr.get_preset_from_id(ref_id)
        else:
            return KnechtItem()
class WorkerSignals(QObject):
    """Signal holder for a worker; ``finished`` carries a single arbitrary object."""

    finished = Signal(object)

    def __init__(self):
        # Zero-argument super() resolves to the same base initializer.
        super().__init__()
Esempio n. 15
0
class TreeViewQLibrary(QTreeView):
    """Handles editing a QComponent

    This class extends `QTreeView`: clicking a file name emits its path on
    ``qlibrary_filepath_signal`` and clicking outside the tree clears the
    selection.
    """

    qlibrary_filepath_signal = Signal(str)

    def __init__(self, parent: QWidget):
        """
        Inits TreeViewQLibrary
        Args:
            parent (QtWidgets.QWidget): parent widget
        """
        QTreeView.__init__(self, parent)
        # Fallback tooltip used by setToolTip when given an empty value.
        self.tool_tip_str = "Library of QComponents"

    def setModel(self, model: QtCore.QAbstractItemModel):
        """Overriding setModel to ensure only correct QAbstractItemModel subclass is used

        Args:
            model (QtCore.QAbstractItemModel): Model to be set

        Raises:
            QLibraryGUIException: if model is not a LibraryFileProxyModel

        """
        if not isinstance(model, LibraryFileProxyModel):
            raise QLibraryGUIException(
                f"Invalid model. Expected type {LibraryFileProxyModel} but got type {type(model)}"
            )

        super().setModel(model)

    def mousePressEvent(self, event: QtGui.QMouseEvent):
        """Overrides inherited mousePressEvent to emit appropriate filepath signals
        based on which columns were clicked, and to allow user to clear any selections
        by clicking off the displayed tree.

        The base class handler is always invoked.

        Args:
            event (QtGui.QMouseEvent): QMouseEvent triggered by user
        """

        index = self.indexAt(event.pos())

        if index.row() == -1:
            # Click landed outside any item: drop selection and current index.
            self.clearSelection()
            self.setCurrentIndex(QModelIndex())
            return super().mousePressEvent(event)

        model = self.model()
        source_model = model.sourceModel()
        source_index = model.mapToSource(index)

        if index.column() == source_model.FILENAME and not source_model.isDir(source_index):
            full_path = source_model.filePath(source_index)
            # Emit the path starting at the top-level package directory (first
            # dotted component of this module's __name__). str.find instead of
            # str.index: a path outside that package must not raise and abort
            # the event; nothing is emitted in that case.
            package_name = __name__.split('.')[0]
            start = full_path.find(package_name)
            if start != -1:
                self.qlibrary_filepath_signal.emit(full_path[start:])

        return super().mousePressEvent(event)

    def setToolTip(self, qcomp_tooltip: str):
        """
        Sets tooltip, falling back to the default library tooltip when
        ``qcomp_tooltip`` is None or empty.

        Args:
            qcomp_tooltip (str): Tooltip to be set

        """
        if not qcomp_tooltip:
            super().setToolTip(self.tool_tip_str)
        else:
            super().setToolTip(qcomp_tooltip)
Esempio n. 16
0
class ServerInterfaceAllStatements(QObject):
    """
    High level interface to lean server.

    Builds one virtual Lean file from a course in which each statement's proof
    is replaced by ``hypo_analysis``/``targets_analysis`` tactics, sends it to
    the server, and turns the returned "context:"/"targets:" messages into an
    initial ProofState for every statement.
    """
    ############################################
    # Qt Signals
    ############################################
    proof_state_change = Signal(ProofState)

    update_started = Signal()
    update_ended = Signal()

    #    proof_no_goals              = Signal()

    # Signal emitted when all effective codes have been received,
    # so that history_replace is called
    # effective_code_received     = Signal(CodeForLean)

    ############################################
    # Init, and state control
    ############################################
    def __init__(self, nursery):
        super().__init__()

        self.log = logging.getLogger("ServerInterfaceAllStatements")

        # Lean environment
        # NOTE(review): ``inst`` is not defined in this class; presumably a
        # module-level name — confirm.
        self.lean_env: LeanEnvironment = LeanEnvironment(inst)
        # Lean attributes
        self.lean_file_content: str = ""
        self.lean_server: LeanServer = LeanServer(nursery, self.lean_env)
        self.nursery: trio.Nursery = nursery
        # Updated by __on_lean_state_change.
        self.is_running: bool = False

        # Set server callbacks
        self.lean_server.on_message_callback = self.__on_lean_message
        self.lean_server.running_monitor.on_state_change_callback = \
            self.__on_lean_state_change

        # Course data
        self.course = None
        # Line number of each inserted analysis tactic -> its statement;
        # filled by set_statements.
        self.statement_from_hypo_line = dict()
        self.statement_from_targets_line = dict()
        self.statements = []
        # Per-statement "context:" / "targets:" texts, None until received.
        self.hypo_analysis = None
        self.targets_analysis = None
        # Number of initial proof states built for the current course.
        self.pf_counter = 0
        # Current proof state + Events
        self.file_invalidated = trio.Event()

        # __proof_receive_done is set when enough information has been
        # received, i.e. either context and targets for every statement,
        # OR an error message.
        self.__proof_receive_done = trio.Event()

        # Errors memory channels
        self.error_send, self.error_recv = \
            trio.open_memory_channel(max_buffer_size=1024)

    async def start(self):
        """
        Asynchronously start the Lean server.
        """
        await self.lean_server.start()

    def stop(self):
        """
        Stop the Lean server.
        """
        self.lean_server.stop()

    ############################################
    # Callbacks from lean server
    ############################################
    def __check_receive_state(self, index):
        """
        Build the initial proof state of statement ``index`` once both its
        context and targets have arrived, and set __proof_receive_done once
        every statement has received both.
        """
        hypo = self.hypo_analysis[index]
        target = self.targets_analysis[index]
        if hypo and target:
            st = self.statements[index]
            if not st.initial_proof_state:
                ps = ProofState.from_lean_data(hypo, target)
                st.initial_proof_state = ps
                self.pf_counter += 1

            # Completion is decided on received data rather than pf_counter:
            # statements that already carried an initial proof state never
            # increment the counter, which would otherwise block forever.
            if all(self.hypo_analysis) and all(self.targets_analysis):
                self.__proof_receive_done.set()

    def __on_lean_message(self, msg: Message):
        """
        Treatment of relevant Lean messages. Note that the text may contain
        several lines. Error messages are treated via the __filter_error
        method; warnings are logged unless they end with LEAN_USES_SORRY.
        Other relevant messages are
        - "context:" messages, stored in hypo_analysis,
        - "targets:" messages, stored in targets_analysis,
        each mapped to its statement through the message line number.
        After relevant messages, the __check_receive_state method is called
        to check if all awaited messages have been received.
        """

        txt = msg.text
        line = msg.pos_line
        severity = msg.severity

        if severity == Message.Severity.error:
            self.log.error(f"Lean error at line {msg.pos_line}: {txt}")
            self.__filter_error(msg)  # Record error ?

        elif severity == Message.Severity.warning:
            if not txt.endswith(LEAN_USES_SORRY):
                self.log.warning(f"Lean warning at line {msg.pos_line}: {txt}")

        elif txt.startswith("context:"):
            st = self.statement_from_hypo_line[line]
            index = self.statements.index(st)
            self.log.info(f"Got new context for statmnt {st.lean_name}, "
                          f"index = {index}")
            self.hypo_analysis[index] = txt

            self.__check_receive_state(index)

        elif txt.startswith("targets:"):
            st = self.statement_from_targets_line[line]
            index = self.statements.index(st)
            self.log.info(f"Got new targets for statmnt {st.lean_name}, "
                          f"index = {index}")
            self.targets_analysis[index] = txt

            self.__check_receive_state(index)

    def __on_lean_state_change(self, is_running: bool):
        self.log.info(f"New lean state: {is_running}")
        self.is_running = is_running

    ############################################
    # Message filtering
    ############################################
    def __filter_error(self, msg: Message):
        """
        Filter error messages from Lean:
        - "no goals" messages emit ``proof_no_goals`` (when such a signal
          exists) and end the wait,
        - "unresolved" messages are ignored,
        - when a ``lean_file`` is available, messages outside its changed
          region are ignored,
        - every other error is sent on the error channel and ends the wait.
        """
        # Neither ``proof_no_goals`` (its Signal is commented out) nor
        # ``lean_file`` is defined by this class; look them up defensively
        # instead of raising AttributeError inside the server callback.
        proof_no_goals = getattr(self, "proof_no_goals", None)
        lean_file = getattr(self, "lean_file", None)

        if msg.text.startswith(LEAN_NOGOALS_TEXT):
            if hasattr(proof_no_goals, "emit"):
                proof_no_goals.emit()
                self.__proof_receive_done.set()  # Done receiving
        elif msg.text.startswith(LEAN_UNRESOLVED_TEXT):
            pass
        # Ignore messages that do not concern current proof
        elif lean_file is not None and (
                msg.pos_line < lean_file.first_line_of_last_change
                or msg.pos_line > lean_file.last_line_of_inner_content):
            pass
        else:
            self.error_send.send_nowait(msg)
            self.__proof_receive_done.set()  # Done receiving

    ############################################
    # Update
    ############################################
    async def get_proof_states(self):
        """
        Send the virtual file to the Lean server. If the server answers
        "file invalidated", wait until all proof states have been received
        (or an error ended the wait) and emit ``update_ended``.
        """
        file_name = str(self.course.relative_course_path)
        req = SyncRequest(file_name=file_name, content=self.lean_file_content)

        # Invalidate events
        self.file_invalidated = trio.Event()
        self.__proof_receive_done = trio.Event()
        # With no statements no context/targets message will ever arrive.
        if not self.statements:
            self.__proof_receive_done.set()

        resp = await self.lean_server.send(req)

        if resp.message == "file invalidated":
            self.file_invalidated.set()

            #########################################
            # Waiting for all pieces of information #
            #########################################
            await self.__proof_receive_done.wait()

            self.log.debug(_("All proof states received"))

            if hasattr(self.update_ended, "emit"):
                self.update_ended.emit()

        # TODO: adapt error reporting (drain self.error_recv and raise
        # exceptions.FailedRequestError when errors were received).

    ############################################
    # Exercise initialisation
    ############################################
    async def set_statements(self, course: Course, statements: list = None):
        """
        Set course, statements, and insert hypo_analysis / targets_analysis
        for each statement, then request the proof states.

        When ``statements`` is empty or None, all ``course.statements`` are used.
        """

        self.course = course
        file_content = self.course.file_content
        lines = file_content.splitlines()
        if not statements:
            self.statements = course.statements
        else:
            self.statements = statements

        # Reset per-course bookkeeping so repeated calls start from scratch.
        self.statement_from_hypo_line = dict()
        self.statement_from_targets_line = dict()
        self.pf_counter = 0

        self.log.info(f"Getting proof states for course {course.title}")
        self.log.info(f"{len(self.statements)} statements")

        self.hypo_analysis = [None] * len(self.statements)
        self.targets_analysis = [None] * len(self.statements)

        hypo_tactic = "    hypo_analysis,"
        targets_tactic = "    targets_analysis,"

        shift = 0  # Shift due to line insertion/deletion
        for statement in self.statements:
            begin_line = statement.lean_begin_line_number + shift
            end_line = statement.lean_end_line_number + shift
            # Drop list indices begin_line .. end_line - 2 (last first, so the
            # remaining indices stay valid) and put the two tactics there.
            for index in reversed(range(begin_line, end_line - 1)):
                lines.pop(index)
            lines.insert(begin_line, hypo_tactic)
            lines.insert(begin_line + 1, targets_tactic)
            # The +1/+2 offsets turn the 0-based list indices into the line
            # numbers carried by Lean messages (presumably 1-based).
            self.statement_from_hypo_line[begin_line + 1] = statement
            self.statement_from_targets_line[begin_line + 2] = statement
            # No shift if end_line = begin_line + 3
            shift += 3 - (end_line - begin_line)

        # Construct virtual file
        self.lean_file_content = "\n".join(lines)

        await self.get_proof_states()
Esempio n. 17
0
class TreeModelU(TreeModel):
    checkChanged = Signal(int, str)

    def __init__(self, headers, data, tablemodel, editables, parent=None):
        """
        Build the tree from the nested dict ``data`` and sync ``tablemodel``.

        Note: super(TreeModel, self) skips TreeModel.__init__ and initialises
        the next base in the MRO directly; the tree state is set up here.
        """
        super(TreeModel, self).__init__(parent)

        rootData = [header for header in headers]
        self.rootItem = TreeItem(rootData)
        self.treeDict = data
        self.tablemodel = tablemodel
        self.editableKeys = editables
        self.examingParents = False
        self.examiningChildren = False
        self.setupModelData(data, self.rootItem)
        self.checkList()

    def checkList(self):
        """
        Move pending metadata rows into the table model and warn about
        missing template sources.

        A template source counts as missing when it does not contain
        "Custom Input" and is absent from ``tablemodel.newmetadatasources``.
        Each entry of ``tablemodel.newmetadataList`` is passed to
        ``tablemodel.addRow``; both pending lists are cleared afterwards.
        """
        tablemodel = self.tablemodel
        missingSources = [
            source for source in tablemodel.templatesources
            if "Custom Input" not in source
            and source not in tablemodel.newmetadatasources
        ]

        # Iterate a snapshot, matching the original fixed-length loop.
        for row in list(tablemodel.newmetadataList):
            tablemodel.addRow(row)

        if missingSources:
            QMessageBox.warning(
                None, QApplication.applicationDisplayName(),
                "Bad stuff happens. " +
                "The file extracted is missing Source: \n\n" +
                str(missingSources))

        tablemodel.newmetadataList = []
        tablemodel.newmetadatasources = []

    def setupModelData(self, data, parent):
        """
        Populate the tree below ``parent`` from the nested dict ``data``.

        Keys are visited breadth-first and each becomes a child row; dict
        values are descended into, and non-dict values are handed to
        ``tablemodel.prepRow(curDict, tempSource, key)`` where ``tempSource``
        is the "/"-joined key path leading to ``curDict``.
        """
        from collections import deque

        visited = {}
        # deque: O(1) pops from the front of the BFS queue.
        queue = deque()
        grandParents = {}

        for key in data.keys():
            # NOTE(review): overwritten on every iteration, so only the last
            # top-level key is recorded here — confirm this is intended.
            visited[(parent.itemData[0])] = [key]
            queue.append((key, parent, ""))
            grandParents[key] = (data[key], parent)
        curDict = data
        tempSource = ""
        while queue:
            poppedItem = queue.popleft()
            child = poppedItem[0]
            parentOfChild = poppedItem[1]
            childSource = poppedItem[2]
            parent = parentOfChild
            parent.insertChildren(parent.childCount(), 1,
                                  self.rootItem.columnCount())
            parent.child(parent.childCount() - 1).setData(0, child)

            if child in grandParents:
                # Descend into the dict stored under this key.
                curDict = grandParents[child][0]
                tempSource = childSource + child + "/"
                for curChild in range(grandParents[child][1].childCount()):
                    if child == grandParents[child][1].child(
                            curChild).itemData[0]:
                        parent = grandParents[child][1].child(curChild)
                        visited[(parent.itemData[0])] = []

            # NOTE(review): visited is keyed by item name, so equal names in
            # different branches share one entry — confirm this is intended.
            if isinstance(curDict, dict):
                for key in curDict.keys():
                    if key not in visited[(parent.itemData[0])]:
                        visited[(parent.itemData[0])].append(key)
                        queue.append((key, parent, tempSource))
                        if (isinstance(curDict[key], dict)):
                            grandParents[key] = (curDict[key], parent)
                        else:
                            self.tablemodel.prepRow(curDict, tempSource, key)
class FitGrainsResultsDialog(QObject):
    """Dialog showing fit-grains results as a table and a 3D scatter plot.

    Each row of ``data`` is one grain. As used in this class, columns 6, 7
    and 8 are the scatter X/Y/Z coordinates, column 1 is completeness,
    column 2 goodness of fit and columns 15-20 the XX/YY/ZZ/YZ/XZ/XY strain
    components. An equivalent-strain column is appended to the plotting
    copy of the data in ``__init__``; the table model keeps the original.
    """
    finished = Signal()

    def __init__(self, data, parent=None):
        super(FitGrainsResultsDialog, self).__init__()

        self.ax = None
        self.cmap = hexrd.ui.constants.DEFAULT_CMAP
        self.data = data
        self.data_model = FitGrainsResultsModel(data)
        self.canvas = None
        self.fig = None
        self.scatter_artist = None
        self.colorbar = None

        loader = UiLoader()
        self.ui = loader.load_file('fit_grains_results_dialog.ui', parent)
        flags = self.ui.windowFlags()
        self.ui.setWindowFlags(flags | Qt.Tool)
        self.ui.splitter.setStretchFactor(0, 1)
        self.ui.splitter.setStretchFactor(1, 10)

        self.setup_tableview()

        # Add a column for equivalent strain, computed from the six strain
        # components in columns 15-20. numpy's append/hstack return a NEW
        # array, so the result must be assigned back. Previously it was
        # discarded, and "Equivalent Strain" (column -1) silently plotted
        # the last existing column instead. The table model was built from
        # the original array above, so the table and export are unaffected.
        ngrains = self.data.shape[0]
        eqv_strain = np.zeros(ngrains)
        for i in range(ngrains):
            emat = vecMVToSymm(self.data[i, 15:21], scale=False)
            eqv_strain[i] = 2.*np.sqrt(np.sum(emat*emat))/3.
        self.data = np.hstack((self.data, eqv_strain[:, np.newaxis]))

        self.setup_gui()

    def setup_gui(self):
        """Build selectors, plot, toolbar and view menu, then draw once."""
        self.setup_selectors()
        self.setup_plot()
        self.setup_toolbar()
        self.setup_view_direction_options()
        self.setup_connections()
        self.on_colorby_changed()
        self.backup_ranges()
        self.update_ranges_gui()

    def clear_artists(self):
        """Remove the colorbar and scatter artist from the axes, if any."""
        # Colorbar must be removed before the scatter artist
        if self.colorbar is not None:
            self.colorbar.remove()
            self.colorbar = None

        if self.scatter_artist is not None:
            self.scatter_artist.remove()
            self.scatter_artist = None

    def on_colorby_changed(self):
        """Redraw the scatter plot colored by the selected data column."""
        column = self.ui.plot_color_option.currentData()
        colors = self.data[:, column]

        xs = self.data[:, 6]
        ys = self.data[:, 7]
        zs = self.data[:, 8]
        sz = matplotlib.rcParams['lines.markersize'] ** 3

        # I could not find a way to update scatter plot marker colors and
        # the colorbar mappable. So we must re-draw both from scratch...
        self.clear_artists()
        self.scatter_artist = self.ax.scatter3D(
            xs, ys, zs, c=colors, cmap=self.cmap, s=sz)
        self.colorbar = self.fig.colorbar(self.scatter_artist, shrink=0.8)
        self.draw()

    def on_export_button_pressed(self):
        """Ask for a file name and save the table model there.

        A ``.out`` extension is appended when the name has none.
        """
        # Qt separates name filters with ";;" (a "|" made one bogus filter).
        selected_file, selected_filter = QFileDialog.getSaveFileName(
            self.ui, 'Export Fit-Grains Results', HexrdConfig().working_dir,
            'Output files (*.out);;All files (*.*)')

        if selected_file:
            HexrdConfig().working_dir = os.path.dirname(selected_file)
            name, ext = os.path.splitext(selected_file)
            if not ext:
                selected_file += '.out'

            self.data_model.save(selected_file)

    def on_sort_indicator_changed(self, index, order):
        """Shows sort indicator for columns 0-2, hides for all others."""
        if index < 3:
            self.ui.table_view.horizontalHeader().setSortIndicatorShown(True)
            self.ui.table_view.horizontalHeader().setSortIndicator(
                index, order)
        else:
            self.ui.table_view.horizontalHeader().setSortIndicatorShown(False)

    @property
    def projection(self):
        """Matplotlib 3D projection type chosen in the projection combo."""
        name_map = {
            'Perspective': 'persp',
            'Orthographic': 'ortho'
        }
        return name_map[self.ui.projection.currentText()]

    def projection_changed(self):
        self.ax.set_proj_type(self.projection)
        self.draw()

    def setup_connections(self):
        """Connect UI widget signals to their handlers."""
        self.ui.export_button.clicked.connect(self.on_export_button_pressed)
        self.ui.projection.currentIndexChanged.connect(self.projection_changed)
        self.ui.plot_color_option.currentIndexChanged.connect(
            self.on_colorby_changed)
        self.ui.hide_axes.toggled.connect(self.update_axis_visibility)
        self.ui.finished.connect(self.finished)

        for name in ('x', 'y', 'z'):
            action = getattr(self, f'set_view_{name}')
            action.triggered.connect(partial(self.reset_view, name))

        for w in self.range_widgets:
            w.valueChanged.connect(self.update_ranges_mpl)
            w.valueChanged.connect(self.update_range_constraints)

        self.ui.reset_ranges.pressed.connect(self.reset_ranges)

    def setup_plot(self):
        """Create the 3D figure/axes and embed the canvas in the dialog."""
        # Create the figure and axes to use
        canvas = FigureCanvas(Figure(tight_layout=True))

        # Get the canvas to take up the majority of the screen most of the time
        canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        fig = canvas.figure
        ax = fig.add_subplot(111, projection='3d', proj_type=self.projection)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        self.ui.canvas_layout.addWidget(canvas)

        self.fig = fig
        self.ax = ax
        self.canvas = canvas

    def setup_toolbar(self):
        """Add the navigation toolbar (without buttons useless in 3D)."""
        # These don't work for 3D plots
        # "None" removes the separators
        button_blacklist = [
            None,
            'Pan',
            'Zoom',
            'Subplots',
            'Customize'
        ]
        self.toolbar = NavigationToolbar(self.canvas, self.ui, False,
                                         button_blacklist)
        self.ui.toolbar_layout.addWidget(self.toolbar)
        self.ui.toolbar_layout.setAlignment(self.toolbar, Qt.AlignCenter)

        # Make sure our ranges editor gets updated any time matplotlib
        # might have modified the ranges underneath.
        self.toolbar.after_home_callback = self.update_ranges_gui
        self.toolbar.after_back_callback = self.update_ranges_gui
        self.toolbar.after_forward_callback = self.update_ranges_gui

    def setup_view_direction_options(self):
        """Attach the XY/XZ/YZ view menu to the view-direction button."""
        b = self.ui.set_view_direction

        m = QMenu(b)
        self.set_view_direction_menu = m

        self.set_view_z = m.addAction('XY')
        self.set_view_y = m.addAction('XZ')
        self.set_view_x = m.addAction('YZ')

        b.setMenu(m)

    def reset_view(self, direction):
        """Look down the given axis ('x', 'y' or 'z')."""
        # The adjustment is to force the tick markers and label to
        # appear on one side.
        adjust = 1.e-5

        angles_map = {
            'x': (0, 0),
            'y': (0, 90 - adjust),
            'z': (90 - adjust, -90 - adjust)
        }
        self.ax.view_init(*angles_map[direction])

        # Temporarily hide the labels of the axis perpendicular to the
        # screen for easier viewing.
        if self.axes_visible:
            self.hide_axis(direction)

        self.draw()

        # As soon as the image is re-drawn, the perpendicular axis will
        # be visible again.
        if self.axes_visible:
            self.show_axis(direction)

    def set_axis_visible(self, name, visible):
        """Show or hide the ticks and label of axis *name*."""
        ax = getattr(self.ax, f'{name}axis')
        set_label_func = getattr(self.ax, f'set_{name}label')
        if visible:
            ax.set_major_locator(ticker.AutoLocator())
            set_label_func(name.upper())
        else:
            ax.set_ticks([])
            set_label_func('')

    def hide_axis(self, name):
        self.set_axis_visible(name, False)

    def show_axis(self, name):
        self.set_axis_visible(name, True)

    @property
    def axes_visible(self):
        return not self.ui.hide_axes.isChecked()

    def update_axis_visibility(self):
        for name in ('x', 'y', 'z'):
            self.set_axis_visible(name, self.axes_visible)

        self.draw()

    def setup_selectors(self):
        """Fill the color-by combo box; item data is the column index."""
        # Build combo boxes in code to assign columns in grains data
        blocker = QSignalBlocker(self.ui.plot_color_option)  # noqa: F841
        self.ui.plot_color_option.clear()
        self.ui.plot_color_option.addItem('Completeness', 1)
        self.ui.plot_color_option.addItem('Goodness of Fit', 2)
        # -1 is the equivalent-strain column appended in __init__
        self.ui.plot_color_option.addItem('Equivalent Strain', -1)
        self.ui.plot_color_option.addItem('XX Strain', 15)
        self.ui.plot_color_option.addItem('YY Strain', 16)
        self.ui.plot_color_option.addItem('ZZ Strain', 17)
        self.ui.plot_color_option.addItem('YZ Strain', 18)
        self.ui.plot_color_option.addItem('XZ Strain', 19)
        self.ui.plot_color_option.addItem('XY Strain', 20)

        index = self.ui.plot_color_option.findData(-1)
        self.ui.plot_color_option.setCurrentIndex(index)

    def setup_tableview(self):
        """Show the data model in the table, sortable by columns 0-2 only."""
        view = self.ui.table_view

        # Subclass QSortFilterProxyModel to restrict sorting by column
        class GrainsTableSorter(QSortFilterProxyModel):
            def sort(self, column, order):
                if column > 2:
                    return
                else:
                    super().sort(column, order)

        proxy_model = GrainsTableSorter(self.ui)
        proxy_model.setSourceModel(self.data_model)
        view.verticalHeader().hide()
        view.setModel(proxy_model)
        view.resizeColumnToContents(0)

        view.setSortingEnabled(True)
        view.horizontalHeader().sortIndicatorChanged.connect(
            self.on_sort_indicator_changed)
        view.sortByColumn(0, Qt.AscendingOrder)
        self.ui.table_view.horizontalHeader().setSortIndicatorShown(False)

    def show(self):
        self.ui.show()

    @property
    def range_widgets(self):
        """Range spin boxes ordered x_min, x_max, y_min, y_max, z_min, z_max."""
        widgets = []
        for name in ('x', 'y', 'z'):
            for i in range(2):
                widgets.append(getattr(self.ui, f'range_{name}_{i}'))

        return widgets

    @property
    def ranges_gui(self):
        return [w.value() for w in self.range_widgets]

    @ranges_gui.setter
    def ranges_gui(self, v):
        # Lift min/max constraints first so intermediate values are not
        # clamped by the widget pairs' mutual limits.
        self.remove_range_constraints()
        for x, w in zip(v, self.range_widgets):
            w.setValue(round(x, 5))
        self.update_range_constraints()

    @property
    def ranges_mpl(self):
        vals = []
        for name in ('x', 'y', 'z'):
            lims_func = getattr(self.ax, f'get_{name}lim')
            vals.extend(lims_func())
        return vals

    @ranges_mpl.setter
    def ranges_mpl(self, v):
        for i, name in enumerate(('x', 'y', 'z')):
            lims = (v[i * 2], v[i * 2 + 1])
            set_func = getattr(self.ax, f'set_{name}lim')
            set_func(*lims)

        # Update the navigation stack so the home/back/forward
        # buttons will know about the range change.
        self.toolbar.push_current()

        self.draw()

    def update_ranges_mpl(self):
        self.ranges_mpl = self.ranges_gui

    def update_ranges_gui(self):
        # Keep the blockers alive while the widgets are updated.
        blocked = [QSignalBlocker(w) for w in self.range_widgets]  # noqa: F841
        self.ranges_gui = self.ranges_mpl

    def backup_ranges(self):
        self._ranges_backup = self.ranges_mpl

    def reset_ranges(self):
        self.ranges_mpl = self._ranges_backup
        self.update_ranges_gui()

    def remove_range_constraints(self):
        """Let every min/max spin box take any finite value."""
        widgets = self.range_widgets
        for w1, w2 in zip(widgets[0::2], widgets[1::2]):
            w1.setMaximum(sys.float_info.max)
            # sys.float_info.min is the smallest *positive* float, which
            # would clamp negative upper bounds to ~0; use the most
            # negative finite float instead.
            w2.setMinimum(-sys.float_info.max)

    def update_range_constraints(self):
        """Keep each axis's min spin box <= its max spin box."""
        widgets = self.range_widgets
        for w1, w2 in zip(widgets[0::2], widgets[1::2]):
            w1.setMaximum(w2.value())
            w2.setMinimum(w1.value())

    def draw(self):
        self.canvas.draw()
class TabularViewHeaderWidget(QFrame):
    """A draggable header for a pivot table.

    Dropping one header onto another emits ``header_dropped`` on the drop
    target with ``(dragged_header, target_header, position)``, where
    ``position`` is ``"before"`` or ``"after"`` depending on which half of
    the target the drop landed in.
    """

    header_dropped = Signal(object, object, str)
    _H_MARGIN = 3
    _SPACING = 16

    def __init__(self, identifier, area, menu=None, parent=None):
        """

        Args:
            identifier (str): text shown on the header
            area (str): either "rows", "columns", or "frozen"
            menu (FilterMenu, optional): menu opened by the header's
                button; the button is disabled when no menu is given
            parent (QWidget, optional): Parent widget

        Raises:
            ValueError: if ``area`` is not one of the values above
        """
        # Validate before creating the Qt object; an unknown area used to
        # fail later with an UnboundLocalError on ``h_alignment``.
        if area not in ("rows", "columns", "frozen"):
            raise ValueError(
                f'area must be "rows", "columns", or "frozen", got {area!r}')
        super().__init__(parent=parent)
        self._identifier = identifier
        self._area = area
        layout = QHBoxLayout(self)
        button = QToolButton(self)
        button.setPopupMode(QToolButton.InstantPopup)
        button.setStyleSheet("QToolButton {border: none;}")
        button.setEnabled(menu is not None)
        if menu is not None:
            self.menu = menu
            button.setMenu(self.menu)
            self.menu.anchor = self
        self.drag_start_pos = None
        label = QLabel(identifier)
        layout.addWidget(label)
        layout.addWidget(button)
        layout.setContentsMargins(self._H_MARGIN, 0, self._H_MARGIN, 0)
        if area == "rows":
            h_alignment = Qt.AlignLeft
            layout.insertSpacing(1, self._SPACING)
            button.setArrowType(Qt.DownArrow)
        elif area == "columns":
            h_alignment = Qt.AlignRight
            layout.insertSpacing(0, self._SPACING)
            button.setArrowType(Qt.RightArrow)
        else:  # "frozen"
            h_alignment = Qt.AlignHCenter
        label.setAlignment(h_alignment | Qt.AlignVCenter)
        label.setStyleSheet("QLabel {font-weight: bold;}")
        self.setAttribute(Qt.WA_DeleteOnClose)
        self.setAutoFillBackground(True)
        self.setFrameStyle(QFrame.Raised)
        self.setFrameShape(QFrame.Panel)
        self.setStyleSheet("QFrame {background: " + PIVOT_TABLE_HEADER_COLOR +
                           ";}")
        self.setAcceptDrops(True)
        self.setToolTip(
            "<p>This is a draggable header. </p>"
            "<p>Drag-and-drop it onto another header to pivot the table, "
            "or onto the Frozen table to freeze this dimension.</p>")
        self.adjustSize()
        self.setMinimumWidth(self.size().width())

    @property
    def identifier(self):
        return self._identifier

    @property
    def area(self):
        return self._area

    def mousePressEvent(self, event):
        """Register drag start position"""
        if event.button() == Qt.LeftButton:
            self.drag_start_pos = event.pos()

    # noinspection PyArgumentList, PyUnusedLocal
    def mouseMoveEvent(self, event):
        """Start dragging action if needed"""
        if not event.buttons() & Qt.LeftButton:
            return
        # Compare against None explicitly: a press at QPoint(0, 0) is a
        # valid start position but may evaluate as false.
        if self.drag_start_pos is None:
            return
        if (event.pos() - self.drag_start_pos
            ).manhattanLength() < QApplication.startDragDistance():
            return
        drag = QDrag(self)
        mime_data = QMimeData()
        drag.setMimeData(mime_data)
        pixmap = self.grab()
        drag.setPixmap(pixmap)
        drag.setHotSpot(pixmap.rect().center())
        drag.exec_()

    def mouseReleaseEvent(self, event):
        """Forget drag start position"""
        self.drag_start_pos = None

    def dragEnterEvent(self, event):
        """Accept drags that originate from another header widget."""
        if isinstance(event.source(), TabularViewHeaderWidget):
            event.accept()

    def dropEvent(self, event):
        """Emit ``header_dropped`` with the side of this header hit."""
        other = event.source()
        if other == self:
            return
        center = self.rect().center()
        drop = event.pos()
        # Rows/frozen headers compare along x, column headers along y.
        if self.area in ("rows", "frozen"):
            position = "before" if center.x() > drop.x() else "after"
        else:  # "columns" (area is validated in __init__)
            position = "before" if center.y() > drop.y() else "after"
        self.header_dropped.emit(other, self, position)
# Example no. 20
# (scraped listing marker: "Esempio n. 20" / "0")
class ScrollbarWithText(QtWidgets.QWidget):
    """Horizontal scrollbar paired with a line edit for choosing an index.

    ``position`` is emitted with the chosen index whenever the scrollbar
    value changes or the user types a number and presses Return.
    ``initialize_state`` sets the number of items and ``update_state``
    syncs both widgets to an externally chosen index.
    """
    position = Signal(int)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.horizontalWidget = QtWidgets.QWidget()
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred,
                                           QtWidgets.QSizePolicy.Maximum)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.horizontalWidget.sizePolicy().hasHeightForWidth())
        self.horizontalWidget.setSizePolicy(sizePolicy)
        self.horizontalWidget.setMaximumSize(QtCore.QSize(16777215, 25))
        self.horizontalWidget.setObjectName("horizontalWidget")
        self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalWidget)
        self.horizontalLayout.setSizeConstraint(
            QtWidgets.QLayout.SetMaximumSize)
        self.horizontalLayout.setContentsMargins(0, 0, 0, 0)

        self.horizontalLayout.setObjectName("horizontalLayout")

        self.horizontalScrollBar = QtWidgets.QScrollBar(self.horizontalWidget)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum,
                                           QtWidgets.QSizePolicy.Maximum)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.horizontalScrollBar.sizePolicy().hasHeightForWidth())
        self.horizontalScrollBar.setSizePolicy(sizePolicy)
        self.horizontalScrollBar.setMaximumSize(QtCore.QSize(16777215, 25))
        self.horizontalScrollBar.setOrientation(QtCore.Qt.Horizontal)
        self.horizontalScrollBar.setObjectName("horizontalScrollBar")
        self.horizontalScrollBar.setPageStep(1)

        self.horizontalLayout.addWidget(self.horizontalScrollBar)

        # Despite its name, this is a single-line QLineEdit.
        self.plainTextEdit = QtWidgets.QLineEdit(self.horizontalWidget)
        self.plainTextEdit.setEnabled(True)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed,
                                           QtWidgets.QSizePolicy.Maximum)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            self.plainTextEdit.sizePolicy().hasHeightForWidth())
        self.plainTextEdit.setSizePolicy(sizePolicy)
        self.plainTextEdit.setMaximumSize(QtCore.QSize(100, 25))
        font = QtGui.QFont()
        font.setPointSize(8)
        self.plainTextEdit.setFont(font)
        self.plainTextEdit.setObjectName("plainTextEdit")
        self.horizontalLayout.addWidget(self.plainTextEdit)
        # NOTE(review): the layout was created on horizontalWidget; setting
        # it here re-homes it (and its widgets) onto this widget, so
        # horizontalWidget itself is never displayed.
        self.setLayout(self.horizontalLayout)

        self.plainTextEdit.returnPressed.connect(self.text_change)
        # Only valueChanged is connected: with slider tracking on (the Qt
        # default) it already fires on every slider move. Also connecting
        # sliderMoved emitted ``position`` twice per drag step, the first
        # time with the stale value, since sliderMoved is emitted before
        # the scrollbar's value is updated.
        self.horizontalScrollBar.valueChanged.connect(self.scrollbar_change)
        self.update()

    def sizeHint(self):
        return QtCore.QSize(240, 25)

    def text_change(self):
        """Emit the index typed into the line edit.

        Text that is not an integer is rejected and replaced by the current
        scrollbar value. Integers are clamped to the scrollbar's range so
        listeners never receive an out-of-range index.
        """
        scrollbar = self.horizontalScrollBar
        try:
            value = int(self.plainTextEdit.text())
        except ValueError:
            self.plainTextEdit.setText('{}'.format(scrollbar.value()))
            return

        clamped = max(scrollbar.minimum(), min(value, scrollbar.maximum()))
        if clamped != value:
            self.plainTextEdit.setText('{}'.format(clamped))
        self.position.emit(clamped)

    def scrollbar_change(self):
        """Emit the scrollbar's current value."""
        value = self.horizontalScrollBar.value()
        self.position.emit(value)

    @Slot(int)
    def update_state(self, value: int):
        """Show *value* in both widgets, touching only those that differ."""
        if self.plainTextEdit.text() != '{}'.format(value):
            self.plainTextEdit.setText('{}'.format(value))

        if self.horizontalScrollBar.value() != value:
            self.horizontalScrollBar.setValue(value)

    @Slot(int)
    def initialize_state(self, value: int):
        """Configure the scrollbar for *value* items and select index 0."""
        self.horizontalScrollBar.setMaximum(value - 1)
        self.horizontalScrollBar.setMinimum(0)
        self.horizontalScrollBar.setValue(0)
        self.plainTextEdit.setText('{}'.format(0))
# Example no. 21
# (scraped listing marker: "Esempio n. 21" / "0")
class DynamicView(QGraphicsView):
    """Zoomable, pannable view of a single image.

    ``viewChanged`` is emitted with (visible scene rect, zoom factor,
    horizontal scroll value, vertical scroll value) whenever the visible
    area changes, so that other views can follow via ``change_view``.
    Zoom ranges from the fit-to-window scale up to 4x.
    """
    viewChanged = Signal(QRect, float, int, int)

    def __init__(self, image, parent=None):
        super(DynamicView, self).__init__(parent)
        self.scene = QGraphicsScene()
        self.scene.setBackgroundBrush(Qt.darkGray)
        self.setScene(self.scene)
        self.set_image(image)
        self.setRenderHint(QPainter.SmoothPixmapTransform)
        self.setDragMode(QGraphicsView.ScrollHandDrag)
        self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.ZOOM_STEP = 0.2
        self.mouse_pressed = False
        # True when a double-click should go back to fit-to-window.
        self.next_fit = False
        self.fit_scale = 0
        self.zoom_fit()

    def set_image(self, image):
        """Display *image* (QPixmap, QImage or numpy array).

        Raises:
            TypeError: for any other image type.
        """
        # isinstance also accepts subclasses (e.g. QBitmap, ndarray views).
        if isinstance(image, QPixmap):
            pixmap = image
        elif isinstance(image, QImage):
            pixmap = QPixmap.fromImage(image)
        elif isinstance(image, np.ndarray):
            pixmap = QPixmap.fromImage(mat2img(image))
        else:
            raise TypeError(
                self.tr('DynamicView.set_image: Unsupported type: {}'.format(
                    type(image))))
        if not self.scene.items():
            self.scene.addPixmap(pixmap)
        else:
            self.scene.items()[0].setPixmap(pixmap)
        self.scene.setSceneRect(QRectF(pixmap.rect()))

    def zoom_full(self):
        """Show the image at 100% scale."""
        self.set_scaling(1)
        self.next_fit = True
        self.notify_change()

    def zoom_fit(self):
        """Fit the image in the view, but never enlarge it beyond 100%."""
        self.fitInView(self.scene.sceneRect(), Qt.KeepAspectRatio)
        self.fit_scale = self.matrix().m11()
        if self.fit_scale > 1:
            self.fit_scale = 1
            self.zoom_full()
        else:
            self.next_fit = False
            self.notify_change()

    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.mouse_pressed = True
        QGraphicsView.mousePressEvent(self, event)

    def mouseMoveEvent(self, event):
        QGraphicsView.mouseMoveEvent(self, event)
        # Report panning while the left button is held.
        if self.mouse_pressed:
            self.notify_change()

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.mouse_pressed = False
        QGraphicsView.mouseReleaseEvent(self, event)

    def mouseDoubleClickEvent(self, event):
        """Toggle between fit-to-window and 100% zoom."""
        if event.button() == Qt.LeftButton:
            if self.next_fit:
                self.zoom_fit()
            else:
                self.zoom_full()
        QGraphicsView.mouseDoubleClickEvent(self, event)

    def wheelEvent(self, event):
        """Zoom one step per vertical wheel event.

        Uses ``angleDelta()`` (``delta()`` is obsolete) and ignores events
        without a vertical component, which previously zoomed out.
        """
        delta = event.angleDelta().y()
        if delta > 0:
            self.change_zoom(+1)
        elif delta < 0:
            self.change_zoom(-1)

    def resizeEvent(self, event):
        # FIXME: if the window is maximized, the fit_scale value is not
        # updated.
        if self.matrix().m11() <= self.fit_scale:
            self.zoom_fit()
        else:
            self.notify_change()
        QGraphicsView.resizeEvent(self, event)

    def change_zoom(self, direction):
        """Zoom in (direction > 0) or out by ZOOM_STEP in log2 scale."""
        level = math.log2(self.matrix().m11())
        if direction > 0:
            level += self.ZOOM_STEP
        else:
            level -= self.ZOOM_STEP
        scaling = 2**level
        if scaling < self.fit_scale:
            scaling = self.fit_scale
            self.next_fit = False
        elif scaling > 1:
            if scaling > 4:
                scaling = 4
            self.next_fit = True
        self.set_scaling(scaling)
        self.notify_change()

    def set_scaling(self, scaling):
        """Replace the view transform with a uniform scale."""
        matrix = QMatrix()
        matrix.scale(scaling, scaling)
        self.setMatrix(matrix)

    def change_view(self, _, new_scaling, new_horiz, new_vert):
        """Apply a view state from ``viewChanged`` if it differs."""
        old_factor = self.matrix().m11()
        old_horiz = self.horizontalScrollBar().value()
        old_vert = self.verticalScrollBar().value()
        if new_scaling != old_factor or new_horiz != old_horiz or new_vert != old_vert:
            self.set_scaling(new_scaling)
            self.horizontalScrollBar().setValue(new_horiz)
            self.verticalScrollBar().setValue(new_vert)
            self.notify_change()

    def notify_change(self):
        """Emit ``viewChanged`` with the current view state."""
        scene_rect = self.get_rect()
        horiz_scroll = self.horizontalScrollBar().value()
        vert_scroll = self.verticalScrollBar().value()
        zoom_factor = self.matrix().m11()
        self.viewChanged.emit(scene_rect, zoom_factor, horiz_scroll,
                              vert_scroll)

    def get_rect(self):
        """Return the visible part of the scene, clipped to the image."""
        top_left = self.mapToScene(0, 0).toPoint()
        if top_left.x() < 0:
            top_left.setX(0)
        if top_left.y() < 0:
            top_left.setY(0)
        view_size = self.viewport().size()
        bottom_right = self.mapToScene(view_size.width(),
                                       view_size.height()).toPoint()
        image_size = self.sceneRect().toRect()
        if bottom_right.x() >= image_size.width():
            bottom_right.setX(image_size.width() - 1)
        if bottom_right.y() >= image_size.height():
            bottom_right.setY(image_size.height() - 1)
        return QRect(top_left, bottom_right)
# Example no. 22
# (scraped listing marker: "Esempio n. 22" / "0")
class KeypointGroup(QtWidgets.QWidget):
    """A set of named, colored keypoints that can be placed on a scene.

    Clicking with the "add" button (``click_type_to_add_keypoint``) places
    the currently selected keypoint and advances the selection. Clicking
    with the other button near a visible keypoint grabs it so it can be
    dragged. ``data`` is emitted with ``{key: [x, y]}`` (NaN for unplaced
    keypoints) after each placement or drag; ``selected`` is emitted with
    the index of the newly selected keypoint.
    """
    selected = Signal(int)
    data = Signal(dict)

    def __init__(self,
                 keypoint_dict,
                 scene,
                 parent=None,
                 colormap: str = 'viridis',
                 radius=20,
                 text_over_mouse=True,
                 click_type_to_add_keypoint='right'):
        super().__init__(parent)

        self.cmap = plt.get_cmap(colormap)

        # One uint8 RGBA color per keypoint, evenly spaced on the colormap.
        colors = self.cmap(np.linspace(0, 1, len(keypoint_dict)))
        self.colors = (colors * 255).clip(0, 255).astype(np.uint8)

        self.keys = list(keypoint_dict.keys())
        self.keypoints = {
            key: Keypoint(color=color)
            for key, color in zip(self.keys, self.colors)
        }

        self.N = len(self.keypoints)
        self.radius = radius
        self.add_to_scene(scene)
        self.index = 0
        self.key = None
        self.selected.emit(self.index)
        self.scene = scene
        # Index of the keypoint currently being dragged, if any.
        self.tmp_selected = None
        self.text_over_mouse = text_over_mouse
        self.click_type_to_add_keypoint = click_type_to_add_keypoint
        self.set_data(keypoint_dict)
        self.text = None
        self.update_text()

    def broadcast_data(self):
        """Emit ``data`` with the current ``{key: [x, y]}`` coordinates."""
        array = self.get_keypoint_coords()
        self.data.emit(dict(zip(self.keys, array)))

    def set_data(self, data: dict):
        """Set coordinates from ``{key: (x, y)}`` or ``{key: (x, y, r)}``.

        None or empty values are skipped; two-element values get
        ``self.radius`` appended.
        """
        for key, value in data.items():
            # Empty tuples/arrays used to reach set_coords() with no args.
            if value is None or len(value) == 0:
                continue
            if len(value) == 2:
                value = (value[0], value[1], self.radius)
            self.keypoints[key].set_coords(*value)

    def add_to_scene(self, scene):
        for keypoint in self.keypoints.values():
            scene.addItem(keypoint)

    def remove_from_scene(self):
        for keypoint in self.keypoints.values():
            self.scene.removeItem(keypoint)

    def clear_data(self):
        for key in self.keys:
            self.keypoints[key].clear()

    def increment_selected(self):
        self.set_selected(self.index + 1)

    def decrement_selected(self):
        self.set_selected(self.index - 1)

    def clear_selected(self):
        self.get_keypoint(self.index).clear()

    def _index_in_range(self, index):
        """Return True when 0 <= index < N, warning on clearly bad values."""
        if index < 0:
            warnings.warn('index below zero, bug somewhere')
            return False
        if index > self.N:
            warnings.warn('index > len(keypoints)')
            return False
        # index == N happens when you click the final keypoint; it tries to
        # increment above the total number. Ignore it silently.
        return index != self.N

    @Slot(int)
    def set_selected(self, index: int):
        """Select keypoint *index*; out-of-range or unchanged is a no-op."""
        if not self._index_in_range(index):
            return
        if index == self.index:
            # Nothing to do; also prevents signal loops between widgets
            # that control the selection.
            return
        self.index = index
        self.update_text()
        self.selected.emit(self.index)

    def update_text(self):
        """Update the label showing the selected key, in its color."""
        self.key = self.keys[self.index]
        if not self.text_over_mouse:
            return
        red, green, blue, alpha = (int(c) for c in self.colors[self.index])
        line_color = QColor(red, green, blue, alpha)
        face_color = QColor(red, green, blue, int(alpha * 0.3))
        pen = QPen(line_color, 0.5, Qt.SolidLine, Qt.FlatCap, Qt.MiterJoin)
        brush = QBrush(face_color)
        if self.text is None:
            self.text = QtWidgets.QGraphicsSimpleTextItem(self.key)
            self.scene.addItem(self.text)

        self.text.setText(self.key)
        self.text.setBrush(brush)
        self.text.setPen(pen)

    def add_keypoint(self, event):
        """Place the selected keypoint at the click and select the next."""
        pos = event.scenePos()
        x, y = pos.x(), pos.y()

        key = self.keys[self.index]
        self.keypoints[key].set_coords(x, y, self.radius)
        self.broadcast_data()

        self.set_selected(self.index + 1)

    def move_keypoint(self, event):
        """Grab the nearest visible keypoint if the click is inside it."""
        pos = event.scenePos()
        x, y = pos.x(), pos.y()
        dists = self.get_distance_to_keypoints(x, y)

        # Nothing to grab if no keypoint has been placed. (The previous
        # ``mean(isnan) > 0.99`` test also bailed out when >= 100 keypoints
        # existed and fewer than 1% of them were placed.)
        if dists.size == 0 or np.all(np.isnan(dists)):
            return

        nearest = int(np.nanargmin(dists))
        keypoint = self.get_keypoint(nearest)
        if dists[nearest] < keypoint.radius and keypoint.isVisible():
            self.tmp_selected = nearest

    def get_keypoint_coords(self):
        """Return an (N, 2) float32 array of x, y; NaN for unplaced ones."""
        coords = []
        for keypoint in self.keypoints.values():
            x, y = keypoint.cx, keypoint.cy
            if x is None or y is None:
                x, y = np.nan, np.nan
            coords.append([x, y])
        # reshape keeps the (N, 2) shape even when there are no keypoints.
        return np.array(coords, dtype=np.float32).reshape(-1, 2)

    def get_distance_to_keypoints(self, x, y):
        """Euclidean distance from (x, y) to every keypoint (NaN if unset)."""
        coords = self.get_keypoint_coords()
        return np.sqrt((coords[:, 0] - x)**2 + (coords[:, 1] - y)**2)

    @Slot(QtGui.QMouseEvent)
    def receive_click(self, event):
        """Dispatch a click to add_keypoint or move_keypoint by button."""
        if event.button() == QtCore.Qt.RightButton:
            if self.click_type_to_add_keypoint == 'left':
                self.move_keypoint(event)
            else:
                self.add_keypoint(event)

        elif event.button() == QtCore.Qt.LeftButton:
            if self.click_type_to_add_keypoint == 'left':
                self.add_keypoint(event)
            else:
                self.move_keypoint(event)

    @Slot(QtGui.QMouseEvent)
    def receive_move(self, event):
        """Drag the grabbed keypoint and move the label with the mouse."""
        pos = event.scenePos()
        x, y = pos.x(), pos.y()
        if ((self.click_type_to_add_keypoint != 'left'
             and event.buttons() == QtCore.Qt.LeftButton)
                or (self.click_type_to_add_keypoint == 'left'
                    and event.buttons() == QtCore.Qt.RightButton)):
            if self.tmp_selected is None:
                return

            self.keypoints[self.keys[self.tmp_selected]].set_coords(
                x, y, self.radius)
            self.broadcast_data()
        if self.text_over_mouse:
            if self.text is not None:
                self.text.setPos(x + 10, y + 10)

    @Slot(QtGui.QMouseEvent)
    def receive_release(self, event):
        self.tmp_selected = None

    def get_keypoint(self, index):
        """Return the keypoint at *index*, or None when out of range."""
        if not self._index_in_range(index):
            return
        return self.keypoints[self.keys[index]]
class CalibrationCrystalEditor(QObject):
    """Qt editor for a crystal's calibration parameter vector.

    ``params`` is split into three slices, each shown in its own widget group:
    ``params[:3]`` orientation, ``params[3:6]`` position and ``params[6:]``
    the inverse stretch (converted with ``matrixutil.vecMVToSymm`` /
    ``matrixutil.symmToVecMV``). The GUI edits the stretch matrix itself as
    nine spin boxes, so it is inverted when read from / written to params.
    """

    # Emitted when the params get modified
    params_modified = Signal()

    def __init__(self, params=None, parent=None):
        super().__init__(parent)

        loader = UiLoader()
        self.ui = loader.load_file('calibration_crystal_editor.ui', parent)

        # Load slider widget
        self.slider_widget = CalibrationCrystalSliderWidget(parent=self.ui)
        self.ui.slider_widget_parent.layout().addWidget(self.slider_widget.ui)

        # The setter stores a deep copy, fills the widgets and resets the
        # slider ranges.
        self.params = params

        self.update_gui()
        self.update_orientation_suffixes()

        # Connect last, so populating the widgets above does not reach our
        # handlers.
        self.setup_connections()

    def setup_connections(self):
        """Connect config, tab, spin box and slider signals to handlers."""
        HexrdConfig().euler_angle_convention_changed.connect(
            self.euler_angle_convention_changed)

        self.ui.tab_widget.currentChanged.connect(
            self.update_tab_gui)

        for w in self.all_widgets:
            w.valueChanged.connect(self.value_changed)

        self.slider_widget.changed.connect(self.slider_widget_changed)

    @property
    def params(self):
        """The edited parameter vector (a private deep copy), or None."""
        return self._params

    @params.setter
    def params(self, v):
        self._params = copy.deepcopy(v)
        self.update_gui()
        self.slider_widget.reset_ranges()

    def value_changed(self):
        """Write the edited widget group back into params and notify.

        Stretch-matrix edits also update the symmetric partner entry. A
        singular stretch matrix marks the matrix widgets invalid and leaves
        the stretch part of params unchanged (no signal is emitted then).
        """
        sender = self.sender()
        stretch_changed = (sender not in self.orientation_widgets and
                           sender not in self.position_widgets)

        if stretch_changed:
            # If the stretch matrix was modified, we may need to update
            # a duplicate value in the matrix. This is pure GUI bookkeeping,
            # so do it even when no params are loaded.
            self.update_duplicate(sender)

        if self.params is None:
            # Nothing to write into yet (same guard as update_params);
            # previously this raised TypeError on item assignment.
            return

        if sender in self.orientation_widgets:
            self.params[:3] = self.orientation
        elif sender in self.position_widgets:
            self.params[3:6] = self.position
        else:
            try:
                self.params[6:] = self.inverse_stretch
            except LinAlgError as e:
                self.set_matrix_invalid(str(e))
                return

            self.set_matrix_valid()

        self.params_modified.emit()

    def slider_widget_changed(self, mode, index, value):
        """Mirror a slider change into the matching spin box.

        The spin box's valueChanged connection (see setup_connections) then
        routes the change through value_changed() if the value differs.
        """
        prefix = 'orientation' if mode == SliderWidgetMode.ORIENTATION \
            else 'position'
        name = f'{prefix}_{index}'
        w = getattr(self.ui, name)
        w.setValue(value)

    def euler_angle_convention_changed(self):
        """Redisplay params and suffixes in the new angle convention."""
        self.update_gui()
        self.update_orientation_suffixes()

    def update_orientation_suffixes(self):
        """Show a degree suffix unless no Euler convention is selected."""
        suffix = '' if HexrdConfig().euler_angle_convention is None else '°'
        for w in self.orientation_widgets:
            w.setSuffix(suffix)
        self.slider_widget.set_orientation_suffix(suffix)

    def update_params(self):
        """Copy all widget values into params and emit params_modified.

        Raises LinAlgError (from inverse_stretch) if the stretch matrix is
        singular.
        """
        if self.params is None:
            return

        self.params[:3] = self.orientation
        self.params[3:6] = self.position
        self.params[6:] = self.inverse_stretch
        self.params_modified.emit()

    def update_gui(self):
        """Fill the widgets from params (no-op when params is None)."""
        if self.params is None:
            return

        self.orientation = self.params[:3]
        self.position = self.params[3:6]
        self.inverse_stretch = self.params[6:]

        self.update_tab_gui()

    @property
    def stretch_matrix_duplicates(self):
        """Map each off-diagonal flat 3x3 index to its transposed partner."""
        return {
            1: 3,
            2: 6,
            5: 7,
            7: 5,
            6: 2,
            3: 1
        }

    def update_duplicate(self, w):
        """Copy w's value into its symmetric partner without emitting signals."""
        ind = int(w.objectName().replace('stretch_matrix_', ''))
        dup_ind = self.stretch_matrix_duplicates.get(ind)
        if dup_ind is not None:
            dup = getattr(self.ui, f'stretch_matrix_{dup_ind}')
            blocker = QSignalBlocker(dup)  # noqa: F841
            dup.setValue(w.value())

    def set_matrix_valid(self):
        """Clear the invalid-matrix highlight and tooltip."""
        self.set_matrix_style_sheet('background-color: white')
        self.set_matrix_tooltips('')

    def set_matrix_invalid(self, msg=''):
        """Highlight the matrix widgets in red and show msg as tooltip."""
        self.set_matrix_style_sheet('background-color: red')
        self.set_matrix_tooltips(msg)

    def set_matrix_style_sheet(self, s):
        for w in self.stretch_matrix_widgets:
            w.setStyleSheet(s)

    def set_matrix_tooltips(self, s):
        for w in self.stretch_matrix_widgets:
            w.setToolTip(s)

    @staticmethod
    def convert_angle_convention(values, old_conv, new_conv):
        """Convert values in place using the module-level converter."""
        values[:] = convert_angle_convention(values, old_conv, new_conv)

    @property
    def orientation(self):
        # This automatically converts from Euler angle conventions: with a
        # convention set, the widgets hold degrees in that convention.
        values = [x.value() for x in self.orientation_widgets]
        if HexrdConfig().euler_angle_convention is not None:
            values = np.radians(values)
            convention = HexrdConfig().euler_angle_convention
            self.convert_angle_convention(values, convention, None)

        return values

    @orientation.setter
    def orientation(self, v):
        # This automatically converts to Euler angle conventions
        if HexrdConfig().euler_angle_convention is not None:
            # Copy first: the conversion works in place and v may be a view
            # into params.
            v = copy.deepcopy(v)
            convention = HexrdConfig().euler_angle_convention
            self.convert_angle_convention(v, None, convention)
            v = np.degrees(v)

        for i, w in enumerate(self.orientation_widgets):
            blocker = QSignalBlocker(w)  # noqa: F841
            w.setValue(v[i])

    @property
    def position(self):
        return [x.value() for x in self.position_widgets]

    @position.setter
    def position(self, v):
        for i, w in enumerate(self.position_widgets):
            blocker = QSignalBlocker(w)  # noqa: F841
            w.setValue(v[i])

    @property
    def inverse_stretch(self):
        """Inverse of the displayed stretch matrix, in vector form.

        Raises LinAlgError if the displayed matrix is singular.
        """
        m = np.array(self.stretch_matrix).reshape(3, 3)
        return matrixutil.symmToVecMV(np.linalg.inv(m), scale=True)

    @inverse_stretch.setter
    def inverse_stretch(self, v):
        m = matrixutil.vecMVToSymm(v, scale=True)
        self.stretch_matrix = np.linalg.inv(m).flatten()

    @property
    def stretch_matrix(self):
        """The nine stretch matrix spin box values, in row-major order."""
        return [x.value() for x in self.stretch_matrix_widgets]

    @stretch_matrix.setter
    def stretch_matrix(self, v):
        for i, w in enumerate(self.stretch_matrix_widgets):
            blocker = QSignalBlocker(w)  # noqa: F841
            w.setValue(v[i])

    @property
    def orientation_widgets(self):
        # Take advantage of the naming scheme
        return [getattr(self.ui, f'orientation_{i}') for i in range(3)]

    @property
    def position_widgets(self):
        # Take advantage of the naming scheme
        return [getattr(self.ui, f'position_{i}') for i in range(3)]

    @property
    def stretch_matrix_widgets(self):
        # Take advantage of the naming scheme
        return [getattr(self.ui, f'stretch_matrix_{i}') for i in range(9)]

    @property
    def all_widgets(self):
        return (
            self.orientation_widgets +
            self.position_widgets +
            self.stretch_matrix_widgets
        )

    def update_tab_gui(self):
        """Updates slider tab contents when it becomes current tab."""
        current_widget = self.ui.tab_widget.currentWidget()
        if current_widget is self.ui.slider_tab:
            o_values = [x.value() for x in self.orientation_widgets]
            p_values = [x.value() for x in self.position_widgets]
            self.slider_widget.update_gui(o_values, p_values)
Esempio n. 24
0
class VideoFrame(QtWidgets.QGraphicsView):
    """Graphics view showing one frame of a video, an image directory or an image.

    Frames are read through ``VideoReader`` (``self.vid``). A single image
    loaded with ``initialize_image`` leaves ``self.vid`` as None.

    Signals:
        frameNum(int): index of the frame just displayed.
        initialized(int): number of frames after loading (1 for an image).
    """
    frameNum = Signal(int)
    initialized = Signal(int)

    def __init__(self,
                 videoFile: Union[str, os.PathLike] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.scene = ClickableScene(self)
        self._photo = QtWidgets.QGraphicsPixmapItem()
        self.scene.addItem(self._photo)
        self.setScene(self.scene)

        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding,
                                           QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.sizePolicy().hasHeightForWidth())
        self.setSizePolicy(sizePolicy)
        self.setMinimumSize(QtCore.QSize(640, 480))

        self.vid = None
        # Must exist before any frame is shown: show_image() and
        # update_frame() read it, and initialize_video() below displays the
        # first frame. (It used to be set afterwards, so passing videoFile
        # raised AttributeError.)
        self.resize_on_each_frame = True

        if videoFile is not None:
            self.initialize_video(videoFile)
            self.update()
        self.setStyleSheet("background:transparent;")
        self.setMouseTracking(True)

        # for pan/zoom
        self.grabGesture(Qt.PinchGesture)
        self.setResizeAnchor(QtWidgets.QGraphicsView.AnchorUnderMouse)

    def event(self, event):
        """Forward to QGraphicsView, then apply pinch gestures as a uniform zoom."""
        handled = super().event(event)
        if isinstance(event, QtWidgets.QGestureEvent):
            gesture = event.gesture(Qt.PinchGesture)
            if gesture is not None:
                # Scale both axes by the same factor; the previous
                # scale(scaleFactor, lastScaleFactor) distorted the aspect ratio.
                factor = gesture.scaleFactor()
                self.scale(factor, factor)
        return handled

    def initialize_image(self, imagefile: Union[str, os.PathLike]):
        """Display a single image file as frame 0, closing any open reader.

        Raises:
            FileNotFoundError: if ``imagefile`` is not an existing file.
        """
        if self.vid is not None:
            self.vid.close()
            self.vid = None

        self.videofile = imagefile
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(imagefile):
            raise FileNotFoundError(imagefile)

        im = cv2.imread(imagefile, 1)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        self.frame = im

        self.initialized.emit(1)

        # A single image is always frame 0.
        self.current_fnum = 0
        self.show_image(self.frame)
        self.frameNum.emit(self.current_fnum)
        self.fitInView()

    def initialize_video(self, videofile: Union[str, os.PathLike]):
        """Open ``videofile`` with VideoReader and display its first frame."""
        if self.vid is not None:
            self.vid.close()
            self.vid = None

        self.videofile = videofile
        self.vid = VideoReader(videofile)
        self.initialized.emit(len(self.vid))
        # there was a bug where sometimes subsequent videos with the same frame would not update the image
        self.update_frame(0, force_update=True)
        self.fitInView()

    def get_image_names(self):
        """Return one name per frame (file names for an image directory)."""
        if self.vid is None:  # single image
            return [self.videofile]
        elif isinstance(self.vid.file_object, list):  # image directory
            return [os.path.split(n)[1] for n in self.vid.file_object]
        else:  # video
            return [self.videofile] * len(self.vid)

    def update_frame(self, value, force_update: bool = False):
        """Show frame ``value`` (clamped to the valid range) and emit frameNum.

        Does nothing when no reader is open (nothing loaded, or a single image
        from initialize_image), or when ``value`` is already displayed and
        ``force_update`` is False.
        """
        # `self.vid` always exists (set in __init__), so the old
        # hasattr(self, 'vid') check never fired and a None reader crashed
        # on `.nframes`.
        if self.vid is None:
            return
        value = int(value)
        if getattr(self, 'current_fnum', None) == value and not force_update:
            return
        if value < 0:
            value = 0
        if value >= self.vid.nframes:
            value = self.vid.nframes - 1

        self.frame = self.vid[value]

        # the frame in the videoreader is the position of the reader. If you've read frame 0, the current reader
        # position is 1. This makes cv2.CAP_PROP_POS_FRAMES match vid.fnum. However, we want to keep track of our
        # currently displayed image, which is fnum - 1
        self.current_fnum = self.vid.fnum - 1
        self.show_image(self.frame)
        self.frameNum.emit(self.current_fnum)
        if self.resize_on_each_frame:
            self.fitInView()

    def next_frame(self):
        self.update_frame(self.current_fnum + 1)

    def previous_frame(self):
        self.update_frame(self.current_fnum - 1)

    def fitInView(self, scale=True):
        """Reset the transform and zoom so the whole pixmap fits the viewport.

        ``scale`` is unused; the parameter is kept for compatibility with
        existing callers.
        """
        rect = QtCore.QRectF(self._photo.pixmap().rect())
        if not rect.isNull():
            self.scene.setSceneRect(rect)
            # Undo the current scaling so the factor below is absolute.
            unity = self.transform().mapRect(QtCore.QRectF(0, 0, 1, 1))
            self.scale(1 / unity.width(), 1 / unity.height())
            viewrect = self.viewport().rect()
            scenerect = self.transform().mapRect(rect)
            factor = min(viewrect.width() / scenerect.width(),
                         viewrect.height() / scenerect.height())
            self.scale(factor, factor)
            self._zoom = 0

    def adjust_aspect_ratio(self):
        """Shrink the widget's fixed width/height to match the frame aspect ratio.

        Raises:
            ValueError: if nothing has been loaded yet.
        """
        if self.vid is not None:
            if not hasattr(self.vid, 'width'):
                self.vid.width, self.vid.height = self.frame.shape[
                    1], self.frame.shape[0]
            width, height = self.vid.width, self.vid.height
        elif getattr(self, 'frame', None) is not None:
            # Single image loaded via initialize_image (no reader).
            height, width = self.frame.shape[0], self.frame.shape[1]
        else:
            raise ValueError(
                'Trying to set GraphicsView aspect ratio before video loaded.')
        video_aspect = width / height
        H, W = self.height(), self.width()
        new_width = video_aspect * H
        if new_width < W:
            # Qt's setFixedWidth/Height take ints.
            self.setFixedWidth(int(new_width))
        new_height = W / width * height
        if new_height < H:
            self.setFixedHeight(int(new_height))

    def show_image(self, array):
        """Display a numpy image array in the view."""
        qpixmap = numpy_to_qpixmap(array)
        # THIS LINE CHANGES THE SCENE WIDTH AND HEIGHT
        self._photo.setPixmap(qpixmap)

        if self.resize_on_each_frame:
            self.fitInView()
        self.update()
Esempio n. 25
0
class MainThread(QObject):
    """Singleton helper for running work on the Qt main (GUI) thread.

    Call ``MainThread.initialise()`` once, on the main thread. Afterwards
    ``execute`` and ``create_QWidgetUpdater`` may be called from any thread.
    On the main thread (or with no QApplication) they run directly. On other
    threads the inputs are handed over through class attributes guarded by
    locks, and a signal asks the main-thread instance to do the work.
    """
    T = TypeVar('T')

    # Instance (singleton)
    _instance: Optional['MainThread'] = None

    # Variables for inter-thread communication
    _updater_mutex = Lock()
    # Serialises whole create_QWidgetUpdater round-trips from worker threads.
    # Without it, a second caller could grab _updater_mutex between the main
    # thread's release and the first caller's re-acquire, overwriting
    # _updater_out so one caller got the wrong updater and the other None.
    _updater_call_mutex = Lock()
    _in_widget: Optional[QWidget] = None
    _in_widget_setter: Optional[Callable[..., None]] = None
    _in_model_getter: Optional[Callable[[], Any]] = None
    _updater_out: Optional[QWidgetUpdater] = None

    # Variables for inter-thread communication of general executor
    _executor_mutex = Lock()
    _in_function: Optional[Callable[[], None]] = None

    # Signal for invoking method on main thread
    _sig_make_object = Signal()
    _sig_execute_function = Signal()

    @classmethod
    def initialise(cls) -> None:
        """
        Create the singleton. This method must be executed first, and on the main thread. This method can be enqueued
        to run upon QTApplication start by using `QTimer.singleShot(0, MainThread.initialise)`.
        """
        cls._instance = cls()

    def __init__(self):
        """
        Intended to be called by `initialise` only. Instantiates the singleton, and checks if we're on the main thread.
        If not, raises a RuntimeError.
        """
        super().__init__()
        app = QtWidgets.QApplication.instance()
        if app is None:
            print(
                'Running without QT event queue. Callbacks will occur on the caller\'s thread, and UI bindings will '
                'be ignored.')
        elif app.thread() != QThread.currentThread():
            raise RuntimeError(
                'QWidgetUpdaterFactory must be created on the main thread')
        else:
            self._sig_make_object.connect(self._create_QWidgetUpdater_slot)
            self._sig_execute_function.connect(self._execute_slot)

    @classmethod
    def _execute_slot(cls):
        # Intended to be run as a slot only!
        try:
            cls._in_function()
        finally:
            # Always hand the executor back, even if the function raised;
            # otherwise every later execute() call would deadlock.
            cls._in_function = None
            cls._executor_mutex.release()

    @classmethod
    def _create_QWidgetUpdater_slot(cls):
        # Intended to be run as a slot only!
        try:
            cls._updater_out = QWidgetUpdater(cls._in_widget,
                                              cls._in_widget_setter,
                                              cls._in_model_getter)
        finally:
            # Release even on failure so the waiting caller wakes up (and
            # receives None) instead of blocking forever.
            cls._in_widget = None
            cls._in_widget_setter = None
            cls._in_model_getter = None
            cls._updater_mutex.release()

    @classmethod
    def execute(cls, fn: Callable[[], None], blocking: bool = False) -> None:
        """Run ``fn`` on the main thread.

        Called from the main thread (or without a QApplication), ``fn`` runs
        immediately. From another thread it is queued to the main thread;
        with ``blocking=True`` this call waits until ``fn`` has finished.
        """
        app = QtWidgets.QApplication.instance()
        if app is None or app.thread() == QThread.currentThread():
            # If already on main thread then just execute directly. This check works even if standard python threads
            # (not QThreads) are used. Function called explicitly (instead of relying on direct QT signal) as the QT
            # event queue may not be available.
            fn()
            return

        # On non-main thread. The lock is released by _execute_slot on the
        # main thread once fn has run.
        cls._executor_mutex.acquire()
        cls._in_function = fn
        try:
            cls._instance._sig_execute_function.emit()
        except BaseException:
            # The slot will never run (e.g. initialise() was not called);
            # don't leave the executor locked.
            cls._in_function = None
            cls._executor_mutex.release()
            raise

        if blocking:
            cls._executor_mutex.acquire()
            cls._executor_mutex.release()

    @classmethod
    def create_QWidgetUpdater(
            cls, widget: QWidget, widget_setter: Callable[[T], None],
            model_getter: Callable[[], T]) -> Callable[[], None]:
        """
        Creates a new QWidgetUpdater. If this is called from a secondary thread, the QWidgetUpdater is created on the
         main thread via signal.
        Parameters
        ----------
        widget The widget to be updated, required so signals can be blocked during update
        widget_setter Function to be called on the widget with input from model_getter()
        model_getter Retrieves the value to be provided to widget_setter

        Returns QWidgetUpdater constructed on the main thread (a no-op callable when running headless; None if
        construction on the main thread raised)
        -------

        """
        app = QtWidgets.QApplication.instance()
        if app is None:
            # Headless mode; updater doesn't need to do anything.
            return lambda: None
        if app.thread() == QThread.currentThread():
            # If already on main thread then just make function directly. This check works even if standard python
            # threads (not QThreads) are used. Function called explicitly (instead of relying on direct QT signal) as
            # the QT event queue may not be available.
            return QWidgetUpdater(widget, widget_setter, model_getter)

        # On non-main thread. One round-trip at a time (see _updater_call_mutex).
        with cls._updater_call_mutex:
            # Acquire lock and set input variables. Lock will be released & inputs will be cleared on main thread.
            cls._updater_mutex.acquire()
            cls._in_widget = widget
            cls._in_widget_setter = widget_setter
            cls._in_model_getter = model_getter
            try:
                cls._instance._sig_make_object.emit()
            except BaseException:
                # The slot will never run; undo the hand-over and unlock.
                cls._in_widget = None
                cls._in_widget_setter = None
                cls._in_model_getter = None
                cls._updater_mutex.release()
                raise

            # Block until the main thread does its job, then take the output
            # and reset the output variable.
            cls._updater_mutex.acquire()
            out = cls._updater_out
            cls._updater_out = None
            cls._updater_mutex.release()
        return out
Esempio n. 26
0
    class QtImageViewer(QGraphicsView):
        """ PyQt image viewer widget for a QPixmap in a QGraphicsView scene with mouse zooming and panning.
        Displays a QImage or QPixmap (QImage is internally converted to a QPixmap).
        To display any other image format, you must first convert it to a QImage or QPixmap.
        Some useful image format conversion utilities:
            qimage2ndarray: NumPy ndarray <==> QImage    (https://github.com/hmeine/qimage2ndarray)
            ImageQt: PIL Image <==> QImage  (https://github.com/python-pillow/Pillow/blob/master/PIL/ImageQt.py)
        Mouse interaction:
            Left mouse button drag: Pan image (when canPan is True).
            Mouse wheel: Zoom in/out by a factor of 1.1 per wheel event, around the cursor (when canZoom is True).
            Right mouse button doubleclick: Zoom to show entire image (when canZoom is True).
        Signals report scene (x, y) positions for left-button press/release/double-click and right-button
        double-click. rightMouseButtonPressed/rightMouseButtonReleased are declared but not emitted here.
        """

        # Mouse button signals emit image scene (x, y) coordinates.
        # !!! For image (row, column) matrix indexing, row = y and column = x.
        leftMouseButtonPressed = Signal(float, float)
        rightMouseButtonPressed = Signal(float, float)
        leftMouseButtonReleased = Signal(float, float)
        rightMouseButtonReleased = Signal(float, float)
        leftMouseButtonDoubleClicked = Signal(float, float)
        rightMouseButtonDoubleClicked = Signal(float, float)

        def __init__(self):
            QGraphicsView.__init__(self)

            # Image is displayed as a QPixmap in a QGraphicsScene attached to this QGraphicsView.
            self.scene = QGraphicsScene()
            self.setScene(self.scene)

            # Store a local handle to the scene's current image pixmap.
            self._pixmapHandle = None

            # Image aspect ratio mode.
            # !!! ONLY applies to full image. Aspect ratio is always ignored when zooming.
            #   Qt.IgnoreAspectRatio: Scale image to fit viewport.
            #   Qt.KeepAspectRatio: Scale image to fit inside viewport, preserving aspect ratio.
            #   Qt.KeepAspectRatioByExpanding: Scale image to fill the viewport, preserving aspect ratio.
            self.aspectRatioMode = Qt.KeepAspectRatio

            # Scroll bar behaviour.
            #   Qt.ScrollBarAlwaysOff: Never shows a scroll bar.
            #   Qt.ScrollBarAlwaysOn: Always shows a scroll bar.
            #   Qt.ScrollBarAsNeeded: Shows a scroll bar only when zoomed.
            self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
            self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)

            # Wheel zoom keeps the point under the cursor fixed.
            self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

            # Stack of QRectF zoom boxes in scene coordinates.
            self.zoomStack = []

            # Flags for enabling/disabling mouse interaction.
            self.canZoom = True
            self.canPan = True

        def hasImage(self):
            """ Returns whether or not the scene contains an image pixmap.
            """
            return self._pixmapHandle is not None

        def clearImage(self):
            """ Removes the current image pixmap from the scene if it exists.
            """
            if self.hasImage():
                self.scene.removeItem(self._pixmapHandle)
                self._pixmapHandle = None

        def pixmap(self):
            """ Returns the scene's current image pixmap as a QPixmap, or else None if no image exists.
            :rtype: QPixmap | None
            """
            if self.hasImage():
                return self._pixmapHandle.pixmap()
            return None

        def image(self):
            """ Returns the scene's current image pixmap as a QImage, or else None if no image exists.
            :rtype: QImage | None
            """
            if self.hasImage():
                return self._pixmapHandle.pixmap().toImage()
            return None

        def setImage(self, image):
            """ Set the scene's current image pixmap to the input QImage or QPixmap.
            Raises a RuntimeError if the input image has type other than QImage or QPixmap.
            :type image: QImage | QPixmap
            """
            # isinstance (not `type(...) is`) so subclasses such as QBitmap are accepted.
            if isinstance(image, QPixmap):
                pixmap = image
            elif isinstance(image, QImage):
                pixmap = QPixmap.fromImage(image)
            else:
                raise RuntimeError(
                    "ImageViewer.setImage: Argument must be a QImage or QPixmap."
                )
            if self.hasImage():
                self._pixmapHandle.setPixmap(pixmap)
            else:
                self._pixmapHandle = self.scene.addPixmap(pixmap)
            self.setSceneRect(QRectF(
                pixmap.rect()))  # Set scene size to image size.
            self.updateViewer()

        def loadImageFromFile(self, fileName):
            """ Load and display the image at fileName.
            Does nothing if fileName is empty or does not name an existing file.
            """
            if fileName and os.path.isfile(fileName):
                image = QImage(os.fspath(fileName))
                self.setImage(image)

        def updateViewer(self):
            """ Show current zoom (if showing entire image, apply current aspect ratio mode).
            """
            if not self.hasImage():
                return
            if self.zoomStack and self.sceneRect().contains(self.zoomStack[-1]):
                # Show zoomed rect (ignore aspect ratio).
                self.fitInView(self.zoomStack[-1], Qt.IgnoreAspectRatio)
            else:
                # Clear the zoom stack (in case we got here because of an invalid zoom).
                self.zoomStack = []
                # Show entire image (use current aspect ratio mode).
                self.fitInView(self.sceneRect(), self.aspectRatioMode)

        def resizeEvent(self, event):
            """ Maintain current zoom on resize.
            """
            self.updateViewer()

        def wheelEvent(self, event):
            """ Zoom in (wheel away from user) or out (towards user) when canZoom is True.
            With zooming disabled, the event goes to QGraphicsView's default handling (scrolling).
            """
            if not self.canZoom:
                QGraphicsView.wheelEvent(self, event)
                return

            scale_factor = 1.1
            # QWheelEvent.delta() is obsolete in Qt 5 and removed in Qt 6;
            # angleDelta() exists in both.
            delta = event.angleDelta().y()
            if delta > 0:
                self.scale(scale_factor, scale_factor)
            elif delta < 0:
                self.scale(1 / scale_factor, 1 / scale_factor)
            # delta == 0 (e.g. purely horizontal scrolling): no zoom.

        def mousePressEvent(self, event):
            """ Start mouse pan mode on left press and emit leftMouseButtonPressed.
            """
            scenePos = self.mapToScene(event.pos())
            if event.button() == Qt.LeftButton:
                if self.canPan:
                    self.setDragMode(QGraphicsView.ScrollHandDrag)
                self.leftMouseButtonPressed.emit(scenePos.x(), scenePos.y())
            QGraphicsView.mousePressEvent(self, event)

        def mouseReleaseEvent(self, event):
            """ Stop mouse pan mode on left release and emit leftMouseButtonReleased.
            """
            QGraphicsView.mouseReleaseEvent(self, event)
            scenePos = self.mapToScene(event.pos())
            if event.button() == Qt.LeftButton:
                self.setDragMode(QGraphicsView.NoDrag)
                self.leftMouseButtonReleased.emit(scenePos.x(), scenePos.y())

        def mouseDoubleClickEvent(self, event):
            """ Emit double-click signals; a right double-click also resets zoom to the entire image.
            """
            scenePos = self.mapToScene(event.pos())
            if event.button() == Qt.LeftButton:
                self.leftMouseButtonDoubleClicked.emit(scenePos.x(),
                                                       scenePos.y())
            elif event.button() == Qt.RightButton:
                if self.canZoom:
                    self.zoomStack = []  # Clear zoom stack.
                    self.updateViewer()
                self.rightMouseButtonDoubleClicked.emit(
                    scenePos.x(), scenePos.y())

            QGraphicsView.mouseDoubleClickEvent(self, event)
Esempio n. 27
0
class CreateLayeredPsdSignals(QObject):
    """Signal container for creating a layered PSD file.

    Only the declarations live here; which object emits these signals, and
    when, is defined elsewhere (presumably a worker reporting progress —
    TODO confirm against the emitters).
    """
    started = Signal()
    progress_step = Signal()
    finished = Signal()
    # Carries a Path argument — presumably the file that was created (verify).
    file_created = Signal(Path)
Esempio n. 28
0
class DRTSig(QObject):
    """QObject that only declares the ``send_device_msg_sig`` signal."""
    # Carries one str argument — presumably a message for a device; verify
    # against the code that emits/connects it.
    send_device_msg_sig = Signal(str)
Esempio n. 29
0
class ImportMappings(QWidget):
    """
    A widget for managing Mappings (add, remove, edit, visualize, and so on).
    Intended to be embedded in an ImportEditor.
    """

    # Emitted with the newly selected mapping (or None) from the Mappings list.
    mappingChanged = Signal("QVariant")
    # Emitted with the currently selected data mapping (or None) whenever the
    # model's data changes.
    mappingDataChanged = Signal("QVariant")

    def __init__(self, parent=None):
        """
        Args:
            parent (QWidget, optional): a parent widget
        """
        from ..ui.import_mappings import Ui_ImportMappings  # pylint: disable=import-outside-toplevel

        super().__init__(parent)

        # state
        self._model = None

        # initialize interface
        self._ui = Ui_ImportMappings()
        self._ui.setupUi(self)
        self._ui.table_view.setItemDelegateForColumn(
            1, ComboBoxDelegate(self, MAPPING_CHOICES))
        for i in range(self._ui.mapping_splitter.count()):
            self._ui.mapping_splitter.setCollapsible(i, False)

        # connect signals
        self._select_handle = None
        self._ui.new_button.clicked.connect(self.new_mapping)
        self._ui.remove_button.clicked.connect(self.delete_selected_mapping)
        self.mappingChanged.connect(self._ui.table_view.setModel)
        self.mappingChanged.connect(self._ui.options.set_model)

    def set_data_source_column_num(self, num):
        """Sets the number of available columns in the options widget."""
        self._ui.options.set_num_available_columns(num)

    def set_model(self, model):
        """
        Sets new model.

        Disconnects from the previous model and selection model, installs
        ``model`` in the list view and selects its first row, if any. Passing
        None detaches the current model.
        """
        selection_model = self._ui.list_view.selectionModel()
        if self._select_handle and selection_model is not None:
            selection_model.selectionChanged.disconnect(self.select_mapping)
        self._select_handle = None
        # `is not None` rather than truthiness: a model object may be falsy
        # while still connected, which would leave a stale connection.
        if self._model is not None:
            self._model.dataChanged.disconnect(self.data_changed)
        self._model = model
        self._ui.list_view.setModel(model)
        if model is None:
            # No selection model / dataChanged to connect to.
            return
        self._select_handle = self._ui.list_view.selectionModel(
        ).selectionChanged.connect(self.select_mapping)
        model.dataChanged.connect(self.data_changed)
        if model.rowCount() > 0:
            self._ui.list_view.setCurrentIndex(model.index(0, 0))
        else:
            self._ui.list_view.clearSelection()

    @Slot()
    def data_changed(self):
        """Emits the mappingDataChanged signal with the currently selected data mappings."""
        m = None
        indexes = self._ui.list_view.selectedIndexes()
        if self._model is not None and indexes:
            # `indexes` is a list; it was previously called (`indexes()`),
            # raising TypeError on every data change with a selection.
            m = self._model.data_mapping(indexes[0])
        self.mappingDataChanged.emit(m)

    @Slot()
    def new_mapping(self):
        """
        Adds new empty mapping
        """
        if self._model is not None:
            self._model.add_mapping()
            if not self._ui.list_view.selectedIndexes():
                # if no item is selected, select the first item
                self._ui.list_view.setCurrentIndex(self._model.index(0, 0))

    @Slot()
    def delete_selected_mapping(self):
        """
        deletes selected mapping
        """
        if self._model is not None:
            # get selected mapping in list
            indexes = self._ui.list_view.selectedIndexes()
            if indexes:
                self._model.remove_mapping(indexes[0].row())
                if self._model.rowCount() > 0:
                    # select the first item
                    self._ui.list_view.setCurrentIndex(self._model.index(0, 0))
                    self.select_mapping(
                        self._ui.list_view.selectionModel().selection())
                else:
                    # no items clear selection so select_mapping is called
                    self._ui.list_view.clearSelection()

    @Slot("QItemSelection")
    def select_mapping(self, selection):
        """Emits mappingChanged with the selected mapping."""
        if selection.indexes():
            m = self._model.data_mapping(selection.indexes()[0])
        else:
            m = None
        self.mappingChanged.emit(m)

    def selected_mapping_name(self):
        """Returns the name of the selected mapping, or None if nothing is selected."""
        selection_model = self._ui.list_view.selectionModel()
        # No selection model exists until a model has been set.
        if selection_model is None or not selection_model.hasSelection():
            return None
        return selection_model.selection().indexes()[0].data()
class CalibrationSliderWidget(QObject):
    """Spin boxes and sliders for interactively adjusting a calibration.

    Widgets loaded from ``calibration_slider_widget.ui`` are named
    ``<prefix>_<key>_<index>``:

    - ``prefix`` is ``'sb'`` (spin box) or ``'slider'``.
    - ``key`` selects the configuration entry.
    - ``index`` selects the component.

    Spin boxes show configuration values directly. Sliders show them
    multiplied by ``CONF_VAL_TO_SLIDER_VAL``. Changing one widget updates
    its counterpart and writes the value into ``HexrdConfig()``.
    """

    # Emitted with 'polar' (debounced) after beam vector edits
    update_if_mode_matches = Signal(str)

    # Conversions from configuration value to slider value and back
    CONF_VAL_TO_SLIDER_VAL = 10
    SLIDER_VAL_TO_CONF_VAL = 0.1

    # Quiet period (ms) before a pending 'polar' update is emitted
    POLAR_UPDATE_DELAY_MS = 500

    def __init__(self, parent=None):
        super(CalibrationSliderWidget, self).__init__(parent)

        loader = UiLoader()
        self.ui = loader.load_file('calibration_slider_widget.ui', parent)

        # Single-shot timer that debounces 'polar' updates. It is created
        # on first use by emit_update_if_polar().
        self._update_if_polar_timer = None

        self.update_gui_from_config()

        self.setup_connections()

    def setup_connections(self):
        """Connect widget signals to their handlers."""
        self.ui.detector.currentIndexChanged.connect(
            self.update_gui_from_config)
        for widget in self.config_widgets():
            widget.valueChanged.connect(self.update_widget_counterpart)
            widget.valueChanged.connect(self.update_config_from_gui)

        self.ui.sb_translation_range.valueChanged.connect(self.update_ranges)
        self.ui.sb_tilt_range.valueChanged.connect(self.update_ranges)
        self.ui.sb_beam_range.valueChanged.connect(self.update_ranges)

        self.ui.push_reset_config.pressed.connect(self.reset_config)

    def update_ranges(self):
        """Re-center every config widget's range on its current value.

        Each range's width comes from the matching ``sb_*_range`` spin box.
        Slider widths are scaled into slider units.
        """
        self._center_ranges(self.ui.sb_translation_range.value(),
                            self.translation_widgets())
        self._center_ranges(self.ui.sb_tilt_range.value(),
                            self.tilt_widgets())
        self._center_ranges(self.ui.sb_beam_range.value(),
                            self.beam_widgets())

    def _center_ranges(self, width, widgets):
        """Set each widget's range to ``width`` centered on its value."""
        slider_width = width * self.CONF_VAL_TO_SLIDER_VAL
        for w in widgets:
            v = w.value()
            is_slider = w.objectName().startswith('slider')
            w_width = slider_width if is_slider else width
            w.setRange(v - w_width / 2.0, v + w_width / 2.0)

    def current_detector(self):
        """Return the name of the detector selected in the combo box."""
        return self.ui.detector.currentText()

    def current_detector_dict(self):
        """Return the config dict of the currently selected detector."""
        return HexrdConfig().get_detector(self.current_detector())

    def _widgets_from_names(self, roots, suffixes):
        """Return ui widgets named ``<prefix>_<root>_<suffix>``.

        All ``sb`` widgets come first, then all ``slider`` widgets, each
        ordered by root and then suffix.
        """
        names = [
            '_'.join([prefix, root, suffix])
            for prefix in ('sb', 'slider')
            for root in roots
            for suffix in suffixes
        ]
        return [getattr(self.ui, name) for name in names]

    def translation_widgets(self):
        """Return the translation spin boxes and sliders (3 components)."""
        return self._widgets_from_names(['translation'], ['0', '1', '2'])

    def tilt_widgets(self):
        """Return the tilt spin boxes and sliders (3 components)."""
        return self._widgets_from_names(['tilt'], ['0', '1', '2'])

    def transform_widgets(self):
        """Return the translation widgets followed by the tilt widgets."""
        return self.translation_widgets() + self.tilt_widgets()

    def beam_widgets(self):
        """Return the beam energy/azimuth/polar spin boxes and sliders."""
        return self._widgets_from_names(['energy', 'azimuth', 'polar'], ['0'])

    def config_widgets(self):
        """Return every widget that edits a configuration value."""
        return self.transform_widgets() + self.beam_widgets()

    def all_widgets(self):
        """Return the config widgets plus the detector combo box."""
        return self.config_widgets() + [self.ui.detector]

    def block_all_signals(self):
        """Block signals on all widgets.

        Returns the previous blocked states, in all_widgets() order.
        """
        return [widget.blockSignals(True) for widget in self.all_widgets()]

    def unblock_all_signals(self, previously_blocked):
        """Restore blocked states returned by block_all_signals()."""
        for block, widget in zip(previously_blocked, self.all_widgets()):
            widget.blockSignals(block)

    def on_detector_changed(self):
        """Refresh the GUI for a newly selected detector."""
        self.update_gui_from_config()

    def update_widget_counterpart(self):
        """Mirror the sender's value onto its sb/slider counterpart."""
        sender = self.sender()
        prefix, root, ind = sender.objectName().split('_')
        value = sender.value()

        if prefix == 'slider':
            value *= self.SLIDER_VAL_TO_CONF_VAL
        else:
            value *= self.CONF_VAL_TO_SLIDER_VAL

        counter = 'slider' if prefix == 'sb' else 'sb'
        counter_widget = getattr(self.ui, '_'.join([counter, root, ind]))

        # Block signals so the counterpart does not echo the change back
        blocked = counter_widget.blockSignals(True)
        try:
            counter_widget.setValue(value)
        finally:
            counter_widget.blockSignals(blocked)

    def update_gui_from_config(self):
        """Reload detector names and all widget values from the config."""
        self.update_detectors_from_config()

        previously_blocked = self.block_all_signals()
        try:
            for widget in self.config_widgets():
                self.update_widget_value(widget)
        finally:
            self.unblock_all_signals(previously_blocked)

        self.update_ranges()

    def update_detectors_from_config(self):
        """Sync the detector combo box with the configured detector names.

        The previously selected detector is kept if it still exists.
        """
        widget = self.ui.detector

        old_detector = self.current_detector()
        old_detectors = [widget.itemText(x) for x in range(widget.count())]
        detectors = HexrdConfig().get_detector_names()

        if old_detectors == detectors:
            # The detectors didn't change. Nothing to update
            return

        blocked = widget.blockSignals(True)
        try:
            widget.clear()
            widget.addItems(detectors)
            if old_detector in detectors:
                # Switch to the old detector if possible
                widget.setCurrentText(old_detector)
        finally:
            widget.blockSignals(blocked)

    @staticmethod
    def _beam_config_entry(key):
        """Return the beam config dict whose 'value' the ``key`` widgets edit."""
        beam = HexrdConfig().config['instrument']['beam']
        if key == 'energy':
            return beam['energy']
        # The 'polar' widgets edit the beam vector's 'polar_angle' entry
        vector_key = 'polar_angle' if key == 'polar' else key
        return beam['vector'][vector_key]

    def update_config_from_gui(self, val):
        """Write the sending widget's new value ``val`` into the config.

        Only the sender's entry is updated.

        - Slider values are converted back to configuration units.
        - Tilt values are converted from degrees to radians when
          ``HexrdConfig().rotation_matrix_euler()`` is not None.
        """
        sender = self.sender()

        # Take advantage of the widget naming scheme
        prefix, key, ind = sender.objectName().split('_')
        ind = int(ind)

        if prefix == 'slider':
            val *= self.SLIDER_VAL_TO_CONF_VAL

        if key in ('tilt', 'translation'):
            det = self.current_detector_dict()
            if key == 'tilt':
                rme = HexrdConfig().rotation_matrix_euler()
                if rme is not None:
                    # Convert to radians, and to the native python type
                    # before saving
                    val = np.radians(val).item()

            det['transform'][key]['value'][ind] = val

            # Since we modify the value directly instead of letting the
            # HexrdConfig() do it, let's also emit the signal it would
            # have emitted.
            HexrdConfig().detector_transform_modified.emit(
                self.current_detector()
            )
        else:
            self._beam_config_entry(key)['value'] = val
            if key == 'energy':
                HexrdConfig().update_active_material_energy()
            else:
                self.emit_update_if_polar()

    def update_widget_value(self, widget):
        """Set ``widget`` to the configuration value it edits.

        - Tilt values are shown in degrees, with a '°' suffix on the spin
          box, when ``HexrdConfig().rotation_matrix_euler()`` is not None.
        - Slider values are scaled by CONF_VAL_TO_SLIDER_VAL.
        - The widget's range is widened first if the value lies outside it.
        """
        # Take advantage of the widget naming scheme
        prefix, key, ind = widget.objectName().split('_')
        ind = int(ind)

        if key in ('tilt', 'translation'):
            val = self.current_detector_dict()['transform'][key]['value'][ind]
        else:
            val = self._beam_config_entry(key)['value']

        if key == 'tilt':
            if HexrdConfig().rotation_matrix_euler() is None:
                suffix = ''
            else:
                # Convert to degrees, and to the native python type
                val = np.degrees(val).item()
                suffix = '°'

            if prefix == 'sb':
                widget.setSuffix(suffix)

        if prefix == 'slider':
            val *= self.CONF_VAL_TO_SLIDER_VAL

        # Make sure the widget's range will accept the value
        if val < widget.minimum():
            widget.setMinimum(val)
        elif val > widget.maximum():
            widget.setMaximum(val)

        widget.setValue(val)

    def emit_update_if_polar(self):
        """Request a debounced ``update_if_mode_matches('polar')``.

        The single-shot timer is restarted on every call. The signal is
        therefore emitted once, POLAR_UPDATE_DELAY_MS after the most recent
        call.
        """
        if self._update_if_polar_timer is None:
            # Parent the timer to this object so Qt destroys it with us,
            # instead of it possibly firing after we are gone.
            timer = QTimer(self)
            timer.setSingleShot(True)
            timer.timeout.connect(self._emit_polar_update)
            self._update_if_polar_timer = timer

        self._update_if_polar_timer.start(self.POLAR_UPDATE_DELAY_MS)

    def _emit_polar_update(self):
        """Emit the deferred 'polar' update request."""
        self.update_if_mode_matches.emit('polar')

    def reset_config(self):
        """Restore the instrument config backup and refresh the GUI."""
        HexrdConfig().restore_instrument_config_backup()
        self.update_gui_from_config()