예제 #1
0
class _valueMenu(QtWidgets.QMenu):
    """
    Menu whose actions each represent a selectable step value.

    Hovering an action makes its step the current ``step``; mouse move and
    release events are re-emitted as signals so an owner can react to them.
    """

    mouseMove = QtCore.Signal(object)
    mouseRelease = QtCore.Signal(object)
    stepChange = QtCore.Signal()

    def __init__(self, parent=None):
        super(_valueMenu, self).__init__(parent)
        self.step = 1
        self.last_action = None
        self.steps = []

    def set_steps(self, steps):
        """
        Replace the menu actions with one action per entry in ``steps``.

        Args:
            steps (list): step values to list in the menu.
        """
        self.clear()
        # clear() removes (and deletes) the old actions, so the remembered
        # action must be dropped as well; otherwise mouseMoveEvent would
        # hand a dead action to setActiveAction().
        self.last_action = None
        self.steps = steps
        for step in steps:
            self._add_action(step)

    def _add_action(self, step):
        """Append an action labelled with ``step`` and tagged with its value."""
        action = QtWidgets.QAction(str(step), self)
        action.step = step
        self.addAction(action)

    def mouseMoveEvent(self, event):
        """Emit ``mouseMove`` and track which step action is hovered."""
        self.mouseMove.emit(event)
        super(_valueMenu, self).mouseMoveEvent(event)

        action = self.actionAt(event.pos())
        if action:
            if action is not self.last_action:
                self.stepChange.emit()
            self.last_action = action
            self.step = action.step
        elif self.last_action:
            # cursor is outside every action: keep the last one highlighted.
            self.setActiveAction(self.last_action)

    def mousePressEvent(self, event):
        """Ignore mouse presses (the event is not forwarded to QMenu)."""
        return

    def mouseReleaseEvent(self, event):
        """Emit ``mouseRelease`` before running the default handling."""
        self.mouseRelease.emit(event)
        super(_valueMenu, self).mouseReleaseEvent(event)

    def set_data_type(self, dt):
        """
        Rebuild the menu for the given data type.

        Args:
            dt (type): ``int`` keeps only steps whose string form has no
                ".", ``float`` rebuilds the menu from the current steps.
                Any other value leaves the menu untouched.
        """
        # NOTE(review): the int branch overwrites self.steps with the
        # filtered list, so a later switch to float cannot restore the
        # dropped steps - confirm this is intended.
        if dt is int:
            self.set_steps([s for s in self.steps if "." not in str(s)])
        elif dt is float:
            self.set_steps(self.steps)
예제 #2
0
class PropTextEdit(QtWidgets.QTextEdit):
    """
    Multi-line text property editor.

    Emits ``value_changed(tooltip, text)`` when focus leaves the editor with
    modified text, or when ``set_value`` changes the text.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropTextEdit, self).__init__(parent)
        # text captured when the editor gained focus.
        self.__prev_text = ''

    def focusInEvent(self, event):
        """Remember the text present when editing starts."""
        super(PropTextEdit, self).focusInEvent(event)
        self.__prev_text = self.toPlainText()

    def focusOutEvent(self, event):
        """Emit ``value_changed`` if the text differs from the focus-in text."""
        super(PropTextEdit, self).focusOutEvent(event)
        current = self.toPlainText()
        if current != self.__prev_text:
            self.value_changed.emit(self.toolTip(), current)
        self.__prev_text = ''

    def get_value(self):
        """Return the editor content as plain text."""
        return self.toPlainText()

    def set_value(self, value):
        """Set the text to ``str(value)`` and emit when it actually changed."""
        text = str(value)
        if text == self.get_value():
            return
        self.setPlainText(text)
        self.value_changed.emit(self.toolTip(), text)
예제 #3
0
class PropComboBox(QtWidgets.QComboBox):
    """
    Combo box property editor.

    Emits ``value_changed(tooltip, current_text)`` whenever the current
    index changes.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropComboBox, self).__init__(parent)
        self.currentIndexChanged.connect(self._on_index_changed)

    def _on_index_changed(self):
        """Forward an index change as ``value_changed``."""
        self.value_changed.emit(self.toolTip(), self.get_value())

    def items(self):
        """Return the text of every item in the combo box."""
        return [self.itemText(i) for i in range(self.count())]

    def set_items(self, items):
        """Replace all items with ``items`` (list of str)."""
        self.clear()
        self.addItems(items)

    def get_value(self):
        """Return the currently selected text."""
        return self.currentText()

    def set_value(self, value):
        """
        Select the item whose text equals ``value``.

        Args:
            value (str or list): text to select; a list replaces the
                available items instead of changing the selection.
        """
        if isinstance(value, list):
            self.set_items(value)
            return
        if value == self.get_value():
            return
        idx = self.findText(value, QtCore.Qt.MatchExactly)
        # setCurrentIndex() triggers currentIndexChanged, and
        # _on_index_changed already emits value_changed for it. Emitting
        # again here notified listeners twice for a single change.
        self.setCurrentIndex(idx)
예제 #4
0
class PropLineEdit(QtWidgets.QLineEdit):
    """
    Single-line text property editor.

    Emits ``value_changed(tooltip, value)`` when Return is pressed or focus
    is lost with modified text, and when ``set_value`` changes the text.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropLineEdit, self).__init__(parent)
        # baseline text used to detect whether an edit changed anything.
        self.__prev_text = ''
        self.returnPressed.connect(self._on_return_pressed)

    def focusInEvent(self, event):
        """Record the text present when editing starts."""
        super(PropLineEdit, self).focusInEvent(event)
        self.__prev_text = self.text()

    def focusOutEvent(self, event):
        """Emit ``value_changed`` if the text differs from the baseline."""
        super(PropLineEdit, self).focusOutEvent(event)
        if self.__prev_text != self.text():
            self.value_changed.emit(self.toolTip(), self.text())
        self.__prev_text = ''

    def _on_return_pressed(self):
        """Emit ``value_changed`` once for text committed with Return."""
        text = self.text()
        if self.__prev_text != text:
            self.value_changed.emit(self.toolTip(), text)
            # move the baseline forward so the same edit is not reported
            # again by a second Return press or the following focus-out.
            self.__prev_text = text

    def get_value(self):
        """Return the current text."""
        return self.text()

    def set_value(self, value):
        """
        Display ``value`` and emit ``value_changed`` if the text changed.

        Args:
            value (object): new value; shown as ``str(value)`` and emitted
                unchanged.
        """
        text = str(value)
        # compare as text: a non-str value (e.g. an int) never equals the
        # displayed string and would otherwise re-emit on every call.
        if text != self.get_value():
            self.setText(text)
            self.value_changed.emit(self.toolTip(), value)
예제 #5
0
class TabSearchWidget(QtWidgets.QLineEdit):
    """
    Line edit with a completer used to search for a node by name.

    Emits ``search_submitted(node_type)`` when the entered text matches a
    known node name.
    """

    search_submitted = QtCore.Signal(str)

    def __init__(self, parent=None, node_dict=None):
        super(TabSearchWidget, self).__init__(parent)
        self.setAttribute(QtCore.Qt.WA_MacShowFocusRect, 0)
        self.setStyleSheet(STYLE_TABSEARCH)
        self.setMinimumSize(200, 22)
        self.setTextMargins(2, 0, 2, 0)
        self.hide()

        # display name -> node type.
        self._node_dict = node_dict or {}

        node_names = sorted(self._node_dict.keys())
        self._model = QtCore.QStringListModel(node_names, self)

        self._completer = TabSearchCompleter()
        self._completer.setModel(self._model)
        self.setCompleter(self._completer)

        popup = self._completer.popup()
        popup.setStyleSheet(STYLE_TABSEARCH_LIST)
        popup.clicked.connect(self._on_search_submitted)
        self.returnPressed.connect(self._on_search_submitted)

    def __repr__(self):
        return '<{} at {}>'.format(self.__class__.__name__, hex(id(self)))

    def _on_search_submitted(self, index=0):
        """
        Emit the node type matching the current text, then close.

        Args:
            index: argument supplied by the popup ``clicked`` signal (unused).
        """
        node_type = self._node_dict.get(self.text())
        if node_type:
            self.search_submitted.emit(node_type)
        self.close()
        # the widget may have been created without a parent.
        parent = self.parentWidget()
        if parent is not None:
            parent.clearFocus()

    def showEvent(self, event):
        """Select the text, take focus and open the completer when empty."""
        super(TabSearchWidget, self).showEvent(event)
        self.setSelection(0, len(self.text()))
        self.setFocus()
        if not self.text():
            self.completer().popup().show()
            self.completer().complete()

    def mousePressEvent(self, event):
        """Open the completer popup when clicking an empty field."""
        # NOTE(review): the event is not forwarded to QLineEdit, so a click
        # does not move the text cursor - confirm this is intended.
        if not self.text():
            self.completer().complete()

    def set_nodes(self, node_dict=None):
        """
        Rebuild the searchable names.

        Args:
            node_dict (dict): name -> list of node types. A name with a
                single type is listed as-is; otherwise one
                "name (type)" entry is added per type. ``None`` clears
                the list.
        """
        self._node_dict = {}
        # the signature defaults to None, which has no .items().
        for name, node_types in (node_dict or {}).items():
            if len(node_types) == 1:
                self._node_dict[name] = node_types[0]
                continue
            for node_id in node_types:
                self._node_dict['{} ({})'.format(name, node_id)] = node_id
        node_names = sorted(self._node_dict.keys())
        self._model.setStringList(node_names)
        self._completer.setModel(self._model)
예제 #6
0
class BaseProperty(QtWidgets.QWidget):
    """
    Abstract base for property widgets.

    Subclasses implement ``set_value`` / ``get_value`` and emit
    ``value_changed`` with two arguments (str, object).
    """

    value_changed = QtCore.Signal(str, object)

    def set_value(self, value):
        """Set the widget value; must be overridden."""
        raise NotImplementedError()

    def get_value(self):
        """Return the widget value; must be overridden."""
        raise NotImplementedError()
예제 #7
0
class PropLabel(QtWidgets.QLabel):
    """
    Read-only label property widget.

    Emits ``value_changed(tooltip, value)`` when ``set_value`` changes the
    displayed text.
    """

    value_changed = QtCore.Signal(str, object)

    def get_value(self):
        """Return the label text."""
        return self.text()

    def set_value(self, value):
        """
        Display ``value`` and emit ``value_changed`` if the text changed.

        Args:
            value (object): new value; shown as ``str(value)`` and emitted
                unchanged.
        """
        text = str(value)
        # compare as text so a non-str value equal to the displayed text
        # does not trigger a redundant update and signal.
        if text != self.get_value():
            self.setText(text)
            self.value_changed.emit(self.toolTip(), value)
예제 #8
0
class NodeAction(GraphAction):
    """
    Graph action bound to a node id.

    When triggered, emits ``executed(graph, node)`` where ``node`` is
    looked up on the graph by ``node_id``.
    """

    executed = QtCore.Signal(object, object)

    def __init__(self, *args, **kwargs):
        super(NodeAction, self).__init__(*args, **kwargs)
        self.node_id = None

    def _on_triggered(self):
        """Resolve the node from its id and emit ``executed``."""
        graph = self.graph
        target = graph.get_node_by_id(self.node_id)
        self.executed.emit(graph, target)
예제 #9
0
class NodeBaseWidget(QtWidgets.QGraphicsProxyWidget):
    """
    Base Node Widget.

    Subclasses provide the ``widget`` and ``value`` properties and emit
    ``value_changed(name, value)`` through ``_value_changed``.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None, name='widget', label=''):
        super(NodeBaseWidget, self).__init__(parent)
        self.setZValue(Z_VAL_NODE_WIDGET)
        self._name = name
        self._label = label

    def _value_changed(self):
        """Emit ``value_changed`` with the widget name and current value."""
        self.value_changed.emit(self.name, self.value)

    def setToolTip(self, tooltip):
        """Set a rich-text tooltip headed by the widget name in bold."""
        tooltip = tooltip.replace('\n', '<br/>')
        tooltip = '<b>{}</b><br/>{}'.format(self.name, tooltip)
        super(NodeBaseWidget, self).setToolTip(tooltip)

    @property
    def widget(self):
        """Embedded widget; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def value(self):
        """Current widget value; must be provided by subclasses."""
        raise NotImplementedError

    @value.setter
    def value(self, text):
        raise NotImplementedError

    @property
    def label(self):
        """Label text given to the widget."""
        return self._label

    @label.setter
    def label(self, label):
        self._label = label

    @property
    def type_(self):
        """Class name of the widget as a string."""
        return str(self.__class__.__name__)

    @property
    def node(self):
        """Parent graphics item of this widget."""
        # the original evaluated parentItem() without returning it, so the
        # property was always None.
        return self.parentItem()

    @property
    def name(self):
        """Name given to the widget."""
        return self._name
예제 #10
0
class PropButton(QtWidgets.QPushButton):
    """
    Button property widget that runs callbacks when clicked.

    The widget has no value of its own; ``value_changed`` exists only so it
    can be connected like the other property widgets.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropButton, self).__init__(parent)

    def set_value(self, value, node=None):
        """
        Connect each function in ``value`` to the button click.

        Args:
            value (list): functions called as ``func(node)`` on click;
                any non-list value is ignored.
            node (object): argument passed to every function.
        """
        if not isinstance(value, list):
            return
        for func in value:
            # bind ``func`` per iteration: a plain ``lambda: func(node)``
            # looks the name up at call time, so every connection would
            # call the last function in the list. ``checked`` absorbs the
            # bool argument that clicked() may pass.
            self.clicked.connect(
                lambda checked=False, func=func: func(node))

    def get_value(self):
        """Return None; a button holds no value."""
        return None
예제 #11
0
class PropComboBox(QtWidgets.QComboBox):
    """
    Combo box property editor supporting plain text items or
    ``(text, data)`` pairs.

    Emits ``value_changed(tooltip, value)`` whenever the current index
    changes, where ``value`` is the item data when set, else the item text.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropComboBox, self).__init__(parent)
        self.currentIndexChanged.connect(self._on_index_changed)
        self.setMaxVisibleItems(10)
        self.setStyleSheet('combobox-popup: 0')

    def _on_index_changed(self):
        """Forward an index change as ``value_changed``."""
        self.value_changed.emit(self.toolTip(), self.get_value())

    def items(self):
        """Return the text of every item in the combo box."""
        return [self.itemText(i) for i in range(self.count())]

    def set_items(self, items):
        """
        Replace all items.

        Args:
            items (list): list of str, or list of ``(text, data)`` tuples
                (decided by the type of the first element).
        """
        self.clear()

        if len(items) > 0 and isinstance(items[0], tuple):
            for key, value in items:
                self.addItem(key, value)
            return

        self.addItems(items)

    def get_value(self):
        """Return the current item data, or its text when no data is set."""
        text = self.currentText()
        data = self.currentData()
        return data if data is not None else text

    def set_value(self, value):
        """
        Select the item whose data, or failing that whose text, equals
        ``value``.

        Args:
            value (object): item data or item text to select.
        """
        if value == self.get_value():
            return

        idx = None
        for i in range(self.count()):
            if self.itemData(i) == value:
                idx = i
                break

        if idx is None:
            idx = self.findText(value, QtCore.Qt.MatchExactly)

        # setCurrentIndex() triggers currentIndexChanged, and
        # _on_index_changed already emits value_changed for it. Emitting
        # again here notified listeners twice for a single change.
        self.setCurrentIndex(idx)
예제 #12
0
class GraphAction(QtWidgets.QAction):
    """
    Action that emits ``executed(graph)`` when triggered.

    ``graph`` is expected to be assigned by the owner after construction.
    """

    executed = QtCore.Signal(object)

    def __init__(self, *args, **kwargs):
        super(GraphAction, self).__init__(*args, **kwargs)
        self.graph = None
        self.triggered.connect(self._on_triggered)

    def _on_triggered(self):
        """Emit ``executed`` with the assigned graph."""
        self.executed.emit(self.graph)

    def get_action(self, name):
        """
        Find a non-submenu action by its text.

        Args:
            name (str): action text to look for.

        Returns:
            QtWidgets.QAction: matching action, or None when there is no
                menu to search or nothing matches.
        """
        # NOTE(review): this class never assigns ``qmenu``; use it when an
        # owner has set it, otherwise fall back to the action's own menu
        # instead of raising AttributeError - confirm the intended source.
        menu = getattr(self, 'qmenu', None)
        if menu is None:
            menu = self.menu()
        if menu is None:
            return None
        for action in menu.actions():
            if not action.menu() and action.text() == name:
                return action
        return None
예제 #13
0
class PropSpinBox(QtWidgets.QSpinBox):
    """
    Integer spin box property editor without step buttons.

    Emits ``value_changed(tooltip, value)`` on every ``valueChanged``.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropSpinBox, self).__init__(parent)
        self.setButtonSymbols(self.NoButtons)
        self.valueChanged.connect(self._on_value_change)

    def _on_value_change(self, value):
        """Forward the spin box change as ``value_changed``."""
        tooltip = self.toolTip()
        self.value_changed.emit(tooltip, value)

    def get_value(self):
        """Return the current spin box value."""
        return self.value()

    def set_value(self, value):
        """Set the spin box value when it differs from the current one."""
        if value == self.get_value():
            return
        self.setValue(value)
예제 #14
0
class PropFloat(_valueSliderEdit):
    """
    Float property editor built on ``_valueSliderEdit``.

    Emits ``value_changed(tooltip, value)`` when the base widget reports
    ``valueChanged`` and when ``set_value`` applies a new value.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropFloat, self).__init__(parent)
        self.valueChanged.connect(self._on_value_changed)

    def _on_value_changed(self, value):
        """Forward the base widget change as ``value_changed``."""
        tooltip = self.toolTip()
        self.value_changed.emit(tooltip, value)

    def get_value(self):
        """Return the current value of the base widget."""
        return self.value()

    def set_value(self, value):
        """Apply ``value`` and emit ``value_changed`` when it differs."""
        if value == self.get_value():
            return
        self.setValue(value)
        self.value_changed.emit(self.toolTip(), value)
예제 #15
0
class PropCheckBox(QtWidgets.QCheckBox):
    """
    Boolean property editor.

    Emits ``value_changed(tooltip, checked)`` when clicked and when
    ``set_value`` changes the checked state.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropCheckBox, self).__init__(parent)
        self.clicked.connect(self._on_clicked)

    def _on_clicked(self):
        """Report the checked state after a click."""
        checked = self.get_value()
        self.value_changed.emit(self.toolTip(), checked)

    def get_value(self):
        """Return whether the box is checked."""
        return self.isChecked()

    def set_value(self, value):
        """Set the checked state and emit when it actually changed."""
        if value == self.get_value():
            return
        self.setChecked(value)
        self.value_changed.emit(self.toolTip(), value)
예제 #16
0
class PropLineEdit(QtWidgets.QLineEdit):
    """
    Single-line text property editor.

    Emits ``value_changed(tooltip, text)`` on ``editingFinished`` and when
    ``set_value`` changes the text.
    """

    value_changed = QtCore.Signal(str, object)

    def __init__(self, parent=None):
        super(PropLineEdit, self).__init__(parent)
        self.editingFinished.connect(self._on_editing_finished)

    def _on_editing_finished(self):
        """Report the text once editing is finished."""
        tooltip = self.toolTip()
        self.value_changed.emit(tooltip, self.text())

    def get_value(self):
        """Return the current text."""
        return self.text()

    def set_value(self, value):
        """Set the text to ``str(value)`` and emit when it changed."""
        text = str(value)
        if text == self.get_value():
            return
        self.setText(text)
        self.value_changed.emit(self.toolTip(), text)
예제 #17
0
class NodePropWidget(QtWidgets.QWidget):
    """
    Node properties widget for displaying a Node object.

    Args:
        parent (QtWidgets.QWidget): parent object.
        node (NodeGraphQt.BaseNode): node (required despite the ``None``
            default; its attributes are read during construction).
    """

    #: signal (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)
    #: signal (node_id)
    property_closed = QtCore.Signal(str)

    def __init__(self, parent=None, node=None):
        super(NodePropWidget, self).__init__(parent)
        self.__node_id = node.id
        # tab name -> PropWindow shown in that tab.
        self.__tab_windows = {}
        self.__tab = QtWidgets.QTabWidget()

        close_btn = QtWidgets.QPushButton('X')
        close_btn.setToolTip('close property')
        close_btn.clicked.connect(self._on_close)

        # property widgets use their tooltip as the property name.
        self.name_wgt = PropLineEdit()
        self.name_wgt.setToolTip('name')
        self.name_wgt.set_value(node.name())
        self.name_wgt.value_changed.connect(self._on_property_changed)

        self.type_wgt = QtWidgets.QLabel(node.type_)
        self.type_wgt.setAlignment(QtCore.Qt.AlignRight)
        self.type_wgt.setToolTip('type_')
        font = self.type_wgt.font()
        font.setPointSize(10)
        self.type_wgt.setFont(font)

        name_layout = QtWidgets.QHBoxLayout()
        name_layout.setContentsMargins(0, 0, 0, 0)
        name_layout.addWidget(QtWidgets.QLabel('name'))
        name_layout.addWidget(self.name_wgt)
        name_layout.addWidget(close_btn)
        layout = QtWidgets.QVBoxLayout(self)
        layout.setSpacing(4)
        layout.addLayout(name_layout)
        layout.addWidget(self.__tab)
        layout.addWidget(self.type_wgt)
        self._read_node(node)

    def __repr__(self):
        return '<NodePropWidget object at {}>'.format(hex(id(self)))

    def _on_close(self):
        """
        called by the close button.
        """
        self.property_closed.emit(self.__node_id)

    def _on_property_changed(self, name, value):
        """
        slot function called when a property widget has changed.

        Args:
            name (str): property name.
            value (object): new value.
        """
        self.property_changed.emit(self.__node_id, name, value)

    def _read_node(self, node):
        """
        Populate widget from a node.

        Args:
            node (NodeGraphQt.BaseNode): node class.
        """
        model = node.model
        graph_model = node.graph.model

        # tolerate a missing entry for this node type.
        common_props = (
            graph_model.get_node_common_properties(node.type_) or {})

        # sort tabs and properties.
        tab_mapping = defaultdict(list)
        for prop_name, prop_val in model.custom_properties.items():
            tab_name = model.get_tab_name(prop_name)
            tab_mapping[tab_name].append((prop_name, prop_val))
        tab_names = sorted(tab_mapping.keys())

        # add tabs.
        for tab in tab_names:
            if tab != 'Node':
                self.add_tab(tab)
        # the "Node" tab is created before populating (still as the last
        # tab) so custom properties mapped to it no longer raise KeyError.
        self.add_tab('Node')

        min_widget_height = 25
        # populate tab properties.
        for tab in tab_names:
            prop_window = self.__tab_windows[tab]
            for prop_name, value in tab_mapping[tab]:
                wid_type = model.get_widget_type(prop_name)
                # widget type 0 is skipped (no widget is created).
                if wid_type == 0:
                    continue

                WidClass = WIDGET_MAP.get(wid_type)
                widget = WidClass()
                widget.setMinimumHeight(min_widget_height)
                if prop_name in common_props:
                    prop_attrs = common_props[prop_name]
                    if 'items' in prop_attrs:
                        # a "_<name>_" node property overrides the
                        # common item list.
                        _prop_name = '_' + prop_name + "_"
                        if node.has_property(_prop_name):
                            widget.set_items(node.get_property(_prop_name))
                        else:
                            widget.set_items(prop_attrs['items'])
                    if 'range' in prop_attrs:
                        prop_range = prop_attrs['range']
                        widget.set_min(prop_range[0])
                        widget.set_max(prop_range[1])
                    if 'ext' in prop_attrs:
                        widget.set_ext(prop_attrs['ext'])
                    if 'funcs' in prop_attrs:
                        widget.set_value(prop_attrs['funcs'], node)

                prop_window.add_widget(prop_name, widget, value,
                                       prop_name.replace('_', ' '))
                widget.value_changed.connect(self._on_property_changed)

        # add "Node" tab properties.
        default_props = ['color', 'text_color', 'disabled', 'id']
        prop_window = self.__tab_windows['Node']
        for prop_name in default_props:
            wid_type = model.get_widget_type(prop_name)
            WidClass = WIDGET_MAP.get(wid_type)

            widget = WidClass()
            widget.setMinimumHeight(min_widget_height)
            prop_window.add_widget(prop_name, widget,
                                   model.get_property(prop_name),
                                   prop_name.replace('_', ' '))

            widget.value_changed.connect(self._on_property_changed)

        self.type_wgt.setText(model.get_property('type_'))

    def node_id(self):
        """
        Returns the node id linked to the widget.

        Returns:
            str: node id
        """
        return self.__node_id

    def add_widget(self, name, widget, tab='Properties'):
        """
        add new node property widget.

        Args:
            name (str): property name.
            widget (BaseProperty): property widget.
            tab (str): tab name; unknown names fall back to 'Properties'.
        """
        # the original tested ``self._widgets``, which is never defined;
        # the tab registry is ``__tab_windows``.
        if tab not in self.__tab_windows:
            tab = 'Properties'
        # create the fallback tab on demand instead of raising KeyError.
        if tab not in self.__tab_windows:
            self.add_tab(tab)
        window = self.__tab_windows[tab]
        # NOTE(review): _read_node passes a value and label to
        # PropWindow.add_widget; confirm the two-argument call is valid.
        window.add_widget(name, widget)
        widget.value_changed.connect(self._on_property_changed)

    def add_tab(self, name):
        """
        add a new tab.

        Args:
            name (str): tab name.

        Returns:
            PropWindow: tab child widget.

        Raises:
            AssertionError: if a tab with this name already exists.
        """
        if name in self.__tab_windows:
            raise AssertionError('Tab name {} already taken!'.format(name))
        self.__tab_windows[name] = PropWindow(self)
        self.__tab.addTab(self.__tab_windows[name], name)
        return self.__tab_windows[name]

    def get_widget(self, name):
        """
        get property widget.

        Args:
            name (str): property name.

        Returns:
            QtWidgets.QWidget: property widget, or None if not found.
        """
        if name == 'name':
            return self.name_wgt
        for prop_win in self.__tab_windows.values():
            widget = prop_win.get_widget(name)
            if widget:
                return widget
        return None
예제 #18
0
class NodeViewer(QtWidgets.QGraphicsView):
    """
    node viewer is the widget used for displaying the scene and nodes

    functions in this class is used internally by the
    class:`NodeGraphQt.NodeGraph` class.
    """

    moved_nodes = QtCore.Signal(dict)
    search_triggered = QtCore.Signal(str, tuple)
    connection_sliced = QtCore.Signal(list)
    connection_changed = QtCore.Signal(list, list)

    # pass through signals
    node_selected = QtCore.Signal(str)
    node_double_clicked = QtCore.Signal(str)
    data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)

    def __init__(self, parent=None):
        """
        Set up the view, its scene, interaction state and helper widgets.

        Args:
            parent (QtWidgets.QWidget): optional parent; when given, the
                viewer is flagged as a window.
        """
        super(NodeViewer, self).__init__(parent)
        if parent is not None:
            self.setWindowFlags(QtCore.Qt.Window)

        # scene rect of SCENE_AREA x SCENE_AREA centred on the origin.
        scene_pos = (SCENE_AREA / 2) * -1
        self.setScene(NodeScene(self))
        self.setSceneRect(scene_pos, scene_pos, SCENE_AREA, SCENE_AREA)
        self.setRenderHint(QtGui.QPainter.Antialiasing, True)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setViewportUpdateMode(QtWidgets.QGraphicsView.FullViewportUpdate)
        self.setAcceptDrops(True)
        self.resize(1000, 800)

        # interaction state used by the mouse / scene event handlers.
        self._pipe_layout = PIPE_LAYOUT_CURVED
        self._live_pipe = None
        self._detached_port = None
        self._start_port = None
        self._origin_pos = None
        self._previous_pos = QtCore.QPoint(self.width(), self.height())
        self._prev_selection = []
        # node item -> position recorded on mouse press.
        self._node_positions = {}
        self._rubber_band = QtWidgets.QRubberBand(
            QtWidgets.QRubberBand.Rectangle, self)
        # the slicer item lives in the scene and is hidden until needed.
        self._pipe_slicer = SlicerPipe()
        self._pipe_slicer.setVisible(False)
        self.scene().addItem(self._pipe_slicer)

        self._undo_stack = QtWidgets.QUndoStack(self)
        self._context_menu = QtWidgets.QMenu('main', self)
        self._context_menu.setStyleSheet(STYLE_QMENU)
        self._search_widget = TabSearchWidget(self)
        self._search_widget.search_submitted.connect(self._on_search_submitted)

        # workaround fix for shortcuts from the non-native menu actions
        # don't seem to trigger so we create a hidden menu bar.
        menu_bar = QtWidgets.QMenuBar(self)
        menu_bar.setNativeMenuBar(False)
        # shortcuts don't work with "setVisibility(False)".
        menu_bar.resize(0, 0)
        menu_bar.addMenu(self._context_menu)

        self.acyclic = True
        # pressed state of the left / right / middle mouse buttons.
        self.LMB_state = False
        self.RMB_state = False
        self.MMB_state = False

    def __repr__(self):
        """Return ``<module>.<ClassName>()``."""
        class_name = self.__class__.__name__
        return '{}.{}()'.format(self.__module__, class_name)

    # --- private ---

    def _set_viewer_zoom(self, value, sensitivity=0.0):
        """
        Scale the view one step in or out, respecting the zoom limits.

        Args:
            value (float): negative zooms out, positive zooms in, zero is
                a no-op.
            sensitivity (float): amount by which the 0.9 / 1.1 step factor
                is moved towards 1.0.
        """
        if value == 0.0:
            return
        zooming_out = value < 0.0
        scale = (0.9 + sensitivity) if zooming_out else (1.1 - sensitivity)
        zoom = self.get_zoom()
        # decide by direction rather than ``scale == 0.9`` / ``1.1``: those
        # float comparisons never match once sensitivity is non-zero, which
        # let the view zoom past ZOOM_MIN / ZOOM_MAX.
        if zooming_out and ZOOM_MIN >= zoom:
            return
        if not zooming_out and ZOOM_MAX <= zoom:
            return
        self.scale(scale, scale)

    def _set_viewer_pan(self, pos_x, pos_y):
        """Shift both scroll bars back by the given x / y deltas."""
        bars_and_offsets = (
            (self.horizontalScrollBar(), pos_x),
            (self.verticalScrollBar(), pos_y),
        )
        for bar, offset in bars_and_offsets:
            bar.setValue(bar.value() - offset)

    def _combined_rect(self, nodes):
        """
        Return the bounding rect of ``nodes`` taken together.

        The items are grouped temporarily; the group is destroyed again
        before returning.
        """
        scene = self.scene()
        temp_group = scene.createItemGroup(nodes)
        bounds = temp_group.boundingRect()
        scene.destroyItemGroup(temp_group)
        return bounds

    def _items_near(self, pos, item_type=None, width=20, height=20):
        """
        Return scene items inside a ``width`` x ``height`` rect at ``pos``.

        Args:
            pos (QtCore.QPointF): scene position.
            item_type (type): optional class filter; falsy keeps all items.
            width (int): rect width.
            height (int): rect height.

        Returns:
            list: matching items.
        """
        # NOTE(review): the rect spans from (pos - size) to pos, i.e. it is
        # not centred on pos - confirm this is intended.
        left = pos.x() - width
        top = pos.y() - height
        area = QtCore.QRectF(left, top, width, height)
        return [
            found for found in self.scene().items(area)
            if not item_type or isinstance(found, item_type)
        ]

    def _on_search_submitted(self, node_type):
        """Emit ``search_triggered`` with the last cursor scene position."""
        scene_pos = self.mapToScene(self._previous_pos)
        coords = (scene_pos.x(), scene_pos.y())
        self.search_triggered.emit(node_type, coords)

    def _on_pipes_sliced(self, path):
        """
        Emit ``connection_sliced`` with an ``[input_port, output_port]``
        pair for every Pipe item found along ``path``.
        """
        port_pairs = []
        for scene_item in self.scene().items(path):
            if isinstance(scene_item, Pipe):
                port_pairs.append(
                    [scene_item.input_port, scene_item.output_port])
        self.connection_sliced.emit(port_pairs)

    # --- reimplemented events ---

    def resizeEvent(self, event):
        """Delegate to the base class resize handling unchanged."""
        super(NodeViewer, self).resizeEvent(event)

    def contextMenuEvent(self, event):
        """
        Show the viewer context menu at the cursor when it is enabled,
        otherwise fall back to the base class handling.
        """
        self.RMB_state = False
        if not self._context_menu.isEnabled():
            return super(NodeViewer, self).contextMenuEvent(event)
        self._context_menu.exec_(event.globalPos())

    def mousePressEvent(self, event):
        """
        Handle a mouse press on the viewer.

        Records button state and cursor positions, closes the tab search,
        starts the pipe slicer (Alt+Shift), toggles node selection (Shift),
        records selected node positions and shows the selection marquee.

        Args:
            event (QtGui.QMouseEvent): mouse press event.
        """
        # True only when the modifier is held on its own.
        alt_modifier = event.modifiers() == QtCore.Qt.AltModifier
        shift_modifier = event.modifiers() == QtCore.Qt.ShiftModifier

        if event.button() == QtCore.Qt.LeftButton:
            self.LMB_state = True
        elif event.button() == QtCore.Qt.RightButton:
            self.RMB_state = True
        elif event.button() == QtCore.Qt.MiddleButton:
            self.MMB_state = True
        self._origin_pos = event.pos()
        self._previous_pos = event.pos()
        self._prev_selection = self.selected_nodes()

        # close tab search
        if self._search_widget.isVisible():
            self.tab_search_toggle()

        # cursor pos.
        map_pos = self.mapToScene(event.pos())

        # pipe slicer enabled.
        if event.modifiers() == (QtCore.Qt.AltModifier
                                 | QtCore.Qt.ShiftModifier):
            self._pipe_slicer.draw_path(map_pos, map_pos)
            self._pipe_slicer.setVisible(True)
            return

        # Alt alone: stop here without forwarding to the base class.
        if alt_modifier:
            return

        items = self._items_near(map_pos, None, 20, 20)
        nodes = [i for i in items if isinstance(i, AbstractNodeItem)]

        # toggle extend node selection.
        if shift_modifier:
            for node in nodes:
                node.selected = not node.selected

        # update the recorded node positions.
        self._node_positions.update(
            {n: n.xy_pos
             for n in self.selected_nodes()})

        # show selection marquee
        if self.LMB_state and not items:
            rect = QtCore.QRect(self._previous_pos, QtCore.QSize())
            rect = rect.normalized()
            map_rect = self.mapToScene(rect).boundingRect()
            self.scene().update(map_rect)
            self._rubber_band.setGeometry(rect)
            self._rubber_band.show()

        # with Shift held the base class handler is skipped.
        if not shift_modifier:
            super(NodeViewer, self).mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Handle a mouse release on the viewer.

        Clears the button state, finishes a pipe slice, hides the selection
        marquee and emits ``moved_nodes`` for nodes whose position changed
        since the press.

        Args:
            event (QtGui.QMouseEvent): mouse release event.
        """
        if event.button() == QtCore.Qt.LeftButton:
            self.LMB_state = False
        elif event.button() == QtCore.Qt.RightButton:
            self.RMB_state = False
        elif event.button() == QtCore.Qt.MiddleButton:
            self.MMB_state = False

        # hide pipe slicer.
        if self._pipe_slicer.isVisible():
            self._on_pipes_sliced(self._pipe_slicer.path())
            # collapse the slicer path to a single point before hiding.
            p = QtCore.QPointF(0.0, 0.0)
            self._pipe_slicer.draw_path(p, p)
            self._pipe_slicer.setVisible(False)

        # hide selection marquee
        if self._rubber_band.isVisible():
            rect = self._rubber_band.rect()
            map_rect = self.mapToScene(rect).boundingRect()
            self._rubber_band.hide()
            self.scene().update(map_rect)

        # find position changed nodes and emit signal.
        # (maps node -> the position recorded before the move)
        moved_nodes = {
            n: xy_pos
            for n, xy_pos in self._node_positions.items() if n.xy_pos != xy_pos
        }
        if moved_nodes:
            self.moved_nodes.emit(moved_nodes)

        # reset recorded positions.
        self._node_positions = {}

        super(NodeViewer, self).mouseReleaseEvent(event)

    def mouseMoveEvent(self, event):
        """
        Handle mouse movement over the viewer.

        Depending on buttons and modifiers this extends the pipe slicer,
        zooms, pans, or updates the rubber band selection.

        Args:
            event (QtGui.QMouseEvent): mouse move event.
        """
        # True only when the modifier is held on its own.
        alt_modifier = event.modifiers() == QtCore.Qt.AltModifier
        shift_modifier = event.modifiers() == QtCore.Qt.ShiftModifier
        # Alt+Shift with LMB: extend the slicer path.
        if event.modifiers() == (QtCore.Qt.AltModifier
                                 | QtCore.Qt.ShiftModifier):
            if self.LMB_state:
                p1 = self._pipe_slicer.path().pointAtPercent(0)
                # NOTE(review): the end point uses the previous cursor
                # position, not event.pos() - confirm this is intended.
                p2 = self.mapToScene(self._previous_pos)
                self._pipe_slicer.draw_path(p1, p2)
                self._previous_pos = event.pos()
                super(NodeViewer, self).mouseMoveEvent(event)
                return

        if self.MMB_state and alt_modifier:
            # Alt+MMB drag: zoom by horizontal movement direction.
            pos_x = (event.x() - self._previous_pos.x())
            zoom = 0.1 if pos_x > 0 else -0.1
            self._set_viewer_zoom(zoom, 0.05)
        elif self.MMB_state or (self.LMB_state and alt_modifier):
            # MMB drag or Alt+LMB drag: pan by the cursor delta.
            pos_x = (event.x() - self._previous_pos.x())
            pos_y = (event.y() - self._previous_pos.y())
            self._set_viewer_pan(pos_x, pos_y)

        if self.LMB_state and self._rubber_band.isVisible():
            rect = QtCore.QRect(self._origin_pos, event.pos()).normalized()
            map_rect = self.mapToScene(rect).boundingRect()
            path = QtGui.QPainterPath()
            path.addRect(map_rect)
            self._rubber_band.setGeometry(rect)
            self.scene().setSelectionArea(path, QtCore.Qt.IntersectsItemShape)
            self.scene().update(map_rect)

            # Shift: re-select nodes that were selected before the drag.
            if shift_modifier and self._prev_selection:
                for node in self._prev_selection:
                    if node not in self.selected_nodes():
                        node.selected = True

        self._previous_pos = event.pos()
        super(NodeViewer, self).mouseMoveEvent(event)

    def wheelEvent(self, event):
        """Zoom the view according to the wheel delta."""
        try:
            wheel_delta = event.delta()
        except AttributeError:
            # For PyQt5
            wheel_delta = event.angleDelta().y()

        zoom_amount = (wheel_delta / 120) * 0.1
        self._set_viewer_zoom(zoom_amount)

    def dropEvent(self, event):
        """
        Emit ``data_dropped`` with the mime data and the drop position in
        scene coordinates.

        Args:
            event (QtGui.QDropEvent): drop event.
        """
        pos = self.mapToScene(event.pos())
        event.setDropAction(QtCore.Qt.MoveAction)
        # mapToScene() returns a QPointF; building a QPoint from its float
        # coordinates can raise TypeError (bindings that reject implicit
        # float -> int), so let Qt convert it instead.
        self.data_dropped.emit(event.mimeData(), pos.toPoint())

    def dragEnterEvent(self, event):
        """Accept the drag only when it carries 'text/plain' data."""
        is_plain_text = event.mimeData().hasFormat('text/plain')
        respond = event.accept if is_plain_text else event.ignore
        respond()

    def dragMoveEvent(self, event):
        """Keep accepting the drag only while it carries 'text/plain' data."""
        is_plain_text = event.mimeData().hasFormat('text/plain')
        respond = event.accept if is_plain_text else event.ignore
        respond()

    def dragLeaveEvent(self, event):
        """Ignore the event when a drag leaves the viewer."""
        event.ignore()

    # --- scene events ---

    def sceneMouseMoveEvent(self, event):
        """
        triggered mouse move event for the scene.
         - redraw the connection pipe, snapping its end to the centre of a
           port when the topmost item under the cursor is a port.

        Args:
            event (QtWidgets.QGraphicsSceneMouseEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        # nothing to redraw unless a live connection is in progress.
        if not (self._live_pipe and self._start_port):
            return

        end_pos = event.scenePos()
        hits = self.scene().items(end_pos)
        if hits and isinstance(hits[0], PortItem):
            port = hits[0]
            half_w = port.boundingRect().width() / 2
            half_h = port.boundingRect().height() / 2
            end_pos = port.scenePos()
            end_pos.setX(end_pos.x() + half_w)
            end_pos.setY(end_pos.y() + half_h)

        self._live_pipe.draw_path(self._start_port, None, end_pos)

    def sceneMousePressEvent(self, event):
        """
        triggered mouse press event for the scene (takes priority over viewer event).
         - detect selected pipe and start connection.
         - remap Shift and Ctrl modifier.

        Args:
            event (QtWidgets.QGraphicsScenePressEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        # pipe slicer enabled.
        if event.modifiers() == (QtCore.Qt.AltModifier
                                 | QtCore.Qt.ShiftModifier):
            return
        # viewer pan mode.
        if event.modifiers() == QtCore.Qt.AltModifier:
            return

        pos = event.scenePos()
        # a port near the cursor: start a live connection from it.
        port_items = self._items_near(pos, PortItem, 5, 5)
        if port_items:
            port = port_items[0]
            # remember the port on the other end of a single connection.
            if not port.multi_connection and port.connected_ports:
                self._detached_port = port.connected_ports[0]
            self.start_live_connection(port)
            # a single-connection port drops its existing pipes.
            if not port.multi_connection:
                [p.delete() for p in port.connected_pipes]
            return

        node_items = self._items_near(pos, AbstractNodeItem, 3, 3)
        if node_items:
            node = node_items[0]

            # record the node positions at selection time.
            for n in node_items:
                self._node_positions[n] = n.xy_pos

            # emit selected node id with LMB.
            if event.button() == QtCore.Qt.LeftButton:
                self.node_selected.emit(node.id)

            # only a backdrop falls through to the pipe check below.
            if not isinstance(node, BackdropNodeItem):
                return

        # a pipe near the cursor: detach it and continue as a live pipe.
        pipe_items = self._items_near(pos, Pipe, 3, 3)
        if pipe_items:
            if not self.LMB_state:
                return
            pipe = pipe_items[0]
            # maps a port type to the pipe attribute holding the port at
            # the opposite end.
            attr = {IN_PORT: 'output_port', OUT_PORT: 'input_port'}
            from_port = pipe.port_from_pos(pos, True)
            from_port._hovered = True

            self._detached_port = getattr(pipe, attr[from_port.port_type])
            self.start_live_connection(from_port)
            self._live_pipe.draw_path(self._start_port, None, pos)
            pipe.delete()

    def sceneMouseReleaseEvent(self, event):
        """
        triggered mouse release event for the scene.
         - decides what happens to the live connection: report a
           disconnect, restore the previous pipe or report a new
           connection through the "connection_changed" signal.

        Args:
            event (QtWidgets.QGraphicsSceneMouseEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        if not self._live_pipe:
            return

        start_port = self._start_port
        start_port._hovered = False

        # find the end port (first port item reported under the cursor).
        end_port = next(
            (item for item in self.scene().items(event.scenePos())
             if isinstance(item, PortItem)),
            None)

        connected, disconnected = [], []

        # released over no port: a detached pipe stays disconnected.
        if end_port is None:
            if self._detached_port:
                disconnected.append((start_port, self._detached_port))
                self.connection_changed.emit(disconnected, connected)

            self._detached_port = None
            self.end_live_connection()
            return

        # restore connection check.
        restore_connection = (
            # if same port type.
            end_port.port_type == start_port.port_type
            # if connection to itself.
            or end_port.node == start_port.node
            # if end port is the start port.
            or end_port == start_port
            # if detached port is the end port.
            or self._detached_port == end_port
        )
        if restore_connection:
            if self._detached_port:
                # put the pipe that was detached on press back in place.
                self.establish_connection(start_port, self._detached_port)
                self._detached_port = None
            self.end_live_connection()
            return

        # register as disconnected if not acyclic.
        if self.acyclic and not self.acyclic_check(start_port, end_port):
            if self._detached_port:
                disconnected.append((start_port, self._detached_port))

            self.connection_changed.emit(disconnected, connected)

            self._detached_port = None
            self.end_live_connection()
            return

        # make connection.
        if not end_port.multi_connection and end_port.connected_ports:
            # a single connection end port loses its current connection.
            replaced_port = end_port.connected_ports[0]
            disconnected.append((end_port, replaced_port))

        if self._detached_port:
            disconnected.append((start_port, self._detached_port))

        connected.append((start_port, end_port))

        self.connection_changed.emit(disconnected, connected)

        self._detached_port = None
        self.end_live_connection()

    # --- port connections ---

    def start_live_connection(self, selected_port):
        """
        create new pipe for the connection.
        (the dashed live pipe that is drawn from the port towards the
        cursor while the connection is being dragged)

        Args:
            selected_port: port the connection starts from, nothing
                happens when it is falsy.
        """
        if not selected_port:
            return
        self._start_port = selected_port
        live_pipe = self._live_pipe = Pipe()
        live_pipe.setZValue(Z_VAL_NODE_WIDGET)
        live_pipe.activate()
        live_pipe.style = PIPE_STYLE_DASHED
        # NOTE(review): the first test reads ".type" while the rest of this
        # class reads ".port_type", and the second test compares the port
        # object itself with OUT_PORT. Both conditions are kept exactly as
        # they were; confirm the intent before changing them, because the
        # live pipe's ports also take part in Pipe.delete().
        if self._start_port.type == IN_PORT:
            live_pipe.input_port = self._start_port
        elif self._start_port == OUT_PORT:
            live_pipe.output_port = self._start_port
        self.scene().addItem(live_pipe)

    def end_live_connection(self):
        """
        delete live connection pipe and reset start port.
        (the start port is cleared even when there is no live pipe)
        """
        live_pipe = self._live_pipe
        if live_pipe:
            live_pipe.delete()
            self._live_pipe = None
        self._start_port = None

    def establish_connection(self, start_port, end_port):
        """
        establish a new pipe connection.
        (adds a new pipe item to the scene, connects it to the 2 ports and
        highlights it when either of the connected nodes is selected)

        Args:
            start_port: port on one end of the connection.
            end_port: port on the other end of the connection.
        """
        new_pipe = Pipe()
        self.scene().addItem(new_pipe)
        new_pipe.set_connections(start_port, end_port)
        new_pipe.draw_path(new_pipe.input_port, new_pipe.output_port)

        touches_selection = (start_port.node.selected
                             or end_port.node.selected)
        if touches_selection:
            new_pipe.highlight()

    def acyclic_check(self, start_port, end_port):
        """
        validate the connection so it doesn't loop itself.

        walks the graph away from "end_port" (through the outputs of every
        visited node when "end_port" is an input port, through the inputs
        otherwise) and fails as soon as the node of "start_port" is reached.
        every node is expanded only once, so the walk also terminates on a
        graph that already contains a loop.

        Args:
            start_port: port item the connection starts from.
            end_port: port item the connection would end on.

        Returns:
            bool: True if port connection is valid.
        """
        from collections import deque

        start_node = start_port.node
        io_types = {IN_PORT: 'outputs', OUT_PORT: 'inputs'}
        # node attribute listing the ports that lead away from "end_port".
        port_side = io_types[end_port.port_type]

        first_node = end_port.node
        check_nodes = deque([first_node])
        # tracked by identity so no assumption is made about node hashing.
        seen = {id(first_node)}
        while check_nodes:
            check_node = check_nodes.popleft()
            for check_port in getattr(check_node, port_side):
                for port in check_port.connected_ports or ():
                    next_node = port.node
                    if next_node != start_node:
                        if id(next_node) not in seen:
                            seen.add(id(next_node))
                            check_nodes.append(next_node)
                    else:
                        return False
        return True

    # --- viewer ---

    def tab_search_set_nodes(self, nodes):
        """
        pass the given nodes on to the tab search widget.

        Args:
            nodes: forwarded unchanged to the search widget "set_nodes".
        """
        self._search_widget.set_nodes(nodes)

    def tab_search_toggle(self):
        """
        toggle the visibility of the tab search widget.
         - when shown, the widget is centered on the previously stored
           cursor position and the scene area behind it is updated.
         - when hidden, the viewer focus is cleared.
        """
        show = not self._search_widget.isVisible()
        if not show:
            self._search_widget.setVisible(False)
            self.clearFocus()
            return

        pos = self._previous_pos
        rect = self._search_widget.rect()
        # QPoint takes integer coordinates, "/" yields floats on Python 3.
        new_pos = QtCore.QPoint(int(pos.x() - rect.width() / 2),
                                int(pos.y() - rect.height() / 2))
        self._search_widget.move(new_pos)
        self._search_widget.setVisible(True)
        rect = self.mapToScene(rect).boundingRect()
        self.scene().update(rect)

    def context_menu(self):
        """
        Returns:
            the context menu object stored on the viewer.
        """
        return self._context_menu

    def question_dialog(self, text, title='Node Graph'):
        """
        prompt a Yes/No question message box.

        Args:
            text (str): question shown in the dialog.
            title (str): dialog window title.

        Returns:
            bool: True if "Yes" was picked.
        """
        box = QtWidgets.QMessageBox
        answer = box.question(self, title, text, box.Yes, box.No)
        return answer == box.Yes

    def message_dialog(self, text, title='Node Graph'):
        """
        prompt an information message box with a single "Ok" button.

        Args:
            text (str): message shown in the dialog.
            title (str): dialog window title.
        """
        box = QtWidgets.QMessageBox
        box.information(self, title, text, box.Ok)

    def load_dialog(self, current_dir=None, ext=None):
        """
        prompt a file open dialog.

        Args:
            current_dir (str): directory the dialog starts in, the user
                home directory when not given.
            ext (str): optional extra file extension added to the
                "Node Graph" name filter next to json.

        Returns:
            str: picked file path or None if nothing was picked.
        """
        start_dir = current_dir or os.path.expanduser('~')
        ext_label = '*{} '.format(ext) if ext else ''
        name_filters = ('Node Graph ({}*json)'.format(ext_label),
                        'All Files (*)')
        picked = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Open Session Setup', start_dir, ';;'.join(name_filters))
        return picked[0] or None

    def save_dialog(self, current_dir=None, ext=None):
        """
        prompt a file save dialog.

        Args:
            current_dir (str): directory the dialog starts in, the user
                home directory when not given.
            ext (str): optional file extension, it is added to the
                "Node Graph" name filter and used instead of ".json" as the
                suffix for that filter.

        Returns:
            str: picked file path (the suffix of the picked name filter is
                appended when missing) or None if nothing was picked.
        """
        start_dir = current_dir or os.path.expanduser('~')
        if ext:
            ext_label = '*{} '.format(ext)
            graph_suffix = '.{}'.format(ext)
        else:
            ext_label = ''
            graph_suffix = '.json'
        # name filter shown in the dialog -> suffix enforced for it.
        suffix_by_filter = {
            'Node Graph ({}*json)'.format(ext_label): graph_suffix,
            'All Files (*)': ''
        }
        picked = QtWidgets.QFileDialog.getSaveFileName(
            self, 'Save Session', start_dir,
            ';;'.join(suffix_by_filter.keys()))
        file_path = picked[0]
        if not file_path:
            return
        suffix = suffix_by_filter[picked[1]]
        if suffix and not file_path.endswith(suffix):
            file_path += suffix
        return file_path

    def all_pipes(self):
        """
        Returns:
            list: every Pipe item currently in the scene.
        """
        return [item for item in self.scene().items()
                if isinstance(item, Pipe)]

    def all_nodes(self):
        """
        Returns:
            list: every AbstractNodeItem currently in the scene.
        """
        return [item for item in self.scene().items()
                if isinstance(item, AbstractNodeItem)]

    def selected_nodes(self):
        """
        Returns:
            list: every selected scene item that is an AbstractNodeItem.
        """
        return [item for item in self.scene().selectedItems()
                if isinstance(item, AbstractNodeItem)]

    def add_node(self, node, pos=None):
        """
        add a node item to the scene.
        (the node "pre_init" and "post_init" hooks are called right before
        and right after it is added)

        Args:
            node: node item to add.
            pos (tuple): x, y position handed to the node hooks, the
                previously stored cursor position is used when not given.
        """
        if not pos:
            cursor = self._previous_pos
            pos = (cursor.x(), cursor.y())
        node.pre_init(self, pos)
        self.scene().addItem(node)
        node.post_init(self, pos)

    def remove_node(self, node):
        """
        delete the given node item.
        (anything that is not an AbstractNodeItem is left untouched)

        Args:
            node: node item to delete.
        """
        if not isinstance(node, AbstractNodeItem):
            return
        node.delete()

    def move_nodes(self, nodes, pos=None, offset=None):
        """
        move the given nodes together as one group.
        (the nodes are put into a temporary item group that is positioned
        and then destroyed again)

        Args:
            nodes (list): node items to move.
            pos (tuple): x, y position for the group, when not given the
                previously stored cursor position (mapped to the scene)
                minus the center of the group bounding rect is used.
            offset (tuple): optional x, y offset added to the position.
        """
        scene = self.scene()
        group = scene.createItemGroup(nodes)
        bounds = group.boundingRect()
        if pos:
            x, y = pos
        else:
            cursor = self.mapToScene(self._previous_pos)
            x = cursor.x() - bounds.center().x()
            y = cursor.y() - bounds.center().y()
        if offset:
            x, y = x + offset[0], y + offset[1]
        group.setPos(x, y)
        self.scene().destroyItemGroup(group)

    def get_pipes_from_nodes(self, nodes=None):
        """
        collect the pipes that connect the given nodes with each other.

        Args:
            nodes (list): node items to inspect, the selected nodes are
                used when not given.

        Returns:
            list: pipes whose far end node is also in "nodes" (a pipe is
                listed once for each of its ends that is found this way),
                or None when there are no nodes to inspect.
        """
        nodes = nodes or self.selected_nodes()
        if not nodes:
            return
        found = []
        for node in nodes:
            # (ports of the node, pipe attribute holding the far end port)
            sides = (
                (getattr(node, 'inputs', []), 'output_port'),
                (getattr(node, 'outputs', []), 'input_port'),
            )
            for ports, far_end in sides:
                for port in ports:
                    for pipe in port.connected_pipes:
                        if getattr(pipe, far_end).node in nodes:
                            found.append(pipe)
        return found

    def center_selection(self, nodes=None):
        """
        center the view on the given nodes.

        Args:
            nodes (list): node items to center on, when not given the
                selected nodes are used and if there is no selection all
                nodes in the scene. nothing happens when no node is found.
        """
        if not nodes:
            nodes = self.selected_nodes() or self.all_nodes()
        if not nodes:
            # empty scene: nothing to center on.
            return
        if len(nodes) == 1:
            self.centerOn(nodes[0])
        else:
            center = self._combined_rect(nodes).center()
            self.centerOn(center.x(), center.y())

    def get_pipe_layout(self):
        """
        Returns:
            the pipe layout value last stored on the viewer.
        """
        return self._pipe_layout

    def set_pipe_layout(self, layout):
        """
        store the pipe layout and redraw every pipe in the scene.

        Args:
            layout: pipe layout value to store.
        """
        self._pipe_layout = layout
        for scene_pipe in self.all_pipes():
            source, target = scene_pipe.input_port, scene_pipe.output_port
            scene_pipe.draw_path(source, target)

    def reset_zoom(self):
        """
        reset the viewer zoom by resetting the view transform.
        """
        # NOTE(review): scaling by 1.0 looks like a no-op before the
        # transform is reset - confirm whether it can be dropped.
        self.scale(1.0, 1.0)
        self.resetTransform()

    def get_zoom(self):
        """
        Returns:
            float: horizontal scale factor of the view transform minus
                1.0, rounded to 2 decimals (0.0 means no zoom).
        """
        horizontal_scale = self.transform().m11()
        return float('{:0.2f}'.format(horizontal_scale - 1.0))

    def set_zoom(self, value=0.0):
        """
        set the viewer zoom level.

        Args:
            value (float): zoom level, 0.0 resets the zoom. the request is
                dropped when the checked level is outside of
                ZOOM_MIN..ZOOM_MAX.
        """
        if value == 0.0:
            self.reset_zoom()
            return
        current = self.get_zoom()
        # while zoomed out the current level is range checked, otherwise
        # the requested one.
        checked = current if current < 0.0 else value
        if not (ZOOM_MIN <= checked <= ZOOM_MAX):
            return
        # the zoom helper is given the difference to the current level.
        self._set_viewer_zoom(value - current)

    def zoom_to_nodes(self, nodes):
        """
        fit the view to the combined rect of the given nodes.
        (the zoom is reset when the fit ends up above a zoom of 0.1)

        Args:
            nodes (list): node items to fit into the view.
        """
        bounds = self._combined_rect(nodes)
        self.fitInView(bounds, QtCore.Qt.KeepAspectRatio)
        zoomed_in_too_far = self.get_zoom() > 0.1
        if zoomed_in_too_far:
            self.reset_zoom()
예제 #19
0
class AutoNode(BaseNode, QtCore.QObject):
    """
    Node that evaluates itself through cook()/run() and requests a down
    stream update when its inputs or custom properties change.

    Ports carry a data type name that is checked when an input gets
    connected; a failed evaluation is shown through the node color and
    tooltip.

    Args:
        defaultInputType: data type used for inputs added with the 'None'
            data type (a type name or an object with a ``__name__``).
        defaultOutputType: same as defaultInputType, for outputs.
    """

    # emitted at the end of a cook() that finished without an error.
    cooked = QtCore.Signal()

    def __init__(self, defaultInputType=None, defaultOutputType=None):
        super(AutoNode, self).__init__()
        QtCore.QObject.__init__(self)
        # cleared while a rejected connection is being undone, see
        # on_input_connected() and on_input_disconnected().
        self.needCook = True
        self._autoCook = True
        self._error = False
        # groups of data type names that may be connected to each other.
        self.matchTypes = [['float', 'int']]
        self.errorColor = (200, 50, 50)
        self.stopCookColor = (200, 200, 200)
        self._cryptoColors = CryptoColors()

        # color put back when an error is closed or auto cook is enabled.
        self.defaultColor = self.get_property("color")
        # returned by getInputData() for a missing or unconnected input.
        self.defaultValue = None
        self.defaultInputType = defaultInputType
        self.defaultOutputType = defaultOutputType

        self._cookTime = 0.0
        # tooltip template with a "{}" placeholder for the cook time.
        self._toolTip = self._setup_tool_tip()

    @property
    def autoCook(self):
        """
        Returns whether the node can update stream automatically.
        """

        return self._autoCook

    @autoCook.setter
    def autoCook(self, mode):
        """
        Set whether the node can update stream automatically.
        The node is shown in the "stopCookColor" while it is switched off.

        Args:
            mode(bool).
        """

        if mode is self._autoCook:
            return

        self._autoCook = mode
        if mode:
            self.set_property('color', self.defaultColor)
        else:
            # remember the current color so it can be restored later.
            self.defaultColor = self.get_property("color")
            self.set_property('color', self.stopCookColor)

    @property
    def cookTime(self):
        """
        Get the last cooked time of the node.
        """

        return self._cookTime

    # the setter has to be built from the "cookTime" property, building it
    # from "autoCook" would make reading cookTime return the autoCook state.
    @cookTime.setter
    def cookTime(self, time):
        """
        Set the last cooked time of the node and refresh the tooltip.

        Args:
            time(float).
        """

        self._cookTime = time
        self._update_tool_tip()

    def update_stream(self, forceCook=False):
        """
        Update all down stream nodes.
        Nothing happens when auto cook is off for the node, when the node
        does not need a cook or when the graph has auto update disabled.

        Args:
            forceCook(bool): if True, the checks above are skipped.
        """

        if not forceCook:
            if not self._autoCook or not self.needCook:
                return
            if self.graph is not None and not self.graph.auto_update:
                return
        update_node_down_stream(self)

    def getData(self, port):
        """
        Get node data by port.
        It is called by getInputData() of the nodes connected to the port.

        Args:
            port(Port).

        Returns:
            the value of the node property named like the port.
        """

        return self.get_property(port.name())

    def getInputData(self, port):
        """
        Get input data by input port name/index/object.

        Args:
            port(str/int/Port): input port name/index/object.

        Returns:
            a deep copy of the data of the first connected port, or a deep
            copy of "defaultValue" when the input does not exist or is not
            connected.
        """

        to_port = port if isinstance(port, Port) else self.get_input(port)
        if to_port is None:
            return copy.deepcopy(self.defaultValue)

        from_ports = to_port.connected_ports()
        if not from_ports:
            return copy.deepcopy(self.defaultValue)

        # only the first connection is read.
        from_port = next(iter(from_ports))
        data = from_port.node().getData(from_port)
        return copy.deepcopy(data)

    def when_disabled(self):
        """
        Node evaluation logic when node has been disabled.
        Every output gets the data of the input with the same index (the
        last input when there are more outputs than inputs).
        """

        last_input = max(0, len(self.input_ports()) - 1)
        for index, out_port in enumerate(self.output_ports()):
            data = self.getInputData(min(index, last_input))
            self.set_property(out_port.name(), data)

    def cook(self):
        """
        The entry of the node evaluation.
        Call this method rather than AutoNode.run(): it switches auto cook
        off while run() is executing, turns an exception raised by run()
        into a node error, stores the cook time and emits "cooked" when no
        error occurred.
        """

        _tmp = self._autoCook
        self._autoCook = False

        if self.error():
            self._close_error()

        _start_time = time.time()
        try:
            self.run()
        except Exception:
            # show the failure on the node instead of raising it.
            self.error(traceback.format_exc())
        finally:
            self._autoCook = _tmp

        if self.error():
            return

        self.cookTime = time.time() - _start_time

        self.cooked.emit()

    def run(self):
        """
        Node evaluation logic, does nothing here.
        """

        pass

    def on_input_connected(self, to_port, from_port):
        """
        Request a down stream update for a valid connection, undo the
        connection when the port data types do not match.

        Args:
            to_port(Port): input port of this node.
            from_port(Port): port that got connected to it.
        """

        if self.checkPortType(to_port, from_port):
            self.update_stream()
        else:
            # presumably disconnect_from() ends up in
            # on_input_disconnected(); the flag makes that call skip the
            # stream update.
            self.needCook = False
            to_port.disconnect_from(from_port)

    def on_input_disconnected(self, to_port, from_port):
        """
        Request a down stream update unless the disconnect is the undo of
        a rejected connection.

        Args:
            to_port(Port): input port of this node.
            from_port(Port): port that got disconnected from it.
        """

        if not self.needCook:
            self.needCook = True
            return
        self.update_stream()

    def checkPortType(self, to_port, from_port):
        """
        Check whether the data_type of the to_port and from_port is matched.
        Types match when they are equal, when one of them is 'None' or when
        both are listed in the same "matchTypes" group.

        Args:
            to_port(Port).
            from_port(Port).

        Returns:
            bool.
        """

        if to_port.data_type == from_port.data_type:
            return True
        if to_port.data_type == 'None' or from_port.data_type == 'None':
            return True
        for types in self.matchTypes:
            if to_port.data_type in types and from_port.data_type in types:
                return True
        return False

    def set_property(self, name, value):
        """
        Set a node property. The port named like the property (if any)
        gets the type name of the value as data type and a down stream
        update is requested for custom properties.

        Args:
            name(str): property name.
            value: property value.
        """

        super(AutoNode, self).set_property(name, value)
        self.set_port_type(name, type(value).__name__)
        if name in self.model.custom_properties.keys():
            self.update_stream()

    def set_port_type(self, port, data_type: str):
        """
        Set the data_type of the port, its colors and its tooltip.
        Nothing happens when the port is not found or already has the type.

        Args:
            port(Port/str): the port (or the name of an input/output port)
                to set the data_type.
            data_type(str): port new data_type.
        """

        current_port = None

        if isinstance(port, Port):
            current_port = port
        elif isinstance(port, str):
            inputs = self.inputs()
            outputs = self.outputs()
            if port in inputs:
                current_port = inputs[port]
            elif port in outputs:
                current_port = outputs[port]

        if not current_port:
            return
        if current_port.data_type == data_type:
            return

        current_port.data_type = data_type
        current_port.border_color = current_port.color = self._cryptoColors.get(data_type)
        conn_type = 'multi' if current_port.multi_connection() else 'single'
        current_port.view.setToolTip('{}: {} ({}) '.format(current_port.name(), data_type, conn_type))

    def create_property(self, name, value, items=None, range=None,
                        widget_type=NODE_PROP, tab=None, ext=None, funcs=None):
        """
        Create a node property. When the value is not None the port named
        like the property (if any) gets the type name of the value as data
        type. All arguments are handed to the base class unchanged.
        """

        super(AutoNode, self).create_property(name, value, items, range, widget_type, tab, ext, funcs)

        if value is not None:
            self.set_port_type(name, type(value).__name__)

    def add_input(self, name='input', data_type='None', multi_input=False, display_name=True,
                  color=None):
        """
        Add an input port and set its data type.

        Args:
            name(str): port name.
            data_type(str/type): port data type, "defaultInputType" is used
                instead of 'None' when it is set.
            multi_input(bool), display_name(bool), color: handed to the
                base class.

        Returns:
            the port created by the base class.
        """

        new_port = super(AutoNode, self).add_input(name, multi_input, display_name, color)
        if data_type == 'None' and self.defaultInputType is not None:
            data_type = self.defaultInputType
        if not isinstance(data_type, str):
            data_type = data_type.__name__
        self.set_port_type(new_port, data_type)

        return new_port

    def add_output(self, name='output', data_type='None', multi_output=True, display_name=True,
                   color=None):
        """
        Add an output port and set its data type.

        Args:
            name(str): port name.
            data_type(str/type): port data type, "defaultOutputType" is
                used instead of 'None' when it is set.
            multi_output(bool), display_name(bool), color: handed to the
                base class.

        Returns:
            the port created by the base class.
        """

        new_port = super(AutoNode, self).add_output(name, multi_output, display_name, color)
        if data_type == 'None' and self.defaultOutputType is not None:
            data_type = self.defaultOutputType
        if not isinstance(data_type, str):
            data_type = data_type.__name__
        self.set_port_type(new_port, data_type)

        return new_port

    def set_disabled(self, mode=False):
        """
        Enable/disable the node. Auto cook is switched off while the node
        is disabled and when_disabled() is run when it gets disabled.

        Args:
            mode(bool): True to disable the node.
        """

        super(AutoNode, self).set_disabled(mode)
        self._autoCook = not mode
        if mode:
            self.when_disabled()
            # NOTE(review): auto cook is off at this point, so this
            # update_stream() call returns early - confirm whether a
            # forced update was intended.
            if self.graph is None or self.graph.auto_update:
                self.update_stream()
        else:
            self.update_stream()

    def _close_error(self):
        """
        Close the node error: restore the node color and the tooltip.
        """

        self._error = False
        self.set_property('color', self.defaultColor)
        self._update_tool_tip()

    def _show_error(self, message):
        """
        Show the node error.
        It will change the node color and add the error description to the
        node tooltip.

        Args:
            message(str): the description of the error.
        """

        if not self._error:
            # first error: remember the color to restore afterwards.
            self.defaultColor = self.get_property("color")

        self._error = True
        self.set_property('color', self.errorColor)
        tooltip = '<font color="red"><br>({})</br></font>'.format(message)
        self._update_tool_tip(tooltip)

    def _update_tool_tip(self, message=None):
        """
        Update the node tooltip.

        Args:
            message(str): text placed between the node name and the node
                type; when None the default tooltip with the last cook
                time is shown.

        Returns:
            str: the tooltip that has been set.
        """

        if message is None:
            tooltip = self._toolTip.format(self._cookTime)
        else:
            tooltip = '<b>{}</b>'.format(self.name())
            tooltip += message
            tooltip += '<br/>{}<br/>'.format(self._view.type_)
        self.view.setToolTip(tooltip)
        return tooltip

    def _setup_tool_tip(self):
        """
        Setup default node tooltip.

        Returns:
            str: tooltip template with a "{}" placeholder for the cook
                time.
        """

        tooltip = '<br> last cook used: {}s</br>'
        return self._update_tool_tip(tooltip)

    def error(self, message=None):
        """
        Query or set the node error.

        Args:
            message(str): the description of the error or None.

        Returns:
            if message is None, returns whether the node has error.
            Otherwise the error is shown on the node and None is returned.
        """

        if message is None:
            return self._error

        self._show_error(message)

    def update_model(self):
        """
        Restore the default node color while the node has an error, then
        run the base class update_model().
        """

        if self.error():
            self.set_property('color', self.defaultColor)
        super(AutoNode, self).update_model()
예제 #20
0
class _valueSliderEdit(QtWidgets.QWidget):
    """
    Value editor made of a _valueEdit line edit next to a slider, the two
    are kept in sync.

    The slider works on integers, so values are multiplied by "_mul"
    (1000.0 for float, 1.0 for int) before they are given to the slider.
    """

    # emitted with the value of the edit whenever it changes.
    valueChanged = QtCore.Signal(object)

    def __init__(self, parent=None):
        super(_valueSliderEdit, self).__init__(parent)
        # state read by the signal handlers, set before anything is
        # connected so an early signal cannot hit a missing attribute.
        self._mul = 1000.0
        self._lock = False

        self._edit = _valueEdit()
        self._edit.valueChanged.connect(self._on_edit_changed)
        self._edit.setMaximumWidth(70)
        self._slider = _slider()
        self._slider.valueChanged.connect(self._on_slider_changed)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self._edit)
        hbox.addWidget(self._slider)
        self.setLayout(hbox)

        self.set_min(0)
        self.set_max(10)
        self.set_data_type(float)

    def _on_edit_changed(self, value):
        """
        Move the slider to the new value and emit "valueChanged".
        """
        self._set_slider_value(value)
        self.valueChanged.emit(self._edit.value())

    def _on_slider_changed(self, value):
        """
        Write the slider value (scaled back by "_mul") into the edit.
        Changes made by _set_slider_value() are skipped.
        """
        if self._lock:
            return
        value = value / float(self._mul)
        self._edit.setValue(value)
        self._on_edit_changed(value)

    def _set_slider_value(self, value):
        """
        Move the slider to the value, clamped to the slider range, without
        feeding the change back into the edit.
        """
        value = int(value * self._mul)
        _min = self._slider.minimum()
        _max = self._slider.maximum()
        target = min(max(value, _min), _max)
        if target == self._slider.value():
            return
        # the lock only lives for the duration of the setValue() call, so
        # it can never stay set and swallow a later slider change.
        self._lock = True
        try:
            self._slider.setValue(target)
        finally:
            self._lock = False

    def set_min(self, value=0):
        """
        Set the lowest value of the slider.
        """
        self._slider.setMinimum(int(value * self._mul))

    def set_max(self, value=10):
        """
        Set the highest value of the slider.
        """
        self._slider.setMaximum(int(value * self._mul))

    def set_data_type(self, dt):
        """
        Switch between int and float values.

        Args:
            dt(type): int or float, any other value keeps the current
                slider scale.
        """
        # read the range with the old scale, it is applied again below
        # with the new one. the bounds are truncated to whole numbers.
        _min = int(self._slider.minimum() / self._mul)
        _max = int(self._slider.maximum() / self._mul)
        if dt is int:
            self._mul = 1.0
        elif dt is float:
            self._mul = 1000.0

        self.set_min(_min)
        self.set_max(_max)
        self._edit.set_data_type(dt)

    def value(self):
        """
        Returns:
            the value of the edit.
        """
        return self._edit.value()

    def setValue(self, value):
        """
        Set the value of the edit and the slider, emits "valueChanged".
        """
        self._edit.setValue(value)
        self._on_edit_changed(value)
예제 #21
0
class _valueEdit(QtWidgets.QLineEdit):
    """
    Line edit holding a single int/float value.

    Besides typing, the value can be changed with the middle mouse button:
    pressing it opens a menu of step sizes and moving the mouse
    horizontally adds/subtracts multiples of the step picked in the menu.
    """

    # emitted with the current value when editing finished or the value
    # has been dragged.
    valueChanged = QtCore.Signal(object)

    def __init__(self, parent=None):
        super(_valueEdit, self).__init__(parent)
        # True while the middle mouse button drag is active.
        self.mid_state = False
        self._data_type = float
        self.setText("0")

        # cursor x position and value at the start of the current drag.
        self.pre_x = None
        self.pre_val = None
        self._step = 1
        # steps applied per pixel of horizontal mouse movement.
        self._speed = 0.1

        self.editingFinished.connect(self._on_text_changed)

        # the menu forwards its mouse events so the drag keeps working
        # while the menu is open.
        self.menu = _valueMenu()
        self.menu.mouseMove.connect(self.mouseMoveEvent)
        self.menu.mouseRelease.connect(self.mouseReleaseEvent)
        self.menu.stepChange.connect(self._reset)
        steps = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
        self.menu.set_steps(steps)

        self.set_data_type(float)

    def _on_text_changed(self):
        """
        Emit "valueChanged" with the current value.
        """
        self.valueChanged.emit(self.value())

    def _reset(self):
        """
        Forget the drag start position so the next mouse move starts a new
        drag from the current value.
        """
        self.pre_x = None

    def mouseMoveEvent(self, event):
        """
        Change the value while the middle mouse button drag is active.
        """
        if self.mid_state:
            if self.pre_x is None:
                # first move of the drag: store the reference point.
                self.pre_x = event.x()
                self.pre_val = self.value()
            else:
                self.set_step(self.menu.step)
                delta = event.x() - self.pre_x
                value = self.pre_val + int(delta * self._speed) * self._step
                self.setValue(value)
                self._on_text_changed()

        super(_valueEdit, self).mouseMoveEvent(event)

    def mousePressEvent(self, event):
        """
        Start the drag and show the step menu at the cursor for the middle
        mouse button.
        """
        if event.button() == QtCore.Qt.MiddleButton:
            self.mid_state = True
            self._reset()
            self.menu.exec_(QtGui.QCursor.pos())
        super(_valueEdit, self).mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Close the step menu and end the drag.
        """
        self.menu.close()
        self.mid_state = False
        super(_valueEdit, self).mouseReleaseEvent(event)

    def set_step(self, step):
        """
        Set the amount added to the value per drag step.
        """
        self._step = step

    def set_data_type(self, dt):
        """
        Set the type of the value.

        Args:
            dt(type): int or float select the matching validator, the type
                is stored and handed to the step menu either way.
        """
        if dt is int:
            self.setValidator(QtGui.QIntValidator())
        elif dt is float:
            self.setValidator(QtGui.QDoubleValidator())
        self._data_type = dt
        self.menu.set_data_type(dt)

    def _convert_text(self, text):
        """
        Convert a text (or number) to the data type of the edit.

        Returns:
            the converted value, 0.0 (0 for int) if "text" can not be
            converted to a number.
        """
        # int("1.0") will return error
        # so we use int(float("1.0"))
        try:
            value = float(text)
        except (TypeError, ValueError, OverflowError):
            value = 0.0
        if self._data_type is int:
            value = int(value)
        return value

    def value(self):
        """
        Returns:
            the text of the edit converted to the data type. A text that
            starts with "." gets a leading "0" written back first.
        """
        if self.text().startswith("."):
            text = "0" + self.text()
            self.setText(text)
        return self._convert_text(self.text())

    def setValue(self, value):
        """
        Write the value into the edit if it differs from the current one.
        """
        if value != self.value():
            self.setText(str(self._convert_text(value)))
class AutoNode(BaseNode, QtCore.QObject):
    cooked = QtCore.Signal()

    def __init__(self, defaultInputType=None, defaultOutputType=None):
        super(AutoNode, self).__init__()
        QtCore.QObject.__init__(self)
        self._need_cook = True
        self._error = False
        self.matchTypes = [['float', 'int']]
        self.errorColor = (200, 50, 50)
        self.stopCookColor = (200, 200, 200)

        self.create_property('auto_cook', True)
        self.defaultValue = None
        self.defaultInputType = defaultInputType
        self.defaultOutputType = defaultOutputType

        self._cook_time = 0.0
        self._toolTip = self._setup_tool_tip()

        # effect
        self.color_effect = QtWidgets.QGraphicsColorizeEffect()
        self.color_effect.setStrength(0.7)
        self.color_effect.setEnabled(False)
        self.view.setGraphicsEffect(self.color_effect)

    @property
    def auto_cook(self):
        """
        Returns whether the node can update stream automatically.
        """

        return self.get_property('auto_cook')

    @auto_cook.setter
    def auto_cook(self, mode):
        """
        Set whether the node can update stream automatically.

        Args:
            mode(bool).
        """

        if mode is self.auto_cook:
            return

        self.model.set_property('auto_cook', mode)
        self.color_effect.setEnabled(not mode)
        if not mode:
            self.color_effect.setColor(QtGui.QColor(*self.stopCookColor))

    @property
    def cook_time(self):
        """
        Returns the last cooked time of the node.
        """

        return self._cook_time

    @cook_time.setter
    def cook_time(self, cook_time):
        """
        Set the last cooked time of the node.

        Args:
            cook_time(float).
        """

        self._cook_time = cook_time
        self._update_tool_tip()

    @property
    def has_error(self):
        """
        Returns whether the node has errors.
        """

        return self._error

    def update_stream(self, forceCook=False):
        """
        Update all down stream nodes.

        Args:
            forceCook(bool): if True, it will ignore the auto_cook and so on.
        """

        if not forceCook:
            if not self.auto_cook or not self._need_cook:
                return
            if self.graph is not None and not self.graph.auto_update:
                return
        update_node_down_stream(self)

    def get_data(self, port):
        """
        Get node data by port.
        Most time it will called by output nodes of the node.

        Args:
            port(Port).

        Returns:
            node data.
        """
        if self.disabled() and self.input_ports():
            out_ports = self.output_ports()
            if port in out_ports:
                idx = out_ports.index(port)
                max_idx = max(0, len(self.input_ports()) - 1)
                return self.get_input_data(min(idx, max_idx))

        return self.get_property(port.name())

    def get_input_data(self, port):
        """
        Get input data by input port name/index/object.

        Args:
            port(str/int/Port): input port name/index/object.
        """
        if type(port) is not Port:
            to_port = self.get_input(port)
        else:
            to_port = port
        if to_port is None:
            return copy.deepcopy(self.defaultValue)

        from_ports = to_port.connected_ports()
        if not from_ports:
            return copy.deepcopy(self.defaultValue)

        for from_port in from_ports:
            data = from_port.node().get_data(from_port)
            return copy.deepcopy(data)

    def cook(self):
        """
        The entry of the node evaluation.
        Most time we need to call this method instead of AutoNode.run'.
        """

        _tmp = self.auto_cook
        self.model.set_property('auto_cook', False)

        if self._error:
            self._close_error()

        _start_time = time.time()
        try:
            self.run()
        except:
            self.error(traceback.format_exc())

        self.model.set_property('auto_cook', _tmp)

        if self._error:
            return

        self.cook_time = time.time() - _start_time

        self.cooked.emit()

    def run(self):
        """
        Node evaluation logic.
        """

        pass

    def on_input_connected(self, to_port, from_port):
        """
        React to a new connection between 'from_port' and 'to_port'.

        Args:
            to_port(Port): port checked as the receiving end.
            from_port(Port): port on the other end of the connection.
        """
        if not self.check_port_type(to_port, from_port):
            # incompatible data types: undo the connection. '_need_cook' is
            # cleared first so 'on_input_disconnected' skips its stream
            # update for this disconnect.
            self._need_cook = False
            to_port.disconnect_from(from_port)
            return

        self.update_stream()

    def on_input_disconnected(self, to_port, from_port):
        """
        React to a connection being removed.

        When '_need_cook' is set the stream is updated. Otherwise the flag
        is set again and no update happens for this call.

        Args:
            to_port(Port): not used by this implementation.
            from_port(Port): not used by this implementation.
        """
        if self._need_cook:
            self.update_stream()
            return
        self._need_cook = True

    def check_port_type(self, to_port, from_port):
        """
        Tell whether the data types of two ports are compatible.

        The types are compatible when they are equal, when either one is
        'NoneType', or when both appear in the same entry of
        'self.matchTypes'.

        Args:
            to_port(Port).
            from_port(Port).

        Returns:
            bool.
        """
        to_type = to_port.data_type
        from_type = from_port.data_type

        mismatch = to_type != from_type
        if not mismatch:
            return True
        if 'NoneType' in (to_type, from_type):
            return True
        return any(to_type in group and from_type in group
                   for group in self.matchTypes)

    def set_property(self, name, value):
        """
        Set a node property, sync the type of the port with the same name
        and update the stream when the property is a custom property.

        Args:
            name(str): property name.
            value(object): new property value.
        """
        super(AutoNode, self).set_property(name, value)

        # 'set_port_type' ignores names that are not ports of this node.
        value_type = type(value).__name__
        self.set_port_type(name, value_type)

        if name not in self.model.custom_properties:
            return
        self.update_stream()

    def set_port_type(self, port, data_type: str):
        """
        Assign a data type to a port and refresh its colors and tooltip.

        Nothing happens when the port cannot be resolved or when it already
        has the requested data type.

        Args:
            port(Port/str): port object, or the name of an input/output port.
            data_type(str): port new data_type.
        """
        target = None
        if type(port) is Port:
            target = port
        elif type(port) is str:
            # inputs are searched before outputs.
            for ports in (self.inputs(), self.outputs()):
                if port in ports.keys():
                    target = ports[port]
                    break

        if not target or target.data_type == data_type:
            return

        target.data_type = data_type

        type_color = CryptoColors.get(data_type)
        target.border_color = type_color
        target.color = type_color

        if target.multi_connection():
            conn_type = 'multi'
        else:
            conn_type = 'single'
        tool_tip = '{}: {} ({}) '.format(target.name(), data_type, conn_type)
        target.view.setToolTip(tool_tip)

    def add_input(self, name='input', data_type='', multi_input=False, display_name=True,
                  color=None, painter_func=None):
        """
        Add an input port and set its data type.

        Args:
            name(str): port name.
            data_type: port data type; an empty string selects
                'self.defaultInputType'.
            multi_input(bool): forwarded to the base class 'add_input'.
            display_name(bool): forwarded to the base class 'add_input'.
            color: forwarded to the base class 'add_input'.
            painter_func: forwarded to the base class 'add_input'.

        Returns:
            the port created by the base class.
        """
        port = super(AutoNode, self).add_input(
            name, multi_input, display_name, color, data_type, painter_func)
        effective_type = self.defaultInputType if data_type == '' else data_type
        self.set_port_type(port, get_data_type(effective_type))
        return port

    def add_output(self, name='output', data_type='', multi_output=True, display_name=True,
                   color=None, painter_func=None):
        """
        Add an output port and set its data type.

        Args:
            name(str): port name.
            data_type: port data type; an empty string selects
                'self.defaultOutputType'.
            multi_output(bool): forwarded to the base class 'add_output'.
            display_name(bool): forwarded to the base class 'add_output'.
            color: forwarded to the base class 'add_output'.
            painter_func: forwarded to the base class 'add_output'.

        Returns:
            the port created by the base class.
        """
        port = super(AutoNode, self).add_output(
            name, multi_output, display_name, color, data_type, painter_func)
        effective_type = self.defaultOutputType if data_type == '' else data_type
        self.set_port_type(port, get_data_type(effective_type))
        return port

    def set_disabled(self, mode=False):
        """
        Set the node disabled state, then update the stream.

        Args:
            mode(bool): forwarded to the base class 'set_disabled'.
        """
        super(AutoNode, self).set_disabled(mode)
        self.update_stream()

    def _close_error(self):
        """
        Close the node error.
        """

        self._error = False
        self.color_effect.setEnabled(False)
        self._update_tool_tip()

    def _update_tool_tip(self, message=None):
        """
        Update the node tooltip.

        Args:
            message(str): new node tooltip.
        """

        if message is None:
            tooltip = self._toolTip.format(self._cook_time)
        else:
            tooltip = '<b>{}</b>'.format(self.name())
            tooltip += message
            tooltip += '<br/>{}<br/>'.format(self._view.type_)
        self.view.setToolTip(tooltip)
        return tooltip

    def _setup_tool_tip(self):
        """
        Setup default node tooltip.

        Returns:
            str: new node tooltip.
        """
        tooltip = '<br> last cook used: {}s</br>'
        return self._update_tool_tip(tooltip)

    def error(self, message):
        """
        Put the node into its error state: set the error flag, enable the
        color effect with 'errorColor' and show the message in the tooltip.

        Args:
            message(str): description of the error.
        """
        self._error = True

        effect = self.color_effect
        effect.setEnabled(True)
        effect.setColor(QtGui.QColor(*self.errorColor))

        # NOTE(review): the message is inserted into the markup unescaped.
        self._update_tool_tip(
            '<font color="red"><br>({})</br></font>'.format(message))
예제 #23
0
class PropertiesBinWidget(QtWidgets.QWidget):
    """
    Node properties bin for displaying node properties.

    Each added node gets one row in the internal properties list; the row
    holds a NodePropWidget plus a table item whose text is the node id.

    Args:
        parent (QtWidgets.QWidget): parent of the new widget.
    """

    #: Signal emitted (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)

    def __init__(self, parent=None):
        super(PropertiesBinWidget, self).__init__(parent)
        self.setWindowTitle('Properties Bin')
        self._prop_list = PropertiesList()
        self._limit = QtWidgets.QSpinBox()
        self._limit.setToolTip('Set node limit to display.')
        self._limit.setMaximum(10)
        self._limit.setMinimum(0)
        self._limit.setValue(10)
        self._limit.valueChanged.connect(self.__on_limit_changed)
        self.resize(400, 400)

        btn_clr = QtWidgets.QPushButton('clear')
        btn_clr.setToolTip('Clear the properties bin.')
        btn_clr.clicked.connect(self.clear_bin)

        top_layout = QtWidgets.QHBoxLayout()
        top_layout.addWidget(self._limit)
        top_layout.addStretch(1)
        top_layout.addWidget(btn_clr)

        layout = QtWidgets.QVBoxLayout(self)
        layout.addLayout(top_layout)
        layout.addWidget(self._prop_list, 1)

    def __on_prop_close(self, node_id):
        """
        Remove every row whose item text matches the node id.

        Args:
            node_id (str): node id.
        """
        items = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        for item in items:
            self._prop_list.removeRow(item.row())

    def __on_limit_changed(self, value):
        """
        Drop rows from the bottom until the row count fits the new limit.

        Args:
            value (int): new node limit.
        """
        rows = self._prop_list.rowCount()
        # the limit can drop by more than one step at once (typed value),
        # so every surplus row is removed, not just the last one.
        for row in range(rows - 1, value - 1, -1):
            self._prop_list.removeRow(row)

    def limit(self):
        """
        Returns the limit for how many nodes can be loaded into the bin.

        Returns:
            int: node limit.
        """
        return int(self._limit.value())

    def add_node(self, node):
        """
        Add node to the top of the properties bin.

        A row already showing the node is replaced. When the bin is full the
        bottom rows are dropped to make room. Nothing is added when the
        limit is 0.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        limit = self.limit()
        if limit == 0:
            return

        # drop the existing row first so re-adding a listed node does not
        # also evict an unrelated row.
        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            self._prop_list.removeRow(itm_find[0].row())

        # leave room for the row inserted below.
        rows = self._prop_list.rowCount()
        for row in range(rows - 1, limit - 2, -1):
            self._prop_list.removeRow(row)

        self._prop_list.insertRow(0)
        prop_widget = NodePropWidget(node=node)
        prop_widget.property_changed.connect(self.property_changed.emit)
        prop_widget.property_closed.connect(self.__on_prop_close)
        self._prop_list.setCellWidget(0, 0, prop_widget)

        item = QtWidgets.QTableWidgetItem(node.id)
        self._prop_list.setItem(0, 0, item)
        self._prop_list.selectRow(0)

    def remove_node(self, node):
        """
        Remove node from the properties bin.

        Args:
            node (NodeGraphQt.Node): node object.
        """
        self.__on_prop_close(node.id)

    def clear_bin(self):
        """
        Clear the properties bin.
        """
        self._prop_list.setRowCount(0)

    def prop_widget(self, node):
        """
        Returns the node property widget.

        Args:
            node (NodeGraphQt.Node): node object.

        Returns:
            NodePropWidget: node property widget, or None when the node is
                not in the bin.
        """
        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            item = itm_find[0]
            return self._prop_list.cellWidget(item.row(), 0)
        return None
예제 #24
0
class PropertiesBinWidget(QtWidgets.QWidget):
    """
    Node properties bin for displaying properties.

    Args:
        parent (QtWidgets.QWidget): parent of the new widget.
        node_graph (NodeGraphQt.NodeGraph): node graph.
    """

    #: Signal emitted (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)

    def __init__(self, parent=None, node_graph=None):
        super(PropertiesBinWidget, self).__init__(parent)
        self.setWindowTitle('Properties Bin')
        self._prop_list = PropertiesList()
        self._limit = QtWidgets.QSpinBox()
        self._limit.setToolTip('Set display nodes limit.')
        self._limit.setMaximum(10)
        self._limit.setMinimum(0)
        self._limit.setValue(5)
        self._limit.valueChanged.connect(self.__on_limit_changed)
        self.resize(400, 400)

        # True while a widget is being updated from a graph signal, so the
        # change is not emitted back through 'property_changed'.
        self._block_signal = False

        btn_clr = QtWidgets.QPushButton('clear')
        btn_clr.setToolTip('Clear the properties bin.')
        btn_clr.clicked.connect(self.clear_bin)

        top_layout = QtWidgets.QHBoxLayout()
        top_layout.addWidget(self._limit)
        top_layout.addStretch(1)
        top_layout.addWidget(btn_clr)

        layout = QtWidgets.QVBoxLayout(self)
        layout.addLayout(top_layout)
        layout.addWidget(self._prop_list, 1)

        # wire up node graph; 'node_graph' defaults to None, in which case
        # the bin is left unconnected instead of raising.
        if node_graph is not None:
            node_graph.add_properties_bin(self)
            node_graph.node_double_clicked.connect(self.add_node)
            node_graph.property_changed.connect(
                self.__on_graph_property_changed)

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__,
                                          hex(id(self)))

    def __on_prop_close(self, node_id):
        """
        Remove every row whose item text matches the node id.

        Args:
            node_id (str): node id.
        """
        items = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        for item in items:
            self._prop_list.removeRow(item.row())

    def __on_limit_changed(self, value):
        """
        Drop rows from the bottom until the row count fits the new limit.

        Args:
            value (int): new node limit.
        """
        rows = self._prop_list.rowCount()
        # 'set_limit' or a typed value can lower the limit by more than one
        # step, so every surplus row is removed, not just the last one.
        for row in range(rows - 1, value - 1, -1):
            self._prop_list.removeRow(row)

    def __on_graph_property_changed(self, node, prop_name, prop_value):
        """
        Slot function that updates the property bin from the node graph signal.

        Args:
            node (NodeGraphQt.NodeObject):
            prop_name (str):
            prop_value (object):
        """
        properties_widget = self.prop_widget(node)
        if properties_widget is None:
            return

        property_window = properties_widget.get_widget(prop_name)
        if property_window is None:
            # no widget is available for this property.
            return

        if prop_value != property_window.get_value():
            self._block_signal = True
            try:
                property_window.set_value(prop_value)
            finally:
                # never leave outgoing signals blocked after a failure.
                self._block_signal = False

    def __on_property_widget_changed(self, node_id, prop_name, prop_value):
        """
        Slot function triggered when a property widget value has changed.

        Args:
            node_id (str):
            prop_name (str):
            prop_value (object):
        """
        if not self._block_signal:
            self.property_changed.emit(node_id, prop_name, prop_value)

    def limit(self):
        """
        Returns the limit for how many nodes can be loaded into the bin.

        Returns:
            int: node limit.
        """
        return int(self._limit.value())

    def set_limit(self, limit):
        """
        Set limit of nodes to display.

        Args:
            limit (int): node limit.
        """
        self._limit.setValue(limit)

    def add_node(self, node):
        """
        Add node to the top of the properties bin.

        A row already showing the node is replaced. When the bin is full the
        bottom rows are dropped to make room. Nothing is added when the
        limit is 0.

        Args:
            node (NodeGraphQt.NodeObject): node object.
        """
        limit = self.limit()
        if limit == 0:
            return

        # drop the existing row first so re-adding a listed node does not
        # also evict an unrelated row.
        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            self._prop_list.removeRow(itm_find[0].row())

        # leave room for the row inserted below.
        rows = self._prop_list.rowCount()
        for row in range(rows - 1, limit - 2, -1):
            self._prop_list.removeRow(row)

        self._prop_list.insertRow(0)
        prop_widget = NodePropWidget(node=node)
        prop_widget.property_changed.connect(self.__on_property_widget_changed)
        prop_widget.property_closed.connect(self.__on_prop_close)
        self._prop_list.setCellWidget(0, 0, prop_widget)

        item = QtWidgets.QTableWidgetItem(node.id)
        self._prop_list.setItem(0, 0, item)
        self._prop_list.selectRow(0)

    def remove_node(self, node):
        """
        Remove node from the properties bin.

        Args:
            node (NodeGraphQt.BaseNode): node object.
        """
        self.__on_prop_close(node.id)

    def clear_bin(self):
        """
        Clear the properties bin.
        """
        self._prop_list.setRowCount(0)

    def prop_widget(self, node):
        """
        Returns the node property widget.

        Args:
            node (NodeGraphQt.NodeObject): node object.

        Returns:
            NodePropWidget: node property widget, or None when the node is
                not in the bin.
        """
        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            item = itm_find[0]
            return self._prop_list.cellWidget(item.row(), 0)
        return None
예제 #25
0
class NodeViewer(QtWidgets.QGraphicsView):
    """
    The widget interface used for displaying the scene and nodes.

    Functions in this class are called by the
    :class:`NodeGraphQt.NodeGraph` class.
    """

    # emitted on mouse release with {node item: recorded xy_pos} for the
    # nodes whose position changed.
    moved_nodes = QtCore.Signal(dict)
    # emitted with (node_type, (x, y)) when a tab search is submitted.
    search_triggered = QtCore.Signal(str, tuple)
    # emitted with a list of [input_port, output_port] pairs of sliced pipes.
    connection_sliced = QtCore.Signal(list)
    # emitted with (disconnected, connected) lists of port pairs.
    connection_changed = QtCore.Signal(list, list)
    # emitted with (pipe, node id, moved nodes dict) when a dragged node is
    # released while colliding with a pipe.
    insert_node = QtCore.Signal(object, str, dict)
    # emitted from the context menu event when the menu has no actions.
    need_show_tab_search = QtCore.Signal()

    # pass through signals
    node_selected = QtCore.Signal(str)
    node_double_clicked = QtCore.Signal(str)
    data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)

    def __init__(self, parent=None):
        """
        Set up the view: scene, render options, helper items (live pipe,
        slicer pipe, rubber band), undo stack, tab search widget, context
        menus and the mouse/keyboard state flags.

        Args:
            parent (QtWidgets.QWidget): parent passed to the base class.
        """
        super(NodeViewer, self).__init__(parent)

        self.setScene(NodeScene(self))
        self.setRenderHint(QtGui.QPainter.Antialiasing, True)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setViewportUpdateMode(QtWidgets.QGraphicsView.FullViewportUpdate)
        self.setCacheMode(QtWidgets.QGraphicsView.CacheBackground)
        self.setOptimizationFlag(
            QtWidgets.QGraphicsView.DontAdjustForAntialiasing)

        self.setAcceptDrops(True)
        self.resize(850, 800)

        # visible scene rectangle; zooming and panning modify this rect and
        # '_update_scene' fits the view to it.
        self._scene_range = QtCore.QRectF(0, 0,
                                          self.size().width(),
                                          self.size().height())
        self._update_scene()
        # widget size at the last resize, used by 'resizeEvent'.
        self._last_size = self.size()

        self._pipe_layout = PIPE_LAYOUT_CURVED
        self._detached_port = None
        self._start_port = None
        # view position of the last mouse press.
        self._origin_pos = None
        # view position of the last processed mouse event.
        self._previous_pos = QtCore.QPoint(self.width(), self.height())
        self._prev_selection_nodes = []
        self._prev_selection_pipes = []
        # {node item: xy_pos} recorded on press, compared on release.
        self._node_positions = {}
        self._rubber_band = QtWidgets.QRubberBand(
            QtWidgets.QRubberBand.Rectangle, self)

        self._LIVE_PIPE = LivePipe()
        self._LIVE_PIPE.setVisible(False)
        self.scene().addItem(self._LIVE_PIPE)

        self._SLICER_PIPE = SlicerPipe()
        self._SLICER_PIPE.setVisible(False)
        self.scene().addItem(self._SLICER_PIPE)

        self._undo_stack = QtWidgets.QUndoStack(self)
        self._search_widget = TabSearchMenuWidget(self)
        self._search_widget.search_submitted.connect(self._on_search_submitted)

        # workaround fix for shortcuts from the non-native menu actions
        # don't seem to trigger so we create a hidden menu bar.
        menu_bar = QtWidgets.QMenuBar(self)
        menu_bar.setNativeMenuBar(False)
        # shortcuts don't work with "setVisibility(False)".
        menu_bar.setMaximumWidth(0)

        self._ctx_menu = BaseMenu('NodeGraph', self)
        self._ctx_node_menu = BaseMenu('Nodes', self)
        menu_bar.addMenu(self._ctx_menu)
        menu_bar.addMenu(self._ctx_node_menu)
        self._ctx_node_menu.setDisabled(True)

        self.acyclic = True
        # mouse button / modifier key states maintained by the event handlers.
        self.LMB_state = False
        self.RMB_state = False
        self.MMB_state = False
        self.ALT_state = False
        self.CTRL_state = False
        self.SHIFT_state = False
        # True while a single dragged node overlaps a pipe.
        self.COLLIDING_state = False

    def __repr__(self):
        return '{}.{}()'.format(self.__module__, self.__class__.__name__)

    # --- private ---

    def _set_viewer_zoom(self, value, sensitivity=None, pos=None):
        if pos:
            pos = self.mapToScene(pos)
        if sensitivity is None:
            scale = 1.001**value
            self.scale(scale, scale, pos)
            return

        if value == 0.0:
            return
        scale = (0.9 + sensitivity) if value < 0.0 else (1.1 - sensitivity)
        zoom = self.get_zoom()
        if ZOOM_MIN >= zoom:
            if scale == 0.9:
                return
        if ZOOM_MAX <= zoom:
            if scale == 1.1:
                return
        self.scale(scale, scale, pos)

    def _set_viewer_pan(self, pos_x, pos_y):
        speed = self._scene_range.width() * 0.0015
        x = -pos_x * speed
        y = -pos_y * speed
        self._scene_range.adjust(x, y, x, y)
        self._update_scene()

    def scale(self, sx, sy, pos=None):
        """
        Scale the visible scene range around an anchor point.

        Args:
            sx (float): scale factor, applied to both axes.
            sy (float): accepted but not used by this implementation.
            pos (QtCore.QPointF): anchor point; the center of the current
                scene range is used when it is falsy.
        """
        factor = sx
        anchor = pos or self._scene_range.center()
        current = self._scene_range

        new_width = current.width() / factor
        new_height = current.height() / factor
        # keep the anchor at the same relative position inside the range.
        new_left = anchor.x() - (anchor.x() - current.left()) / factor
        new_top = anchor.y() - (anchor.y() - current.top()) / factor

        self._scene_range = QtCore.QRectF(
            new_left, new_top, new_width, new_height)
        self._update_scene()

    def _update_scene(self):
        """
        Apply '_scene_range' as the scene rect and fit the view to it while
        keeping the aspect ratio.
        """
        self.setSceneRect(self._scene_range)
        aspect_mode = QtCore.Qt.KeepAspectRatio
        self.fitInView(self._scene_range, aspect_mode)

    def _combined_rect(self, nodes):
        group = self.scene().createItemGroup(nodes)
        rect = group.boundingRect()
        self.scene().destroyItemGroup(group)
        return rect

    def _items_near(self, pos, item_type=None, width=20, height=20):
        """
        Collect the scene items inside a rectangle next to a position.

        The rectangle has the given size and its bottom right corner at
        'pos'. The live pipe and the slicer pipe are never returned.

        Args:
            pos (QtCore.QPointF): scene position.
            item_type: optional type (or tuple of types) the items must be
                an instance of.
            width (int): rectangle width.
            height (int): rectangle height.

        Returns:
            list: matching items.
        """
        search_rect = QtCore.QRectF(
            pos.x() - width, pos.y() - height, width, height)
        ignored = [self._LIVE_PIPE, self._SLICER_PIPE]
        return [
            found for found in self.scene().items(search_rect)
            if found not in ignored
            and (not item_type or isinstance(found, item_type))
        ]

    def _on_search_submitted(self, node_type):
        pos = self.mapToScene(self._previous_pos)
        self.search_triggered.emit(node_type, (pos.x(), pos.y()))

    def _on_pipes_sliced(self, path):
        """
        Emit 'connection_sliced' with the port pairs of every pipe found on
        the given path, the live pipe excluded.

        Args:
            path (QtGui.QPainterPath): slicer path in scene coordinates.
        """
        sliced = []
        for item in self.scene().items(path):
            if isinstance(item, Pipe) and item != self._LIVE_PIPE:
                sliced.append([item.input_port, item.output_port])
        self.connection_sliced.emit(sliced)

    # --- reimplemented events ---

    def resizeEvent(self, event):
        """
        Adjust the zoom for the new widget size and remember that size.

        Args:
            event (QtGui.QResizeEvent): resize event forwarded to the base
                class.
        """
        size = self.size()
        width, height = size.width(), size.height()
        last_width = self._last_size.width()
        last_height = self._last_size.height()

        # a size with a zero dimension is never recorded; dividing by it on
        # the next resize would raise ZeroDivisionError.
        if width > 0 and height > 0:
            if last_width > 0 and last_height > 0:
                delta = max(width / last_width, height / last_height)
                self._set_viewer_zoom(delta)
            self._last_size = size
        super(NodeViewer, self).resizeEvent(event)

    def contextMenuEvent(self, event):
        """
        Show the node context menu for a node near the last mouse position,
        otherwise the graph context menu. When the chosen menu has no
        actions 'need_show_tab_search' is emitted instead.

        Args:
            event (QtGui.QContextMenuEvent): context menu event.
        """
        self.RMB_state = False
        node_menu = None

        if self._ctx_node_menu.isEnabled():
            scene_pos = self.mapToScene(self._previous_pos)
            node_items = [i for i in self._items_near(scene_pos)
                          if isinstance(i, AbstractNodeItem)]
            if node_items:
                hit = node_items[0]
                node_menu = self._ctx_node_menu.get_menu(hit.type_, hit.id)
                if node_menu:
                    # tag the plain actions (not sub menus) with the node id.
                    for action in node_menu.actions():
                        if not action.menu():
                            action.node_id = hit.id

        menu = node_menu or self._ctx_menu
        if len(menu.actions()) == 0:
            self.need_show_tab_search.emit()
            return
        if not menu.isEnabled():
            return super(NodeViewer, self).contextMenuEvent(event)
        menu.exec_(event.globalPos())

    def mousePressEvent(self, event):
        """
        Record the pressed button and position, then start the matching
        interaction: pipe slicing (Alt+Shift+LMB), panning (Alt), selection
        toggling or the selection marquee.

        Args:
            event (QtGui.QMouseEvent): mouse press event.
        """
        if event.button() == QtCore.Qt.LeftButton:
            self.LMB_state = True
        elif event.button() == QtCore.Qt.RightButton:
            self.RMB_state = True
        elif event.button() == QtCore.Qt.MiddleButton:
            self.MMB_state = True

        self._origin_pos = event.pos()
        self._previous_pos = event.pos()
        # remembered so a Shift/Ctrl marquee can restore the selection.
        self._prev_selection_nodes, \
            self._prev_selection_pipes = self.selected_items()

        # close tab search
        if self._search_widget.isVisible():
            self.tab_search_toggle()

        # cursor pos.
        map_pos = self.mapToScene(event.pos())

        # pipe slicer enabled.
        if self.ALT_state and self.SHIFT_state and self.LMB_state:
            self._SLICER_PIPE.draw_path(map_pos, map_pos)
            self._SLICER_PIPE.setVisible(True)
            return

        # pan mode.
        if self.ALT_state:
            return

        items = self._items_near(map_pos, None, 20, 20)
        nodes = [i for i in items if isinstance(i, AbstractNodeItem)]
        # NOTE(review): 'pipes' is only referenced by the commented-out code
        # below.
        pipes = [i for i in items if isinstance(i, Pipe)]

        # a press near a node clears the middle button state.
        if nodes:
            self.MMB_state = False

        # toggle extend node selection.
        if self.LMB_state:
            if self.SHIFT_state:
                for node in nodes:
                    node.selected = not node.selected
            elif self.CTRL_state:
                for node in nodes:
                    node.selected = False

        # update the recorded node positions.
        self._node_positions.update(
            {n: n.xy_pos
             for n in self.selected_nodes()})

        # show the selection marquee.
        if self.LMB_state and not items:
            rect = QtCore.QRect(self._previous_pos, QtCore.QSize())
            rect = rect.normalized()
            map_rect = self.mapToScene(rect).boundingRect()
            self.scene().update(map_rect)
            self._rubber_band.setGeometry(rect)
            self._rubber_band.show()

        # allow new live pipe with the shift modifier.
        # if self.LMB_state:
        #     if (not self.SHIFT_state and not self.CTRL_state) or\
        #             (self.SHIFT_state and pipes):
        # the base class only gets the event when no live pipe is shown.
        if not self._LIVE_PIPE.isVisible():
            super(NodeViewer, self).mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """
        Clear the released button state and finish the active interaction:
        apply the pipe slicer, hide the selection marquee, emit moved nodes
        and emit 'insert_node' when a node was dropped onto a pipe.

        Args:
            event (QtGui.QMouseEvent): mouse release event.
        """
        if event.button() == QtCore.Qt.LeftButton:
            self.LMB_state = False
        elif event.button() == QtCore.Qt.RightButton:
            self.RMB_state = False
        elif event.button() == QtCore.Qt.MiddleButton:
            self.MMB_state = False

        # hide pipe slicer.
        if self._SLICER_PIPE.isVisible():
            self._on_pipes_sliced(self._SLICER_PIPE.path())
            # reset the slicer path to a single point.
            p = QtCore.QPointF(0.0, 0.0)
            self._SLICER_PIPE.draw_path(p, p)
            self._SLICER_PIPE.setVisible(False)

        # hide selection marquee
        if self._rubber_band.isVisible():
            rect = self._rubber_band.rect()
            map_rect = self.mapToScene(rect).boundingRect()
            self._rubber_band.hide()
            self.scene().update(map_rect)

        # find position changed nodes and emit signal.
        # maps each node to the position recorded when it was pressed.
        moved_nodes = {
            n: xy_pos
            for n, xy_pos in self._node_positions.items() if n.xy_pos != xy_pos
        }
        # only emit if the node is not colliding with a pipe.
        if moved_nodes and not self.COLLIDING_state:
            self.moved_nodes.emit(moved_nodes)

        # reset recorded positions.
        self._node_positions = {}

        # emit signal if selected node collides with pipe.
        # Note: if collide state is true then only 1 node is selected.
        if self.COLLIDING_state:
            nodes, pipes = self.selected_items()
            if nodes and pipes:
                self.insert_node.emit(pipes[0], nodes[0].id, moved_nodes)

        super(NodeViewer, self).mouseReleaseEvent(event)

    def mouseMoveEvent(self, event):
        """
        Update the active interaction while the mouse moves: pipe slicer
        (Alt+Shift), zoom (Alt+MMB), pan (MMB or Alt+LMB), selection marquee,
        or the node/pipe collision check while a node is dragged.

        Args:
            event (QtGui.QMouseEvent): mouse move event.
        """
        # pipe slicer mode.
        if self.ALT_state and self.SHIFT_state:
            if self.LMB_state and self._SLICER_PIPE.isVisible():
                p1 = self._SLICER_PIPE.path().pointAtPercent(0)
                # the end point uses the previously recorded position.
                p2 = self.mapToScene(self._previous_pos)
                self._SLICER_PIPE.draw_path(p1, p2)
                self._SLICER_PIPE.show()
            self._previous_pos = event.pos()
            super(NodeViewer, self).mouseMoveEvent(event)
            return

        if self.MMB_state and self.ALT_state:
            # horizontal drag direction selects zoom in or zoom out.
            pos_x = (event.x() - self._previous_pos.x())
            zoom = 0.1 if pos_x > 0 else -0.1
            self._set_viewer_zoom(zoom, 0.05, pos=event.pos())
        elif self.MMB_state or (self.LMB_state and self.ALT_state):
            pos_x = (event.x() - self._previous_pos.x())
            pos_y = (event.y() - self._previous_pos.y())
            self._set_viewer_pan(pos_x, pos_y)

        # selection marquee.
        if self.LMB_state and self._rubber_band.isVisible():
            rect = QtCore.QRect(self._origin_pos, event.pos()).normalized()
            map_rect = self.mapToScene(rect).boundingRect()
            path = QtGui.QPainterPath()
            path.addRect(map_rect)
            self._rubber_band.setGeometry(rect)
            self.scene().setSelectionArea(path, QtCore.Qt.IntersectsItemShape)
            self.scene().update(map_rect)

            if self.SHIFT_state or self.CTRL_state:
                # items currently inside the marquee.
                nodes, pipes = self.selected_items()

                # restore the selection recorded at press time.
                for pipe in self._prev_selection_pipes:
                    pipe.setSelected(True)
                for node in self._prev_selection_nodes:
                    node.selected = True

                # Ctrl deselects the items inside the marquee.
                if self.CTRL_state:
                    for pipe in pipes:
                        pipe.setSelected(False)
                    for node in nodes:
                        node.selected = False

        elif self.LMB_state:
            # collision check: select the first visible pipe that overlaps
            # the single selected node.
            self.COLLIDING_state = False
            nodes = self.selected_nodes()
            if len(nodes) == 1:
                node = nodes[0]
                for pipe in self.selected_pipes():
                    pipe.setSelected(False)
                for item in node.collidingItems():
                    if isinstance(item, Pipe) and item.isVisible():
                        if not item.input_port:
                            continue
                        # skip pipes attached to the node itself.
                        if not item.input_port.node is node and \
                                not item.output_port.node is node:
                            item.setSelected(True)
                            self.COLLIDING_state = True
                            break

        self._previous_pos = event.pos()
        super(NodeViewer, self).mouseMoveEvent(event)

    def wheelEvent(self, event):
        """
        Zoom the viewer around the cursor by the wheel delta.

        Args:
            event (QtGui.QWheelEvent): wheel event.
        """
        try:
            steps = event.delta()
        except AttributeError:
            # For PyQt5: fall back to the horizontal delta when the
            # vertical one is zero.
            angle = event.angleDelta()
            steps = angle.y() or angle.x()

        self._set_viewer_zoom(steps, pos=event.pos())

    def dropEvent(self, event):
        """
        Emit 'data_dropped' with the dropped mime data and the drop position
        in scene coordinates.

        Args:
            event (QtGui.QDropEvent): drop event.
        """
        scene_pos = self.mapToScene(event.pos())
        event.setDropAction(QtCore.Qt.MoveAction)
        # QPoint takes integers while the scene position holds floats;
        # convert explicitly because Python 3.10+ no longer truncates floats
        # implicitly for integer arguments.
        drop_point = QtCore.QPoint(int(scene_pos.x()), int(scene_pos.y()))
        self.data_dropped.emit(event.mimeData(), drop_point)

    def dragEnterEvent(self, event):
        """
        Accept the drag only when it carries 'text/plain' data.

        Args:
            event (QtGui.QDragEnterEvent): drag enter event.
        """
        is_plain_text = event.mimeData().hasFormat('text/plain')
        if not is_plain_text:
            event.ignore()
            return
        event.accept()

    def dragMoveEvent(self, event):
        """
        Accept the drag move only when it carries 'text/plain' data.

        Args:
            event (QtGui.QDragMoveEvent): drag move event.
        """
        is_plain_text = event.mimeData().hasFormat('text/plain')
        if not is_plain_text:
            event.ignore()
            return
        event.accept()

    def dragLeaveEvent(self, event):
        """
        Ignore drag leave events.

        Args:
            event (QtGui.QDragLeaveEvent): drag leave event.
        """
        event.ignore()

    def keyPressEvent(self, event):
        """
        Update the Alt/Ctrl/Shift state flags from the event modifiers.

        A flag is set when its modifier is the only one held; Alt+Shift held
        together set both the Alt and the Shift flag.

        Args:
            event (QtGui.QKeyEvent): key press event.
        """
        modifiers = event.modifiers()
        alt_held = modifiers == QtCore.Qt.AltModifier
        ctrl_held = modifiers == QtCore.Qt.ControlModifier
        shift_held = modifiers == QtCore.Qt.ShiftModifier

        # Todo: find a better solution to catch modifier keys.
        alt_shift = QtCore.Qt.AltModifier | QtCore.Qt.ShiftModifier
        if modifiers == alt_shift:
            alt_held = True
            shift_held = True

        self.ALT_state = alt_held
        self.CTRL_state = ctrl_held
        self.SHIFT_state = shift_held

        super(NodeViewer, self).keyPressEvent(event)

    def keyReleaseEvent(self, event):
        """
        Update the Alt/Ctrl/Shift state flags from the event modifiers.

        A flag is set when its modifier is the only one still held; Alt+Shift
        still held together keep both the Alt and the Shift flag set.

        Args:
            event (QtGui.QKeyEvent): key release event.
        """
        modifiers = event.modifiers()
        self.ALT_state = modifiers == QtCore.Qt.AltModifier
        self.CTRL_state = modifiers == QtCore.Qt.ControlModifier
        self.SHIFT_state = modifiers == QtCore.Qt.ShiftModifier

        # same combination handling as 'keyPressEvent': releasing another
        # key while Alt+Shift stay held must not clear both flags.
        if modifiers == (QtCore.Qt.AltModifier
                         | QtCore.Qt.ShiftModifier):
            self.ALT_state = True
            self.SHIFT_state = True

        super(NodeViewer, self).keyReleaseEvent(event)

    # --- scene events ---

    def sceneMouseMoveEvent(self, event):
        """
        Scene mouse move handler: redraws the live connection pipe.

        Nothing happens unless the live pipe is visible and a start port is
        set. When the top item under the cursor is a port item, the pipe end
        is offset from that item's scene position by half of its bounding
        rect size.

        Args:
            event (QtWidgets.QGraphicsSceneMouseEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        if not (self._LIVE_PIPE.isVisible() and self._start_port):
            return

        cursor = event.scenePos()
        hits = self.scene().items(cursor)
        if hits and isinstance(hits[0], PortItem):
            port_item = hits[0]
            bounds = port_item.boundingRect()
            half_width = bounds.width() / 2
            half_height = bounds.height() / 2
            cursor = port_item.scenePos()
            cursor.setX(cursor.x() + half_width)
            cursor.setY(cursor.y() + half_height)

        self._LIVE_PIPE.draw_path(self._start_port, cursor_pos=cursor)

    def sceneMousePressEvent(self, event):
        """
        triggered mouse press event for the scene (takes priority over viewer event).
         - detect selected pipe and start connection.
         - remap Shift and Ctrl modifier.

        Args:
            event (QtWidgets.QGraphicsScenePressEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        # pipe slicer enabled.
        if self.ALT_state and self.SHIFT_state:
            return
        # viewer pan mode.
        if self.ALT_state:
            return
        # a press while a live pipe is shown finishes that connection.
        if self._LIVE_PIPE.isVisible():
            self.apply_live_connection(event)
            return

        pos = event.scenePos()
        # press near a port: start a live connection from it.
        port_items = self._items_near(pos, PortItem, 5, 5)
        if port_items:
            port = port_items[0]
            # remember the existing peer of a single-connection port.
            if not port.multi_connection and port.connected_ports:
                self._detached_port = port.connected_ports[0]
            self.start_live_connection(port)
            # a single-connection port drops its existing pipes.
            if not port.multi_connection:
                [p.delete() for p in port.connected_pipes]
            return

        node_items = self._items_near(pos, AbstractNodeItem, 3, 3)
        if node_items:
            node = node_items[0]

            # record the node positions at selection time.
            for n in node_items:
                self._node_positions[n] = n.xy_pos

            # emit selected node id with LMB.
            if event.button() == QtCore.Qt.LeftButton:
                self.node_selected.emit(node.id)

            # only backdrop nodes fall through to the pipe check below.
            if not isinstance(node, BackdropNodeItem):
                return

        # press near a pipe with LMB: detach it and start a live connection.
        pipe_items = self._items_near(pos, Pipe, 3, 3)
        if pipe_items:
            if not self.LMB_state:
                return
            pipe = pipe_items[0]
            from_port = pipe.port_from_pos(pos, True)
            from_port.hovered = True

            # the detached port is the pipe end opposite to 'from_port'.
            attr = {IN_PORT: 'output_port', OUT_PORT: 'input_port'}
            self._detached_port = getattr(pipe, attr[from_port.port_type])
            self.start_live_connection(from_port)
            self._LIVE_PIPE.draw_path(self._start_port, cursor_pos=pos)

            # with Shift the original pipe is kept.
            if self.SHIFT_state:
                self._LIVE_PIPE.shift_selected = True
                return

            pipe.delete()

    def sceneMouseReleaseEvent(self, event):
        """
        Scene mouse release handler: applies the live connection for every
        button except the middle one.

        Args:
            event (QtWidgets.QGraphicsSceneMouseEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        released = event.button()
        if released == QtCore.Qt.MiddleButton:
            return
        self.apply_live_connection(event)

    # --- port connections ---

    def apply_live_connection(self, event):
        """
        triggered mouse press/release event for the scene.
         - verify whether to make the connection Pipe.

        Does nothing unless the live pipe is visible. Otherwise it looks for
        a port item under the cursor and then either restores the previously
        detached connection, or emits ``connection_changed`` with the
        (disconnected, connected) port pair lists, before ending the live
        connection.

        Args:
            event (QtWidgets.QGraphicsSceneMouseEvent):
                The event handler from the QtWidgets.QGraphicsScene
        """
        # no live pipe is being dragged, nothing to resolve.
        if not self._LIVE_PIPE.isVisible():
            return

        self._start_port.hovered = False

        # find the end port.
        # (first PortItem returned by the scene at the cursor position)
        end_port = None
        for item in self.scene().items(event.scenePos()):
            if isinstance(item, PortItem):
                end_port = item
                break

        # port pair lists handed to the "connection_changed" signal.
        connected = []
        disconnected = []

        # if port disconnected from existing pipe.
        if end_port is None:
            if self._detached_port and not self._LIVE_PIPE.shift_selected:
                # distance between the previous and the origin cursor pos.
                dist = math.hypot(
                    self._previous_pos.x() - self._origin_pos.x(),
                    self._previous_pos.y() - self._origin_pos.y())
                if dist <= 2.0:  # cursor pos threshold.
                    # cursor barely moved, put the detached pipe back.
                    self.establish_connection(self._start_port,
                                              self._detached_port)
                    self._detached_port = None
                else:
                    # cursor moved away, report the pair as disconnected.
                    disconnected.append(
                        (self._start_port, self._detached_port))
                    self.connection_changed.emit(disconnected, connected)

            self._detached_port = None
            self.end_live_connection()
            return

        else:
            # cursor is over the start port itself: return early WITHOUT
            # ending the live connection, so the live pipe stays visible.
            if self._start_port is end_port:
                return

        # restore connection check.
        restore_connection = any([
            # if same port type.
            end_port.port_type == self._start_port.port_type,
            # if connection to itself.
            end_port.node == self._start_port.node,
            # if end port is the start port.
            end_port == self._start_port,
            # if detached port is the end port.
            self._detached_port == end_port
        ])
        if restore_connection:
            if self._detached_port:
                # "_detached_port" is truthy here so "to_port" always
                # resolves to the detached port.
                to_port = self._detached_port or end_port
                self.establish_connection(self._start_port, to_port)
                self._detached_port = None
            self.end_live_connection()
            return

        # register as disconnected if not acyclic.
        # (acyclic_check returning False means the connection is rejected)
        if self.acyclic and not self.acyclic_check(self._start_port, end_port):
            if self._detached_port:
                disconnected.append((self._start_port, self._detached_port))

            self.connection_changed.emit(disconnected, connected)

            self._detached_port = None
            self.end_live_connection()
            return

        # make connection.
        # a single connection end port first drops its existing connection.
        if not end_port.multi_connection and end_port.connected_ports:
            dettached_end = end_port.connected_ports[0]
            disconnected.append((end_port, dettached_end))

        if self._detached_port:
            disconnected.append((self._start_port, self._detached_port))

        connected.append((self._start_port, end_port))

        self.connection_changed.emit(disconnected, connected)

        self._detached_port = None
        self.end_live_connection()

    def start_live_connection(self, selected_port):
        """
        create new pipe for the connection.
        (show the live pipe visibility from the port following the cursor position)

        The selected port is stored as the start port and assigned to the
        matching end of the live pipe before the live pipe is made visible.

        Args:
            selected_port (PortItem): port the live connection starts from,
                nothing happens when it is falsy.
        """
        if not selected_port:
            return
        self._start_port = selected_port
        # compare the port type (``port_type`` is the attribute used by the
        # other connection methods), the previous checks compared ``.type``
        # and the port object itself so neither branch could ever match.
        port_type = selected_port.port_type
        if port_type == IN_PORT:
            self._LIVE_PIPE.input_port = selected_port
        elif port_type == OUT_PORT:
            self._LIVE_PIPE.output_port = selected_port
        self._LIVE_PIPE.setVisible(True)

    def end_live_connection(self):
        """
        Finish the live connection.

        Resets the path of the live pipe, hides it, clears its
        ``shift_selected`` flag and forgets the start port.
        """
        live_pipe = self._LIVE_PIPE
        live_pipe.reset_path()
        live_pipe.setVisible(False)
        live_pipe.shift_selected = False
        self._start_port = None

    def establish_connection(self, start_port, end_port):
        """
        Draw a new pipe item between two ports.

        The pipe is added to the scene, connected to both ports, drawn and
        highlighted when the node of either port is selected.

        Args:
            start_port (PortItem): first port of the connection.
            end_port (PortItem): second port of the connection.
        """
        new_pipe = Pipe()
        self.scene().addItem(new_pipe)
        new_pipe.set_connections(start_port, end_port)
        new_pipe.draw_path(new_pipe.input_port, new_pipe.output_port)

        node_selected = start_port.node.selected or end_port.node.selected
        if node_selected:
            new_pipe.highlight()

    def acyclic_check(self, start_port, end_port):
        """
        validate the connection so it doesn't loop itself.

        Walks the nodes reachable from the end port node (through its
        outputs when the end port is an input port, through its inputs
        otherwise) and fails as soon as the start port node is reached.

        Args:
            start_port (PortItem): port the connection starts from.
            end_port (PortItem): port the connection would be made to.

        Returns:
            bool: True if port connection is valid.
        """
        start_node = start_port.node
        io_types = {IN_PORT: 'outputs', OUT_PORT: 'inputs'}
        ports_attr = io_types[end_port.port_type]

        pending = [end_port.node]
        # nodes already expanded; without this the walk never terminates
        # when the graph already contains a cycle that doesn't include the
        # start node, and shared branches get re-walked over and over.
        # NOTE(review): assumes node items are hashable -- confirm.
        visited = set()
        while pending:
            check_node = pending.pop()
            if check_node in visited:
                continue
            visited.add(check_node)
            for check_port in getattr(check_node, ports_attr):
                for port in check_port.connected_ports:
                    if port.node != start_node:
                        pending.append(port.node)
                    else:
                        return False
        return True

    # --- viewer ---

    def tab_search_set_nodes(self, nodes):
        """
        Hand the given nodes over to the search widget.

        Args:
            nodes: value forwarded to the search widget ``set_nodes``.
        """
        search_widget = self._search_widget
        search_widget.set_nodes(nodes)

    def tab_search_toggle(self):
        """
        Toggle the visibility of the tab search widget.

        Nothing happens when the search widget is a ``TabSearchMenuWidget``.
        When shown, the widget is centered on the previous cursor position
        and the scene area under it is updated. When hidden, the viewer
        focus is cleared.
        """
        if type(self._search_widget) is TabSearchMenuWidget:
            return

        pos = self._previous_pos
        state = not self._search_widget.isVisible()
        if state:
            rect = self._search_widget.rect()
            # QPoint takes integer coordinates, "/" yields floats on
            # python 3 so floor divide and cast explicitly.
            new_pos = QtCore.QPoint(int(pos.x() - rect.width() // 2),
                                    int(pos.y() - rect.height() // 2))
            self._search_widget.move(new_pos)
            self._search_widget.setVisible(state)
            rect = self.mapToScene(rect).boundingRect()
            self.scene().update(rect)
        else:
            self._search_widget.setVisible(state)
            self.clearFocus()

    def context_menus(self):
        """
        Return the context menus of the viewer.

        Returns:
            dict: the graph menu under 'graph', the node menu under 'nodes'.
        """
        menus = {
            'graph': self._ctx_menu,
            'nodes': self._ctx_node_menu,
        }
        return menus

    def question_dialog(self, text, title='Node Graph'):
        """
        Prompt a message box with a "Yes" and a "No" button.

        Args:
            text (str): message box text.
            title (str): message box title.

        Returns:
            bool: True if the message box result is "Yes".
        """
        buttons = QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No
        answer = messageBox(text, title, buttons)
        return answer == QtWidgets.QMessageBox.Yes

    def message_dialog(self, text, title='Node Graph'):
        """
        Prompt a message box with a single "Ok" button.

        Args:
            text (str): message box text.
            title (str): message box title.
        """
        ok_button = QtWidgets.QMessageBox.Ok
        messageBox(text, title, ok_button)

    def load_dialog(self, current_dir=None, ext=None):
        """
        Prompt the "Open File" dialog.

        Args:
            current_dir (str): directory passed to the file dialog.
            ext (str): optional extension added to the name filter.

        Returns:
            str: the picked file path or None if nothing was picked.
        """
        ext_label = '*{} '.format(ext) if ext else ''
        name_filters = [
            'Node Graph ({}*json)'.format(ext_label),
            'All Files (*)',
        ]
        picked = file_dialog.getOpenFileName(
            self, 'Open File', current_dir, ';;'.join(name_filters))
        return picked[0] or None

    def save_dialog(self, current_dir=None, ext=None):
        """
        Prompt the "Save Session" dialog.

        The suffix mapped to the picked name filter is appended to the file
        path when the path doesn't already end with it.

        Args:
            current_dir (str): directory passed to the file dialog.
            ext (str): optional file extension (falls back to ".json").

        Returns:
            str: the file path to save to or None if nothing was picked.
        """
        label = '*{} '.format(ext) if ext else ''
        default_suffix = '.{}'.format(ext) if ext else '.json'
        # name filter -> suffix appended to the picked file path.
        ext_map = {
            'Node Graph ({}*json)'.format(label): default_suffix,
            'All Files (*)': ''
        }
        picked = file_dialog.getSaveFileName(
            self, 'Save Session', current_dir, ';;'.join(ext_map.keys()))

        file_path = picked[0]
        if not file_path:
            return None

        suffix = ext_map[picked[1]]
        if suffix and not file_path.endswith(suffix):
            file_path += suffix
        return file_path

    def all_pipes(self):
        """
        Return the pipe items in the scene.
        (the live pipe and the slicer pipe are left out)

        Returns:
            list[Pipe]: pipe items.
        """
        excluded = [self._LIVE_PIPE, self._SLICER_PIPE]
        return [
            item for item in self.scene().items()
            if isinstance(item, Pipe) and item not in excluded
        ]

    def all_nodes(self):
        """
        Return the node items in the scene.

        Returns:
            list[AbstractNodeItem]: node items.
        """
        return [
            item for item in self.scene().items()
            if isinstance(item, AbstractNodeItem)
        ]

    def selected_nodes(self):
        """
        Return the selected node items in the scene.

        Returns:
            list[AbstractNodeItem]: selected node items.
        """
        nodes = []
        for item in self.scene().selectedItems():
            if isinstance(item, AbstractNodeItem):
                nodes.append(item)
        return nodes

    def selected_pipes(self):
        """
        Return the selected pipe items in the scene.

        Returns:
            list[Pipe]: selected pipe items.
        """
        pipes = []
        for item in self.scene().selectedItems():
            if isinstance(item, Pipe):
                pipes.append(item)
        return pipes

    def selected_items(self):
        """
        Return the selected node and pipe items in the scene.

        Returns:
            tuple(list, list): selected node items, selected pipe items.
        """
        selection = self.scene().selectedItems()
        nodes = [
            item for item in selection
            if isinstance(item, AbstractNodeItem)
        ]
        # an item counted as a node is never counted as a pipe.
        pipes = [
            item for item in selection
            if not isinstance(item, AbstractNodeItem)
            and isinstance(item, Pipe)
        ]
        return nodes, pipes

    def add_node(self, node, pos=None):
        """
        Add a node item to the scene.

        Args:
            node: node item, its ``pre_init`` is called before and its
                ``post_init`` after it has been added to the scene.
            pos: node position, falls back to the x, y of the previous
                cursor position when not given.
        """
        if not pos:
            cursor = self._previous_pos
            pos = (cursor.x(), cursor.y())
        node.pre_init(self, pos)
        self.scene().addItem(node)
        node.post_init(self, pos)

    def remove_node(self, node):
        """
        Delete the given node item.

        Args:
            node: only ``AbstractNodeItem`` instances are deleted, anything
                else is ignored.
        """
        if not isinstance(node, AbstractNodeItem):
            return
        node.delete()

    def move_nodes(self, nodes, pos=None, offset=None):
        """
        Move the node items together as a group.

        Args:
            nodes (list): node items grouped for the move.
            pos (tuple): x, y position for the group. When not given the
                group is centered on the previous cursor position mapped
                into the scene.
            offset (tuple): optional x, y offset added to the position.
        """
        scene = self.scene()
        group = scene.createItemGroup(nodes)
        bounds = group.boundingRect()
        if pos:
            x, y = pos
        else:
            cursor = self.mapToScene(self._previous_pos)
            center = bounds.center()
            x = cursor.x() - center.x()
            y = cursor.y() - center.y()
        if offset:
            x += offset[0]
            y += offset[1]
        group.setPos(x, y)
        # the group was only needed to move the items together.
        scene.destroyItemGroup(group)

    def get_pipes_from_nodes(self, nodes=None):
        """
        Collect the pipes connecting the given node items to each other.

        A pipe is appended once for every port it is found on, so a pipe
        between two of the given nodes shows up twice in the result.

        Args:
            nodes (list): node items (default: the selected node items).

        Returns:
            list: pipe items or None if there are no nodes to look at.
        """
        nodes = nodes or self.selected_nodes()
        if not nodes:
            return None

        pipes = []
        for node in nodes:
            in_ports = getattr(node, 'inputs', [])
            out_ports = getattr(node, 'outputs', [])

            # input ports: keep the pipes coming from one of the nodes.
            for port in in_ports:
                pipes.extend(
                    pipe for pipe in port.connected_pipes
                    if pipe.output_port.node in nodes)
            # output ports: keep the pipes going into one of the nodes.
            for port in out_ports:
                pipes.extend(
                    pipe for pipe in port.connected_pipes
                    if pipe.input_port.node in nodes)
        return pipes

    def center_selection(self, nodes=None):
        """
        Center the view on the given node items.

        Falls back to the selected node items, then to all node items, when
        no nodes are given. Nothing happens when there is no node at all
        (this used to raise a ``TypeError``).

        Args:
            nodes (list): node items to center on.
        """
        if not nodes:
            nodes = self.selected_nodes() or self.all_nodes()
        if not nodes:
            # empty scene, nothing to center on.
            return
        if len(nodes) == 1:
            self.centerOn(nodes[0])
        else:
            rect = self._combined_rect(nodes)
            self.centerOn(rect.center().x(), rect.center().y())

    def get_pipe_layout(self):
        """
        Return the pipe layout currently stored on the viewer.

        Returns:
            the value last stored by ``set_pipe_layout``.
        """
        current_layout = self._pipe_layout
        return current_layout

    def set_pipe_layout(self, layout):
        """
        Store the pipe layout and redraw every pipe in the scene.

        Args:
            layout: pipe layout value to store.
        """
        self._pipe_layout = layout
        for pipe_item in self.all_pipes():
            in_port = pipe_item.input_port
            out_port = pipe_item.output_port
            pipe_item.draw_path(in_port, out_port)

    def reset_zoom(self, cent=None):
        """
        Reset the scene range to the size of the viewer widget.

        Args:
            cent: optional point the reset scene range gets centered on.
        """
        size = self.size()
        self._scene_range = QtCore.QRectF(0, 0, size.width(), size.height())
        if cent:
            shift = cent - self._scene_range.center()
            self._scene_range.translate(shift)
        self._update_scene()

    def get_zoom(self):
        """
        Return the zoom level read from the view transform.

        Returns:
            float: horizontal scale of the transform minus 1.0, rounded to
                two decimals.
        """
        transform = self.transform()
        scale_x, _scale_y = transform.m11(), transform.m22()
        return float('{:0.2f}'.format(scale_x - 1.0))

    def set_zoom(self, value=0.0):
        """
        Set the zoom level of the viewer.

        A value of 0.0 resets the zoom. Otherwise the range check is done
        on the current zoom when that one is negative, else on the
        requested value, and the zoom is left untouched when the checked
        value is outside of ``ZOOM_MIN``/``ZOOM_MAX``.

        Args:
            value (float): zoom level.
        """
        if value == 0.0:
            self.reset_zoom()
            return

        zoom = self.get_zoom()
        checked = zoom if zoom < 0.0 else value
        if not (ZOOM_MIN <= checked <= ZOOM_MAX):
            return
        # the viewer zoom is applied relative to the current zoom.
        self._set_viewer_zoom(value - zoom, 0.0)

    def zoom_to_nodes(self, nodes):
        """
        Frame the given node items in the view.

        Args:
            nodes (list): node items to frame.
        """
        self._scene_range = self._combined_rect(nodes)
        self._update_scene()

        # zoomed in past 0.1, reset the zoom around the framed nodes.
        zoomed_in = self.get_zoom() > 0.1
        if zoomed_in:
            self.reset_zoom(self._scene_range.center())

    def use_opengl(self):
        """
        Use a ``QtOpenGL.QGLWidget`` with 4 samples as the viewport.
        """
        gl_format = QtOpenGL.QGLFormat(QtOpenGL.QGL.SampleBuffers)
        gl_format.setSamples(4)
        gl_viewport = QtOpenGL.QGLWidget(gl_format)
        self.setViewport(gl_viewport)
예제 #26
0
class NodeGraph(QtCore.QObject):
    """
    base node graph controller.
    """

    #: signal emits the node object when a node is created in the node graph.
    node_created = QtCore.Signal(NodeObject)
    #: signal emits a list of node ids from the deleted nodes.
    nodes_deleted = QtCore.Signal(list)
    #: signal emits the node object when selected in the node graph.
    node_selected = QtCore.Signal(NodeObject)
    #: signal triggered when a node is double clicked and emits the node.
    node_double_clicked = QtCore.Signal(NodeObject)
    #: signal for when a node has been connected emits (source port, target port).
    port_connected = QtCore.Signal(Port, Port)
    #: signal for when a node has been disconnected emits (source port, target port).
    port_disconnected = QtCore.Signal(Port, Port)
    #: signal for when a node property has changed emits (node, property name, property value).
    property_changed = QtCore.Signal(NodeObject, str, object)
    #: signal for when drop data has been added to the graph.
    data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)

    def __init__(self, parent=None):
        """
        Args:
            parent: parent passed to the QObject and to the node viewer.
        """
        super(NodeGraph, self).__init__(parent)
        self.setObjectName('NodeGraphQt')
        self._model = NodeGraphModel()
        self._viewer = NodeViewer(parent)
        self._node_factory = NodeFactory()
        self._undo_stack = QtWidgets.QUndoStack(self)

        # "tab" key shortcut toggling the node search widget.
        search_action = QtWidgets.QAction('Search Nodes', self)
        search_action.setShortcut('tab')
        search_action.triggered.connect(self._toggle_tab_search)
        self._viewer.addAction(search_action)

        self._wire_signals()

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__,
                                          hex(id(self)))

    def _wire_signals(self):
        # internal signals.
        self._viewer.search_triggered.connect(self._on_search_triggered)
        self._viewer.connection_sliced.connect(self._on_connection_sliced)
        self._viewer.connection_changed.connect(self._on_connection_changed)
        self._viewer.moved_nodes.connect(self._on_nodes_moved)
        self._viewer.node_double_clicked.connect(self._on_node_double_clicked)

        # pass through signals.
        self._viewer.node_selected.connect(self._on_node_selected)
        self._viewer.data_dropped.connect(self._on_node_data_dropped)

    def _toggle_tab_search(self):
        """
        toggle the tab search widget.
        """
        self._viewer.tab_search_set_nodes(self._node_factory.names)
        self._viewer.tab_search_toggle()

    def _on_property_bin_changed(self, node_id, prop_name, prop_value):
        """
        called when a property widget has changed in a properties bin.
        (emits the node object, property name, property value)

        Args:
            node_id (str): node id.
            prop_name (str): node property name.
            prop_value (object): python object.
        """
        node = self.get_node_by_id(node_id)

        # prevent signals from causing a infinite loop.
        if node.get_property(prop_name) != prop_value:
            node.set_property(prop_name, prop_value)

    def _on_node_double_clicked(self, node_id):
        """
        called when a node in the viewer is double click.
        (emits the node object when the node is clicked)

        Args:
            node_id (str): node id emitted by the viewer.
        """
        node = self.get_node_by_id(node_id)
        self.node_double_clicked.emit(node)

    def _on_node_selected(self, node_id):
        """
        called when a node in the viewer is selected on left click.
        (emits the node object when the node is clicked)

        Args:
            node_id (str): node id emitted by the viewer.
        """
        node = self.get_node_by_id(node_id)
        self.node_selected.emit(node)

    def _on_node_data_dropped(self, data, pos):
        """
        called when data has been dropped on the viewer.

        Plain text data starting with the internal drag & drop prefix is
        read as a comma separated list of node types and a node is created
        for each of them. Any other data is passed on through the
        ``data_dropped`` signal.

        Args:
            data (QtCore.QMimeData): mime data.
            pos (QtCore.QPoint): scene position relative to the drop.
        """

        # don't emit signal for internal widget drops.
        if data.hasFormat('text/plain'):
            prefix = '<${}>:'.format(DRAG_DROP_ID)
            if data.text().startswith(prefix):
                node_ids = data.text()[len(prefix):]
                x, y = pos.x(), pos.y()
                for node_id in node_ids.split(','):
                    self.create_node(node_id, pos=[x, y])
                    # offset each following node so the created nodes
                    # don't end up stacked on the same position.
                    x += 20
                    y += 20
                return

        self.data_dropped.emit(data, pos)

    def _on_nodes_moved(self, node_data):
        """
        Slot for the viewer "moved_nodes" signal.
        (pushes one move command per node inside a single undo macro)

        Args:
            node_data (dict): {<node_view>: <previous_pos>}
        """
        undo_stack = self._undo_stack
        undo_stack.beginMacro('move nodes')
        for view, previous_pos in node_data.items():
            moved_node = self._model.nodes[view.id]
            move_cmd = NodeMovedCmd(moved_node, moved_node.pos(),
                                    previous_pos)
            undo_stack.push(move_cmd)
        undo_stack.endMacro()

    def _on_search_triggered(self, node_type, pos):
        """
        called when the tab search widget is triggered in the viewer.

        Args:
            node_type (str): node identifier.
            pos (tuple): x,y position for the node.
        """
        self.create_node(node_type, pos=pos)

    def _on_connection_changed(self, disconnected, connected):
        """
        called when a pipe connection has been changed in the viewer.

        Args:
            disconnected (list[list[widgets.port.PortItem]):
                pair list of port view items.
            connected (list[list[widgets.port.PortItem]]):
                pair list of port view items.
        """
        if not (disconnected or connected):
            return

        label = 'connect node(s)' if connected else 'disconnect node(s)'
        ptypes = {'in': 'inputs', 'out': 'outputs'}

        self._undo_stack.beginMacro(label)
        for p1_view, p2_view in disconnected:
            node1 = self._model.nodes[p1_view.node.id]
            node2 = self._model.nodes[p2_view.node.id]
            port1 = getattr(node1, ptypes[p1_view.port_type])()[p1_view.name]
            port2 = getattr(node2, ptypes[p2_view.port_type])()[p2_view.name]
            port1.disconnect_from(port2)
        for p1_view, p2_view in connected:
            node1 = self._model.nodes[p1_view.node.id]
            node2 = self._model.nodes[p2_view.node.id]
            port1 = getattr(node1, ptypes[p1_view.port_type])()[p1_view.name]
            port2 = getattr(node2, ptypes[p2_view.port_type])()[p2_view.name]
            port1.connect_to(port2)
        self._undo_stack.endMacro()

    def _on_connection_sliced(self, ports):
        """
        slot when connection pipes have been sliced.

        Args:
            ports (list[list[widgets.port.PortItem]]):
                pair list of port connections (in port, out port)
        """
        if not ports:
            return
        ptypes = {'in': 'inputs', 'out': 'outputs'}
        self._undo_stack.beginMacro('slice connections')
        for p1_view, p2_view in ports:
            node1 = self._model.nodes[p1_view.node.id]
            node2 = self._model.nodes[p2_view.node.id]
            port1 = getattr(node1, ptypes[p1_view.port_type])()[p1_view.name]
            port2 = getattr(node2, ptypes[p2_view.port_type])()[p2_view.name]
            port1.disconnect_from(port2)
        self._undo_stack.endMacro()

    @property
    def model(self):
        """
        Returns the model used to store the node graph data.

        Returns:
            NodeGraphQt.base.model.NodeGraphModel: node graph model.
        """
        return self._model

    def show(self):
        """
        Show the node graph viewer widget.
        (convenience function to :meth:`NodeGraph.viewer().show()`)
        """
        viewer = self._viewer
        viewer.show()

    def close(self):
        """
        Close the node graph viewer widget.
        (convenience function to :meth:`NodeGraph.viewer().close()`)
        """
        viewer = self._viewer
        viewer.close()

    def viewer(self):
        """
        The viewer widget used by the node graph.

        Returns:
            NodeGraphQt.widgets.viewer.NodeViewer: viewer widget.
        """
        viewer_widget = self._viewer
        return viewer_widget

    def scene(self):
        """
        The scene object of the viewer widget.

        Returns:
            NodeGraphQt.widgets.scene.NodeScene: node scene.
        """
        viewer = self._viewer
        return viewer.scene()

    def background_color(self):
        """
        The background color stored on the scene.

        Returns:
            tuple: r, g ,b
        """
        node_scene = self.scene()
        return node_scene.background_color

    def set_background_color(self, r, g, b):
        """
        Store a new background color on the scene.

        Args:
            r (int): red value.
            g (int): green value.
            b (int): blue value.
        """
        rgb = (r, g, b)
        self.scene().background_color = rgb

    def grid_color(self):
        """
        The grid color stored on the scene.

        Returns:
            tuple: r, g ,b
        """
        node_scene = self.scene()
        return node_scene.grid_color

    def set_grid_color(self, r, g, b):
        """
        Store a new grid color on the scene.

        Args:
            r (int): red value.
            g (int): green value.
            b (int): blue value.
        """
        rgb = (r, g, b)
        self.scene().grid_color = rgb

    def display_grid(self, display=True):
        """
        Store the background grid display state on the scene.

        Args:
            display: False to not draw the background grid.
        """
        node_scene = self.scene()
        node_scene.grid = display

    def add_properties_bin(self, prop_bin):
        """
        Connect the "property_changed" signal of a properties bin widget
        to the node graph.

        Args:
            prop_bin (NodeGraphQt.PropertiesBinWidget): properties widget.
        """
        changed_signal = prop_bin.property_changed
        changed_signal.connect(self._on_property_bin_changed)

    def undo_stack(self):
        """
        The undo stack used by the node graph.

        Returns:
            QtWidgets.QUndoStack: undo stack.
        """
        stack = self._undo_stack
        return stack

    def clear_undo_stack(self):
        """
        Clear the undo stack of the node graph.
        (convenience function to :meth:`NodeGraph.undo_stack().clear`)
        """
        stack = self._undo_stack
        stack.clear()

    def begin_undo(self, name):
        """
        Open an undo block, to be closed with
        :meth:`NodeGraph.end_undo()`.

        Args:
            name (str): name for the undo block.
        """
        stack = self._undo_stack
        stack.beginMacro(name)

    def end_undo(self):
        """
        Close the undo block opened with
        :meth:`NodeGraph.begin_undo()`.
        """
        stack = self._undo_stack
        stack.endMacro()

    def context_menu(self):
        """
        Wrap the context menu of the viewer in a ``Menu`` object.

        Returns:
            Menu: context menu object.
        """
        viewer = self._viewer
        return Menu(viewer, viewer.context_menu())

    def disable_context_menu(self, disabled=True):
        """
        Disable/Enable node graph context menu.

        Args:
            disabled (bool): true to disable (and hide) the context menu,
                false to enable (and show) it.
        """
        ctx_menu = self._viewer.context_menu()
        ctx_menu.setDisabled(disabled)
        visible = not disabled
        ctx_menu.setVisible(visible)

    def acyclic(self):
        """
        The acyclic state stored on the node graph model.

        Returns:
            bool: true if acyclic (default: True).
        """
        graph_model = self._model
        return graph_model.acyclic

    def set_acyclic(self, mode=False):
        """
        Store the acyclic mode on both the model and the viewer.
        (default=False)

        Args:
            mode (bool): true to enable acyclic.
        """
        for target in (self._model, self._viewer):
            target.acyclic = mode

    def set_pipe_style(self, style=None):
        """
        Set node graph pipes to be drawn straight or curved by default
        all pipes are set curved. (default=0)

        ``NodeGraphQt.constants.PIPE_LAYOUT_CURVED`` = 0
        ``NodeGraphQt.constants.PIPE_LAYOUT_STRAIGHT`` = 1

        Args:
            style (int): pipe style, ``None`` falls back to the curved
                layout. Values above the highest layout constant are
                clamped to the straight layout.
        """
        if style is None:
            # comparing ``None`` against an int raises on python 3, use the
            # documented default instead.
            style = PIPE_LAYOUT_CURVED
        pipe_max = max([PIPE_LAYOUT_CURVED, PIPE_LAYOUT_STRAIGHT])
        style = PIPE_LAYOUT_STRAIGHT if style > pipe_max else style
        self._viewer.set_pipe_layout(style)

    def fit_to_selection(self):
        """
        Frame the selected nodes in the viewer.
        If no nodes are selected then all nodes in the graph will be framed.
        """
        targets = self.selected_nodes() or self.all_nodes()
        if targets:
            views = [node.view for node in targets]
            self._viewer.zoom_to_nodes(views)

    def reset_zoom(self):
        """
        Reset the zoom level of the viewer.
        """
        viewer = self._viewer
        viewer.reset_zoom()

    def set_zoom(self, zoom=0):
        """
        Pass the zoom factor on to the viewer, the default is 0.0

        Args:
            zoom (float): zoom factor (max zoom out -0.9 / max zoom in 2.0)
        """
        viewer = self._viewer
        viewer.set_zoom(zoom)

    def get_zoom(self):
        """
        The current zoom level reported by the viewer.

        Returns:
            float: the current zoom level.
        """
        viewer = self._viewer
        return viewer.get_zoom()

    def center_on(self, nodes=None):
        """
        Center the node graph on the given nodes or all nodes by default.

        The viewer is handed the ``view`` of each node, like
        :meth:`NodeGraph.fit_to_selection` does. Objects without a
        ``view`` attribute are passed through unchanged.

        Args:
            nodes (list[NodeGraphQt.BaseNode]): a list of nodes.
        """
        if nodes:
            nodes = [getattr(node, 'view', node) for node in nodes]
        self._viewer.center_selection(nodes)

    def center_selection(self):
        """
        Center the viewer on its currently selected nodes.
        """
        viewer = self._viewer
        selection = viewer.selected_nodes()
        viewer.center_selection(selection)

    def registered_nodes(self):
        """
        List the node types registered in the node factory.

        To register a node see :meth:`NodeGraph.register_node`

        Returns:
            list[str]: sorted list of node type identifiers.
        """
        node_types = self._node_factory.nodes.keys()
        return sorted(node_types)

    def register_node(self, node, alias=None):
        """
        Register the node in the node factory of the node graph.

        Args:
            node (NodeGraphQt.NodeObject): node.
            alias (str): custom alias name for the node type.
        """
        factory = self._node_factory
        factory.register_node(node, alias)

    def create_node(self,
                    node_type,
                    name=None,
                    selected=True,
                    color=None,
                    text_color=None,
                    pos=None):
        """
        Create a new node in the node graph.

        (To list all node types see :meth:`NodeGraph.registered_nodes`)

        Args:
            node_type (str): node instance type.
            name (str): set name of the node.
            selected (bool): set created node to be selected.
            color (tuple or str): node color (255, 255, 255) or '#FFFFFF'.
            text_color (tuple or str): node text color (255, 255, 255) or '#FFFFFF'.
            pos (list[int, int]): initial x, y position for the node (default: (0, 0)).

        Returns:
            NodeGraphQt.BaseNode: the created instance of the node.

        Raises:
            Exception: if the node factory has nothing for the node type.
        """
        NodeCls = self._node_factory.create_node_instance(node_type)
        if not NodeCls:
            raise Exception(
                '\n\n>> Cannot find node:\t"{}"\n'.format(node_type))

        def to_rgb(clr):
            # '#RRGGBB' strings become a (r, g, b) tuple, anything else is
            # handed back untouched.
            if not isinstance(clr, str):
                return clr
            clr = clr.strip('#')
            return tuple(int(clr[i:i + 2], 16) for i in (0, 2, 4))

        node = NodeCls()
        node._graph = self
        node.model._graph_model = self.model

        # temp attributes of the node model, removed once read.
        wid_types = node.model.__dict__.pop('_TEMP_property_widget_types')
        prop_attrs = node.model.__dict__.pop('_TEMP_property_attrs')

        # first node of this type, store its common properties.
        if self.model.get_node_common_properties(node.type_) is None:
            common = {}
            for prop_name, widget_type in wid_types.items():
                common[prop_name] = {'widget_type': widget_type}
            node_attrs = {node.type_: common}
            for pname, pattrs in prop_attrs.items():
                node_attrs[node.type_][pname].update(pattrs)
            self.model.set_node_common_properties(node_attrs)

        node.NODE_NAME = self.get_unique_name(name or node.NODE_NAME)
        node.model.name = node.NODE_NAME
        node.model.selected = selected

        if color:
            node.model.color = to_rgb(color)
        if text_color:
            node.model.text_color = to_rgb(text_color)
        if pos:
            node.model.pos = [float(pos[0]), float(pos[1])]

        node.update()

        undo_cmd = NodeAddedCmd(self, node, node.model.pos)
        undo_cmd.setText('create node: "{}"'.format(node.NODE_NAME))
        self._undo_stack.push(undo_cmd)
        self.node_created.emit(node)
        return node

    def add_node(self, node, pos=None):
        """
        Add an existing node object into the node graph.

        Args:
            node (NodeGraphQt.BaseNode): node object.
            pos (list[float]): node x,y position. (optional)
        """
        assert isinstance(node, NodeObject), 'node must be a Node instance.'

        # temp attributes of the node model, removed once read.
        wid_types = node.model.__dict__.pop('_TEMP_property_widget_types')
        prop_attrs = node.model.__dict__.pop('_TEMP_property_attrs')

        # first node of this type, store its common properties.
        if self.model.get_node_common_properties(node.type_) is None:
            common = {}
            for prop_name, widget_type in wid_types.items():
                common[prop_name] = {'widget_type': widget_type}
            node_attrs = {node.type_: common}
            for pname, pattrs in prop_attrs.items():
                node_attrs[node.type_][pname].update(pattrs)
            self.model.set_node_common_properties(node_attrs)

        node._graph = self
        node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
        node.model._graph_model = self.model
        node.model.name = node.NODE_NAME
        node.update()

        added_cmd = NodeAddedCmd(self, node, pos)
        self._undo_stack.push(added_cmd)

    def delete_node(self, node):
        """
        Remove the node from the node graph.
        (the node id is emitted through "nodes_deleted" first)

        Args:
            node (NodeGraphQt.BaseNode): node object.
        """
        assert isinstance(node, NodeObject), \
            'node must be a instance of a NodeObject.'
        deleted_ids = [node.id]
        self.nodes_deleted.emit(deleted_ids)
        removed_cmd = NodeRemovedCmd(self, node)
        self._undo_stack.push(removed_cmd)

    def delete_nodes(self, nodes):
        """
        Remove a list of specified nodes from the node graph.
        (the node ids are emitted through "nodes_deleted" first, the
        removals are grouped in one undo macro)

        Args:
            nodes (list[NodeGraphQt.BaseNode]): list of node instances.
        """
        deleted_ids = [node.id for node in nodes]
        self.nodes_deleted.emit(deleted_ids)

        undo_stack = self._undo_stack
        undo_stack.beginMacro('delete nodes')
        for node in nodes:
            undo_stack.push(NodeRemovedCmd(self, node))
        undo_stack.endMacro()

    def all_nodes(self):
        """
        List every node stored in the node graph model.

        Returns:
            list[NodeGraphQt.BaseNode]: list of nodes.
        """
        model_nodes = self._model.nodes
        return list(model_nodes.values())

    def selected_nodes(self):
        """
        List the nodes matching the node items selected in the viewer.

        Returns:
            list[NodeGraphQt.BaseNode]: list of nodes.
        """
        return [
            self._model.nodes[item.id]
            for item in self._viewer.selected_nodes()
        ]

    def select_all(self):
        """
        Mark every node in the node graph as selected.

        The ``set_selected(True)`` calls are wrapped in a 'select all' macro
        on the undo stack.
        """
        self._undo_stack.beginMacro('select all')
        every_node = self.all_nodes()
        for graph_node in every_node:
            graph_node.set_selected(True)
        self._undo_stack.endMacro()

    def clear_selection(self):
        """
        Deselect every node in the node graph.

        The ``set_selected(False)`` calls are wrapped in a 'clear selection'
        macro on the undo stack.
        """
        self._undo_stack.beginMacro('clear selection')
        every_node = self.all_nodes()
        for graph_node in every_node:
            graph_node.set_selected(False)
        self._undo_stack.endMacro()

    def get_node_by_id(self, node_id=None):
        """
        Look up a node in the graph model by its id string.

        Args:
            node_id (str): node id (:meth:`NodeObject.id`)

        Returns:
            NodeGraphQt.NodeObject: node object, or None if the id is unknown.
        """
        known_nodes = self._model.nodes
        return known_nodes.get(node_id)

    def get_node_by_name(self, name):
        """
        Find the first node in the graph model whose name equals ``name``.

        Args:
            name (str): name of the node.
        Returns:
            NodeGraphQt.NodeObject: node object, or None if nothing matches.
        """
        for candidate in self._model.nodes.values():
            if candidate.name() == name:
                return candidate
        return None

    def get_unique_name(self, name):
        """
        Creates a unique node name to avoid having nodes with the same name.

        Runs of whitespace in ``name`` are collapsed to single spaces first.
        If the result is not used by any node in the graph it is returned
        unchanged. Otherwise a trailing number (if any) is removed and the
        lowest free ``'<base> <number>'`` name is returned.

        Args:
            name (str): node name.

        Returns:
            str: unique node name.
        """
        name = ' '.join(name.split())
        taken = {n.name() for n in self.all_nodes()}
        if name not in taken:
            return name

        # Only a number at the very end of the name counts as a version, and
        # the whole number is removed ("node 12" -> "node"). A name made of
        # digits only has no base to fall back to, so it is kept as is.
        match = re.match(r'^(.*\D)\d+$', name)
        base = match.group(1).strip() if match else name

        # ``len(taken) + 1`` candidates are checked against ``len(taken)``
        # used names, so one of them is always free.
        for index in range(1, len(taken) + 2):
            new_name = '{} {}'.format(base, index)
            if new_name not in taken:
                return new_name

    def current_session(self):
        """
        Report the session value stored on the graph model.

        Returns:
            str: path to the currently loaded session
        """
        graph_model = self._model
        return graph_model.session

    def clear_session(self):
        """
        Remove every node and reset the session state.

        A ``NodeRemovedCmd`` is pushed for each node, after which the undo
        stack is cleared and the stored session is reset to ``None``.
        """
        for graph_node in self.all_nodes():
            removal_cmd = NodeRemovedCmd(self, graph_node)
            self._undo_stack.push(removal_cmd)
        self._undo_stack.clear()
        self._model.session = None

    def _serialize(self, nodes):
        """
        Convert nodes into a dictionary of node data and connections.
        (used internally by the node graph)

        Every node has ``update_model()`` called before its ``model.to_dict``
        data is read. Non-empty ``inputs`` / ``outputs`` entries are removed
        from the node data and flattened into a list of unique
        ``{'in': [node_id, port_name], 'out': [node_id, port_name]}`` pipes.
        The ``connections`` key is left out when there are no pipes.

        Args:
            nodes (list[NodeGraphQt.Nodes]): list of node instances.

        Returns:
            dict: serialized data.
        """
        collected = {}
        for node in nodes:
            # update the node model before reading it.
            node.update_model()
            collected.update(node.model.to_dict)

        node_table = {}
        pipes = []
        for node_id, node_data in collected.items():
            node_table[node_id] = node_data

            inputs = node_data.pop('inputs') if node_data.get('inputs') else {}
            outputs = node_data.pop('outputs') if node_data.get('outputs') else {}

            # (key for this node's port, key for the connected port, ports)
            directions = (('in', 'out', inputs), ('out', 'in', outputs))
            for own_key, peer_key, port_map in directions:
                for port_name, links in port_map.items():
                    for peer_id, peer_ports in links.items():
                        for peer_port in peer_ports:
                            pipe = {
                                own_key: [node_id, port_name],
                                peer_key: [peer_id, peer_port],
                            }
                            # the same pipe is seen from both of its ends.
                            if pipe not in pipes:
                                pipes.append(pipe)

        serial_data = {'nodes': node_table}
        if pipes:
            serial_data['connections'] = pipes
        return serial_data

    def _deserialize(self, data, relative_pos=False, pos=None):
        """
        Rebuild nodes and their connections from serialized node data.
        (used internally by the node graph)

        Args:
            data (dict): node data.
            relative_pos (bool): position node relative to the cursor.
            pos: position handed to the viewer's ``move_nodes`` when
                ``relative_pos`` is false.

        Returns:
            list[NodeGraphQt.Nodes]: list of node instances.
        """
        created = {}

        # build the nodes.
        for node_id, node_data in data.get('nodes', {}).items():
            NodeCls = self._node_factory.create_node_instance(
                node_data['type_'])
            if not NodeCls:
                # node types the factory can't resolve are skipped.
                continue
            node = NodeCls()
            node.NODE_NAME = node_data.get('name', node.NODE_NAME)
            # set properties.
            for prop in node.model.properties.keys():
                if prop in node_data.keys():
                    node.model.set_property(prop, node_data[prop])
            # set custom properties.
            for prop, val in node_data.get('custom', {}).items():
                node.model.set_property(prop, val)

            created[node_id] = node
            self.add_node(node, node_data.get('pos'))

        # build the connections.
        for connection in data.get('connections', []):
            in_id, in_name = connection.get('in', ('', ''))
            in_node = created.get(in_id)
            if not in_node:
                continue
            in_port = in_node.inputs().get(in_name)

            out_id, out_name = connection.get('out', ('', ''))
            out_node = created.get(out_id)
            if not out_node:
                continue
            out_port = out_node.outputs().get(out_name)

            if in_port and out_port:
                self._undo_stack.push(PortConnectedCmd(in_port, out_port))

        node_objs = list(created.values())
        if relative_pos or pos:
            views = [n.view for n in node_objs]
            if relative_pos:
                self._viewer.move_nodes(views)
            else:
                self._viewer.move_nodes(views, pos=pos)
            # copy the moved view positions back onto the node models.
            for n in node_objs:
                n.model.pos = n.view.xy_pos

        return node_objs

    def serialize_session(self):
        """
        Serialize every node of the current node graph into a dictionary.

        Returns:
            dict: serialized session of the current node layout.
        """
        session_nodes = self.all_nodes()
        return self._serialize(session_nodes)

    def deserialize_session(self, layout_data):
        """
        Replace the current session with the one described by a dictionary.

        The current session is cleared, the layout is deserialized and the
        undo stack is cleared afterwards.

        Args:
            layout_data (dict): dictionary object containing a node session.
        """
        # drop whatever is currently in the graph.
        self.clear_session()
        # rebuild the nodes and connections from the layout.
        self._deserialize(layout_data)
        # clear the undo stack once the layout has been rebuilt.
        self._undo_stack.clear()

    def save_session(self, file_path):
        """
        Write the current node graph session layout to a `JSON` formatted
        file.

        The session is serialized first; ``file_path`` is then stripped of
        surrounding whitespace and opened for writing.

        Args:
            file_path (str): path to the saved node layout.
        """
        session_data = self._serialize(self.all_nodes())
        target_path = file_path.strip()
        with open(target_path, 'w') as handle:
            json.dump(session_data, handle, indent=2, separators=(',', ':'))

    def load_session(self, file_path):
        """
        Load a node graph session layout from a `JSON` file.

        The current session is cleared before the file is read. If the file
        can't be read or holds no data, a message is printed (read failures
        only) and nothing is loaded.

        Args:
            file_path (str): path to the serialized layout file.

        Raises:
            IOError: if ``file_path`` does not point to an existing file.
        """
        session_path = file_path.strip()
        file_found = os.path.isfile(session_path)
        if not file_found:
            raise IOError('file does not exist.')

        self.clear_session()

        layout_data = None
        try:
            with open(session_path) as data_file:
                layout_data = json.load(data_file)
        except Exception as e:
            print('Cannot read data from file.\n{}'.format(e))

        if layout_data:
            self._deserialize(layout_data)
            self._undo_stack.clear()
            self._model.session = session_path

    def copy_nodes(self, nodes=None):
        """
        Serialize nodes to a JSON string and place it on the clipboard.

        Args:
            nodes (list[NodeGraphQt.BaseNode]): list of nodes (default: selected nodes).

        Returns:
            bool: True if text was set on the clipboard, otherwise False.
        """
        to_copy = nodes or self.selected_nodes()
        if not to_copy:
            return False
        clipboard = QtWidgets.QApplication.clipboard()
        serial_str = json.dumps(self._serialize(to_copy))
        if not serial_str:
            return False
        clipboard.setText(serial_str)
        return True

    def paste_nodes(self):
        """
        Pastes nodes copied from the clipboard.

        The clipboard text is decoded as JSON and handed to
        ``_deserialize`` with ``relative_pos=True``; the created nodes are
        then selected. Clipboard text that is empty, is not valid JSON or
        does not decode to a dictionary is ignored.
        """
        clipboard = QtWidgets.QApplication.clipboard()
        cb_text = clipboard.text()
        if not cb_text:
            return

        # Decode before opening the undo macro: the clipboard can hold any
        # text, and an exception raised between beginMacro() and endMacro()
        # would leave the macro open on the undo stack.
        try:
            serial_data = json.loads(cb_text)
        except ValueError as e:
            print('Cannot decode clipboard data.\n{}'.format(e))
            return
        if not isinstance(serial_data, dict):
            return

        self._undo_stack.beginMacro('pasted nodes')
        self.clear_selection()
        nodes = self._deserialize(serial_data, relative_pos=True)
        for n in nodes:
            n.set_selected(True)
        self._undo_stack.endMacro()

    def duplicate_nodes(self, nodes):
        """
        Duplicate the given nodes by serializing and deserializing them.

        The copies are shifted by 50 on both axes and marked as selected. All
        of this happens inside a 'duplicate nodes' undo macro.

        Args:
            nodes (list[NodeGraphQt.BaseNode]): list of nodes.
        Returns:
            list[NodeGraphQt.BaseNode]: list of duplicated node instances
                (None if ``nodes`` is empty).
        """
        if not nodes:
            return

        self._undo_stack.beginMacro('duplicate nodes')

        self.clear_selection()
        duplicates = self._deserialize(self._serialize(nodes))
        shift = 50
        for dup in duplicates:
            pos_x, pos_y = dup.pos()
            dup.set_pos(pos_x + shift, pos_y + shift)
            dup.set_property('selected', True)

        self._undo_stack.endMacro()
        return duplicates

    def disable_nodes(self, nodes, mode=None):
        """
        Set whether to disable or enable the specified nodes.

        When ``mode`` is omitted the inverse of the first node's disabled
        state is applied to all nodes. Changes to more than one node are
        wrapped in a single undo macro.

        see: :meth:`NodeObject.set_disabled`

        Args:
            nodes (list[NodeGraphQt.BaseNode]): list of node instances.
            mode (bool): (optional) disable state of the nodes.
        """
        if not nodes:
            return
        if mode is None:
            mode = not nodes[0].disabled()
        if len(nodes) <= 1:
            nodes[0].set_disabled(mode)
            return
        action = {False: 'enable', True: 'disable'}[mode]
        macro_text = '{} ({}) nodes'.format(action, len(nodes))
        self._undo_stack.beginMacro(macro_text)
        for node in nodes:
            node.set_disabled(mode)
        self._undo_stack.endMacro()

    def question_dialog(self, text, title='Node Graph'):
        """
        Prompts a question open dialog with "Yes" and "No" buttons in
        the node graph.

        (convenience function to :meth:`NodeGraph.viewer().question_dialog`)

        Args:
            text (str): question text.
            title (str): dialog window title.

        Returns:
            bool: true if user clicked yes.
        """
        # hand the viewer's answer back to the caller; without the return
        # this method always evaluated to None.
        return self._viewer.question_dialog(text, title)

    def message_dialog(self, text, title='Node Graph'):
        """
        Show a message dialog through the node graph viewer.

        (convenience function to :meth:`NodeGraph.viewer().message_dialog`)

        Args:
            text (str): message text.
            title (str): dialog window title.
        """
        viewer = self._viewer
        viewer.message_dialog(text, title)

    def load_dialog(self, current_dir=None, ext=None):
        """
        Open the viewer's file open dialog and pass back its result.

        (convenience function to :meth:`NodeGraph.viewer().load_dialog`)

        Args:
            current_dir (str): path to a directory.
            ext (str): custom file type extension (default: json)

        Returns:
            str: selected file path.
        """
        viewer = self._viewer
        return viewer.load_dialog(current_dir, ext)

    def save_dialog(self, current_dir=None, ext=None):
        """
        Open the viewer's file save dialog and pass back its result.

        (convenience function to :meth:`NodeGraph.viewer().save_dialog`)

        Args:
            current_dir (str): path to a directory.
            ext (str): custom file type extension (default: json)

        Returns:
            str: selected file path.
        """
        viewer = self._viewer
        return viewer.save_dialog(current_dir, ext)
예제 #27
0
class NodeBaseWidget(QtWidgets.QGraphicsProxyWidget):
    """
    NodeBaseWidget is the main base class for all node widgets that is
    embedded in a :class:`NodeGraphQt.BaseNode` object.

    The ``widget`` and ``value`` properties raise ``NotImplementedError``
    here and are meant to be implemented by subclasses.
    """

    value_changed = QtCore.Signal(str, object)
    """
    Signal triggered when the ``value`` attribute has changed.

    :parameters: str, object
    :emits: property name, property value
    """

    def __init__(self, parent=None, name='widget', label=''):
        """
        Args:
            parent: parent item passed on to the proxy widget.
            name (str): property name reported by :attr:`name`.
            label (str): label text reported by :attr:`label`.
        """
        super(NodeBaseWidget, self).__init__(parent)
        self.setZValue(Z_VAL_NODE_WIDGET)
        self._name = name
        self._label = label

    def _value_changed(self):
        """
        Emit ``value_changed`` with the widget name and its current value.
        """
        self.value_changed.emit(self.name, self.value)

    def setToolTip(self, tooltip):
        """
        Set the tooltip, prefixed with the widget name in bold.

        Newlines in ``tooltip`` are converted to ``<br/>`` tags.

        Args:
            tooltip (str): tooltip text.
        """
        tooltip = tooltip.replace('\n', '<br/>')
        tooltip = '<b>{}</b><br/>{}'.format(self.name, tooltip)
        super(NodeBaseWidget, self).setToolTip(tooltip)

    @property
    def widget(self):
        """
        Returns the embedded QWidget used in the node.

        Returns:
            QtWidgets.QWidget: nested QWidget
        """
        raise NotImplementedError

    @property
    def value(self):
        """
        Returns the widgets current value.

        Returns:
            str: current property value.
        """
        raise NotImplementedError

    @value.setter
    def value(self, text):
        """
        Sets the widgets current value.

        Args:
            text (str): new text value.
        """
        raise NotImplementedError

    @property
    def label(self):
        """
        Returns the label text displayed above the embedded node widget.

        Returns:
            str: label text.
        """
        return self._label

    @label.setter
    def label(self, label):
        """
        Sets the label text above the embedded widget.

        Args:
            label (str): new label text.
        """
        self._label = label

    @property
    def type_(self):
        """
        Returns the node widget type.

        Returns:
            str: widget type (the class name).
        """
        return str(self.__class__.__name__)

    @property
    def node(self):
        """
        Returns the parent base node qgraphics item.

        Returns:
            NodeItem: parent node.
        """
        # the result must be returned; without it the property was always
        # None.
        return self.parentItem()

    @property
    def name(self):
        """
        Returns the parent node property name.

        Returns:
            str: property name.
        """
        return self._name

    def get_icon(self, name):
        """
        Returns the qt default icon.

        Args:
            name: value converted with ``QtWidgets.QStyle.StandardPixmap``.

        Returns:
            the icon returned by the widget style's ``standardIcon``.
        """
        return self.style().standardIcon(QtWidgets.QStyle.StandardPixmap(name))
예제 #28
0
class TabSearchMenuWidget(QtWidgets.QLineEdit):
    """
    Line edit embedded in a popup menu that lists node names and filters
    them while the user types.

    ``search_submitted`` is emitted with the node type mapped to the chosen
    node name.
    """
    search_submitted = QtCore.Signal(str)

    def __init__(self, parent=None, node_dict=None):
        """
        Args:
            parent: parent widget.
            node_dict (dict): optional mapping of display name to node type
                used to populate the menu straight away.
        """
        super(TabSearchMenuWidget, self).__init__(parent)
        self.setAttribute(QtCore.Qt.WA_MacShowFocusRect, 0)
        self.setStyleSheet(STYLE_TABSEARCH)
        self.setMinimumSize(200, 22)
        self.setTextMargins(2, 0, 2, 0)

        # The containers and the menu must exist before any items are
        # generated, because _generate_items_from_node_dict() fills them.
        self._actions = []
        self._menus = {}
        self._searched_actions = []

        self.SearchMenu = QtWidgets.QMenu()
        searchWidget = QtWidgets.QWidgetAction(self)
        searchWidget.setDefaultWidget(self)
        self.SearchMenu.addAction(searchWidget)
        self.SearchMenu.setStyleSheet(STYLE_QMENU)

        self._node_dict = node_dict or {}
        if self._node_dict:
            self._generate_items_from_node_dict()

        self.returnPressed.connect(self._on_search_submitted)
        self.textChanged.connect(self._on_text_changed)

    def __repr__(self):
        return '<{} at {}>'.format(self.__class__.__name__, hex(id(self)))

    def _on_text_changed(self, text):
        """
        Show the actions whose text contains ``text`` (case insensitive).

        With empty text the sub menus are shown again instead.
        """
        self._clear_actions()

        if not text:
            self._set_menu_visible(True)
            return

        self._set_menu_visible(False)

        needle = text.lower()
        self._searched_actions = [
            action for action in self._actions
            if needle in action.text().lower()
        ]

        self.SearchMenu.addActions(self._searched_actions)

    def _clear_actions(self):
        """Remove the previous search results from the menu."""
        for action in self._searched_actions:
            self.SearchMenu.removeAction(action)
        self._searched_actions = []

    def _set_menu_visible(self, visible):
        """Show or hide every sub menu entry."""
        for menu in self._menus.values():
            menu.menuAction().setVisible(visible)

    def _close(self):
        """Hide the sub menus and the search menu."""
        self._set_menu_visible(False)
        self.SearchMenu.setVisible(False)
        self.SearchMenu.menuAction().setVisible(False)

    def _show(self):
        """Run the search menu at the cursor position and reset the text."""
        self.SearchMenu.exec_(QtGui.QCursor.pos())
        self.setText("")
        self.setFocus()
        self._set_menu_visible(True)

    def _on_search_submitted(self):
        """
        Emit ``search_submitted`` for the triggered action.

        If the sender is not a ``QAction`` (e.g. return was pressed in the
        line edit), the first search result is used; with no results the
        menu is just closed.
        """
        action = self.sender()
        if type(action) is not QtWidgets.QAction:
            if len(self._searched_actions) > 0:
                action = self._searched_actions[0]
            else:
                self._close()
                return

        text = action.text()
        node_type = self._node_dict.get(text)
        if node_type:
            self.search_submitted.emit(node_type)

        self._close()

    def _generate_items_from_node_dict(self):
        """
        Rebuild the sub menus and actions from ``self._node_dict``.

        Sub menus are named after the node type without its last dotted
        component; an action goes into its matching sub menu.
        """
        node_names = sorted(self._node_dict.keys())
        node_types = sorted(self._node_dict.values())

        self._menus.clear()
        self._actions.clear()
        self._searched_actions.clear()

        for node_type in node_types:
            menu_name = ".".join(node_type.split(".")[:-1])
            if menu_name not in self._menus.keys():
                new_menu = QtWidgets.QMenu(menu_name)
                new_menu.setStyleSheet(STYLE_QMENU)
                self._menus[menu_name] = new_menu
                self.SearchMenu.addMenu(new_menu)

        for name in node_names:
            action = QtWidgets.QAction(name, self)
            action.setText(name)
            action.triggered.connect(self._on_search_submitted)
            self._actions.append(action)

            menu_name = self._node_dict[name]
            menu_name = ".".join(menu_name.split(".")[:-1])

            if menu_name in self._menus.keys():
                self._menus[menu_name].addAction(action)
            else:
                self.SearchMenu.addAction(action)

    def set_nodes(self, node_dict=None):
        """
        Populate the menu (only if it has no nodes yet) and show it.

        Args:
            node_dict (dict): mapping of node name to a list of node types.
                A name with several types gets one
                ``'<name> (<type>)'`` entry per type.
        """
        if not self._node_dict:
            self._node_dict.clear()
            # tolerate the ``None`` default instead of failing on .items().
            for name, node_types in (node_dict or {}).items():
                if len(node_types) == 1:
                    self._node_dict[name] = node_types[0]
                    continue
                for node_id in node_types:
                    self._node_dict['{} ({})'.format(name, node_id)] = node_id
            self._generate_items_from_node_dict()

        self._show()
예제 #29
0
class AutoNode(BaseNode, QtCore.QObject):
    """
    Base node that "cooks" (calls :meth:`run`) when its inputs or custom
    properties change and then cooks the nodes connected to its outputs.

    Ports may carry a ``DataType`` attribute which is used to validate new
    connections (see :meth:`checkPortType`).
    """
    cooked = QtCore.Signal()

    def __init__(self, defaultInputType=None, defaultOutputType=None):
        """
        Args:
            defaultInputType (type): data type applied to inputs added
                without an explicit ``data_type``.
            defaultOutputType (type): data type applied to outputs added
                without an explicit ``data_type``.
        """
        super(AutoNode, self).__init__()
        QtCore.QObject.__init__(self)
        self.needCook = True
        self._autoCook = True
        self._error = False
        # each inner list is a group of types that may be connected together.
        self.matchTypes = [[float, int]]
        self.errorColor = (200, 50, 50)
        self.stopCookColor = (200, 200, 200)
        self._cryptoColors = CryptoColors()

        self.defaultColor = self.get_property("color")
        self.defaultValue = None
        self.defaultInputType = defaultInputType
        self.defaultOutputType = defaultOutputType

        self._cookTime = 0.0
        self._toolTip = self._setup_tool_tip()

    @property
    def autoCook(self):
        """bool: whether :meth:`cook` runs without ``forceCook``."""
        return self._autoCook

    @autoCook.setter
    def autoCook(self, mode):
        if mode is self._autoCook:
            return

        self._autoCook = mode
        if mode:
            self.set_property('color', self.defaultColor)
        else:
            # remember the current color so it can be restored later.
            self.defaultColor = self.get_property("color")
            self.set_property('color', self.stopCookColor)

    @property
    def cookTime(self):
        """float: duration in seconds of the last successful cook."""
        return self._cookTime

    # NOTE: this must be ``cookTime.setter``. Decorating with
    # ``autoCook.setter`` built a property that used the autoCook getter,
    # so reading ``cookTime`` returned ``_autoCook``.
    @cookTime.setter
    def cookTime(self, cook_time):
        self._cookTime = cook_time
        self._update_tool_tip()

    def cookNextNode(self):
        """Cook every node connected to this node's outputs."""
        for nodeList in self.connected_output_nodes().values():
            for n in nodeList:
                if n is not self:
                    n.cook()

    def getData(self, port):
        """
        Return the data exposed on an output port.

        By default this is the property named after the port; override for
        custom output data.
        """
        # for custom output data
        return self.get_property(port.name())

    def getInputData(self, port):
        """
        get input data by input Port,the type of "port" can be :
        int : Port index
        str : Port name
        Port : Port object

        Returns a deep copy of the data of the first connected port, or of
        ``defaultValue`` when nothing is connected.
        """

        if type(port) is int:
            to_port = self.input(port)
        elif type(port) is str:
            to_port = self.inputs()[port]
        elif type(port) is Port:
            to_port = port
        else:
            print(self.inputs().keys())
            return self.defaultValue

        from_ports = to_port.connected_ports()
        if not from_ports:
            return copy.deepcopy(self.defaultValue)

        # only the first connected port is read.
        for from_port in from_ports:
            data = from_port.node().getData(from_port)
            return copy.deepcopy(data)

    def when_disabled(self):
        """
        Pass input data straight through to the outputs.

        Output ``i`` receives the data of input ``i % number_of_inputs``.
        """
        num = len(self.input_ports())
        if not num:
            # nothing to pass through; also avoids ``index % 0``.
            return
        for index, out_port in enumerate(self.output_ports()):
            self.set_property(out_port.name(), self.getInputData(index % num))

    def cook(self, forceCook=False):
        """
        Run the node and, on success, cook the connected output nodes.

        Nothing happens when auto cooking is off (unless ``forceCook`` is
        True) or when ``needCook`` is false. An exception raised by
        :meth:`run` is stored through :meth:`error` and stops propagation.

        Args:
            forceCook (bool): cook even if auto cooking is disabled.
        """
        if not self._autoCook and forceCook is not True:
            return

        if not self.needCook:
            return

        # auto cooking is switched off while running so property changes
        # made by run() don't trigger nested cooks.
        _tmp = self._autoCook
        self._autoCook = False

        if self.error():
            self._close_error()

        _start_time = time.time()

        try:
            self.run()
        except Exception as error:
            self.error(error)

        self._autoCook = _tmp

        if self.error():
            return

        self.cookTime = time.time() - _start_time

        self.cooked.emit()
        self.cookNextNode()

    def run(self):
        """Node logic executed by :meth:`cook`; does nothing by default."""
        pass

    def on_input_connected(self, to_port, from_port):
        """Cook on a valid connection, otherwise undo the connection."""
        if self.checkPortType(to_port, from_port):
            self.cook()
        else:
            # suppress the cook triggered by the disconnect below.
            self.needCook = False
            to_port.disconnect_from(from_port)

    def on_input_disconnected(self, to_port, from_port):
        """Cook after a disconnect, unless it was a rejected connection."""
        if not self.needCook:
            self.needCook = True
            return
        self.cook()

    def checkPortType(self, to_port, from_port):
        """
        Check whether two ports may be connected.

        Returns:
            bool: False only if both ports have a ``DataType``, the types
                differ and no group in ``matchTypes`` contains both.
        """
        # None type port can connect with any other type port
        # types in self.matchTypes can connect with each other

        if hasattr(to_port, "DataType") and hasattr(from_port, "DataType"):
            if to_port.DataType is not from_port.DataType:
                for types in self.matchTypes:
                    if to_port.DataType in types and from_port.DataType in types:
                        return True
                return False

        return True

    def set_property(self, name, value):
        """
        Set a property, update the type of a port with the same name and
        cook if it is a custom property.
        """
        super(AutoNode, self).set_property(name, value)
        self.set_port_type(name, type(value))
        if name in self.model.custom_properties.keys():
            self.cook()

    def set_port_type(self, port, value_type):
        """
        Set the data type of a port and update its colors and tooltip.

        Args:
            port (Port or str): port object, or name of an input/output.
            value_type (type): new data type of the port.
        """
        current_port = None

        if type(port) is Port:
            current_port = port
        elif type(port) is str:
            inputs = self.inputs()
            outputs = self.outputs()
            if port in inputs.keys():
                current_port = inputs[port]
            elif port in outputs.keys():
                current_port = outputs[port]

        if current_port:
            if (hasattr(current_port, "DataType")
                    and current_port.DataType is value_type):
                return
            # Always store the new type: previously an existing, different
            # ``DataType`` was kept while the colors and tooltip below
            # switched to the new type.
            current_port.DataType = value_type

            current_port.border_color = self._cryptoColors.get(str(value_type))
            current_port.color = self._cryptoColors.get(str(value_type))
            conn_type = 'multi' if current_port.multi_connection(
            ) else 'single'
            data_type_name = value_type.__name__ if value_type else "all"
            current_port.view.setToolTip('{}: {} ({}) '.format(
                current_port.name(), data_type_name, conn_type))

    def create_property(self,
                        name,
                        value,
                        items=None,
                        range=None,
                        widget_type=NODE_PROP,
                        tab=None,
                        ext=None,
                        funcs=None):
        """
        Create a property and, for a non-None value, set the type of the
        port with the same name.
        """
        super(AutoNode, self).create_property(name, value, items, range,
                                              widget_type, tab, ext, funcs)

        if value is not None:
            self.set_port_type(name, type(value))

    def add_input(self,
                  name='input',
                  data_type=None,
                  multi_input=False,
                  display_name=True,
                  color=None):
        """
        Add an input port typed with ``data_type`` or, failing that, with
        ``defaultInputType``.

        Returns:
            the created port.
        """
        new_port = super(AutoNode, self).add_input(name, multi_input,
                                                   display_name, color)
        if data_type:
            self.set_port_type(new_port, data_type)
        elif self.defaultInputType:
            self.set_port_type(new_port, self.defaultInputType)
        return new_port

    def add_output(self,
                   name='output',
                   data_type=None,
                   multi_output=True,
                   display_name=True,
                   color=None):
        """
        Add an output port typed with ``data_type`` or, failing that, with
        ``defaultOutputType``.

        Returns:
            the created port.
        """
        new_port = super(AutoNode, self).add_output(name, multi_output,
                                                    display_name, color)
        if data_type:
            self.set_port_type(new_port, data_type)
        elif self.defaultOutputType:
            self.set_port_type(new_port, self.defaultOutputType)
        return new_port

    def set_disabled(self, mode=False):
        """
        Disable or enable the node.

        A disabled node passes its inputs through (:meth:`when_disabled`)
        and cooks the next nodes; an enabled node cooks itself.
        """
        super(AutoNode, self).set_disabled(mode)
        self._autoCook = not mode
        if mode is True:
            self.when_disabled()
            self.cookNextNode()
        else:
            self.cook()

    def _close_error(self):
        """Clear the error state and restore the color and tooltip."""
        self._error = False
        self.set_property('color', self.defaultColor)
        self._update_tool_tip()

    def _show_error(self, message):
        """Flag the node as failed and show ``message`` in the tooltip."""
        if not self._error:
            # remember the regular color before switching to the error one.
            self.defaultColor = self.get_property("color")

        self._error = True
        self.set_property('color', self.errorColor)
        tooltip = '<font color="red"><br>({})</br></font>'.format(message)
        self._update_tool_tip(tooltip)

    def _update_tool_tip(self, message=None):
        """
        Set the view tooltip.

        Without ``message`` the stored tooltip template is filled with the
        last cook time; otherwise the tooltip is built from the node name,
        ``message`` and the view type.

        Returns:
            str: the tooltip text that was set.
        """
        if message is None:
            tooltip = self._toolTip.format(self._cookTime)
        else:
            tooltip = '<b>{}</b>'.format(self.name())
            tooltip += message
            tooltip += '<br/>{}<br/>'.format(self._view.type_)
        self.view.setToolTip(tooltip)
        return tooltip

    def _setup_tool_tip(self):
        """Build the tooltip template holding a cook time placeholder."""
        tooltip = '<br> last cook used: {}s</br>'
        return self._update_tool_tip(tooltip)

    def error(self, message=None):
        """
        Query or set the error state.

        Args:
            message: error to show; when omitted nothing is changed.

        Returns:
            bool: the current error flag when ``message`` is None,
                otherwise None.
        """
        if message is None:
            return self._error

        self._show_error(message)

    def update_model(self):
        """Restore the regular color of a failed node, then update the model."""
        if self.error():
            self.set_property('color', self.defaultColor)
        super(AutoNode, self).update_model()
예제 #30
0
class NodeGraph(QtCore.QObject):
    """
    The ``NodeGraph`` class is the main controller for managing all nodes.

    It holds the graph model, the viewer, the node factory and the undo
    stack, and exposes the Qt signals declared on this class.

    Inherited from: :class:`PySide2.QtCore.QObject`

    .. image:: _images/graph.png
        :width: 60%
    """

    # The bare string literal following each signal is an attribute
    # docstring; it is a no-op expression statement at runtime
    # (presumably consumed by the documentation build).

    # emitted with the newly created node object.
    node_created = QtCore.Signal(NodeObject)
    """
    Signal triggered when a node is created in the node graph.
    
    :parameters: :class:`NodeGraphQt.NodeObject`
    :emits: created node
    """
    # emitted with a list of node ids.
    nodes_deleted = QtCore.Signal(list)
    """
    Signal triggered when nodes have been deleted from the node graph.

    :parameters: list[str]
    :emits: list of deleted node ids.
    """
    # emitted with the node that was selected in the viewer.
    node_selected = QtCore.Signal(NodeObject)
    """
    Signal triggered when a node is clicked with the LMB.
    
    :parameters: :class:`NodeGraphQt.NodeObject`
    :emits: selected node
    """
    # emitted with the node that was double clicked in the viewer.
    node_double_clicked = QtCore.Signal(NodeObject)
    """
    Signal triggered when a node is double clicked and emits the node.

    :parameters: :class:`NodeGraphQt.NodeObject`
    :emits: selected node
    """
    # carries two Port objects.
    port_connected = QtCore.Signal(Port, Port)
    """
    Signal triggered when a node port has been connected.

    :parameters: :class:`NodeGraphQt.Port`, :class:`NodeGraphQt.Port` 
    :emits: input port, output port
    """
    # carries two Port objects.
    port_disconnected = QtCore.Signal(Port, Port)
    """
    Signal triggered when a node port has been disconnected.

    :parameters: :class:`NodeGraphQt.Port`, :class:`NodeGraphQt.Port` 
    :emits: input port, output port
    """
    # NOTE(review): the signature declares ``NodeObject`` while the text
    # below mentions ``BaseNode`` - confirm which one is intended.
    property_changed = QtCore.Signal(NodeObject, str, object)
    """
    Signal is triggered when a property has changed on a node.  

    :parameters: :class:`NodeGraphQt.BaseNode`, str, object
    :emits: triggered node, property name, property value
    """
    # emitted for drops that are not internal node drag & drop payloads.
    data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)
    """
    Signal is triggered when data has been dropped to the graph.
    
    :parameters: :class:`PySide2.QtCore.QMimeData`, :class:`PySide2.QtCore.QPoint`
    :emits: mime data, node graph position
    """
    # NOTE(review): ":parameters: :str" below looks like a typo for "str".
    session_changed = QtCore.Signal(str)
    """
    Signal is triggered when session has been changed.

    :parameters: :str
    :emits: new session path
    """
    def __init__(self, parent=None):
        """
        Args:
            parent (QtCore.QObject): optional parent object.
        """
        super(NodeGraph, self).__init__(parent)
        self.setObjectName('NodeGraphQt')
        self._widget = None
        self._model = NodeGraphModel()
        self._viewer = NodeViewer()
        self._node_factory = NodeFactory()
        self._undo_stack = QtWidgets.QUndoStack(self)

        # the tab key and the viewer's own request signal both toggle the
        # tab search widget.
        tab_key = QtGui.QKeySequence(QtCore.Qt.Key_Tab)
        tab_shortcut = QtWidgets.QShortcut(tab_key, self._viewer)
        tab_shortcut.activated.connect(self._toggle_tab_search)
        self._viewer.need_show_tab_search.connect(self._toggle_tab_search)

        self._wire_signals()

        # accessing the ``widget`` property creates the container widget.
        container = self.widget
        container.setAcceptDrops(True)

    def __repr__(self):
        """
        Returns:
            str: class name followed by the hexadecimal ``id()`` of the object.
        """
        class_name = self.__class__.__name__
        address = hex(id(self))
        return '<{} object at {}>'.format(class_name, address)

    def _wire_signals(self):
        """
        Connect the viewer signals to the matching slots on the graph.
        """
        viewer = self._viewer
        wiring = (
            # internal signals.
            (viewer.search_triggered, self._on_search_triggered),
            (viewer.connection_sliced, self._on_connection_sliced),
            (viewer.connection_changed, self._on_connection_changed),
            (viewer.moved_nodes, self._on_nodes_moved),
            (viewer.node_double_clicked, self._on_node_double_clicked),
            (viewer.insert_node, self._insert_node),
            # pass through signals.
            (viewer.node_selected, self._on_node_selected),
            (viewer.data_dropped, self._on_node_data_dropped),
        )
        for signal, slot in wiring:
            signal.connect(slot)

    def _insert_node(self, pipe, node_id, prev_node_pos):
        """
        Slot invoked when a selected node has been dropped onto a pipe: the
        pipe is split and the node is connected in between.

        Args:
            pipe (Pipe): collided pipe item.
            node_id (str): selected node id to insert.
            prev_node_pos (dict): previous node position.
                {NodeItem: [prev_x, prev_y]}
        """
        node = self.get_node_by_id(node_id)

        # only BaseNode instances can be inserted (excludes the BackdropNode).
        if not isinstance(node, BaseNode):
            return

        to_disconnect = [(pipe.input_port, pipe.output_port)]
        to_connect = []

        # the first input / first output of the node take over the pipe ends.
        node_inputs = node.inputs()
        if node_inputs:
            first_input = next(iter(node_inputs.values()))
            to_connect.append((pipe.output_port, first_input.view))

        node_outputs = node.outputs()
        if node_outputs:
            first_output = next(iter(node_outputs.values()))
            to_connect.append((first_output.view, pipe.input_port))

        # group re-wiring and the node move under a single macro.
        undo_stack = self._undo_stack
        undo_stack.beginMacro('inserted node')
        self._on_connection_changed(to_disconnect, to_connect)
        self._on_nodes_moved(prev_node_pos)
        undo_stack.endMacro()

    def _toggle_tab_search(self):
        """
        Toggle the tab search widget, but only while the mouse cursor is
        over the viewer.
        """
        viewer = self._viewer
        if not viewer.underMouse():
            return
        # refresh the searchable node names before toggling the widget.
        viewer.tab_search_set_nodes(self._node_factory.names)
        viewer.tab_search_toggle()

    def _on_property_bin_changed(self, node_id, prop_name, prop_value):
        """
        called when a property widget has changed in a properties bin.
        (updates the node property when the value actually differs)

        Args:
            node_id (str): node id.
            prop_name (str): node property name.
            prop_value (object): python object.
        """
        node = self.get_node_by_id(node_id)

        # prevent signals from causing a infinite loop.
        if node.get_property(prop_name) != prop_value:
            # Immutable primitives can be stored as-is; anything else is deep
            # copied so the node does not share the object handed over by the
            # widget. The decision is made on the incoming value because that
            # is the object being stored, and ``type(None)`` is used because a
            # bare ``None`` never equals the result of ``type(...)``.
            immutable_types = (float, int, str, bool, type(None))
            if type(prop_value) in immutable_types:
                value = prop_value
            else:
                value = copy.deepcopy(prop_value)
            node.set_property(prop_name, value)

    def _on_node_double_clicked(self, node_id):
        """
        Slot for a double click on a node in the viewer; looks the node up
        by id and re-emits it through ``node_double_clicked``.

        Args:
            node_id (str): node id emitted by the viewer.
        """
        self.node_double_clicked.emit(self.get_node_by_id(node_id))

    def _on_node_selected(self, node_id):
        """
        Slot for a node selection in the viewer; looks the node up by id
        and re-emits it through ``node_selected``.

        Args:
            node_id (str): node id emitted by the viewer.
        """
        self.node_selected.emit(self.get_node_by_id(node_id))

    def _on_node_data_dropped(self, data, pos):
        """
        called when data has been dropped on the viewer.

        Plain text payloads starting with the internal drag & drop prefix are
        treated as a comma separated list of node identifiers and turned into
        new nodes; every other payload is forwarded through ``data_dropped``.

        Args:
            data (QtCore.QMimeData): mime data.
            pos (QtCore.QPoint): scene position relative to the drop.
        """
        prefix = '<${}>:'.format(DRAG_DROP_ID)

        # don't emit signal for internal widget drops.
        if data.hasFormat('text/plain') and data.text().startswith(prefix):
            node_ids = data.text()[len(prefix):]
            x, y = pos.x(), pos.y()
            for node_id in node_ids.split(','):
                self.create_node(node_id, pos=[x, y])
                # offset each following node so that multiple dropped nodes
                # do not end up stacked on the exact same position.
                x += 20
                y += 20
            return

        self.data_dropped.emit(data, pos)

    def _on_nodes_moved(self, node_data):
        """
        Slot for nodes that have been moved in the viewer; pushes one move
        command per node, grouped under a single undo macro.

        Args:
            node_data (dict): {<node_view>: <previous_pos>}
        """
        undo_stack = self._undo_stack
        undo_stack.beginMacro('move nodes')
        for view_item, old_pos in node_data.items():
            moved_node = self._model.nodes[view_item.id]
            command = NodeMovedCmd(moved_node, moved_node.pos(), old_pos)
            undo_stack.push(command)
        undo_stack.endMacro()

    def _on_search_triggered(self, node_type, pos):
        """
        Slot for the viewer's tab search widget; creates a node of the
        chosen type at the given position.

        Args:
            node_type (str): node identifier.
            pos (tuple): x,y position for the node.
        """
        self.create_node(node_type, pos=pos)

    def _on_connection_changed(self, disconnected, connected):
        """
        Slot for pipe connection changes made in the viewer; applies the
        disconnects first, then the connects, inside one undo macro.

        Args:
            disconnected (list[list[widgets.port.PortItem]]):
                pair list of port view items.
            connected (list[list[widgets.port.PortItem]]):
                pair list of port view items.
        """
        if not disconnected and not connected:
            return

        accessor_names = {IN_PORT: 'inputs', OUT_PORT: 'outputs'}

        def resolve(port_view):
            # map a port view item to the port object of its node.
            owner = self._model.nodes[port_view.node.id]
            ports = getattr(owner, accessor_names[port_view.port_type])()
            return ports[port_view.name]

        if connected:
            macro_label = 'connect node(s)'
        else:
            macro_label = 'disconnect node(s)'

        self._undo_stack.beginMacro(macro_label)
        for first_view, second_view in disconnected:
            first_port = resolve(first_view)
            second_port = resolve(second_view)
            first_port.disconnect_from(second_port)
        for first_view, second_view in connected:
            first_port = resolve(first_view)
            second_port = resolve(second_view)
            first_port.connect_to(second_port)
        self._undo_stack.endMacro()

    def _on_connection_sliced(self, ports):
        """
        Slot for pipes that have been sliced in the viewer; disconnects
        every given port pair inside one undo macro.

        Args:
            ports (list[list[widgets.port.PortItem]]):
                pair list of port connections (in port, out port)
        """
        if ports:
            accessor_names = {IN_PORT: 'inputs', OUT_PORT: 'outputs'}
            self._undo_stack.beginMacro('slice connections')
            for first_view, second_view in ports:
                first_node = self._model.nodes[first_view.node.id]
                second_node = self._model.nodes[second_view.node.id]
                first_ports = getattr(
                    first_node, accessor_names[first_view.port_type])()
                second_ports = getattr(
                    second_node, accessor_names[second_view.port_type])()
                first_ports[first_view.name].disconnect_from(
                    second_ports[second_view.name])
            self._undo_stack.endMacro()

    @property
    def model(self):
        """
        Read-only access to the data model backing this node graph.

        Returns:
            NodeGraphQt.base.model.NodeGraphModel: node graph model.
        """
        graph_model = self._model
        return graph_model

    @property
    def widget(self):
        """
        The node graph widget for adding into a layout.

        The container is created on first access and reused afterwards: a
        ``QWidgetDrops`` instance that is given this graph's
        ``import_session`` callable and hosts the viewer in a vertical
        layout with zero margins.

        Returns:
            PySide2.QtWidgets.QWidget: node graph widget.
        """
        if self._widget is None:
            self._widget = QWidgetDrops()
            self._widget.import_session = self.import_session

            layout = QtWidgets.QVBoxLayout(self._widget)
            layout.setContentsMargins(0, 0, 0, 0)
            layout.addWidget(self._viewer)
        return self._widget