示例#1
0
 def __init__(self, parent):
     """Create the line-plot widget with custom axes and a white background.

     The axis items are constructed before the base-class constructor runs
     because they are handed to it through ``axisItems``.

     :param parent: parent widget forwarded to the base class
     """
     self.groups: List[ProfileGroup] = []

     # Both axes start without a label; only the bottom one is kept as an
     # attribute, the left one is owned by the plot afterwards.
     self.bottom_axis = BottomAxisItem(orientation="bottom")
     self.bottom_axis.setLabel("")
     y_axis = AxisItem(orientation="left")
     y_axis.setLabel("")
     axes = {"bottom": self.bottom_axis, "left": y_axis}

     super().__init__(
         parent,
         viewBox=LinePlotViewBox(),
         background="w",
         enableMenu=False,
         axisItems=axes,
     )

     self.view_box = self.getViewBox()
     self.selection = set()
     # Anchor points are interpreted by _create_legend (presumably
     # item/parent anchors, i.e. top-right corner) — TODO confirm.
     self.legend = self._create_legend(((1, 0), (1, 0)))
     plot_item = self.getPlotItem()
     plot_item.buttonsHidden = True
     self.setRenderHint(QPainter.Antialiasing, True)
class ViolinPlot(QGraphicsWidget):
    """Rows of per-variable violins with labels, a legend and a bottom axis.

    The grid layout has three columns (``LABEL_COLUMN``, ``VIOLIN_COLUMN``,
    ``LEGEND_COLUMN``). Row ``i`` holds the label and the violin of the i-th
    variable, the row after the last violin holds the bottom axis, and the
    legend sits in row 0 of the legend column.
    """
    LABEL_COLUMN, VIOLIN_COLUMN, LEGEND_COLUMN = range(3)
    VIOLIN_COLUMN_WIDTH, OFFSET = 300, 80
    MAX_N_ITEMS = 100  # upper bound on the number of violins created
    selection_cleared = Signal()
    selection_changed = Signal(float, float, str)
    resized = Signal()

    def __init__(self):
        super().__init__()
        self.__violin_column_width = self.VIOLIN_COLUMN_WIDTH  # type: int
        # Symmetric data range (-m, m); assigned in set_data().
        self.__range = None  # type: Optional[Tuple[float, float]]
        self.__violin_items = []  # type: List[ViolinItem]
        # Layout wrappers around the VariableItem labels (see _set_labels).
        self.__variable_items = []  # type: List[SimpleLayoutItem]
        self.__bottom_axis = AxisItem(parent=self, orientation="bottom",
                                      maxTickLength=7, pen=QPen(Qt.black))
        self.__bottom_axis.setLabel("Impact on model output")
        # Vertical reference line, a child of the bottom axis. It is drawn at
        # half the column width, i.e. at value 0 of the symmetric range.
        self.__vertical_line = QGraphicsLineItem(self.__bottom_axis)
        self.__vertical_line.setPen(QPen(Qt.gray))
        self.__legend = Legend(self)

        self.__layout = QGraphicsGridLayout()
        self.__layout.addItem(self.__legend, 0, ViolinPlot.LEGEND_COLUMN)
        self.__layout.setVerticalSpacing(0)
        self.setLayout(self.__layout)

        self.parameter_setter = ParameterSetter(self)

    @property
    def violin_column_width(self):
        """Current width of the violin column."""
        return self.__violin_column_width

    @violin_column_width.setter
    def violin_column_width(self, view_width: int):
        """Derive the violin column width from the available *view* width.

        The legend width, the widest label and ``OFFSET`` are subtracted from
        ``view_width``. The result is never narrower than
        ``VIOLIN_COLUMN_WIDTH``.
        """
        j = ViolinPlot.LABEL_COLUMN
        w = max([self.__layout.itemAt(i, j).item.boundingRect().width()
                 for i in range(len(self.__violin_items))] + [0])
        width = view_width - self.legend.sizeHint().width() - self.OFFSET - w
        self.__violin_column_width = max(self.VIOLIN_COLUMN_WIDTH, width)

    @property
    def bottom_axis(self):
        return self.__bottom_axis

    @property
    def labels(self):
        """Label layout items (``SimpleLayoutItem`` wrappers), one per row."""
        return self.__variable_items

    @property
    def legend(self):
        return self.__legend

    def set_data(self, x: np.ndarray, colors: np.ndarray,
                 names: List[str], n_attrs: int, view_width: int):
        """Populate the plot.

        :param x: 2-D values; column ``i`` is drawn as the violin of
            ``names[i]``
        :param colors: colour values passed per column (``colors[:, i]``)
            to the corresponding violin
        :param names: variable names, one per column of ``x``
        :param n_attrs: number of leading rows to make visible
        :param view_width: available view width used to size the column
        """
        self.violin_column_width = view_width
        # Symmetric range with a 5% margin around the largest magnitude.
        abs_max = np.max(np.abs(x)) * 1.05
        self.__range = (-abs_max, abs_max)
        self._set_violin_items(x, colors, names)
        self._set_labels(names)
        self._set_bottom_axis()
        self.set_n_visible(n_attrs)

    def set_n_visible(self, n: int):
        """Show the first ``n`` rows (label and violin) and hide the rest."""
        for i in range(len(self.__violin_items)):
            violin_item = self.__layout.itemAt(i, ViolinPlot.VIOLIN_COLUMN)
            violin_item.setVisible(i < n)
            text_item = self.__layout.itemAt(i, ViolinPlot.LABEL_COLUMN).item
            text_item.setVisible(i < n)
        self.set_vertical_line()

    def rescale(self, view_width: int):
        """Recompute the column width for ``view_width``.

        The violins, the bottom axis and the vertical line are resized to
        match.
        """
        self.violin_column_width = view_width
        # Fixed seed, presumably so any randomness inside ViolinItem (e.g.
        # point jitter) is reproducible across rescales — TODO confirm.
        with temp_seed(0):
            for item in self.__violin_items:
                item.rescale(self.violin_column_width)
        self._update_axis_and_vertical_line()

    def show_legend(self, show: bool):
        """Show or hide the legend and re-sync the axis and vertical line.

        The column width itself is not recomputed here.
        """
        self.__legend.setVisible(show)
        self._update_axis_and_vertical_line()

    def _update_axis_and_vertical_line(self):
        """Fit the bottom axis to the column and re-centre the vertical line.

        The vertical line keeps its current length.
        """
        width = self.violin_column_width
        self.__bottom_axis.setWidth(width)
        x = width / 2
        self.__vertical_line.setLine(x, 0, x, self.__vertical_line.line().y2())

    def _set_violin_items(self, x: np.ndarray, colors: np.ndarray,
                          labels: List[str]):
        """Create one violin per column of ``x``, up to ``MAX_N_ITEMS``."""
        n_items = min(x.shape[1], self.MAX_N_ITEMS)
        with temp_seed(0):
            for i in range(n_items):
                item = ViolinItem(self, labels[i], self.__range,
                                  self.violin_column_width)
                item.set_data(x[:, i], colors[:, i])
                item.selection_changed.connect(self.select)
                self.__violin_items.append(item)
                self.__layout.addItem(item, i, ViolinPlot.VIOLIN_COLUMN)

    def _set_labels(self, labels: List[str]):
        """Add a right-aligned label next to every existing violin."""
        # zip() limits the labels to the number of violins actually created.
        for i, (label, _) in enumerate(zip(labels, self.__violin_items)):
            text = VariableItem(self, label)
            item = SimpleLayoutItem(text)
            item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            self.__layout.addItem(item, i, ViolinPlot.LABEL_COLUMN,
                                  Qt.AlignRight | Qt.AlignVCenter)
            self.__variable_items.append(item)

    def _set_bottom_axis(self):
        """Set the axis range and place the axis below the last violin."""
        self.__bottom_axis.setRange(*self.__range)
        self.__layout.addItem(self.__bottom_axis,
                              len(self.__violin_items),
                              ViolinPlot.VIOLIN_COLUMN)

    def set_vertical_line(self):
        """Stretch the centre line over the height of all visible rows."""
        x = self.violin_column_width / 2
        height = 0
        for i in range(len(self.__violin_items)):
            violin_item = self.__layout.itemAt(i, ViolinPlot.VIOLIN_COLUMN)
            text_item = self.__layout.itemAt(i, ViolinPlot.LABEL_COLUMN).item
            if violin_item.isVisible():
                height += max(text_item.boundingRect().height(),
                              violin_item.preferredSize().height())
        # Negative y: the line extends upward from its parent, the axis.
        self.__vertical_line.setLine(x, 0, x, -height)

    def deselect(self):
        self.selection_cleared.emit()

    def select(self, *args):
        self.selection_changed.emit(*args)

    def select_from_settings(self, x1: float, x2: float, attr_name: str):
        """Restore a saved selection ``[x1, x2]`` on the violin ``attr_name``.

        The selection is then re-emitted.
        """
        # Half the column spans __range[1] data units, so this is two pixels
        # expressed in data units (presumably a point radius) — TODO confirm.
        point_r_diff = 2 * self.__range[1] / (self.violin_column_width / 2)
        for item in self.__violin_items:
            if item.attr_name == attr_name:
                item.add_selection_rect(x1 - point_r_diff, x2 + point_r_diff)
                break
        self.select(x1, x2, attr_name)

    def apply_visual_settings(self, settings: Dict):
        """Forward every ``(key, value)`` pair to the parameter setter."""
        for key, value in settings.items():
            self.parameter_setter.set_parameter(key, value)
示例#3
0
class BarPlotGraph(PlotWidget):
    """Bar chart driven by an ``OWBarPlot`` master widget.

    Bar ``i`` is drawn at x = ``i``. Values, labels, colours, group labels,
    legend entries and tooltips are all requested from ``master``. The
    selected bars are kept as a list of indices in ``selection`` and are
    announced through ``selection_changed``.
    """
    selection_changed = Signal(list)
    bar_width = 0.7

    def __init__(self, master, parent=None):
        self.selection = []
        self.master: OWBarPlot = master
        self.state: int = SELECT
        self.bar_item: pg.BarGraphItem = None
        super().__init__(
            parent=parent,
            viewBox=BarPlotViewBox(self),
            enableMenu=False,
            axisItems={
                "bottom": AxisItem(orientation="bottom", rotate_ticks=True),
                "left": AxisItem(orientation="left"),
            },
        )
        self.hideAxis("left")
        self.hideAxis("bottom")
        self.getPlotItem().buttonsHidden = True
        self.getPlotItem().setContentsMargins(10, 0, 0, 10)
        self.getViewBox().setMouseMode(pg.ViewBox.PanMode)

        # Second bottom axis (row 4 of the plot item's layout). It shows one
        # tick per group and is linked to the same view box as the bars.
        self.group_axis = AxisItem("bottom")
        self.group_axis.hide()
        self.group_axis.linkToView(self.getViewBox())
        self.getPlotItem().layout.addItem(self.group_axis, 4, 1)

        self.legend = self._create_legend()

        self.tooltip_delegate = HelpEventDelegate(self.help_event)
        self.scene().installEventFilter(self.tooltip_delegate)

        self.parameter_setter = ParameterSetter(self)

        self.showGrid(
            y=self.parameter_setter.DEFAULT_SHOW_GRID,
            alpha=self.parameter_setter.DEFAULT_ALPHA_GRID / 255,
        )

    def _create_legend(self):
        """Create a hidden legend anchored to the view box's top-right."""
        legend = LegendItem()
        legend.setParentItem(self.getViewBox())
        legend.anchor((1, 0), (1, 0), offset=(-3, 1))
        legend.hide()
        return legend

    def update_legend(self):
        """Rebuild the legend from ``master.get_legend_data()``.

        The legend is shown only if there is at least one entry.
        """
        self.legend.clear()
        self.legend.hide()
        has_entries = False
        for color, text in self.master.get_legend_data():
            dot = pg.ScatterPlotItem(pen=pg.mkPen(color=color),
                                     brush=pg.mkBrush(color=color))
            self.legend.addItem(dot, escape(text))
            has_entries = True
        if has_entries:
            self.legend.show()
        Updater.update_legend_font(self.legend.items,
                                   **self.parameter_setter.legend_settings)

    def reset_graph(self):
        """Clear the plot and rebuild bars, axes, group lines and legend.

        The view is then reset to fit the bars.
        """
        self.clear()
        self.update_bars()
        self.update_axes()
        self.update_group_lines()
        self.update_legend()
        self.reset_view()

    def update_bars(self):
        """Replace the bar item with one built from the master's values.

        Nothing is drawn if the master has no values.
        """
        if self.bar_item is not None:
            self.removeItem(self.bar_item)
            self.bar_item = None

        values = self.master.get_values()
        if values is None:
            return

        self.bar_item = pg.BarGraphItem(
            x=np.arange(len(values)),
            height=values,
            width=self.bar_width,
            pen=pg.mkPen(QColor(Qt.white)),
            labels=self.master.get_labels(),
            brushes=self.master.get_colors(),
        )
        self.addItem(self.bar_item)
        self.__select_bars()

    def update_axes(self):
        """Show and label the axes if bars exist; otherwise hide them.

        The bottom axis gets one tick per bar. The group axis gets one tick
        per distinct group label, placed at the middle of that group's bars
        (assuming a group's bars are contiguous).
        """
        if self.bar_item is not None:
            self.showAxis("left")
            self.showAxis("bottom")
            self.group_axis.show()

            vals_label, group_label, annot_label = self.master.get_axes()
            self.setLabel(axis="left", text=vals_label)
            self.setLabel(axis="bottom", text=annot_label)
            self.group_axis.setLabel(group_label)

            ticks = [list(enumerate(self.master.get_labels()))]
            self.getAxis("bottom").setTicks(ticks)

            labels = np.array(self.master.get_group_labels())
            _, indices, counts = np.unique(labels,
                                           return_index=True,
                                           return_counts=True)
            ticks = [[(i + (c - 1) / 2, labels[i])
                      for i, c in zip(indices, counts)]]
            self.group_axis.setTicks(ticks)

            if not group_label:
                self.group_axis.hide()
            elif not annot_label:
                self.hideAxis("bottom")
        else:
            self.hideAxis("left")
            self.hideAxis("bottom")
            self.group_axis.hide()

    def reset_view(self):
        """Fit the view to all bars and to the value baseline.

        Horizontally the view covers every bar. Vertically it covers
        ``[min(heights, 0), max(heights, 0)]``; NaN heights are ignored.
        """
        if self.bar_item is None:
            return
        # Appending 0 keeps the baseline in view even if all bars share a sign.
        values = np.append(self.bar_item.opts["height"], 0)
        y_min = np.nanmin(values)
        y_span = np.nanmax(values) - y_min
        # len(values) - 1 == number of bars; bars are centred on 0..n-1.
        rect = QRectF(-0.5, y_min, len(values) - 1, y_span)
        self.getViewBox().setRange(rect)

    def zoom_button_clicked(self):
        self.state = ZOOMING
        self.getViewBox().setMouseMode(pg.ViewBox.RectMode)

    def pan_button_clicked(self):
        self.state = PANNING
        self.getViewBox().setMouseMode(pg.ViewBox.PanMode)

    def select_button_clicked(self):
        self.state = SELECT
        self.getViewBox().setMouseMode(pg.ViewBox.RectMode)

    def reset_button_clicked(self):
        self.reset_view()

    def update_group_lines(self):
        """Draw a vertical separator before the first bar of every group.

        No separator is drawn before the first group.
        """
        if self.bar_item is None:
            return

        # Check the raw value before wrapping it: np.array(None) is a 0-d
        # array, not None, and len() of it raises TypeError.
        group_labels = self.master.get_group_labels()
        if group_labels is None or len(group_labels) == 0:
            return
        labels = np.array(group_labels)

        _, indices = np.unique(labels, return_index=True)
        # Midway between adjacent bars (the expression equals 0.5).
        offset = self.bar_width / 2 + (1 - self.bar_width) / 2
        for index in sorted(indices)[1:]:
            line = pg.InfiniteLine(pos=index - offset, angle=90)
            self.addItem(line)

    def select_by_rectangle(self, rect: QRectF):
        """Select every bar whose area intersects ``rect``."""
        if self.bar_item is None:
            return

        x0, x1 = sorted((rect.topLeft().x(), rect.bottomRight().x()))
        y0, y1 = sorted((rect.topLeft().y(), rect.bottomRight().y()))
        x = self.bar_item.opts["x"]
        height = self.bar_item.opts["height"]
        d = self.bar_width / 2
        # positive bars
        mask = (x0 <= x + d) & (x1 >= x - d) & (y0 <= height) & (y1 > 0)
        # negative bars
        mask |= (x0 <= x + d) & (x1 >= x - d) & (y0 <= 0) & (y1 > height)
        self.select_by_indices(list(np.flatnonzero(mask)))

    def select_by_click(self, p: QPointF):
        """Select the bar under ``p``; clicking empty space selects nothing."""
        if self.bar_item is None:
            return

        index = self.__get_index_at(p)
        self.select_by_indices([index] if index is not None else [])

    def __get_index_at(self, p: QPointF):
        """Return the index of the bar containing ``p``, or None."""
        x = p.x()
        index = round(x)
        heights = self.bar_item.opts["height"]
        if 0 <= index < len(heights) and abs(x - index) <= self.bar_width / 2:
            height = heights[index]  # pylint: disable=unsubscriptable-object
            # The bar spans [0, height] or [height, 0] depending on sign.
            if 0 <= p.y() <= height or height <= p.y() <= 0:
                return index
        return None

    def select_by_indices(self, indices: List):
        """Combine ``indices`` with the selection according to modifiers.

        Ctrl toggles, Alt removes, Shift adds, and no modifier replaces.
        The new selection is then emitted.
        """
        keys = QApplication.keyboardModifiers()
        if keys & Qt.ControlModifier:
            self.selection = list(set(self.selection) ^ set(indices))
        elif keys & Qt.AltModifier:
            self.selection = list(set(self.selection) - set(indices))
        elif keys & Qt.ShiftModifier:
            self.selection = list(set(self.selection) | set(indices))
        else:
            self.selection = list(set(indices))
        self.__select_bars()
        self.selection_changed.emit(self.selection)

    def __select_bars(self):
        """Outline the selected bars with a dashed black pen.

        All other bars get a white pen.
        """
        if self.bar_item is None:
            return

        n = len(self.bar_item.opts["height"])
        pens = np.full(n, pg.mkPen(QColor(Qt.white)))
        pen = pg.mkPen(QColor(Qt.black))
        pen.setStyle(Qt.DashLine)
        # Skip indices that do not refer to an existing bar, e.g. a stale
        # selection after the bars were rebuilt with fewer values; using
        # them as an index would raise IndexError.
        selected = [i for i in self.selection if 0 <= i < n]
        pens[selected] = pen
        self.bar_item.setOpts(pens=pens)

    def help_event(self, ev: QGraphicsSceneHelpEvent):
        """Show the master's tooltip for the bar under the cursor.

        Returns True if a tooltip was shown (the event is handled) and
        False otherwise.
        """
        if self.bar_item is None:
            return False

        index = self.__get_index_at(self.bar_item.mapFromScene(ev.scenePos()))
        text = ""
        if index is not None:
            text = self.master.get_tooltip(index)
        if text:
            QToolTip.showText(ev.screenPos(), text, widget=self)
            return True
        return False
示例#4
0
class FeaturesPlot(QGraphicsWidget):
    """Base plot showing one item per feature, with labels and a bottom axis.

    Row ``i`` of the grid layout holds the label (``LABEL_COLUMN``) and the
    item (``ITEM_COLUMN``) of the i-th feature. The row after the last item
    holds the bottom axis. Subclasses implement ``_set_range``,
    ``_set_items`` and ``select_from_settings``.
    """
    BOTTOM_AXIS_LABEL = "Feature Importance"
    LABEL_COLUMN, ITEM_COLUMN = range(2)
    ITEM_COLUMN_WIDTH, OFFSET = 300, 80
    selection_cleared = Signal()
    selection_changed = Signal(object)
    resized = Signal()

    def __init__(self):
        super().__init__()
        self._item_column_width = self.ITEM_COLUMN_WIDTH
        # (low, high) data range mapped onto the item column; set by the
        # subclass's _set_range().
        self._range: Optional[Tuple[float, float]] = None
        self._items: List[FeatureItem] = []
        # Layout wrappers around the VariableItem labels (see _set_labels).
        self._variable_items: List[SimpleLayoutItem] = []
        self._bottom_axis = AxisItem(parent=self,
                                     orientation="bottom",
                                     maxTickLength=7,
                                     pen=QPen(Qt.black))
        self._bottom_axis.setLabel(self.BOTTOM_AXIS_LABEL)
        # Vertical line at data value 0, a child of the bottom axis.
        self._vertical_line = QGraphicsLineItem(self._bottom_axis)
        self._vertical_line.setPen(QPen(Qt.gray))

        self._layout = QGraphicsGridLayout()
        self._layout.setVerticalSpacing(0)
        self.setLayout(self._layout)

        self.parameter_setter = BaseParameterSetter(self)

    @property
    def item_column_width(self) -> int:
        """Current width of the item column."""
        return self._item_column_width

    @item_column_width.setter
    def item_column_width(self, view_width: int):
        """Derive the item column width from the available *view* width.

        The widest label and ``OFFSET`` are subtracted from ``view_width``.
        The result is never narrower than ``ITEM_COLUMN_WIDTH``.
        """
        j = FeaturesPlot.LABEL_COLUMN
        w = max([
            self._layout.itemAt(i, j).item.boundingRect().width()
            for i in range(len(self._items))
        ] + [0])
        width = view_width - self.OFFSET - w
        self._item_column_width = max(self.ITEM_COLUMN_WIDTH, width)

    @property
    def x0_scaled(self) -> float:
        """X position of data value 0 within the item column.

        The position is computed by mapping ``_range`` onto the full column
        width. A zero-width range maps to 0 instead of dividing by zero.
        """
        low, high = self._range
        span = high - low
        if span == 0:
            return 0.0
        return -low * self.item_column_width / span

    @property
    def bottom_axis(self) -> AxisItem:
        return self._bottom_axis

    @property
    def labels(self) -> "List[SimpleLayoutItem]":
        """Label layout items (``SimpleLayoutItem`` wrappers), one per row."""
        return self._variable_items

    def set_data(self, x: np.ndarray, names: List[str], n_attrs: int,
                 view_width: int, *plot_args):
        """Populate the plot.

        :param x: feature values; their interpretation is up to the subclass
        :param names: feature names, one label per created item
        :param n_attrs: number of leading rows to make visible
        :param view_width: available view width used to size the column
        :param plot_args: extra arguments forwarded to ``_set_range`` and
            ``_set_items``
        """
        self.item_column_width = view_width
        self._set_range(x, *plot_args)
        self._set_items(x, names, *plot_args)
        self._set_labels(names)
        self._set_bottom_axis()
        self.set_n_visible(n_attrs)

    def _set_range(self, *_):
        """Set ``self._range``; must be implemented by subclasses."""
        raise NotImplementedError

    def _set_items(self, *_):
        """Create the per-feature items; must be implemented by subclasses.

        The other methods rely on the items being appended to ``_items`` and
        added to row ``i`` of ``ITEM_COLUMN`` (presumably — confirm in
        subclasses).
        """
        raise NotImplementedError

    def set_n_visible(self, n: int):
        """Show the first ``n`` rows (label and item) and hide the rest."""
        for i in range(len(self._items)):
            item = self._layout.itemAt(i, FeaturesPlot.ITEM_COLUMN)
            item.setVisible(i < n)
            text_item = self._layout.itemAt(i, FeaturesPlot.LABEL_COLUMN).item
            text_item.setVisible(i < n)
        self.set_vertical_line()

    def rescale(self, view_width: int):
        """Recompute the column width for ``view_width``.

        The items, the bottom axis and the vertical line are resized to
        match.
        """
        self.item_column_width = view_width
        for item in self._items:
            item.rescale(self.item_column_width)

        self._bottom_axis.setWidth(self.item_column_width)
        x = self.x0_scaled
        # Keep the vertical line's current length; only move it.
        self._vertical_line.setLine(x, 0, x, self._vertical_line.line().y2())
        self.updateGeometry()

    def set_height(self, height: float):
        """Set the height of every item and re-fit the vertical line."""
        for i in range(len(self._items)):
            item = self._layout.itemAt(i, FeaturesPlot.ITEM_COLUMN)
            item.set_height(height)
        self.set_vertical_line()
        self.updateGeometry()

    def _set_labels(self, labels: List[str]):
        """Add a right-aligned label next to every existing item."""
        # zip() limits the labels to the number of items actually created.
        for i, (label, _) in enumerate(zip(labels, self._items)):
            text = VariableItem(self, label)
            item = SimpleLayoutItem(text)
            item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            self._layout.addItem(item, i, FeaturesPlot.LABEL_COLUMN,
                                 Qt.AlignRight | Qt.AlignVCenter)
            self._variable_items.append(item)

    def _set_bottom_axis(self):
        """Set the axis range and place the axis below the last item."""
        self._bottom_axis.setRange(*self._range)
        self._layout.addItem(self._bottom_axis, len(self._items),
                             FeaturesPlot.ITEM_COLUMN)

    def set_vertical_line(self):
        """Stretch the value-0 line over the height of all visible rows."""
        height = 0
        for i in range(len(self._items)):
            item = self._layout.itemAt(i, FeaturesPlot.ITEM_COLUMN)
            text_item = self._layout.itemAt(i, FeaturesPlot.LABEL_COLUMN).item
            if item.isVisible():
                height += max(text_item.boundingRect().height(),
                              item.preferredSize().height())
        x = self.x0_scaled
        # Negative y: the line extends upward from its parent, the axis.
        self._vertical_line.setLine(x, 0, x, -height)

    def deselect(self):
        self.selection_cleared.emit()

    def select(self, *args):
        self.selection_changed.emit(*args)

    def select_from_settings(self, *_):
        """Restore a saved selection; must be implemented by subclasses."""
        raise NotImplementedError

    def apply_visual_settings(self, settings: Dict):
        """Forward every ``(key, value)`` pair to the parameter setter."""
        for key, value in settings.items():
            self.parameter_setter.set_parameter(key, value)