Exemplo n.º 1
0
class ViolinPlot(QGraphicsWidget):
    """Grid of per-variable violin items sharing one bottom axis.

    Each row holds a variable label (``LABEL_COLUMN``) and its violin
    (``VIOLIN_COLUMN``).  The legend sits in ``LEGEND_COLUMN`` of row 0 and
    the bottom axis is placed in the row below the last violin.  The data
    range is symmetric around 0, so the vertical line drawn at the middle
    of the violin column marks the value 0.  At most ``MAX_N_ITEMS``
    variables are plotted.
    """
    LABEL_COLUMN, VIOLIN_COLUMN, LEGEND_COLUMN = range(3)
    VIOLIN_COLUMN_WIDTH, OFFSET = 300, 80
    MAX_N_ITEMS = 100
    selection_cleared = Signal()
    selection_changed = Signal(float, float, str)
    resized = Signal()

    def __init__(self):
        super().__init__()
        self.__violin_column_width = self.VIOLIN_COLUMN_WIDTH  # type: int
        self.__range = None  # type: Optional[Tuple[float, float]]
        self.__violin_items = []  # type: List[ViolinItem]
        # Layout wrappers of the labels; each VariableItem is the
        # wrapper's ``.item``.
        self.__variable_items = []  # type: List[SimpleLayoutItem]
        self.__bottom_axis = AxisItem(parent=self, orientation="bottom",
                                      maxTickLength=7, pen=QPen(Qt.black))
        self.__bottom_axis.setLabel("Impact on model output")
        # Child of the axis, so its coordinates are relative to the axis.
        self.__vertical_line = QGraphicsLineItem(self.__bottom_axis)
        self.__vertical_line.setPen(QPen(Qt.gray))
        self.__legend = Legend(self)

        self.__layout = QGraphicsGridLayout()
        self.__layout.addItem(self.__legend, 0, ViolinPlot.LEGEND_COLUMN)
        self.__layout.setVerticalSpacing(0)
        self.setLayout(self.__layout)

        self.parameter_setter = ParameterSetter(self)

    @property
    def violin_column_width(self):
        """Current pixel width of the violin column."""
        return self.__violin_column_width

    @violin_column_width.setter
    def violin_column_width(self, view_width: int):
        """Derive the violin column width from the available view width.

        The legend width, ``OFFSET`` and the widest label are subtracted
        from *view_width*.  The result never drops below
        ``VIOLIN_COLUMN_WIDTH``.
        """
        label_widths = [self._label_text_item(row).boundingRect().width()
                        for row in range(len(self.__violin_items))]
        widest_label = max(label_widths + [0])
        width = (view_width - self.legend.sizeHint().width()
                 - self.OFFSET - widest_label)
        self.__violin_column_width = max(self.VIOLIN_COLUMN_WIDTH, width)

    @property
    def bottom_axis(self):
        """The shared bottom ``AxisItem``."""
        return self.__bottom_axis

    @property
    def labels(self):
        """Layout wrappers (``SimpleLayoutItem``) of the variable labels."""
        return self.__variable_items

    @property
    def legend(self):
        """The legend widget shown in ``LEGEND_COLUMN``."""
        return self.__legend

    def set_data(self, x: np.ndarray, colors: np.ndarray,
                 names: List[str], n_attrs: int, view_width: int):
        """Populate the plot.

        :param x: 2-D array; column ``i`` holds the values of variable ``i``.
            Only the first ``MAX_N_ITEMS`` columns are plotted.
        :param colors: array indexed like ``x``; column ``i`` is passed to
            violin ``i``
        :param names: variable names, one per column of ``x``
        :param n_attrs: number of rows to show initially
        :param view_width: width of the containing view, used to size the
            violin column
        """
        self.violin_column_width = view_width
        # Symmetric range with 5 % padding, so 0 is in the middle.
        abs_max = np.max(np.abs(x)) * 1.05
        self.__range = (-abs_max, abs_max)
        self._set_violin_items(x, colors, names)
        self._set_labels(names)
        self._set_bottom_axis()
        self.set_n_visible(n_attrs)

    def set_n_visible(self, n: int):
        """Show the first *n* rows (violin and label) and hide the rest."""
        for row in range(len(self.__violin_items)):
            visible = row < n
            violin_item = self.__layout.itemAt(row, ViolinPlot.VIOLIN_COLUMN)
            violin_item.setVisible(visible)
            self._label_text_item(row).setVisible(visible)
        self.set_vertical_line()

    def rescale(self, view_width: int):
        """Recompute the column width from *view_width* and resize content."""
        self.violin_column_width = view_width
        # NOTE(review): temp_seed presumably fixes the RNG so that any
        # random jitter in the violins is reproducible — confirm.
        with temp_seed(0):
            for item in self.__violin_items:
                item.rescale(self.violin_column_width)
        self._update_axis_and_line()

    def show_legend(self, show: bool):
        """Toggle the legend and refit the axis and the vertical line."""
        self.__legend.setVisible(show)
        self._update_axis_and_line()

    def _label_text_item(self, row: int):
        """Return the text item shown in the label column of *row*."""
        return self.__layout.itemAt(row, ViolinPlot.LABEL_COLUMN).item

    def _update_axis_and_line(self):
        """Fit the axis to the violin column and re-centre the line.

        The vertical line keeps its current height.
        """
        width = self.violin_column_width
        self.__bottom_axis.setWidth(width)
        x = width / 2
        self.__vertical_line.setLine(x, 0, x, self.__vertical_line.line().y2())

    def _set_violin_items(self, x: np.ndarray, colors: np.ndarray,
                          labels: List[str]):
        """Create one violin per column of *x*, up to ``MAX_N_ITEMS``."""
        # The original loop broke only *after* adding the item at index
        # MAX_N_ITEMS, creating one item too many.
        n_items = min(x.shape[1], self.MAX_N_ITEMS)
        with temp_seed(0):
            for i in range(n_items):
                item = ViolinItem(self, labels[i], self.__range,
                                  self.violin_column_width)
                item.set_data(x[:, i], colors[:, i])
                item.selection_changed.connect(self.select)
                self.__violin_items.append(item)
                self.__layout.addItem(item, i, ViolinPlot.VIOLIN_COLUMN)

    def _set_labels(self, labels: List[str]):
        """Add a right-aligned label next to each plotted violin."""
        # zip() stops at the number of violins, so labels beyond the
        # plotted items are skipped.
        for i, (label, _) in enumerate(zip(labels, self.__violin_items)):
            text = VariableItem(self, label)
            item = SimpleLayoutItem(text)
            item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            self.__layout.addItem(item, i, ViolinPlot.LABEL_COLUMN,
                                  Qt.AlignRight | Qt.AlignVCenter)
            self.__variable_items.append(item)

    def _set_bottom_axis(self):
        """Set the axis range and place it in the row below the violins."""
        self.__bottom_axis.setRange(*self.__range)
        self.__layout.addItem(self.__bottom_axis,
                              len(self.__violin_items),
                              ViolinPlot.VIOLIN_COLUMN)

    def set_vertical_line(self):
        """Stretch the centre line up through all visible rows.

        Each visible row adds the larger of its label height and its
        violin's preferred height.
        """
        x = self.violin_column_width / 2
        height = 0
        for row in range(len(self.__violin_items)):
            violin_item = self.__layout.itemAt(row, ViolinPlot.VIOLIN_COLUMN)
            text_item = self._label_text_item(row)
            if violin_item.isVisible():
                height += max(text_item.boundingRect().height(),
                              violin_item.preferredSize().height())
        # Negative y: the line extends upwards from the axis origin.
        self.__vertical_line.setLine(x, 0, x, -height)

    def deselect(self):
        """Emit ``selection_cleared``."""
        self.selection_cleared.emit()

    def select(self, *args):
        """Re-emit a violin's selection as ``selection_changed``."""
        self.selection_changed.emit(*args)

    def select_from_settings(self, x1: float, x2: float, attr_name: str):
        """Restore a saved selection ``[x1, x2]`` on violin *attr_name*.

        ``selection_changed`` is emitted even if no plotted violin has
        that name.
        """
        # The full range (2 * abs_max) spans violin_column_width pixels, so
        # this is two pixels expressed in data units.  It presumably pads
        # the rectangle so boundary points fall inside it.
        point_r_diff = 2 * self.__range[1] / (self.violin_column_width / 2)
        for item in self.__violin_items:
            if item.attr_name == attr_name:
                item.add_selection_rect(x1 - point_r_diff, x2 + point_r_diff)
                break
        self.select(x1, x2, attr_name)

    def apply_visual_settings(self, settings: Dict):
        """Forward every ``(key, value)`` pair to the parameter setter."""
        for key, value in settings.items():
            self.parameter_setter.set_parameter(key, value)
Exemplo n.º 2
0
class FeaturesPlot(QGraphicsWidget):
    """Base grid plot with one row per feature and a shared bottom axis.

    Each row holds a feature label (``LABEL_COLUMN``) and a plot item
    (``ITEM_COLUMN``).  The bottom axis sits in the row below the last
    item, and a vertical line marks the data value 0.  Subclasses
    implement ``_set_range``, ``_set_items`` and ``select_from_settings``.
    """
    BOTTOM_AXIS_LABEL = "Feature Importance"
    LABEL_COLUMN, ITEM_COLUMN = range(2)
    ITEM_COLUMN_WIDTH, OFFSET = 300, 80
    selection_cleared = Signal()
    selection_changed = Signal(object)
    resized = Signal()

    def __init__(self):
        super().__init__()
        self._item_column_width = self.ITEM_COLUMN_WIDTH
        self._range: Optional[Tuple[float, float]] = None
        self._items: List[FeatureItem] = []
        # Layout wrappers of the labels; each VariableItem is the
        # wrapper's ``.item``.
        self._variable_items: List[SimpleLayoutItem] = []
        self._bottom_axis = AxisItem(parent=self,
                                     orientation="bottom",
                                     maxTickLength=7,
                                     pen=QPen(Qt.black))
        self._bottom_axis.setLabel(self.BOTTOM_AXIS_LABEL)
        # Child of the axis, so its coordinates are relative to the axis.
        self._vertical_line = QGraphicsLineItem(self._bottom_axis)
        self._vertical_line.setPen(QPen(Qt.gray))

        self._layout = QGraphicsGridLayout()
        self._layout.setVerticalSpacing(0)
        self.setLayout(self._layout)

        self.parameter_setter = BaseParameterSetter(self)

    @property
    def item_column_width(self) -> int:
        """Current pixel width of the item column."""
        return self._item_column_width

    @item_column_width.setter
    def item_column_width(self, view_width: int):
        """Derive the item column width from the available view width.

        ``OFFSET`` and the widest label are subtracted from *view_width*.
        The result never drops below ``ITEM_COLUMN_WIDTH``.
        """
        label_widths = [self._label_text_item(row).boundingRect().width()
                        for row in range(len(self._items))]
        widest_label = max(label_widths + [0])
        width = view_width - self.OFFSET - widest_label
        self._item_column_width = max(self.ITEM_COLUMN_WIDTH, width)

    @property
    def x0_scaled(self) -> float:
        """Horizontal pixel position of the data value 0 in the item column.

        The value 0 is mapped linearly from ``self._range`` onto
        ``[0, item_column_width]``.  Returns 0.0 when no range is set yet
        or the range has zero width, because the mapping is undefined
        there (the previous code raised in both cases).
        """
        if self._range is None:
            return 0.0
        x_min, x_max = self._range
        span = x_max - x_min
        if span == 0:
            return 0.0
        return -x_min * self.item_column_width / span

    @property
    def bottom_axis(self) -> AxisItem:
        """The shared bottom ``AxisItem``."""
        return self._bottom_axis

    @property
    def labels(self) -> List[VariableItem]:
        """Layout wrappers (``SimpleLayoutItem``) of the feature labels."""
        return self._variable_items

    def set_data(self, x: np.ndarray, names: List[str], n_attrs: int,
                 view_width: int, *plot_args):
        """Populate the plot.

        :param x: data passed to ``_set_range`` and ``_set_items``
        :param names: feature names; one label is added per created item
        :param n_attrs: number of rows to show initially
        :param view_width: width of the containing view, used to size the
            item column
        :param plot_args: extra arguments forwarded to the subclass hooks
        """
        self.item_column_width = view_width
        self._set_range(x, *plot_args)
        self._set_items(x, names, *plot_args)
        self._set_labels(names)
        self._set_bottom_axis()
        self.set_n_visible(n_attrs)

    def _set_range(self, *_):
        """Subclass hook, called with ``(x, *plot_args)``.

        Expected to set ``self._range`` to ``(min, max)``, which the
        bottom axis and ``x0_scaled`` use.
        """
        raise NotImplementedError

    def _set_items(self, *_):
        """Subclass hook, called with ``(x, names, *plot_args)``.

        Expected to append the row items to ``self._items`` and add each
        one to the layout at ``(row, ITEM_COLUMN)``.
        """
        raise NotImplementedError

    def set_n_visible(self, n: int):
        """Show the first *n* rows (item and label) and hide the rest."""
        for row in range(len(self._items)):
            visible = row < n
            item = self._layout.itemAt(row, FeaturesPlot.ITEM_COLUMN)
            item.setVisible(visible)
            self._label_text_item(row).setVisible(visible)
        self.set_vertical_line()

    def rescale(self, view_width: int):
        """Recompute the column width from *view_width* and resize content."""
        self.item_column_width = view_width
        for item in self._items:
            item.rescale(self.item_column_width)

        self._bottom_axis.setWidth(self.item_column_width)
        x = self.x0_scaled
        # Keep the line's current height; only move it horizontally.
        self._vertical_line.setLine(x, 0, x, self._vertical_line.line().y2())
        self.updateGeometry()

    def set_height(self, height: float):
        """Pass *height* to every row item, then refresh the line."""
        for row in range(len(self._items)):
            item = self._layout.itemAt(row, FeaturesPlot.ITEM_COLUMN)
            item.set_height(height)
        self.set_vertical_line()
        self.updateGeometry()

    def _label_text_item(self, row: int):
        """Return the text item shown in the label column of *row*."""
        return self._layout.itemAt(row, FeaturesPlot.LABEL_COLUMN).item

    def _set_labels(self, labels: List[str]):
        """Add a right-aligned label next to each plot item."""
        # zip() stops at the number of items, so extra names are skipped.
        for i, (label, _) in enumerate(zip(labels, self._items)):
            text = VariableItem(self, label)
            item = SimpleLayoutItem(text)
            item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            self._layout.addItem(item, i, FeaturesPlot.LABEL_COLUMN,
                                 Qt.AlignRight | Qt.AlignVCenter)
            self._variable_items.append(item)

    def _set_bottom_axis(self):
        """Set the axis range and place it in the row below the items."""
        self._bottom_axis.setRange(*self._range)
        self._layout.addItem(self._bottom_axis, len(self._items),
                             FeaturesPlot.ITEM_COLUMN)

    def set_vertical_line(self):
        """Stretch the zero line up through all visible rows.

        Each visible row adds the larger of its label height and its
        item's preferred height.
        """
        height = 0
        for row in range(len(self._items)):
            item = self._layout.itemAt(row, FeaturesPlot.ITEM_COLUMN)
            text_item = self._label_text_item(row)
            if item.isVisible():
                height += max(text_item.boundingRect().height(),
                              item.preferredSize().height())
        x = self.x0_scaled
        # Negative y: the line extends upwards from the axis origin.
        self._vertical_line.setLine(x, 0, x, -height)

    def deselect(self):
        """Emit ``selection_cleared``."""
        self.selection_cleared.emit()

    def select(self, *args):
        """Re-emit a selection as ``selection_changed``."""
        self.selection_changed.emit(*args)

    def select_from_settings(self, *_):
        """Subclass hook: restore a selection saved in the settings."""
        raise NotImplementedError

    def apply_visual_settings(self, settings: Dict):
        """Forward every ``(key, value)`` pair to the parameter setter."""
        for key, value in settings.items():
            self.parameter_setter.set_parameter(key, value)