class NodeGraph(QtCore.QObject):
    """
    base node graph controller.

    Args:
        parent (QtCore.QObject): parent object.
        tab_search_key (str): hotkey for the tab search widget
            (default: "tab").
    """

    #: signal for when a node has been created in the node graph.
    node_created = QtCore.Signal(NodeObject)
    #: signal for when a node is selected.
    node_selected = QtCore.Signal(NodeObject)
    #: signal for when a node has been connected.
    port_connected = QtCore.Signal(Port, Port)
    #: signal for when drop data has been added to the graph.
    data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)

    def __init__(self, parent=None, tab_search_key='tab'):
        super(NodeGraph, self).__init__(parent)
        self.setObjectName('NodeGraphQt')
        self._model = NodeGraphModel()
        self._viewer = NodeViewer()
        self._vendor = NodeVendor()
        self._undo_stack = QUndoStack(self)

        tab = QAction('Search Nodes', self)
        tab.setShortcut(tab_search_key)
        tab.triggered.connect(self._toggle_tab_search)
        self._viewer.addAction(tab)

        self._wire_signals()

    def _wire_signals(self):
        """
        connect the viewer signals to the node graph slots.
        """
        # internal signals.
        self._viewer.search_triggered.connect(self._on_search_triggered)
        self._viewer.connection_changed.connect(self._on_connection_changed)
        self._viewer.moved_nodes.connect(self._on_nodes_moved)

        # pass through signals.
        self._viewer.node_selected.connect(self._on_node_selected)
        self._viewer.data_dropped.connect(self._on_node_data_dropped)

    def _toggle_tab_search(self):
        """
        toggle the tab search widget.
        """
        self._viewer.tab_search_set_nodes(self._vendor.names)
        self._viewer.tab_search_toggle()

    def _on_node_selected(self, node_id):
        """
        called when a node in the viewer is selected on left click.
        (emits the node object when the node is clicked)

        Args:
            node_id (str): node id emitted by the viewer.
        """
        node = self.get_node_by_id(node_id)
        self.node_selected.emit(node)

    def _on_node_data_dropped(self, data, pos):
        """
        called when data has been dropped on the viewer.

        Args:
            data (QtCore.QMimeData): mime data.
            pos (QtCore.QPoint): scene position relative to the drop.
        """
        self.data_dropped.emit(data, pos)

    def _on_nodes_moved(self, node_data):
        """
        called when selected nodes in the viewer has changed position.

        Args:
            node_data (dict): {<node_view>: <previous_pos>}
        """
        self._undo_stack.beginMacro('moved nodes')
        for node_view, prev_pos in node_data.items():
            node = self._model.nodes[node_view.id]
            self._undo_stack.push(NodeMovedCmd(node, node.pos(), prev_pos))
        self._undo_stack.endMacro()

    def _on_search_triggered(self, node_type, pos):
        """
        called when the tab search widget is triggered in the viewer.

        Args:
            node_type (str): node identifier.
            pos (tuple): x,y position for the node.
        """
        self.create_node(node_type, pos=pos)

    def _port_from_view(self, port_view):
        """
        Resolve a port view item to the port object of its node.

        Args:
            port_view (widgets.port.PortItem): port view item.

        Returns:
            NodeGraphQt.Port: the matching port object.
        """
        ptypes = {'in': 'inputs', 'out': 'outputs'}
        node = self._model.nodes[port_view.node.id]
        # 'in' -> node.inputs(), 'out' -> node.outputs()
        return getattr(node, ptypes[port_view.port_type])()[port_view.name]

    def _on_connection_changed(self, disconnected, connected):
        """
        called when a pipe connection has been changed in the viewer.

        Args:
            disconnected (list[list[widgets.port.PortItem]):
                pair list of port view items.
            connected (list[list[widgets.port.PortItem]]):
                pair list of port view items.
        """
        if not (disconnected or connected):
            return

        label = 'connected node(s)' if connected else 'disconnected node(s)'
        self._undo_stack.beginMacro(label)
        for p1_view, p2_view in disconnected:
            port1 = self._port_from_view(p1_view)
            port2 = self._port_from_view(p2_view)
            port1.disconnect_from(port2)
        for p1_view, p2_view in connected:
            port1 = self._port_from_view(p1_view)
            port2 = self._port_from_view(p2_view)
            port1.connect_to(port2)
        self._undo_stack.endMacro()

    @property
    def model(self):
        """
        Returns the model used to store the node graph data.

        Returns:
            NodeGraphQt.base.model.NodeGraphModel: node graph model.
        """
        return self._model

    def show(self):
        """
        Show node graph viewer widget this is just a convenience
        function to :meth:`NodeGraph.viewer().show()`.
        """
        self._viewer.show()

    def close(self):
        """
        Close node graph NodeViewer widget this is just a convenience
        function to :meth:`NodeGraph.viewer().close()`.
        """
        self._viewer.close()

    def viewer(self):
        """
        Return the node graph viewer widget.

        Returns:
            NodeGraphQt.widgets.viewer.NodeViewer: viewer widget.
        """
        return self._viewer

    def scene(self):
        """
        Return the scene object.

        Returns:
            NodeGraphQt.widgets.scene.NodeScene: node scene.
        """
        return self._viewer.scene()

    def undo_stack(self):
        """
        Returns the undo stack used in the node graph

        Returns:
            QtWidgets.QUndoStack: undo stack.
        """
        return self._undo_stack

    def clear_undo_stack(self):
        """
        Clears the undo stack.
        (convenience function to :meth:`NodeGraph.undo_stack().clear`)
        """
        self._undo_stack.clear()

    def begin_undo(self, name='undo'):
        """
        Start of an undo block followed by a
        :meth:`NodeGraph.end_undo()`.

        Args:
            name (str): name for the undo block.
        """
        self._undo_stack.beginMacro(name)

    def end_undo(self):
        """
        End of an undo block started by
        :meth:`NodeGraph.begin_undo()`.
        """
        self._undo_stack.endMacro()

    def context_menu(self):
        """
        Returns the node graph root context menu object.

        Returns:
            Menu: context menu object.
        """
        return Menu(self._viewer, self._viewer.context_menu())

    def acyclic(self):
        """
        Returns true if the current node graph is acyclic.

        Returns:
            bool: true if acyclic (default: True).
        """
        return self._model.acyclic

    def set_acyclic(self, mode=True):
        """
        Set the node graph to be acyclic or not. (default=True)

        Args:
            mode (bool): false to disable acyclic.
        """
        self._model.acyclic = mode
        self._viewer.acyclic = mode

    def set_pipe_layout(self, layout='curved'):
        """
        Set node graph pipes to be drawn straight or curved by default
        all pipes are set curved. (default='curved')

        Args:
            layout (str): 'straight' or 'curved'
        """
        self._viewer.set_pipe_layout(layout)

    def fit_to_selection(self):
        """
        Sets the zoom level to fit selected nodes.
        If no nodes are selected then all nodes in the graph will be framed.
        """
        nodes = self.selected_nodes() or self.all_nodes()
        if not nodes:
            return
        self._viewer.zoom_to_nodes([n.view for n in nodes])

    def reset_zoom(self):
        """
        Reset the zoom level
        """
        self._viewer.reset_zoom()

    def set_zoom(self, zoom=0):
        """
        Set the zoom factor of the Node Graph the default is 0.0

        Args:
            zoom (float): zoom factor (max zoom out -0.9 / max zoom in 2.0)
        """
        self._viewer.set_zoom(zoom)

    def get_zoom(self):
        """
        Get the current zoom level of the node graph.

        Returns:
            float: the current zoom level.
        """
        return self._viewer.get_zoom()

    def center_on(self, nodes=None):
        """
        Center the node graph on the given nodes or all nodes by default.

        Args:
            nodes (list[NodeGraphQt.Node]): a list of nodes.
        """
        self._viewer.center_selection(nodes)

    def center_selection(self):
        """
        Centers on the current selected nodes.
        """
        nodes = self._viewer.selected_nodes()
        self._viewer.center_selection(nodes)

    def registered_nodes(self):
        """
        Return a list of all node types that have been registered.
        To register a node see :meth:`NodeGraph.register_node`

        Returns:
            list[str]: list of node type identifiers.
        """
        return sorted(self._vendor.nodes.keys())

    def register_node(self, node, alias=None):
        """
        Register the node to the node graph vendor.

        Args:
            node (NodeGraphQt.NodeObject): node.
            alias (str): custom alias name for the node type.
        """
        self._vendor.register_node(node, alias)

    def _register_property_attrs(self, node):
        """
        Move the temporary property widget types / attributes stored on the
        node model into the graph model's ``node_property_attrs`` table.
        Only the first node of a given type populates the table.

        Args:
            node (NodeGraphQt.NodeObject): node being added to the graph.
        """
        # the _TEMP_ entries are removed from the node model either way.
        wid_types = node.model.__dict__.pop('_TEMP_property_widget_types')
        prop_attrs = node.model.__dict__.pop('_TEMP_property_attrs')

        graph_attrs = self.model.node_property_attrs
        if node.type not in graph_attrs:
            graph_attrs[node.type] = {
                n: {'widget_type': wt} for n, wt in wid_types.items()
            }
            for pname, pattrs in prop_attrs.items():
                graph_attrs[node.type][pname].update(pattrs)

    def create_node(self, node_type, name=None, selected=True, color=None,
                    pos=None):
        """
        Create a new node in the node graph.
        (To list all node types see :meth:`NodeGraph.registered_nodes`)

        Args:
            node_type (str): node instance type.
            name (str): set name of the node.
            selected (bool): set created node to be selected.
            color (tuple or str): node color (255, 255, 255) or '#FFFFFF'.
            pos (list[int, int]): initial x, y position for the node
                (default: (0, 0)).

        Returns:
            NodeGraphQt.Node: the created instance of the node.

        Raises:
            Exception: if ``node_type`` is not registered.
        """
        NodeCls = self._vendor.create_node_instance(node_type)
        if not NodeCls:
            raise Exception(
                '\n\n>> Cannot find node:\t"{}"\n'.format(node_type))

        node = NodeCls()
        node._graph = self
        node.model._graph_model = self.model
        self._register_property_attrs(node)

        node.NODE_NAME = self.get_unique_name(name or node.NODE_NAME)
        node.model.name = node.NODE_NAME
        node.model.selected = selected
        if color:
            if isinstance(color, str):
                # '#RRGGBB' or 'RRGGBB' -> (r, g, b)
                color = color[1:] if color.startswith('#') else color
                color = tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))
            node.model.color = color
        if pos:
            node.model.pos = [float(pos[0]), float(pos[1])]
        node.update()

        undo_cmd = NodeAddedCmd(self, node, node.model.pos)
        undo_cmd.setText('created node')
        self._undo_stack.push(undo_cmd)
        self.node_created.emit(node)
        return node

    def add_node(self, node):
        """
        Add a node into the node graph.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        assert isinstance(node, NodeObject), 'node must be a Node instance.'

        self._register_property_attrs(node)

        node._graph = self
        node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
        node.model._graph_model = self.model
        node.model.name = node.NODE_NAME
        node.update()
        self._undo_stack.push(NodeAddedCmd(self, node))

    def delete_node(self, node):
        """
        Remove the node from the node graph.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        assert isinstance(node, NodeObject), \
            'node must be a instance of a NodeObject.'
        self._undo_stack.push(NodeRemovedCmd(self, node))

    def delete_nodes(self, nodes):
        """
        Remove a list of specified nodes from the node graph.

        Args:
            nodes (list[NodeGraphQt.Node]): list of node instances.
        """
        self._undo_stack.beginMacro('deleted nodes')
        for n in nodes:
            self.delete_node(n)
        self._undo_stack.endMacro()

    def all_nodes(self):
        """
        Return all nodes in the node graph.

        Returns:
            list[NodeGraphQt.Node]: list of nodes.
        """
        return list(self._model.nodes.values())

    def selected_nodes(self):
        """
        Return all selected nodes that are in the node graph.

        Returns:
            list[NodeGraphQt.Node]: list of nodes.
        """
        return [self._model.nodes[item.id]
                for item in self._viewer.selected_nodes()]

    def select_all(self):
        """
        Select all nodes in the node graph.
        """
        self._undo_stack.beginMacro('select all')
        for node in self.all_nodes():
            node.set_selected(True)
        self._undo_stack.endMacro()

    def clear_selection(self):
        """
        Clears the selection in the node graph.
        """
        self._undo_stack.beginMacro('deselected nodes')
        for node in self.all_nodes():
            node.set_selected(False)
        self._undo_stack.endMacro()

    def get_node_by_id(self, node_id=None):
        """
        Returns the node from the node id string.

        Args:
            node_id (str): node id (:meth:`NodeObject.id`)

        Returns:
            NodeGraphQt.NodeObject: node object (None if not found).
        """
        return self._model.nodes.get(node_id)

    def get_node_by_name(self, name):
        """
        Returns node that matches the name.

        Args:
            name (str): name of the node.

        Returns:
            NodeGraphQt.NodeObject: node object (None if not found).
        """
        for node in self._model.nodes.values():
            if node.name() == name:
                return node
        return None

    def get_unique_name(self, name):
        """
        Creates a unique node name to avoid having nodes with the same name.

        Whitespace runs are collapsed to single spaces. If the name is taken,
        a trailing number is stripped ("node 3" -> "node") and the first free
        "<name> <n>" (n = 1, 2, ...) is returned.

        Args:
            name (str): node name.

        Returns:
            str: unique node name.
        """
        name = ' '.join(name.split())
        node_names = {n.name() for n in self.all_nodes()}
        if name not in node_names:
            return name

        # split "<base><spaces><digits>" so only a *trailing* number is
        # treated as a version suffix.
        match = re.match(r'^(.*?)\s*(\d+)$', name)
        if match and match.group(1):
            name = match.group(1)

        # unbounded search: always terminates because node_names is finite.
        counter = 1
        while True:
            new_name = '{} {}'.format(name, counter)
            if new_name not in node_names:
                return new_name
            counter += 1

    def current_session(self):
        """
        Returns the file path to the currently loaded session.

        Returns:
            str: path to the currently loaded session
        """
        return self._model.session

    def clear_session(self):
        """
        Clears the current node graph session.
        """
        for n in self.all_nodes():
            self.delete_node(n)
        self._undo_stack.clear()
        self._model.session = None

    def _serialize(self, nodes):
        """
        serialize nodes to a dict.
        (used internally by the node graph)

        Args:
            nodes (list[NodeGraphQt.Nodes]): list of node instances.

        Returns:
            dict: serialized data.
        """
        serial_data = {'nodes': {}, 'connections': []}
        nodes_data = {}
        for n in nodes:
            # update the node model.
            n.update_model()
            nodes_data.update(n.model.to_dict)

        for n_id, n_data in nodes_data.items():
            serial_data['nodes'][n_id] = n_data

            # port connection data is moved out of the node entry and into
            # the flat "connections" list.
            inputs = n_data.pop('inputs') if n_data.get('inputs') else {}
            outputs = n_data.pop('outputs') if n_data.get('outputs') else {}

            for pname, conn_data in inputs.items():
                for conn_id, prt_names in conn_data.items():
                    for conn_prt in prt_names:
                        pipe = {
                            'in': [n_id, pname],
                            'out': [conn_id, conn_prt]
                        }
                        if pipe not in serial_data['connections']:
                            serial_data['connections'].append(pipe)

            for pname, conn_data in outputs.items():
                for conn_id, prt_names in conn_data.items():
                    for conn_prt in prt_names:
                        pipe = {
                            'out': [n_id, pname],
                            'in': [conn_id, conn_prt]
                        }
                        if pipe not in serial_data['connections']:
                            serial_data['connections'].append(pipe)

        if not serial_data['connections']:
            serial_data.pop('connections')

        return serial_data

    def _deserialize(self, data, relative_pos=False, pos=None):
        """
        deserialize node data.
        (used internally by the node graph)

        Args:
            data (dict): node data.
            relative_pos (bool): position node relative to the cursor.
            pos (tuple): position passed to the viewer's ``move_nodes``
                when ``relative_pos`` is false.

        Returns:
            list[NodeGraphQt.Nodes]: list of node instances.
        """
        nodes = {}

        # build the nodes.
        for n_id, n_data in data.get('nodes', {}).items():
            identifier = n_data['type']
            NodeCls = self._vendor.create_node_instance(identifier)
            if not NodeCls:
                continue
            node = NodeCls()
            node._graph = self

            name = self.get_unique_name(n_data.get('name', node.NODE_NAME))
            n_data['name'] = name

            # set properties.
            for prop in node.model.properties.keys():
                if prop in n_data:
                    setattr(node.model, prop, n_data[prop])

            # set custom properties.
            for prop, val in n_data.get('custom', {}).items():
                if prop in node.model.custom_properties.keys():
                    node.model.custom_properties[prop] = val

            node.update()

            self._undo_stack.push(
                NodeAddedCmd(self, node, n_data.get('pos')))
            nodes[n_id] = node

        # build the connections.
        for connection in data.get('connections', []):
            nid, pname = connection.get('in', ('', ''))
            in_node = nodes.get(nid)
            if not in_node:
                continue
            in_port = in_node.inputs().get(pname)

            nid, pname = connection.get('out', ('', ''))
            out_node = nodes.get(nid)
            if not out_node:
                continue
            out_port = out_node.outputs().get(pname)

            if in_port and out_port:
                self._undo_stack.push(PortConnectedCmd(in_port, out_port))

        node_objs = list(nodes.values())
        if relative_pos:
            self._viewer.move_nodes([n.view for n in node_objs])
            for n in node_objs:
                n.model.pos = n.view.pos
        elif pos:
            self._viewer.move_nodes([n.view for n in node_objs], pos=pos)

        return node_objs

    def serialize_session(self):
        """
        Serializes the current node graph layout to a dictionary.

        Returns:
            dict: serialized session of the current node layout.
        """
        return self._serialize(self.all_nodes())

    def save_session(self, file_path):
        """
        Saves the current node graph session layout to a `JSON` formatted file.

        Args:
            file_path (str): path to the saved node layout.
        """
        serialized_data = self._serialize(self.all_nodes())
        file_path = file_path.strip()
        with open(file_path, 'w') as file_out:
            json.dump(serialized_data, file_out, indent=2,
                      separators=(',', ':'))

    def load_session(self, file_path):
        """
        Load node graph session layout file.

        Args:
            file_path (str): path to the serialized layout file.

        Raises:
            IOError: if ``file_path`` is not an existing file.
        """
        file_path = file_path.strip()
        if not os.path.isfile(file_path):
            raise IOError('file does not exist.')

        self.clear_session()

        try:
            with open(file_path) as data_file:
                layout_data = json.load(data_file)
        except Exception as e:
            # best effort: report and keep the (now empty) session.
            layout_data = None
            print('Cannot read data from file.\n{}'.format(e))

        if not layout_data:
            return

        self._deserialize(layout_data)
        self._undo_stack.clear()
        self._model.session = file_path

    def copy_nodes(self, nodes=None):
        """
        Copy nodes to the clipboard.

        Args:
            nodes (list[NodeGraphQt.Node]): list of nodes (default: selected nodes).

        Returns:
            bool: true if data was written to the clipboard.
        """
        nodes = nodes or self.selected_nodes()
        if not nodes:
            return False
        clipboard = QApplication.clipboard()
        serial_data = self._serialize(nodes)
        serial_str = json.dumps(serial_data)
        if serial_str:
            clipboard.setText(serial_str)
            return True
        return False

    def paste_nodes(self):
        """
        Pastes nodes copied from the clipboard.

        Clipboard text that is not a serialized node dictionary is ignored.
        """
        clipboard = QApplication.clipboard()
        cb_string = clipboard.text()
        if not cb_string:
            return

        # parse before opening the undo macro so invalid clipboard text
        # cannot leave the undo stack stuck inside an unfinished macro.
        try:
            serial_data = json.loads(cb_string)
        except ValueError:
            return
        if not isinstance(serial_data, dict):
            return

        self._undo_stack.beginMacro('pasted nodes')
        try:
            self.clear_selection()
            nodes = self._deserialize(serial_data, True)
            for n in nodes:
                n.set_selected(True)
        finally:
            self._undo_stack.endMacro()

    def duplicate_nodes(self, nodes):
        """
        Create duplicate copy from the list of nodes.

        Args:
            nodes (list[NodeGraphQt.Node]): list of nodes.

        Returns:
            list[NodeGraphQt.Node]: list of duplicated node instances
            (None when ``nodes`` is empty).
        """
        if not nodes:
            return

        self._undo_stack.beginMacro('duplicated nodes')

        self.clear_selection()
        serial = self._serialize(nodes)
        new_nodes = self._deserialize(serial)
        offset = 50
        for n in new_nodes:
            x, y = n.pos()
            n.set_pos(x + offset, y + offset)
            n.set_property('selected', True)

        self._undo_stack.endMacro()
        return new_nodes

    def disable_nodes(self, nodes, mode=None):
        """
        Set weather to Disable or Enable specified nodes.

        see: :meth:`NodeObject.set_disabled`

        Args:
            nodes (list[NodeGraphQt.Node]): list of node instances.
            mode (bool): (optional) disable state of the nodes. When omitted
                the inverse of the first node's state is used.
        """
        if not nodes:
            return
        if mode is None:
            mode = not nodes[0].disabled()
        if len(nodes) > 1:
            text = {False: 'enabled', True: 'disabled'}[mode]
            text = '{} ({}) nodes'.format(text, len(nodes))
            self._undo_stack.beginMacro(text)
            for n in nodes:
                n.set_disabled(mode)
            self._undo_stack.endMacro()
            return
        return nodes[0].set_disabled(mode)

    def question_dialog(self, text, title='Node Graph'):
        """
        Prompts a question open dialog with "Yes" and "No" buttons in
        the node graph.

        (convenience function to :meth:`NodeGraph.viewer().question_dialog`)

        Args:
            text (str): question text.
            title (str): dialog window title.

        Returns:
            bool: true if user clicked yes.
        """
        return self._viewer.question_dialog(text, title)

    def message_dialog(self, text, title='Node Graph'):
        """
        Prompts a message dialog in the node graph.

        (convenience function to :meth:`NodeGraph.viewer().message_dialog`)

        Args:
            text (str): message text.
            title (str): dialog window title.
        """
        self._viewer.message_dialog(text, title)

    def load_dialog(self, current_dir=None, ext=None):
        """
        Prompts a file open dialog in the node graph.

        (convenience function to :meth:`NodeGraph.viewer().load_dialog`)

        Args:
            current_dir (str): path to a directory.
            ext (str): custom file type extension (default: json)

        Returns:
            str: selected file path.
        """
        return self._viewer.load_dialog(current_dir, ext)

    def save_dialog(self, current_dir=None, ext=None):
        """
        Prompts a file save dialog in the node graph.

        (convenience function to :meth:`NodeGraph.viewer().save_dialog`)

        Args:
            current_dir (str): path to a directory.
            ext (str): custom file type extension (default: json)

        Returns:
            str: selected file path.
        """
        return self._viewer.save_dialog(current_dir, ext)
class NodeGraph(QtCore.QObject):
    """
    node graph controller wrapping a NodeViewer widget and its model.

    Args:
        parent (QtCore.QObject): parent object.
    """

    def __init__(self, parent=None):
        super(NodeGraph, self).__init__(parent)
        self._model = NodeGraphModel()
        self._viewer = NodeViewer()
        self._undo_stack = QUndoStack(self)
        self._init_actions()
        self._wire_signals()

    def _wire_signals(self):
        """
        connect the viewer signals to the node graph slots.
        """
        self._viewer.moved_nodes.connect(self._on_nodes_moved)
        self._viewer.search_triggered.connect(self._on_search_triggered)
        self._viewer.connection_changed.connect(self._on_connection_changed)

    def _init_actions(self):
        """
        setup the tab search shortcut and the remaining graph actions.
        """
        # setup tab search shortcut.
        tab = QAction('Search Nodes', self)
        tab.setShortcut('tab')
        tab.triggered.connect(self._toggle_tab_search)
        self._viewer.addAction(tab)
        setup_actions(self)

    def _toggle_tab_search(self):
        """
        toggle the tab search widget.
        """
        self._viewer.tab_search_set_nodes(NodeVendor.names)
        self._viewer.tab_search_toggle()

    def _on_nodes_moved(self, node_data):
        """
        called when selected nodes in the viewer has changed position.

        Args:
            node_data (dict): {<node_view>: <previous_pos>}
        """
        self._undo_stack.beginMacro('moved nodes')
        for node_view, prev_pos in node_data.items():
            node = self._model.nodes[node_view.id]
            self._undo_stack.push(NodeMovedCmd(node, node.pos(), prev_pos))
        self._undo_stack.endMacro()

    def _on_search_triggered(self, node_type, pos):
        """
        called when the tab search widget is triggered in the viewer.

        Args:
            node_type (str): node identifier.
            pos (tuple): x,y position for the node.
        """
        self.create_node(node_type, pos=pos)

    def _port_from_view(self, port_view):
        """
        Resolve a port view item to the port object of its node.

        Args:
            port_view (widgets.port.PortItem): port view item.

        Returns:
            NodeGraphQt.Port: the matching port object.
        """
        ptypes = {'in': 'inputs', 'out': 'outputs'}
        node = self._model.nodes[port_view.node.id]
        # 'in' -> node.inputs(), 'out' -> node.outputs()
        return getattr(node, ptypes[port_view.port_type])()[port_view.name]

    def _on_connection_changed(self, disconnected, connected):
        """
        called when a pipe connection has been changed in the viewer.

        Args:
            disconnected (list[list[widgets.port.PortItem]):
                pair list of port view items.
            connected (list[list[widgets.port.PortItem]]):
                pair list of port view items.
        """
        if not (disconnected or connected):
            return

        label = 'connected node(s)' if connected else 'disconnected node(s)'
        self._undo_stack.beginMacro(label)
        for p1_view, p2_view in disconnected:
            port1 = self._port_from_view(p1_view)
            port2 = self._port_from_view(p2_view)
            port1.disconnect_from(port2)
        for p1_view, p2_view in connected:
            port1 = self._port_from_view(p1_view)
            port2 = self._port_from_view(p2_view)
            port1.connect_to(port2)
        self._undo_stack.endMacro()

    @property
    def model(self):
        """
        Return the node graph model.

        Returns:
            NodeGraphModel: model object.
        """
        return self._model

    def show(self):
        """
        Show node graph viewer widget.
        """
        self._viewer.show()

    def hide(self):
        """
        Hide node graph viewer widget.
        """
        self._viewer.hide()

    def close(self):
        """
        Close node graph viewer widget.
        """
        self._viewer.close()

    def viewer(self):
        """
        Return the node graph viewer widget object.

        Returns:
            NodeGraphQt.widgets.viewer.NodeViewer: viewer widget.
        """
        return self._viewer

    def scene(self):
        """
        Return the scene object.

        Returns:
            NodeGraphQt.widgets.scene.NodeScene: node scene.
        """
        return self._viewer.scene()

    def undo_stack(self):
        """
        Returns the undo stack used in the node graph

        Returns:
            QUndoStack: undo stack.
        """
        return self._undo_stack

    def begin_undo(self, name='undo'):
        """
        Start of an undo block followed by a end_undo().

        Args:
            name (str): name for the undo block.
        """
        self._undo_stack.beginMacro(name)

    def end_undo(self):
        """
        End of an undo block started by begin_undo().
        """
        self._undo_stack.endMacro()

    def context_menu(self):
        """
        Returns a node graph context menu object.

        Returns:
            ContextMenu: node graph context menu object instance.
        """
        return self._viewer.context_menu()

    def acyclic(self):
        """
        Returns true if the current node graph is acyclic.

        Returns:
            bool: true if acyclic.
        """
        return self._model.acyclic

    def set_acyclic(self, mode=True):
        """
        Set the node graph to be acyclic or not. (default=True)

        Args:
            mode (bool): false to disable acyclic.
        """
        self._model.acyclic = mode
        self._viewer.acyclic = mode

    def set_pipe_layout(self, layout='curved'):
        """
        Set node graph pipes to be drawn straight or curved by default
        all pipes are set curved. (default='curved')

        Args:
            layout (str): 'straight' or 'curved'
        """
        self._viewer.set_pipe_layout(layout)

    def fit_to_selection(self):
        """
        Sets the zoom level to fit selected nodes.
        If no nodes are selected then all nodes in the graph will be framed.
        """
        nodes = self.selected_nodes() or self.all_nodes()
        if not nodes:
            return
        self._viewer.zoom_to_nodes([n.view for n in nodes])

    def reset_zoom(self):
        """
        Reset the zoom level
        """
        self._viewer.reset_zoom()

    def set_zoom(self, zoom=0):
        """
        Set the zoom factor of the Node Graph the default is 0.0

        Args:
            zoom (float): zoom factor max zoom out -0.9 max zoom in 2.0
        """
        self._viewer.set_zoom(zoom)

    def get_zoom(self):
        """
        Get the current zoom level of the node graph.

        Returns:
            float: the current zoom level.
        """
        return self._viewer.get_zoom()

    def center_on(self, nodes=None):
        """
        Center the node graph on the given nodes or all nodes by default.

        Args:
            nodes (list[NodeGraphQt.Node]): a list of nodes.
        """
        self._viewer.center_selection(nodes)

    def center_selection(self):
        """
        Center the node graph on the current selected nodes.
        """
        nodes = self._viewer.selected_nodes()
        self._viewer.center_selection(nodes)

    def registered_nodes(self):
        """
        Return a list of all node types that have been registered.
        To register a node see "NodeGraph.register_node()"

        Returns:
            list[str]: node types.
        """
        return sorted(NodeVendor.nodes.keys())

    def register_node(self, node, alias=None):
        """
        Register a node.

        Args:
            node (NodeGraphQt.Node): node object.
            alias (str): custom alias name for the node type.
        """
        NodeVendor.register_node(node, alias)

    def create_node(self, node_type, name=None, selected=True, color=None,
                    pos=None):
        """
        Create a new node in the node graph.
        To list all node types see "NodeGraph.registered_nodes()"

        Args:
            node_type (str): node instance type.
            name (str): set name of the node.
            selected (bool): set created node to be selected.
            color (tuple or str): node color (255, 255, 255) or '#FFFFFF'.
            pos (tuple): set position of the node (x, y).

        Returns:
            NodeGraphQt.Node: created instance of a node.

        Raises:
            Exception: if ``node_type`` is not registered.
        """
        NodeInstance = NodeVendor.create_node_instance(node_type)
        if not NodeInstance:
            raise Exception(
                '\n\n>> Cannot find node:\t"{}"\n'.format(node_type))

        node = NodeInstance()
        node._graph = self
        node.update()

        # name / color / selection changes are grouped with the add command
        # so a single undo removes the whole creation.
        self._undo_stack.beginMacro('created node')
        self._undo_stack.push(NodeAddedCmd(self, node, pos))
        node.set_name(name or node.NODE_NAME)
        if color:
            if isinstance(color, str):
                # '#RRGGBB' or 'RRGGBB' -> (r, g, b)
                color = color[1:] if color.startswith('#') else color
                color = tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))
            node.set_color(*color)
        node.set_selected(selected)
        self._undo_stack.endMacro()
        return node

    def add_node(self, node):
        """
        Add a node into the node graph.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        assert isinstance(node, NodeObject), 'node must be a Node instance.'
        node._graph = self
        node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
        node.model.name = node.NODE_NAME
        node.update()
        self._undo_stack.push(NodeAddedCmd(self, node))

    def delete_node(self, node):
        """
        Remove the node from the node graph.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        assert isinstance(node, NodeObject), \
            'node must be a instance of a NodeObject.'
        self._undo_stack.push(NodeRemovedCmd(self, node))

    def delete_nodes(self, nodes):
        """
        Remove a list of nodes from the node graph.

        Args:
            nodes (list[NodeGraphQt.Node]): list of node instances.
        """
        self._undo_stack.beginMacro('deleted nodes')
        for n in nodes:
            self.delete_node(n)
        self._undo_stack.endMacro()

    def all_nodes(self):
        """
        Return all nodes in the node graph.

        Returns:
            list[NodeGraphQt.Node]: list of nodes.
        """
        return list(self._model.nodes.values())

    def selected_nodes(self):
        """
        Return all selected nodes that are in the node graph.

        Returns:
            list[NodeGraphQt.Node]: list of nodes.
        """
        return [self._model.nodes[item.id]
                for item in self._viewer.selected_nodes()]

    def select_all(self):
        """
        Select all nodes in the current node graph.
        """
        self._undo_stack.beginMacro('select all')
        for node in self.all_nodes():
            node.set_selected(True)
        self._undo_stack.endMacro()

    def clear_selection(self):
        """
        Clears the selection in the node graph.
        """
        self._undo_stack.beginMacro('deselected nodes')
        for node in self.all_nodes():
            node.set_selected(False)
        self._undo_stack.endMacro()

    def get_node_by_id(self, node_id=None):
        """
        Get the node object by its id.

        Args:
            node_id (str): node id

        Returns:
            NodeGraphQt.NodeObject: node object (None if not found).
        """
        return self._model.nodes.get(node_id)

    def get_node_by_name(self, name):
        """
        Returns node object that matches the name.

        Args:
            name (str): name of the node.

        Returns:
            NodeGraphQt.Node: node object (None if not found).
        """
        for node in self._model.nodes.values():
            if node.name() == name:
                return node
        return None

    def get_unique_name(self, name):
        """
        return a unique node name for the node.

        Whitespace runs are collapsed to single spaces. If the name is taken,
        a trailing number is stripped ("node 3" -> "node") and the first free
        "<name> <n>" (n = 1, 2, ...) is returned.

        Args:
            name (str): node name.

        Returns:
            str: unique node name.
        """
        name = ' '.join(name.split())
        node_names = {n.name() for n in self.all_nodes()}
        if name not in node_names:
            return name

        # split "<base><spaces><digits>" so only a *trailing* number is
        # treated as a version suffix.
        match = re.match(r'^(.*?)\s*(\d+)$', name)
        if match and match.group(1):
            name = match.group(1)

        # unbounded search: always terminates because node_names is finite.
        counter = 1
        while True:
            new_name = '{} {}'.format(name, counter)
            if new_name not in node_names:
                return new_name
            counter += 1

    def current_session(self):
        """
        returns the file path to the currently loaded session.

        Returns:
            str: path to the currently loaded session
        """
        return self._model.session

    def clear_session(self):
        """
        clear the loaded node layout session.
        """
        for n in self.all_nodes():
            self.delete_node(n)
        self._undo_stack.clear()
        self._model.session = None

    def _serialize(self, nodes):
        """
        serialize nodes to a dict.

        Args:
            nodes (list[NodeGraphQt.Nodes]): list of node instances.

        Returns:
            dict: serialized data.
        """
        serial_data = {'nodes': {}, 'connections': []}
        nodes_data = {}
        for n in nodes:
            # update the node model.
            n.update_model()
            nodes_data.update(n.model.to_dict)

        for n_id, n_data in nodes_data.items():
            serial_data['nodes'][n_id] = n_data

            # port connection data is moved out of the node entry and into
            # the flat "connections" list.
            inputs = n_data.pop('inputs') if n_data.get('inputs') else {}
            outputs = n_data.pop('outputs') if n_data.get('outputs') else {}

            for pname, conn_data in inputs.items():
                for conn_id, prt_names in conn_data.items():
                    for conn_prt in prt_names:
                        pipe = {
                            'in': [n_id, pname],
                            'out': [conn_id, conn_prt]
                        }
                        if pipe not in serial_data['connections']:
                            serial_data['connections'].append(pipe)

            for pname, conn_data in outputs.items():
                for conn_id, prt_names in conn_data.items():
                    for conn_prt in prt_names:
                        pipe = {
                            'out': [n_id, pname],
                            'in': [conn_id, conn_prt]
                        }
                        if pipe not in serial_data['connections']:
                            serial_data['connections'].append(pipe)

        if not serial_data['connections']:
            serial_data.pop('connections')

        return serial_data

    def _deserialize(self, data, relative_pos=False, pos=None):
        """
        deserialize node data.

        Args:
            data (dict): node data.
            relative_pos (bool): position node relative to the cursor.
            pos (tuple): position passed to the viewer's ``move_nodes``
                when ``relative_pos`` is false.

        Returns:
            list[NodeGraphQt.Nodes]: list of node instances.
        """
        nodes = {}

        # build the nodes.
        for n_id, n_data in data.get('nodes', {}).items():
            identifier = n_data['type']
            NodeInstance = NodeVendor.create_node_instance(identifier)
            if not NodeInstance:
                continue
            node = NodeInstance()
            node._graph = self

            name = self.get_unique_name(n_data.get('name', node.NODE_NAME))
            n_data['name'] = name

            # set properties.
            for prop in node.model.properties.keys():
                if prop in n_data:
                    setattr(node.model, prop, n_data[prop])

            # set custom properties.
            for prop, val in n_data.get('custom', {}).items():
                if prop in node.model.custom_properties.keys():
                    node.model.custom_properties[prop] = val

            node.update()

            self._undo_stack.push(
                NodeAddedCmd(self, node, n_data.get('pos')))
            nodes[n_id] = node

        # build the connections.
        for connection in data.get('connections', []):
            nid, pname = connection.get('in', ('', ''))
            in_node = nodes.get(nid)
            if not in_node:
                continue
            in_port = in_node.inputs().get(pname)

            nid, pname = connection.get('out', ('', ''))
            out_node = nodes.get(nid)
            if not out_node:
                continue
            out_port = out_node.outputs().get(pname)

            if in_port and out_port:
                self._undo_stack.push(PortConnectedCmd(in_port, out_port))

        node_objs = list(nodes.values())
        if relative_pos:
            self._viewer.move_nodes([n.view for n in node_objs])
            for n in node_objs:
                n.model.pos = n.view.pos
        elif pos:
            self._viewer.move_nodes([n.view for n in node_objs], pos=pos)

        return node_objs

    def save_session(self, file_path):
        """
        Saves the current node graph session layout (all nodes) to a JSON
        formatted file.

        Args:
            file_path (str): path to the saved node layout.
        """
        # serialize every node: saving only the selection would silently
        # drop the unselected part of the session.
        serialized_data = self._serialize(self.all_nodes())
        file_path = file_path.strip()
        with open(file_path, 'w') as file_out:
            json.dump(serialized_data, file_out, indent=2,
                      separators=(',', ':'))

    def load_session(self, file_path):
        """
        Load node graph session layout file.

        Args:
            file_path (str): path to the serialized layout file.

        Raises:
            IOError: if ``file_path`` is not an existing file.
        """
        file_path = file_path.strip()
        if not os.path.isfile(file_path):
            raise IOError('file does not exist.')

        self.clear_session()

        try:
            with open(file_path) as data_file:
                layout_data = json.load(data_file)
        except Exception as e:
            # best effort: report and keep the (now empty) session.
            layout_data = None
            print('Cannot read data from file.\n{}'.format(e))

        if not layout_data:
            return

        self._deserialize(layout_data)
        self._undo_stack.clear()
        self._model.session = file_path

    def copy_nodes(self, nodes=None):
        """
        copy nodes to the clipboard by default this method copies
        the selected nodes from the node graph.

        Args:
            nodes (list[NodeGraphQt.Node]): list of node instances.

        Returns:
            bool: true if data was written to the clipboard.
        """
        nodes = nodes or self.selected_nodes()
        if not nodes:
            return False
        # QClipboard has no public constructor; use the application's
        # global clipboard instead.
        clipboard = QApplication.clipboard()
        serial_data = self._serialize(nodes)
        serial_str = json.dumps(serial_data)
        if serial_str:
            clipboard.setText(serial_str)
            return True
        return False

    def paste_nodes(self):
        """
        Pastes nodes from the clipboard.

        Clipboard text that is not a serialized node dictionary is ignored.
        """
        clipboard = QApplication.clipboard()
        cb_string = clipboard.text()
        if not cb_string:
            return

        # parse before opening the undo macro so invalid clipboard text
        # cannot leave the undo stack stuck inside an unfinished macro.
        try:
            serial_data = json.loads(cb_string)
        except ValueError:
            return
        if not isinstance(serial_data, dict):
            return

        self._undo_stack.beginMacro('pasted nodes')
        try:
            self.clear_selection()
            nodes = self._deserialize(serial_data, True)
            for n in nodes:
                n.set_selected(True)
        finally:
            self._undo_stack.endMacro()

    def duplicate_nodes(self, nodes):
        """
        Create duplicates nodes.

        Args:
            nodes (list[NodeGraphQt.Node]): list of node objects.

        Returns:
            list[NodeGraphQt.Node]: list of duplicated node instances
            (None when ``nodes`` is empty).
        """
        if not nodes:
            return

        self._undo_stack.beginMacro('duplicated nodes')

        self.clear_selection()
        serial = self._serialize(nodes)
        new_nodes = self._deserialize(serial)
        offset = 50
        for n in new_nodes:
            x, y = n.pos()
            n.set_pos(x + offset, y + offset)
            n.set_property('selected', True)

        self._undo_stack.endMacro()
        return new_nodes

    def disable_nodes(self, nodes, mode=None):
        """
        Disable/Enable specified nodes.

        Args:
            nodes (list[NodeGraphQt.Node]): list of node instances.
            mode (bool): (optional) disable state of the nodes. When omitted
                the inverse of the first node's state is used.
        """
        if not nodes:
            return
        if mode is None:
            mode = not nodes[0].disabled()
        if len(nodes) > 1:
            text = {False: 'enabled', True: 'disabled'}[mode]
            text = '{} ({}) nodes'.format(text, len(nodes))
            self._undo_stack.beginMacro(text)
            for n in nodes:
                n.set_disabled(mode)
            self._undo_stack.endMacro()
            return
        return nodes[0].set_disabled(mode)
class URList(UserList):
    """List whose item assignments can be undone and redone.

    ``lst[key] = value`` and ``append`` do not modify the list directly:
    they push a command onto a private QUndoStack, and ``push`` executes the
    command's ``redo``. ``undo()``/``redo()`` walk that stack, and
    ``startBulkUpdate``/``endBulkUpdate`` group several assignments into a
    single undo step.

    NOTE: slice keys are handled as *inclusive* ranges ``[start, stop]``
    (``start`` and ``stop`` must be ints), unlike built-in lists.
    """

    class __SingleStackCommand__(QUndoCommand):
        """Undoable assignment of ``value`` to ``myList[key]``."""

        # Sentinel meaning "nothing was stored at key before this command".
        # Using a sentinel instead of None means a list element that really
        # holds None is restored on undo rather than deleted.
        _NO_OLD_VALUE = object()

        def __init__(self, myList: 'URList', key: Union[int, slice], value: Any):
            QUndoCommand.__init__(self)
            self._list = myList
            self._key = key
            try:
                old_value = self._list[key]
            except IndexError:
                # Writing past the end (e.g. via append): undo deletes the slot.
                old_value = self._NO_OLD_VALUE
            else:
                if isinstance(key, slice):
                    # __getitem__ returns a 0-based list, but __realsetitem__
                    # looks values up by absolute list index; re-key the old
                    # values by absolute index so undo writes them back to
                    # the right positions.
                    old_value = dict(zip(self._list._inclusive_range(key), old_value))
            self._old_value = old_value
            self._new_value = value

        def undo(self) -> None:
            if self._old_value is self._NO_OLD_VALUE:
                del self._list[self._key]
            else:
                self._list.__realsetitem__(self._key, self._old_value)

        def redo(self) -> None:
            self._list.__realsetitem__(self._key, self._new_value)

    class __MultiStackCommand__(QUndoCommand):
        """Command forwarding undo/redo to the element stored at ``key``.

        NOTE(review): not used anywhere in this class; presumably intended
        for lists whose elements are themselves undoable objects.
        """

        def __init__(self, myList: 'URList', key: Union[int, slice]):
            QUndoCommand.__init__(self)
            self._key = key
            self._list = myList

        def undo(self) -> None:
            self._list[self._key].undo()

        def redo(self) -> None:
            self._list[self._key].redo()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__stack__ = QUndoStack()
        self._macroRunning = False

    @staticmethod
    def _inclusive_range(key: slice) -> range:
        """Return the indices covered by ``key``, treating ``key.stop`` as inclusive."""
        if key.step is None:
            return range(key.start, key.stop + 1)
        return range(key.start, key.stop + 1, key.step)

    def __setitem__(self, key: Union[int, slice], value: Any) -> None:
        """Record ``self[key] = value`` as an undoable command (applied immediately)."""
        self.__stack__.push(self.__SingleStackCommand__(self, key, value))

    def __getitem__(self, key):
        """Return one element, or a plain list for an inclusive slice ``key``."""
        if isinstance(key, slice):
            items = []
            for index in self._inclusive_range(key):
                items.append(super().__getitem__(index))
            return items
        return super().__getitem__(key)

    def __realsetitem__(self, key: Union[int, slice], value: Any) -> None:
        """Write to the underlying list directly, bypassing the undo stack.

        For a slice key, ``value`` is looked up by absolute list index:
        ``value[i]`` is stored at position ``i`` for each ``i`` in the
        inclusive range.
        """
        if isinstance(key, slice):
            for index in self._inclusive_range(key):
                self._store(index, value[index])
            return
        self._store(key, value)

    def _store(self, index: int, value: Any) -> None:
        """Set ``self.data[index]``; append first when ``index`` is past the end."""
        if index > len(self.data) - 1:
            # UserList.append, not self.append, so no command is pushed.
            super().append(value)
        super().__setitem__(index, value)

    def undo(self) -> None:
        """Undo the most recent command (or macro) on the stack."""
        self.__stack__.undo()

    def redo(self) -> None:
        """Redo the most recently undone command (or macro)."""
        self.__stack__.redo()

    def startBulkUpdate(self) -> None:
        """Begin grouping subsequent assignments into one undo step."""
        self.__stack__.beginMacro('Bulk update')
        self._macroRunning = True

    def endBulkUpdate(self) -> None:
        """Close the group opened by ``startBulkUpdate``."""
        self.__stack__.endMacro()
        self._macroRunning = False

    def append(self, item) -> None:
        """Append ``item`` as an undoable assignment to index ``len(self)``."""
        self.__setitem__(len(self.data), item)
class MainWindow(QMainWindow):
    """Main window whose state-changing actions go through an undo stack.

    Each user action is wrapped in a QUndoCommand and pushed onto
    ``self.undoStack``; multi-step actions are grouped with
    ``beginMacro``/``endMacro`` so they undo as one step. Widgets used here
    but not created here (``tabWidget``, ``comboBox``, ``graphicsView``,
    ``statusbar``, ``menuEdit``, ``action*``) are presumably supplied by a
    generated UI class — NOTE(review): confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.undoStack = QUndoStack(self)
        self.undoStack.cleanChanged.connect(self.cleanChanged)
        self.undoAction = self.undoStack.createUndoAction(self, "&Undo")
        self.undoAction.setShortcut(QKeySequence.Undo)
        self.redoAction = self.undoStack.createRedoAction(self, "&Redo")
        self.redoAction.setShortcut(QKeySequence.Redo)
        self.messageLabel = QLabel()
        self.coordLabel = QLabel()
        self.stopwatchLabel = QLabel()
        self.time = QTime(0, 0)
        self.stopwatch = QTimer()
        self.undoView = QUndoView(self.undoStack)
        self.graphicsScene = GraphicsScene(self)
        # Items copied with copy(); paste() uses the last one.
        self.copyList = []

    def setupUi(self):
        """Set icons, status-bar widgets, timer wiring and scene signals."""
        if QIcon.themeName() == "":
            QIcon.setThemeName('breeze')
        self.openIcon = QIcon().fromTheme("document-open")
        self.actionOpen_Datasets.setIcon(self.openIcon)
        self.saveIcon = QIcon().fromTheme("document-save")
        self.actionSave.setIcon(self.saveIcon)
        self.closeIcon = QIcon().fromTheme("document-close")
        self.actionClose_Dataset.setIcon(self.closeIcon)
        self.statusbar.addWidget(self.messageLabel)
        self.statusbar.addWidget(self.coordLabel)
        self.statusbar.addWidget(self.stopwatchLabel)
        self.stopwatch.setInterval(1000)
        self.stopwatch.timeout.connect(self.updateStopWatchLabel)
        self.stopwatchStartIcon = QIcon().fromTheme("chronometer-start")
        self.actionTimer_Start.setIcon(self.stopwatchStartIcon)
        self.actionTimer_Start.triggered.connect(self.startStopWatch)
        self.stopwatchStopIcon = QIcon().fromTheme("chronometer-pause")
        self.actionTimer_Stop.setIcon(self.stopwatchStopIcon)
        self.actionTimer_Stop.triggered.connect(self.stopStopWatch)
        self.stopwatchResetIcon = QIcon().fromTheme("chronometer-reset")
        self.actionTimer_Reset.setIcon(self.stopwatchResetIcon)
        self.actionTimer_Reset.triggered.connect(self.resetStopWatch)
        self.leftIcon = QIcon().fromTheme("go-previous")
        self.actionSend_To_Left.setIcon(self.leftIcon)
        self.rightIcon = QIcon().fromTheme("go-next")
        self.actionSend_To_Right.setIcon(self.rightIcon)
        self.upIcon = QIcon().fromTheme("go-up")
        self.actionPrevious_Item.setIcon(self.upIcon)
        self.downIcon = QIcon().fromTheme("go-down")
        self.actionNext_Item.setIcon(self.downIcon)
        self.undoView.setWindowTitle("Command List")
        self.undoView.show()
        self.undoView.setAttribute(Qt.WA_QuitOnClose, False)
        self.menuEdit.addAction(self.undoAction)
        self.menuEdit.addAction(self.redoAction)
        self.graphicsView.setScene(self.graphicsScene)
        self.graphicsView.mouseMoved.connect(self.coordLabel.setText)
        self.graphicsScene.tabWidget = self.tabWidget
        self.graphicsScene.comboBox = self.comboBox
        self.graphicsScene.signalHandler.boxPressed.connect(self.selectBox)
        self.graphicsScene.signalHandler.boxChanged.connect(self.changeBox)
        self.graphicsScene.signalHandler.boxCreated.connect(self.createItem)

    def _sendCurrentItemTo(self, originIndex, targetIndex):
        """Move the selected row of tab ``originIndex`` to tab ``targetIndex`` (one undo step)."""
        modelIndex = self.tabWidget.getCurrentTableView().currentIndex()
        if not modelIndex.isValid():
            return
        self.undoStack.beginMacro(f"Send item to {targetIndex}")
        try:
            sendToCommand = SendToCommand(originIndex, targetIndex,
                                          modelIndex.row(), self.tabWidget,
                                          self.graphicsScene)
            self.undoStack.push(sendToCommand)
            # Re-resolve the same (row, column) after the move and select it.
            modelIndex = modelIndex.model().index(modelIndex.row(),
                                                  modelIndex.column())
            if modelIndex.isValid():
                self.cellClicked(originIndex, modelIndex,
                                 originIndex, modelIndex)
        finally:
            # Always close the macro so the stack is not left mid-macro.
            self.undoStack.endMacro()

    @Slot()
    def sendToLeft(self):
        """Send the selected item to the previous tab (wrapping around)."""
        originIndex = self.tabWidget.currentIndex()
        numTabs = self.tabWidget.count()
        if numTabs == 0:
            # No tabs: nothing to send (and avoids a modulo by zero).
            return
        self._sendCurrentItemTo(originIndex,
                                (numTabs + originIndex - 1) % numTabs)

    @Slot()
    def sendToRight(self):
        """Send the selected item to the next tab (wrapping around)."""
        originIndex = self.tabWidget.currentIndex()
        numTabs = self.tabWidget.count()
        if numTabs == 0:
            return
        self._sendCurrentItemTo(originIndex, (originIndex + 1) % numTabs)

    @Slot()
    @Slot(int)
    def closeDataset(self, i=-1):
        """Close dataset tab ``i`` (default: the current tab) as an undoable command."""
        # put a dialog if there is a pending modification
        # say that modifications are not lost and retrievable with CTRL-Z
        if self.tabWidget.count() > 0:
            if i == -1:
                i = self.tabWidget.currentIndex()
            deleteDatasetCommand = DeleteDatasetCommand([i], self.tabWidget,
                                                        self.comboBox,
                                                        self.graphicsView,
                                                        self.graphicsScene)
            self.undoStack.push(deleteDatasetCommand)

    @Slot()
    def openDatasets(self):
        """Open dataset directory"""
        prevTabIndex = self.tabWidget.currentIndex()
        prevModelIndex = self.tabWidget.getCurrentSelectedCell()
        numTabs = self.tabWidget.count()
        # NOTE(review): the start directory is a developer-specific absolute
        # path; consider making it configurable.
        (filenames, _ext) = QFileDialog.getOpenFileNames(
            self,
            QApplication.translate("MainWindow", "Open datasets", None, -1),
            "/home/kwon-young/Documents/PartageVirtualBox/data/omr_dataset/choi_dataset",
            "*.csv")
        if not filenames:
            return
        self.undoStack.beginMacro(f"open Datasets {filenames}")
        try:
            filenames.sort()
            openDatasetCommand = OpenDatasetCommand(filenames, self.tabWidget,
                                                    self.comboBox,
                                                    self.graphicsScene,
                                                    self.messageLabel)
            self.undoStack.push(openDatasetCommand)
            # The first newly opened tab (if any) sits right after the old ones.
            tabIndex = numTabs
            if self.tabWidget.count() > tabIndex:
                modelIndex = self.tabWidget.getTableModel(tabIndex).index(0, 0)
                if modelIndex.isValid():
                    self.cellClicked(tabIndex, modelIndex,
                                     prevTabIndex, prevModelIndex)
        finally:
            self.undoStack.endMacro()

    @Slot(int)
    def currentTabChanged(self, index):
        """Switch to tab ``index`` and refresh every tab's color in the scene."""
        self.tabWidget.setCurrentIndex(index)
        for tabIndex in range(self.tabWidget.count()):
            self.graphicsScene.changeTabColor(
                tabIndex, self.tabWidget.color_map(tabIndex))

    @Slot(int, QModelIndex, int, QModelIndex)
    def cellClicked(self, tabIndex, cellIndex, prevTabIndex, prevCellIndex):
        """Push a CellClickedCommand moving the selection to ``cellIndex``."""
        cellClickedCommand = CellClickedCommand(
            tabIndex, cellIndex, prevTabIndex, prevCellIndex,
            self.tabWidget, self.graphicsScene, self.graphicsView,
            self.comboBox, self.messageLabel)
        self.undoStack.push(cellClickedCommand)

    def _selectRelativeRow(self, step):
        """Select the row ``step`` rows away from the current one (wrapping)."""
        if self.tabWidget.count() <= 0:
            return
        tabIndex = self.tabWidget.currentIndex()
        prevCellIndex = self.tabWidget.getCurrentSelectedCell()
        model = self.tabWidget.getCurrentTableModel()
        rowCount = model.rowCount(QModelIndex())
        if rowCount == 0:
            # Empty table: nothing to select (and avoids a modulo by zero).
            return
        nextRow = (prevCellIndex.row() + step) % rowCount
        cellIndex = model.index(nextRow, prevCellIndex.column())
        self.cellClicked(tabIndex, cellIndex, tabIndex, prevCellIndex)

    def _selectRelativePage(self, step):
        """Walk rows in direction ``step`` until the page prefix changes, then select that row.

        The page prefix is the text before the first "-" of
        ``model.pageAtIndex``. If no row has a different prefix, the last row
        visited is selected.
        """
        if self.tabWidget.count() <= 0:
            return
        tabIndex = self.tabWidget.currentIndex()
        prevCellIndex = self.tabWidget.getCurrentSelectedCell()
        model = self.tabWidget.getCurrentTableModel()
        prevPage = model.pageAtIndex(prevCellIndex)
        rowCount = model.rowCount(QModelIndex())
        if rowCount == 0:
            # Empty table: the loop below would leave cellIndex unset.
            return
        for i in range(rowCount):
            row = (prevCellIndex.row() + step * i) % rowCount
            cellIndex = model.index(row, prevCellIndex.column())
            page = model.pageAtIndex(cellIndex)
            if prevPage.split("-")[0] != page.split("-")[0]:
                break
        self.cellClicked(tabIndex, cellIndex, tabIndex, prevCellIndex)

    @Slot()
    def SelectNextItem(self):
        """Select the next row of the current table (wrapping around)."""
        self._selectRelativeRow(1)

    @Slot()
    def SelectPreviousItem(self):
        """Select the previous row of the current table (wrapping around)."""
        self._selectRelativeRow(-1)

    @Slot()
    def SelectNextPage(self):
        """Select the first row found going forward whose page differs."""
        self._selectRelativePage(1)

    @Slot()
    def SelectPreviousPage(self):
        """Select the first row found going backward whose page differs."""
        self._selectRelativePage(-1)

    def _pushLabelChange(self, label):
        """Push a LabelChangedCommand applying ``label`` to the selected cell."""
        tabIndex = self.tabWidget.currentIndex()
        cellIndex = self.tabWidget.getCurrentSelectedCell()
        labelChangedCommand = LabelChangedCommand(label, tabIndex, cellIndex,
                                                  self.tabWidget,
                                                  self.graphicsScene,
                                                  self.comboBox)
        self.undoStack.push(labelChangedCommand)

    def _selectRelativeLabel(self, step):
        """Apply the combo-box label ``step`` entries away from the current one (wrapping)."""
        count = self.comboBox.count()
        if count > 0:
            newIndex = (self.comboBox.currentIndex() + step) % count
            self._pushLabelChange(self.comboBox.itemText(newIndex))

    @Slot()
    def selectNextLabel(self):
        """Apply the next label in the combo box to the selected cell."""
        self._selectRelativeLabel(1)

    @Slot()
    def selectPreviousLabel(self):
        """Apply the previous label in the combo box to the selected cell."""
        self._selectRelativeLabel(-1)

    @Slot(int)
    def labelChanged(self, index):
        """Apply combo-box entry ``index`` as the selected cell's label."""
        self._pushLabelChange(self.comboBox.itemText(index))

    @Slot()
    def saveDataToDisk(self):
        """Save every tab's model to its file, then mark the undo stack clean."""
        for name, model in zip(self.tabWidget.filenames(),
                               self.tabWidget.models()):
            model.save(name)
        # Mark clean only after every model was written, so a failing save
        # leaves the window flagged as modified.
        self.undoStack.setClean()

    @Slot(bool)
    def cleanChanged(self, clean):
        """Reflect the undo stack's clean state in the window-modified flag."""
        self.setWindowModified(not clean)

    @Slot(int, int)
    def selectBox(self, tabIndex, rowIndex):
        """Select box (tabIndex, rowIndex) unless it is already selected."""
        if tabIndex != self.tabWidget.currentIndex() or \
                rowIndex != self.tabWidget.getCurrentSelectedCell().row():
            selectBoxCommand = SelectBoxCommand(tabIndex, rowIndex,
                                                self.tabWidget,
                                                self.graphicsScene,
                                                self.comboBox)
            self.undoStack.push(selectBoxCommand)

    @Slot(int, int, QRectF)
    def changeBox(self, tabIndex, rowIndex, box):
        """Record a box geometry change as an undoable MoveBoxCommand."""
        moveBoxCommand = MoveBoxCommand(tabIndex, rowIndex, box,
                                        self.tabWidget, self.graphicsScene)
        self.undoStack.push(moveBoxCommand)

    @Slot(QRectF, QRectF)
    def viewportMoved(self, rect, prevRect):
        """Record a viewport move as an undoable command."""
        viewportMovedCommand = ViewportMovedCommand(rect, prevRect,
                                                    self.graphicsView)
        self.undoStack.push(viewportMovedCommand)

    @Slot()
    def updateStopWatchLabel(self):
        """Show the elapsed stopwatch time in the status bar."""
        elapsed = QTime(0, 0).addMSecs(self.time.elapsed())
        self.stopwatchLabel.setText(elapsed.toString())

    @Slot()
    def startStopWatch(self):
        """Restart the elapsed-time reference and the 1 s refresh timer."""
        self.time.start()
        self.stopwatch.start(1000)

    @Slot()
    def stopStopWatch(self):
        """Stop refreshing the stopwatch label."""
        self.stopwatch.stop()

    @Slot()
    def resetStopWatch(self):
        """Restart the elapsed-time reference and show 00:00:00."""
        self.time.start()
        self.stopwatchLabel.setText(QTime(0, 0).toString())

    @Slot()
    def deleteItem(self):
        """Delete the selected item and reselect the row that takes its place."""
        tabIndex = self.tabWidget.currentIndex()
        cellIndex = self.tabWidget.getCurrentSelectedCell()
        self.undoStack.beginMacro(f"Delete item {tabIndex}:{cellIndex.row()}")
        try:
            deleteItemCommand = DeleteItemCommand(tabIndex, cellIndex,
                                                  self.tabWidget,
                                                  self.graphicsView,
                                                  self.graphicsScene,
                                                  self.comboBox)
            self.undoStack.push(deleteItemCommand)
            cellIndex = cellIndex.model().index(cellIndex.row(),
                                                cellIndex.column())
            if cellIndex.isValid():
                self.cellClicked(tabIndex, cellIndex, tabIndex, cellIndex)
        finally:
            self.undoStack.endMacro()

    @Slot(ResizableRect)
    def createItem(self, rect):
        """Create an item for ``rect`` and select it, as one undo step."""
        self.undoStack.beginMacro(
            f"Create item {rect.tabIndex}:{rect.rowIndex}")
        try:
            createItemCommand = CreateItemCommand(rect, self.tabWidget,
                                                  self.graphicsScene,
                                                  self.comboBox)
            self.undoStack.push(createItemCommand)
            self.selectBox(rect.tabIndex, rect.rowIndex)
        finally:
            self.undoStack.endMacro()

    @Slot()
    def tabItemForward(self):
        """Raise the current tab's items by one z-level."""
        changeTabItemZValueCommand = ChangeTabItemZValueCommand(
            self.tabWidget.currentIndex(), 1, self.graphicsScene)
        self.undoStack.push(changeTabItemZValueCommand)

    @Slot()
    def tabItemBackward(self):
        """Lower the current tab's items by one z-level."""
        changeTabItemZValueCommand = ChangeTabItemZValueCommand(
            self.tabWidget.currentIndex(), -1, self.graphicsScene)
        self.undoStack.push(changeTabItemZValueCommand)

    @Slot()
    def copy(self):
        """Copy the selected box into ``copyList`` via an undoable command."""
        tabIndex = self.tabWidget.currentIndex()
        cellIndex = self.tabWidget.getCurrentSelectedCell()
        box = self.graphicsScene.box(tabIndex, cellIndex.row())
        copyCommand = CopyCommand(box, self.copyList)
        self.undoStack.push(copyCommand)

    @Slot()
    def paste(self):
        """Paste the last copied item at the mouse position and select it."""
        if not self.copyList:
            # Nothing has been copied yet.
            return
        pos = self.graphicsView.mapFromGlobal(QCursor.pos())
        scenePos = self.graphicsView.mapToScene(pos)
        prop = self.copyList[-1]
        self.undoStack.beginMacro(f"paste item {prop.box}")
        try:
            pasteCommand = PasteCommand(scenePos, prop, self.tabWidget,
                                        self.graphicsScene)
            self.undoStack.push(pasteCommand)
            self.selectBox(prop.tabIndex, prop.rowIndex)
        finally:
            self.undoStack.endMacro()
class UndoableDict(PathDict):
    """
    The UndoableDict class implements a PathDict-based class with undo/redo
    functionality based on QUndoStack.
    """

    def __init__(self, *args, **kwargs):
        # Both attributes exist before PathDict.__init__ runs, in case that
        # initializer assigns items through __setitem__ (which needs the
        # stack). NOTE(review): depends on PathDict's implementation.
        self.__stack = QUndoStack()
        self._macroRunning = False
        super().__init__(*args, **kwargs)

    # Public methods: dictionary-related

    def __setitem__(self, key: str, val: Any) -> None:
        """
        Overrides PathDict assignment to self[key]: pushes a set-item command
        for an existing key, or an add-item command for a new key, on the
        stack.
        """
        command_class = _SetItemCommand if key in self else _AddItemCommand
        self.__stack.push(command_class(self, key, val))

    def setItemByPath(self, keys: list, value: Any) -> None:
        """
        Pushes an undoable command that sets a value in a nested object by
        key sequence.
        """
        self.__stack.push(_SetItemCommand(self, keys, value))

    # Public methods: undo/redo-related

    def clearUndoStack(self) -> None:
        """
        Clears the command stack by deleting all commands on it, and returns
        the stack to the clean state.
        """
        self.__stack.clear()

    def canUndo(self) -> bool:
        """
        :return true if there is a command available for undo; otherwise
        returns false.
        """
        return self.__stack.canUndo()

    def canRedo(self) -> bool:
        """
        :return true if there is a command available for redo; otherwise
        returns false.
        """
        return self.__stack.canRedo()

    def undo(self) -> None:
        """
        Undoes the current command on stack.
        """
        self.__stack.undo()

    def redo(self) -> None:
        """
        Redoes the current command on stack.
        """
        self.__stack.redo()

    def undoText(self) -> str:
        """
        :return the text of the command that the next undo() would undo.
        """
        return self.__stack.undoText()

    def redoText(self) -> str:
        """
        :return the text of the command that the next redo() would redo.
        """
        return self.__stack.redoText()

    def startBulkUpdate(self, text='Bulk update') -> None:
        """
        Begins composition of a macro command with the given text description.
        Prints a message and does nothing if a macro is already running.
        """
        if self._macroRunning:
            print('Macro already running')
            return
        self.__stack.beginMacro(text)
        self._macroRunning = True

    def endBulkUpdate(self) -> None:
        """
        Ends composition of a macro command.
        Prints a message and does nothing if no macro is running.
        """
        if not self._macroRunning:
            print('Macro not running')
            return
        self.__stack.endMacro()
        self._macroRunning = False

    def bulkUpdate(self, key_list: list, item_list: list, text='Bulk update') -> None:
        """
        Performs a bulk update based on a list of keys and a list of values,
        recorded as a single undo step.

        Pairs are taken with zip(), so extra entries in the longer list are
        ignored. When called while another bulk update is running, the
        updates join that macro instead of closing it.

        :param key_list: list of keys or path keys to be updated
        :param item_list: the values to be assigned, one per key
        :param text: description of the macro command
        :return: None
        """
        owns_macro = not self._macroRunning
        if owns_macro:
            self.startBulkUpdate(text)
        try:
            for key, value in zip(key_list, item_list):
                self.setItemByPath(key, value)
        finally:
            # Close only the macro opened here, even if an update raised, so
            # the stack is never left mid-macro and a caller's macro stays open.
            if owns_macro:
                self.endBulkUpdate()
class URDict(UserDict):
    """
    The URDict class implements a dictionary-based class with undo/redo
    functionality based on QUndoStack.
    """

    def __init__(self, *args, **kwargs):
        # The stack must exist before UserDict.__init__, which stores initial
        # items via self[key] = value (i.e. through __setitem__ below).
        # NOTE(review): this means initial items are pushed onto the undo
        # stack as well.
        self._stack = QUndoStack()
        super().__init__(*args, **kwargs)
        self._macroRunning = False

    # Private URDict dictionary-based methods to be called via the QUndoCommand-based classes.

    def _realSetItem(self, key: Union[str, List], value: Any) -> None:
        """Actually changes the value for the existing key in dictionary.

        A list key is treated as a path into nested objects.
        """
        if isinstance(key, list):
            self._realSetItemByPath(key, value)
        else:
            super().__setitem__(key, value)

    def _realAddItem(self, key: str, value: Any) -> None:
        """Actually adds a key-value pair to dictionary."""
        super().__setitem__(key, value)

    def _realDelItem(self, key: str) -> None:
        """Actually deletes a key-value pair from dictionary."""
        del self[key]

    def _realSetItemByPath(self, keys: list, value: Any) -> None:
        """Actually sets the value in a nested object by the key sequence."""
        parent = self.getItemByPath(keys[:-1])
        if parent is self:
            # One-element path: write directly. Going through
            # URDict.__setitem__ would push another command from inside the
            # running command's redo/undo.
            super().__setitem__(keys[-1], value)
        else:
            parent[keys[-1]] = value

    # Public URDict dictionary-based methods

    def __setitem__(self, key: str, val: Any) -> None:
        """Overrides default dictionary assignment to self[key] implementation.
        Calls the undoable command and pushes this command on the stack."""
        if key in self:
            self._stack.push(_SetItemCommand(self, key, val))
        else:
            self._stack.push(_AddItemCommand(self, key, val))

    def setItemByPath(self, keys: list, value: Any) -> None:
        """Calls the undoable command to set a value in a nested object
        by key sequence and pushes this command on the stack."""
        self._stack.push(_SetItemCommand(self, keys, value))

    def getItemByPath(self, keys: list, default=None) -> Any:
        """Returns a value in a nested object by key sequence, or ``default``
        as soon as a key is missing. An empty sequence returns ``self``."""
        item = self
        for key in keys:
            if key in item.keys():
                item = item[key]
            else:
                return default
        return item

    def getItem(self, key: Union[str, list], default=None):
        """Returns a value in a nested object. Key can be either a sequence
        or a simple string."""
        if isinstance(key, list):
            return self.getItemByPath(key, default)
        else:
            return self.get(key, default)

    # Public URDict undostack-based methods

    def undoText(self) -> str:
        """Returns the text of the command which will be undone in the next
        call to undo()."""
        return self._stack.undoText()

    def redoText(self) -> str:
        """Returns the text of the command which will be redone in the next
        call to redo()."""
        return self._stack.redoText()

    def undo(self) -> None:
        """Undoes the current command on stack."""
        self._stack.undo()

    def redo(self) -> None:
        """Redoes the current command on stack."""
        self._stack.redo()

    def startBulkUpdate(self, text='Bulk update') -> None:
        """Begins composition of a macro command with the given text description.
        Prints a message and does nothing if a macro is already running."""
        if self._macroRunning:
            print('Macro already running')
            return
        self._stack.beginMacro(text)
        self._macroRunning = True

    def endBulkUpdate(self) -> None:
        """Ends composition of a macro command.
        Prints a message and does nothing if no macro is running."""
        if not self._macroRunning:
            print('Macro not running')
            return
        self._stack.endMacro()
        self._macroRunning = False
class NodeGraph(QtCore.QObject): node_selected = QtCore.Signal(list) def __init__(self, parent=None): super(NodeGraph, self).__init__(parent) self._model = NodeGraphModel() self._viewer = NodeViewer() self._undo_stack = QUndoStack(self) self._init_actions() self._wire_signals() self.patch_context_menu() def _wire_signals(self): self._viewer.moved_nodes.connect(self._on_nodes_moved) self._viewer.search_triggered.connect(self._on_search_triggered) self._viewer.connection_changed.connect(self._on_connection_changed) self._viewer.node_selected.connect(self._on_node_selected) def _init_actions(self): # setup tab search shortcut. tab = QAction('Search Nodes', self) tab.setShortcut('tab') tab.triggered.connect(self._toggle_tab_search) self._viewer.addAction(tab) setup_actions(self) def _toggle_tab_search(self): """ toggle the tab search widget. """ self._viewer.tab_search_set_nodes(NodeVendor.names) self._viewer.tab_search_toggle() def _on_nodes_moved(self, node_data): """ called when selected nodes in the viewer has changed position. Args: node_data (dict): {<node_view>: <previous_pos>} """ self._undo_stack.beginMacro('moved nodes') for node_view, prev_pos in node_data.items(): node = self._model.nodes[node_view.id] self._undo_stack.push(NodeMovedCmd(node, node.pos(), prev_pos)) self._undo_stack.endMacro() def _on_nodes_moved(self, node_data): """ called when a node in the viewer is selected on left click. """ nodes = self.selected_nodes() self.node_selected.emit(nodes) def _on_search_triggered(self, node_type, pos): """ called when the tab search widget is triggered in the viewer. Args: node_type (str): node identifier. pos (tuple): x,y position for the node. """ self.create_node(node_type, pos=pos) def _on_connection_changed(self, disconnected, connected): """ called when a pipe connection has been changed in the viewer. Args: disconnected (list[list[widgets.port.PortItem]): pair list of port view items. 
connected (list[list[widgets.port.PortItem]]): pair list of port view items. """ if not (disconnected or connected): return label = 'connected node(s)' if connected else 'disconnected node(s)' ptypes = {'in': 'inputs', 'out': 'outputs'} self._undo_stack.beginMacro(label) for p1_view, p2_view in disconnected: node1 = self._model.nodes[p1_view.node.id] node2 = self._model.nodes[p2_view.node.id] port1 = getattr(node1, ptypes[p1_view.port_type])()[p1_view.name] port2 = getattr(node2, ptypes[p2_view.port_type])()[p2_view.name] port1.disconnect_from(port2) for p1_view, p2_view in connected: node1 = self._model.nodes[p1_view.node.id] node2 = self._model.nodes[p2_view.node.id] port1 = getattr(node1, ptypes[p1_view.port_type])()[p1_view.name] port2 = getattr(node2, ptypes[p2_view.port_type])()[p2_view.name] port1.connect_to(port2) self._undo_stack.endMacro() @property def model(self): """ Return the node graph model. Returns: NodeGraphModel: model object. """ return self._model def show(self): """ Show node graph viewer widget. """ self._viewer.show() def hide(self): """ Hide node graph viewer widget. """ self._viewer.hide() def close(self): """ Close node graph viewer widget. """ self._viewer.close() def viewer(self): """ Return the node graph viewer widget object. Returns: NodeGraphQt.widgets.viewer.NodeViewer: viewer widget. """ return self._viewer def scene(self): """ Return the scene object. Returns: NodeGraphQt.widgets.scene.NodeScene: node scene. """ return self._viewer.scene() def undo_stack(self): """ Returns the undo stack used in the node graph Returns: QUndoStack: undo stack. """ return self._undo_stack def begin_undo(self, name='undo'): """ Start of an undo block followed by a end_undo(). Args: name (str): name for the undo block. """ self._undo_stack.beginMacro(name) def end_undo(self): """ End of an undo block started by begin_undo(). """ self._undo_stack.endMacro() def context_menu(self): """ Returns a node graph context menu object. 
Returns: ContextMenu: node graph context menu object instance. """ return self._viewer.context_menu() @staticmethod def modify_context_menu(viewer):