def setUp(self):
     """Build the shared fixture: a domain, its expected encoding, a handler.

     Sets ``self.domain`` (three attributes, one class variable, two metas),
     ``self.args`` (the domain followed by two name -> type-tag mappings) and
     ``self.handler`` (a ClassValuesContextHandler with ``read_defaults``
     replaced by a no-op).
     """
     features = [
         ContinuousVariable("c1"),
         DiscreteVariable("d1", values="abc"),
         DiscreteVariable("d2", values="def"),
     ]
     targets = [DiscreteVariable("d3", values="ghi")]
     meta_vars = [
         ContinuousVariable("c2"),
         DiscreteVariable("d4", values="jkl"),
     ]
     self.domain = Domain(
         attributes=features, class_vars=targets, metas=meta_vars)

     # First mapping covers attributes plus the class variable, the second
     # one covers the metas.
     attr_types = {
         "c1": Continuous,
         "d1": Discrete,
         "d2": Discrete,
         "d3": Discrete,
     }
     meta_types = {"c2": Continuous, "d4": Discrete}
     self.args = (self.domain, attr_types, meta_types)

     handler = ClassValuesContextHandler()
     # Replace read_defaults with a no-op -- presumably so the tests do not
     # depend on stored default settings.
     handler.read_defaults = lambda: None
     self.handler = handler
Esempio n. 2
0
 def setUp(self):
     """Create the test domain, the matching encoded args and the handler."""
     attributes = [
         ContinuousVariable('c1'),
         DiscreteVariable('d1', values='abc'),
         DiscreteVariable('d2', values='def'),
     ]
     class_vars = [DiscreteVariable('d3', values='ghi')]
     metas = [
         ContinuousVariable('c2'),
         DiscreteVariable('d4', values='jkl'),
     ]
     self.domain = Domain(
         attributes=attributes, class_vars=class_vars, metas=metas)

     # Name -> type tag: attributes and class variable first, metas second.
     encoded_attributes = {
         'c1': Continuous,
         'd1': Discrete,
         'd2': Discrete,
         'd3': Discrete,
     }
     encoded_metas = {'c2': Continuous, 'd4': Discrete}
     self.args = (self.domain, encoded_attributes, encoded_metas)

     self.handler = ClassValuesContextHandler()
     # read_defaults is stubbed out with a function that does nothing.
     self.handler.read_defaults = lambda: None
 def setUp(self):
     """Prepare ``self.domain``, ``self.args`` and ``self.handler``.

     ``self.args`` is the domain followed by two dicts that map variable
     names to ``Continuous``/``Discrete``: one for attributes and the class
     variable, one for the metas.
     """
     continuous_attr = ContinuousVariable('c1')
     discrete_attrs = [
         DiscreteVariable(name, values=values)
         for name, values in (('d1', 'abc'), ('d2', 'def'))
     ]
     target = DiscreteVariable('d3', values='ghi')
     continuous_meta = ContinuousVariable('c2')
     discrete_meta = DiscreteVariable('d4', values='jkl')

     self.domain = Domain(
         attributes=[continuous_attr] + discrete_attrs,
         class_vars=[target],
         metas=[continuous_meta, discrete_meta],
     )
     self.args = (
         self.domain,
         dict(c1=Continuous, d1=Discrete, d2=Discrete, d3=Discrete),
         dict(c2=Continuous, d4=Discrete),
     )

     handler = ClassValuesContextHandler()
     # Make read_defaults a no-op on this handler instance.
     handler.read_defaults = lambda: None
     self.handler = handler
Esempio n. 4
0
class OWConfusionMatrix(widget.OWWidget):
    """Confusion matrix widget

    Displays the confusion matrix of one learner taken from the incoming
    evaluation results and outputs the data instances that correspond to
    the selected cells.
    """

    # Widget metadata.
    name = "Confusion Matrix"
    description = "Display a confusion matrix constructed from " \
                  "the results of classifier evaluations."
    icon = "icons/ConfusionMatrix.svg"
    priority = 1001
    keywords = []

    # Input signals.
    class Inputs:
        evaluation_results = Input("Evaluation Results",
                                   Orange.evaluation.Results)

    # Output signals.
    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    # Items of the "Show" combo box; `selected_quantity` indexes this list.
    quantities = [
        "Number of instances", "Proportion of predicted",
        "Proportion of actual"
    ]

    settings_version = 1
    settingsHandler = ClassValuesContextHandler()

    # List whose first element is the index of the displayed learner.
    selected_learner = Setting([0], schema_only=True)
    # Selected cells as (row - 2, column - 2) pairs, i.e. positions inside
    # the matrix proper (rows are actual classes, columns predicted ones).
    selection = ContextSetting(set())
    # Index into `quantities`.
    selected_quantity = Setting(0)
    # Whether predictions / probabilities are appended as meta columns to
    # the output data.
    append_predictions = Setting(True)
    append_probabilities = Setting(False)
    autocommit = Setting(True)

    UserAdviceMessages = [
        widget.Message(
            "Clicking on cells or in headers outputs the corresponding "
            "data instances", "click_cell")
    ]

    # Error messages the widget can show.
    class Error(widget.OWWidget.Error):
        no_regression = Msg("Confusion Matrix cannot show regression results.")
        invalid_values = Msg(
            "Evaluation Results input contains invalid values")
        empty_input = widget.Msg("Empty result on input. Nothing to display.")
    def __init__(self):
        """Initialise the widget state and build its GUI.

        The control area holds the learner list and the output options;
        the main area holds the "Show" combo box, the matrix table view and
        the three selection buttons.
        """
        super().__init__()

        # Current input (set in `set_results`) and data derived from it.
        self.data = None
        self.results = None
        self.learners = []
        self.headers = []

        # List of learners; bound to `selected_learner` / `learners`.
        self.learners_box = gui.listBox(self.controlArea,
                                        self,
                                        "selected_learner",
                                        "learners",
                                        box=True,
                                        callback=self._learner_changed)

        # Output options: which extra columns to append to the output.
        self.outputbox = gui.vBox(self.controlArea, "Output")
        box = gui.hBox(self.outputbox)
        gui.checkBox(box,
                     self,
                     "append_predictions",
                     "Predictions",
                     callback=self._invalidate)
        gui.checkBox(box,
                     self,
                     "append_probabilities",
                     "Probabilities",
                     callback=self._invalidate)

        gui.auto_apply(self.outputbox, self, "autocommit", box=False)

        self.info.set_output_summary(self.info.NoOutput)

        self.mainArea.layout().setContentsMargins(0, 0, 0, 0)

        box = gui.vBox(self.mainArea, box=True)

        # "Show" combo box selecting the displayed quantity.
        sbox = gui.hBox(box)
        gui.rubber(sbox)
        gui.comboBox(sbox,
                     self,
                     "selected_quantity",
                     items=self.quantities,
                     label="Show: ",
                     orientation=Qt.Horizontal,
                     callback=self._update)

        # Table view showing the matrix; its own headers are hidden because
        # the header labels are ordinary cells of the model (see
        # `_init_table`).
        self.tablemodel = QStandardItemModel(self)
        view = self.tableview = QTableView(
            editTriggers=QTableView.NoEditTriggers)
        view.setModel(self.tablemodel)
        view.horizontalHeader().hide()
        view.verticalHeader().hide()
        view.horizontalHeader().setMinimumSectionSize(60)
        # Any selection change stores the selection and commits the output.
        view.selectionModel().selectionChanged.connect(self._invalidate)
        view.setShowGrid(False)
        view.setItemDelegate(BorderedItemDelegate(Qt.white))
        view.setSizePolicy(QSizePolicy.MinimumExpanding,
                           QSizePolicy.MinimumExpanding)
        # Clicks on header/sum cells select whole rows, columns or the matrix.
        view.clicked.connect(self.cell_clicked)
        box.layout().addWidget(view)

        # Selection shortcut buttons.
        selbox = gui.hBox(box)
        gui.button(selbox,
                   self,
                   "Select Correct",
                   callback=self.select_correct,
                   autoDefault=False)
        gui.button(selbox,
                   self,
                   "Select Misclassified",
                   callback=self.select_wrong,
                   autoDefault=False)
        gui.button(selbox,
                   self,
                   "Clear Selection",
                   callback=self.select_none,
                   autoDefault=False)

    @staticmethod
    def sizeHint():
        """Return the initial widget size, 750 wide by 340 high."""
        width, height = 750, 340
        return QSize(width, height)

    def _item(self, i, j):
        """Return the model item at row `i`, column `j`, or a new empty item.

        A freshly created item is not inserted into the model; callers do
        that with `_set_item`.
        """
        existing = self.tablemodel.item(i, j)
        if existing:
            return existing
        return QStandardItem()

    def _set_item(self, i, j, item):
        """Store `item` in the table model at row `i`, column `j`."""
        model = self.tablemodel
        model.setItem(i, j, item)

    def _init_table(self, nclasses):
        """Fill in the static cells of the table for `nclasses` classes.

        Layout: row 0 / column 0 carry the "Predicted" / "Actual" captions,
        row 1 / column 1 carry the labels from `self.headers`, and the
        matrix itself starts at cell (2, 2). The model ends up with
        `nclasses + 3` rows and columns.
        """
        view = self.tableview
        inert = Qt.NoItemFlags

        # "Predicted" caption, spanning the class columns.
        caption = self._item(0, 2)
        caption.setData("Predicted", Qt.DisplayRole)
        caption.setTextAlignment(Qt.AlignCenter)
        caption.setFlags(inert)
        self._set_item(0, 2, caption)

        # "Actual" caption, drawn by a vertical delegate and spanning the
        # class rows.
        caption = self._item(2, 0)
        caption.setData("Actual", Qt.DisplayRole)
        caption.setTextAlignment(Qt.AlignHCenter | Qt.AlignBottom)
        caption.setFlags(inert)
        view.setItemDelegateForColumn(0, gui.VerticalItemDelegate())
        self._set_item(2, 0, caption)
        view.setSpan(0, 2, 1, nclasses)
        view.setSpan(2, 0, nclasses, 1)

        header_font = QFont(self.tablemodel.invisibleRootItem().font())
        header_font.setBold(True)

        # The 2x2 block in the top-left corner is inert.
        for row, col in ((0, 0), (0, 1), (1, 0), (1, 1)):
            corner = self._item(row, col)
            corner.setFlags(inert)
            self._set_item(row, col, corner)

        # Header labels: once along row 1 (bottom border) and once along
        # column 1 (right border). The last label gets no border.
        last_pos = len(self.headers) - 1
        for pos, label in enumerate(self.headers):
            offset = pos + 2
            for row, col, border in ((1, offset, "b"), (offset, 1, "r")):
                cell = self._item(row, col)
                cell.setData(label, Qt.DisplayRole)
                cell.setFont(header_font)
                cell.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                cell.setFlags(Qt.ItemIsEnabled)
                if pos < last_pos:
                    cell.setData(border, BorderRole)
                    cell.setData(QColor(192, 192, 192), BorderColorRole)
                self._set_item(row, col, cell)

        # Fit columns to their contents unless the labels are long overall.
        header_view = view.horizontalHeader()
        labels_are_short = len(' '.join(self.headers)) < 120
        if labels_are_short:
            header_view.setSectionResizeMode(QHeaderView.ResizeToContents)
        else:
            header_view.setDefaultSectionSize(60)

        size = nclasses + 3
        self.tablemodel.setRowCount(size)
        self.tablemodel.setColumnCount(size)

    @Inputs.evaluation_results
    def set_results(self, results):
        """Set the input results.

        Resets the widget, validates `results` (regression results, empty
        results and results containing NaNs are rejected with an error
        message), then rebuilds the table, restores the learner and cell
        selection and commits the output.

        `results` is an `Orange.evaluation.Results` or None.
        """
        # false positive, pylint: disable=no-member
        # Remember the selected learner; `clear()` below resets `learners`.
        prev_sel_learner = self.selected_learner.copy()
        self.clear()
        self.warning()
        self.closeContext()

        data = None
        if results is not None and results.data is not None:
            # Rows of the input data, indexed by the results' row indices.
            data = results.data[results.row_indices]

        self.Error.no_regression.clear()
        self.Error.empty_input.clear()
        if data is not None and not data.domain.has_discrete_class:
            self.Error.no_regression()
            data = results = None
        elif results is not None and not results.actual.size:
            self.Error.empty_input()
            data = results = None

        nan_values = False
        if results is not None:
            assert isinstance(results, Orange.evaluation.Results)
            if np.any(np.isnan(results.actual)) or \
                    np.any(np.isnan(results.predicted)):
                # Error out here (could filter them out with a warning
                # instead).
                nan_values = True
                results = data = None

        # Shows the error when NaNs were found, clears it otherwise.
        self.Error.invalid_values(shown=nan_values)

        self.results = results
        self.data = data

        if data is not None:
            class_values = data.domain.class_var.values
        elif results is not None:
            # Results without attached data are not supported.
            raise NotImplementedError

        if results is None:
            self.report_button.setDisabled(True)
            return

        self.report_button.setDisabled(False)

        nmodels = results.predicted.shape[0]
        # Class labels followed by a summation sign for the totals.
        self.headers = class_values + \
                       (unicodedata.lookup("N-ARY SUMMATION"), )

        # NOTE: The 'learner_names' is set in 'Test Learners' widget.
        self.learners = getattr(results, "learner_names",
                                [f"Learner #{i + 1}" for i in range(nmodels)])

        self._init_table(len(class_values))
        self.openContext(data.domain.class_var)
        # Keep the previously selected learner if it still exists; otherwise
        # fall back to the first one.
        if not prev_sel_learner or prev_sel_learner[0] >= len(self.learners):
            if self.learners:
                self.selected_learner[:] = [0]
        else:
            self.selected_learner[:] = prev_sel_learner
        self._update()
        self._set_selection()
        self.unconditional_commit()

    def clear(self):
        """Reset the widget, clear controls"""
        self.results = self.data = None
        self.tablemodel.clear()
        self.headers = []
        # `learners` goes last on purpose: resetting it will invoke
        # `_learner_changed`.
        self.learners = []

    def select_correct(self):
        """Select the diagonal elements of the matrix"""
        model = self.tablemodel
        diagonal = QItemSelection()
        # The matrix proper starts at row/column 2.
        for pos in range(2, model.rowCount()):
            cell = model.index(pos, pos)
            diagonal.select(cell, cell)
        selection_model = self.tableview.selectionModel()
        selection_model.select(diagonal, QItemSelectionModel.ClearAndSelect)

    def select_wrong(self):
        """Select the off-diagonal elements of the matrix"""
        model = self.tablemodel
        size = model.rowCount()
        off_diagonal = QItemSelection()
        # Visit each pair row < col once and select both mirrored cells.
        for row in range(2, size):
            for col in range(row + 1, size):
                upper = model.index(row, col)
                off_diagonal.select(upper, upper)
                lower = model.index(col, row)
                off_diagonal.select(lower, lower)
        selection_model = self.tableview.selectionModel()
        selection_model.select(off_diagonal,
                               QItemSelectionModel.ClearAndSelect)

    def select_none(self):
        """Reset selection"""
        selection_model = self.tableview.selectionModel()
        selection_model.clear()

    def cell_clicked(self, model_index):
        """Handle cell click event

        A click in the label row/column (index 1) or in the last row/column
        selects a whole column, a whole row or, for the two corner cells,
        the whole area from (2, 2) to the last cell. Clicks in row 0 or
        column 0, and clicks on other cells, are left alone here.
        """
        row, col = model_index.row(), model_index.column()
        if not row or not col:
            return

        model = self.tablemodel
        last = model.rowCount() - 1
        edges = (1, last)

        # (top, left, bottom, right) of the range to select.
        if row == col and row in edges:
            bounds = (2, 2, last, last)
        elif row in edges:
            bounds = (2, col, last, col)
        elif col in edges:
            bounds = (row, 2, row, last)
        else:
            return

        top, left, bottom, right = bounds
        selection = QItemSelection(model.index(top, left),
                                   model.index(bottom, right))
        self.tableview.selectionModel().select(
            selection, QItemSelectionModel.ClearAndSelect)

    def _prepare_data(self):
        """Build the output tables for the current selection.

        Returns a pair `(data, annotated_data)`: `data` holds the instances
        whose (actual, predicted) pair matches a selected cell, or is None
        when nothing matches; `annotated_data` is the table produced by
        `create_annotated_table` for the same rows. Predictions and/or
        probabilities of the selected learner are appended as metas when
        the corresponding options are on.
        """
        learner_index = self.selected_learner[0]
        results = self.results

        # Selected cells as (actual, predicted) class-index pairs.
        chosen_cells = {(ind.row() - 2, ind.column() - 2)
                        for ind in self.tableview.selectedIndexes()}
        learner_name = self.learners[learner_index]
        predicted = results.predicted[learner_index]
        selected = [
            row for row, pair in enumerate(zip(results.actual, predicted))
            if pair in chosen_cells
        ]

        source = self.data
        source_domain = source.domain
        class_var = source_domain.class_var
        metas = source_domain.metas
        extra = []

        if self.append_predictions:
            extra.append(predicted.reshape(-1, 1))
            prediction_var = Orange.data.DiscreteVariable(
                "{}({})".format(class_var.name, learner_name),
                class_var.values)
            metas = metas + (prediction_var, )

        if self.append_probabilities and results.probabilities is not None:
            probs = results.probabilities[learner_index]
            extra.append(np.array(probs, dtype=object))
            metas = metas + tuple(
                Orange.data.ContinuousVariable("p({})".format(value))
                for value in class_var.values)

        domain = Orange.data.Domain(source_domain.attributes,
                                    source_domain.class_vars, metas)
        data = source.transform(domain)
        if extra:
            # The new columns follow the original metas.
            first_new = len(source_domain.metas)
            data.metas[:, first_new:] = np.hstack(tuple(extra))
        data.name = learner_name

        # `selected` is an empty list when no instance matches.
        annotated_data = create_annotated_table(data, selected)
        data = data[selected] if selected else None
        return data, annotated_data

    def commit(self):
        """Output data instances corresponding to selected cells"""
        data = annotated_data = None
        ready = self.results is not None and self.data is not None \
            and self.selected_learner
        if ready:
            data, annotated_data = self._prepare_data()

        if data:
            summary, details = len(data), format_summary_details(data)
        else:
            summary, details = self.info.NoOutput, ""
        self.info.set_output_summary(summary, details)

        outputs = self.Outputs
        outputs.selected_data.send(data)
        outputs.annotated_data.send(annotated_data)

    def _invalidate(self):
        """Store the view's current selection in `self.selection`, commit.

        Cells are stored as (row - 2, column - 2), i.e. relative to the
        top-left cell of the matrix.
        """
        self.selection = {
            (index.row() - 2, index.column() - 2)
            for index in self.tableview.selectedIndexes()
        }
        self.commit()

    def _set_selection(self):
        """Apply the cells stored in `self.selection` to the table view."""
        model = self.tableview.model()
        restored = QItemSelection()
        for row, col in self.selection:
            # Stored positions are relative to the matrix, which starts at
            # cell (2, 2) of the model.
            cell = model.index(row + 2, col + 2)
            restored.select(cell, cell)
        self.tableview.selectionModel().select(
            restored, QItemSelectionModel.ClearAndSelect)

    def _learner_changed(self):
        """React to a change of the selected learner.

        Redraws the matrix, re-applies the stored cell selection and
        commits the output.
        """
        self._update()
        self._set_selection()
        self.commit()

    def _update(self):
        """Recompute the matrix for the selected learner and redraw its cells.

        Does nothing when there are no results or no learner is selected.
        Depending on `selected_quantity` the cells show instance counts,
        percentages of the predicted class (column-wise) or percentages of
        the actual class (row-wise). Values that are NaN or infinite are
        shown as "NA". Row and column sums are written to the last column
        and row.
        """
        def _isinvalid(x):
            return isnan(x) or isinf(x)

        # Update the displayed confusion matrix
        if self.results is not None and self.selected_learner:
            cmatrix = confusion_matrix(self.results, self.selected_learner[0])
            colsum = cmatrix.sum(axis=0)
            rowsum = cmatrix.sum(axis=1)
            n = len(cmatrix)
            diag = np.diag_indices(n)

            # `colors` holds the shading intensity of each cell. Off-diagonal
            # cells are scaled by the largest off-diagonal value (overall,
            # per column or per row, matching the displayed quantity).
            colors = cmatrix.astype(np.double)
            colors[diag] = 0
            if self.selected_quantity == 0:
                # Builtin `int`: the `np.int` alias was removed from NumPy.
                normalized = cmatrix.astype(int)
                formatstr = "{}"
                div = np.array([colors.max()])
            else:
                if self.selected_quantity == 1:
                    normalized = 100 * cmatrix / colsum
                    div = colors.max(axis=0)
                else:
                    normalized = 100 * cmatrix / rowsum[:, np.newaxis]
                    div = colors.max(axis=1)[:, np.newaxis]
                formatstr = "{:2.1f} %"
            # Avoid dividing by zero where there are no off-diagonal counts.
            div[div == 0] = 1
            colors /= div
            # Diagonal cells are scaled by the largest diagonal value.
            maxval = normalized[diag].max()
            if maxval > 0:
                colors[diag] = normalized[diag] / maxval

            for i in range(n):
                for j in range(n):
                    val = normalized[i, j]
                    col_val = colors[i, j]
                    item = self._item(i + 2, j + 2)
                    item.setData(
                        "NA" if _isinvalid(val) else formatstr.format(val),
                        Qt.DisplayRole)
                    # Hue 240 on the diagonal, 0 elsewhere; lightness drops
                    # from 255 as the intensity grows.
                    bkcolor = QColor.fromHsl(
                        [0, 240][i == j], 160,
                        255 if _isinvalid(col_val) else int(255 -
                                                            30 * col_val))
                    item.setData(QBrush(bkcolor), Qt.BackgroundRole)
                    item.setData("trbl", BorderRole)
                    item.setToolTip("actual: {}\npredicted: {}".format(
                        self.headers[i], self.headers[j]))
                    item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                    item.setFlags(Qt.ItemIsEnabled | Qt.ItemIsSelectable)
                    self._set_item(i + 2, j + 2, item)

            bold_font = self.tablemodel.invisibleRootItem().font()
            bold_font.setBold(True)

            def _sum_item(value, border=""):
                # Bold, enabled-but-not-selectable cell holding a total.
                item = QStandardItem()
                item.setData(value, Qt.DisplayRole)
                item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                item.setFlags(Qt.ItemIsEnabled)
                item.setFont(bold_font)
                item.setData(border, BorderRole)
                item.setData(QColor(192, 192, 192), BorderColorRole)
                return item

            # Column sums in the last row, row sums in the last column, and
            # the grand total in the bottom-right corner.
            for i in range(n):
                self._set_item(n + 2, i + 2, _sum_item(int(colsum[i]), "t"))
                self._set_item(i + 2, n + 2, _sum_item(int(rowsum[i]), "l"))
            self._set_item(n + 2, n + 2, _sum_item(int(rowsum.sum())))

    def send_report(self):
        """Send report"""
        if self.results is None or not self.selected_learner:
            return
        learner_name = self.learners[self.selected_learner[0]]
        quantity = self.quantities[self.selected_quantity].lower()
        caption = "Confusion matrix for {} (showing {})".format(
            learner_name, quantity)
        self.report_table(caption, self.tableview)

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade settings stored without a version number, in place.

        Wraps an integer 'selected_learner' into a one-element list.
        """
        if version:
            return
        key = "selected_learner"
        # For some period of time the 'selected_learner' property was
        # changed from List[int] -> int
        # (commit 4e49bb3fd0e11262f3ebf4b1116a91a4b49cc982) and then back
        # again (commit 8a492d79a2e17154a0881e24a05843406c8892c0)
        if key in settings and isinstance(settings[key], int):
            settings[key] = [settings[key]]
Esempio n. 5
0
 def test_migrate_removes_invalid_contexts(self):
     """Migration keeps only the context that is not from the old handler.

     One context comes from ClassValuesContextHandler, one from
     PerfectDomainContextHandler; after `migrate_settings(settings, 2)`
     only the latter must remain in 'context_settings'.
     """
     stale_context = ClassValuesContextHandler().new_context([0, 1, 2])
     empty_args = [[]] * 4
     kept_context = PerfectDomainContextHandler().new_context(*empty_args)
     settings = {'context_settings': [stale_context, kept_context]}

     self.widget.migrate_settings(settings, 2)

     remaining = settings['context_settings']
     self.assertEqual(remaining, [kept_context])
Esempio n. 6
0
class OWClassificationTreeGraph(OWTreeViewer2D):
    """Viewer that draws a classification tree as a graph of nodes."""

    name = "Classification Tree Viewer"
    description = "Classification Tree Viewer"
    icon = "icons/ClassificationTree.svg"

    settingsHandler = ClassValuesContextHandler()
    # Index in the target combo box; 0 is the "None" entry, so class k is
    # stored as k + 1 (see `ctree` and `update_node_info`).
    target_class_index = ContextSetting(0)
    # Colour schemas and the selected schema, as returned by the palette
    # dialog in `set_colors`.
    color_settings = Setting(None)
    selected_color_settings_index = Setting(0)

    # (signal name, type, handler method name) / (signal name, type)
    inputs = [("Classification Tree", ClassificationTreeClassifier, "ctree")]
    outputs = [("Data", Table)]

    def __init__(self):
        """Create the tree scene/view and the "Nodes" control box."""
        super().__init__()
        # Set when a classifier arrives (see `ctree`).
        self.domain = None
        self.classifier = None
        self.dataset = None

        # Graphics scene holding the tree and the view that displays it.
        self.scene = TreeGraphicsScene(self)
        self.scene_view = TreeGraphicsView(self.scene)
        self.scene_view.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
        self.mainArea.layout().addWidget(self.scene_view)
        self.toggle_zoom_slider()
        # Selecting nodes in the scene outputs the matching instances.
        self.scene.selectionChanged.connect(self.update_selection)

        box = gui.widgetBox(self.controlArea, "Nodes", addSpace=True)
        # Items are filled in by `ctree` once the class values are known.
        self.target_combo = gui.comboBox(box,
                                         self,
                                         "target_class_index",
                                         orientation=0,
                                         items=[],
                                         label="Target class",
                                         callback=self.toggle_target_class)
        gui.separator(box)
        gui.button(box, self, "Set Colors", callback=self.set_colors)
        # The dialog is built (not shown) only to obtain the stored palette.
        dlg = self.create_color_dialog()
        self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
        gui.rubber(self.controlArea)

    def sendReport(self):
        """Add the viewer's settings to the report.

        Reports the target class shown in the combo box, the line-width
        method and the tree size; target class and tree size are "N/A"
        when no tree is loaded.
        """
        if self.tree:
            # The combo box is stored as `target_combo` (see __init__);
            # the former `targetCombo` attribute does not exist.
            tclass = str(self.target_combo.currentText())
            # Use the same node/leaf counters `ctree` uses for the info
            # label instead of the legacy `orngTree` helpers.
            tsize = "%i nodes, %i leaves" % (self.root_node.num_nodes(),
                                             self.root_node.num_leaves())
        else:
            tclass = tsize = "N/A"
        self.reportSettings("Information",
                            [("Target class", tclass),
                             ("Line widths", [
                                 "Constant", "Proportion of all instances",
                                 "Proportion of parent's instances"
                             ][self.line_width_method]), ("Tree size", tsize)])
        super().sendReport()

    def set_colors(self):
        """Show the palette dialog and apply the choice if it is accepted."""
        dialog = self.create_color_dialog()
        accepted = dialog.exec_()
        if not accepted:
            return
        self.color_settings = dialog.getColorSchemas()
        self.selected_color_settings_index = dialog.selectedSchemaIndex
        self.scene.colorPalette = dialog.getDiscretePalette("colorPalette")
        self.scene.update()
        self.toggle_node_color()

    def create_color_dialog(self):
        """Return a palette dialog initialised with the stored schemas."""
        dialog = ColorPaletteDlg(self, "Color Palette")
        dialog.createDiscretePalette("colorPalette", "Discrete Palette")
        schemas = self.color_settings
        selected = self.selected_color_settings_index
        dialog.setColorSchemas(schemas, selected)
        return dialog

    def set_node_info(self):
        """Refresh every node's text and give all nodes a common width."""
        scene = self.scene
        for node in scene.nodes():
            # Reset the rectangle before the text is regenerated.
            node.set_rect(QRectF())
            self.update_node_info(node)

        # Widest node, never below zero.
        widths = [node.rect().width() for node in scene.nodes()]
        widths.append(0)
        w = max(widths)
        limit = self.max_node_width
        # The cap is applied only when the limit itself is below 200.
        if w > limit and limit < 200:
            w = limit

        for node in scene.nodes():
            rect = QRectF(node.rect().x(), node.rect().y(), w,
                          node.rect().height())
            node.set_rect(rect)
        scene.fix_pos(self.root_node, 10, 10)

    def update_node_info(self, node):
        """Set the HTML text shown inside `node`.

        With a target class selected the text gives that class's share of
        the node; otherwise it names the majority class and gives its
        share. Non-leaf nodes also show the name of their split attribute.
        """
        distr = node.get_distribution()
        total = int(node.num_instances())

        pieces = []
        if self.target_class_index:
            # Index 0 means "no target", so class k is stored as k + 1.
            share = distr[self.target_class_index - 1]
        else:
            modus = node.majority()
            share = distr[modus]
            pieces.append(self.domain.class_vars[0].values[modus] + "<br/>")

        if share > 0.999:
            pieces.append("100%, {}/{}".format(total, total))
        else:
            pieces.append("{:2.1f}%, {}/{}".format(
                100 * share, int(total * share), total))

        if not node.is_leaf():
            split_attr = self.domain.attributes[node.attribute()]
            pieces.append("<hr/>{}".format(split_attr.name))

        text = "".join(pieces)
        node.setHtml('<center><p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p></center>'.format(text))

    def activate_loaded_settings(self):
        """Apply loaded settings and refresh nodes; no-op without a tree."""
        if self.tree:
            super().activate_loaded_settings()
            self.set_node_info()
            self.toggle_node_color()

    def toggle_node_size(self):
        """Recompute node texts and sizes, then redraw the scene and view."""
        self.set_node_info()
        self.scene.update()
        self.scene_view.repaint()

    def toggle_node_color(self):
        """Recolour all nodes from the scene's palette.

        With a target class selected, a node's shade depends on that
        class's proportion in the node; otherwise on the proportion of
        the node's majority class.
        """
        palette = self.scene.colorPalette
        for node in self.scene.nodes():
            distr = node.get_distribution()
            total = sum(distr)
            target = self.target_class_index
            if target:
                proportion = distr[target - 1] / total
                color = palette[target].light(200 - 100 * proportion)
            else:
                modus = node.majority()
                proportion = distr[modus] / total
                color = palette[int(modus)].light(400 - 300 * proportion)
            node.backgroundBrush = QBrush(color)
        self.scene.update()

    def toggle_target_class(self):
        """Callback of the target combo: recolour and relabel the nodes."""
        self.toggle_node_color()
        self.set_node_info()
        self.scene.update()

    def ctree(self, clf=None):
        """Input handler: show the tree of classifier `clf`.

        With `clf` None the view is cleared and "No tree." is shown.
        Otherwise the target combo is refilled, the node graph is built
        and laid out, and the view is centred on the root node.
        """
        self.clear()
        self.closeContext()
        self.classifier = clf
        if clf is None:
            self.info.setText('No tree.')
            self.tree = None
            self.root_node = None
            self.dataset = None
        else:
            self.tree = clf.clf.tree_
            self.domain = clf.domain
            # Training instances, if the classifier carries them.
            self.dataset = getattr(clf, "instances", None)
            # Entry 0 is "None"; the class values follow.
            self.target_combo.clear()
            self.target_combo.addItem("None")
            self.target_combo.addItems(self.domain.class_vars[0].values)
            # Reset to "None" before the context is opened.
            self.target_class_index = 0
            self.openContext(self.domain.class_var)
            self.root_node = self.walkcreate(self.tree, None, distr=clf.distr)
            self.info.setText('{} nodes, {} leaves'.format(
                self.root_node.num_nodes(), self.root_node.num_leaves()))
            self.scene.fix_pos(self.root_node, self._HSPACING, self._VSPACING)
            self.activate_loaded_settings()
            self.scene_view.centerOn(self.root_node.x(), self.root_node.y())
            self.update_node_tooltips()
        self.scene.update()

    def walkcreate(self, tree, parent=None, level=0, i=0, distr=None):
        """Recursively create the graphics node for tree node `i`.

        Creates a ClassificationTreeNode for node `i` of `tree`, links it
        to `parent` with an edge when a parent is given, and recurses into
        the left and then the right child (a negative child index means
        there is no child). `level` is only passed on, increased by one.
        `distr` is indexed by node number. Returns the created node.
        """
        node = ClassificationTreeNode(
            tree, self.domain, parent, None, self.scene, i=i, distr=distr[i])
        if parent:
            edge = GraphicsEdge(None, self.scene, node1=parent, node2=node)
            parent.graph_add_edge(edge)

        children = (tree.children_left[i], tree.children_right[i])
        for child_index in children:
            if child_index >= 0:
                self.walkcreate(tree,
                                parent=node,
                                level=level + 1,
                                i=child_index,
                                distr=distr)
        return node

    def node_tooltip(self, node):
        """Return the tooltip for `node`.

        "Root" for the node with index 0; otherwise the conditions from
        `node.rule()` joined with " AND ", one per line.
        """
        if node.i > 0:
            conditions = (
                "%s %s %.3f" % (self.domain.attributes[attr].name, sign, thr)
                for attr, sign, thr in node.rule()
            )
            return "<br/> AND ".join(conditions)
        return "Root"

    def update_selection(self):
        """Send the instances that belong to the selected tree nodes.

        Does nothing when the classifier carried no instances. Sends None
        on the "Data" output when no tree node is selected or the selected
        nodes hold no instances.
        """
        if self.dataset is None:
            return

        nodes = [
            item for item in self.scene.selectedItems()
            if isinstance(item, ClassificationTreeNode)
        ]
        if nodes:
            per_node = [self.classifier.get_items(node.i) for node in nodes]
            # hstack joins the per-node index lists into one 1-D array even
            # when the nodes hold different numbers of instances, which
            # `numpy.r_[list_of_arrays]` could not handle.
            indices = numpy.unique(numpy.hstack(per_node))
        else:
            indices = []

        if len(indices):
            data = self.dataset[indices]
        else:
            data = None
        self.send("Data", data)
Esempio n. 7
0
class OWNomogram(OWWidget):
    # Widget metadata.
    name = "Nomogram"
    description = " Nomograms for Visualization of Naive Bayesian" \
                  " and Logistic Regression Classifiers."
    icon = "icons/Nomogram.svg"
    priority = 2000

    # Input declarations as (channel name, type, handler name); the handler
    # names match the set_classifier / set_instance methods of this class.
    inputs = [("Classifier", Model, "set_classifier"),
              ("Data", Table, "set_instance")]

    # Upper bound of the "Best ranked" spin box; update_controls forces
    # display_index to 1 for domains with more attributes than this.
    MAX_N_ATTRS = 1000
    # Value of `scale` that selects the point scale (first radio button).
    POINT_SCALE = 0
    # Values of `self.align`; update_controls picks ALIGN_LEFT for logistic
    # regression classifiers and ALIGN_ZERO otherwise.
    ALIGN_LEFT = 0
    ALIGN_ZERO = 1
    # Classifier types accepted by set_classifier.
    ACCEPTABLE = (NaiveBayesModel, LogisticRegressionClassifier)
    settingsHandler = ClassValuesContextHandler()
    # Index into the class values shown in the "Target class" combo.
    target_class_index = ContextSetting(0)
    normalize_probabilities = Setting(False)
    # Index of the selected "Scale" radio button (0 == POINT_SCALE).
    scale = Setting(1)
    # 0 shows all features, non-zero only the best `n_attributes` ones.
    display_index = Setting(1)
    n_attributes = Setting(10)
    sort_index = Setting(SortBy.ABSOLUTE)
    # Truthy selects ContinuousFeature2DItem in create_main_nomogram,
    # otherwise ContinuousFeatureItem is used.
    cont_feature_dim_index = Setting(0)

    graph_name = "scene"

    # Error shown by set_classifier for classifiers not in ACCEPTABLE.
    class Error(OWWidget.Error):
        invalid_classifier = Msg("Nomogram accepts only Naive Bayes and "
                                 "Logistic Regression classifiers.")

    def __init__(self):
        """Initialise widget state and build the control and main areas.

        The control area receives the target-class box (class combo plus an
        initially hidden "Normalize probabilities" check box), the scale
        radio buttons, and the "Display features" box (all / best-ranked
        radio buttons with a spin box, the sort combo and the
        continuous-feature combo).  The main area receives a QGraphicsView
        over ``self.scene``.
        """
        super().__init__()
        # Inputs and data derived from the classifier.
        self.instances = None
        self.domain = None
        self.data = None
        self.classifier = None
        self.align = OWNomogram.ALIGN_ZERO
        self.log_odds_ratios = []
        self.log_reg_coeffs = []
        self.log_reg_coeffs_orig = []
        self.log_reg_cont_data_extremes = []
        self.p = None
        self.b0 = None
        self.points = []
        # Scene items and marker state.
        self.feature_items = []
        self.feature_marker_values = []
        # Identity conversions until update_scene installs real ones.
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.old_target_class_index = self.target_class_index
        self.markers_set = False
        # update_scene returns immediately while this is False; eventFilter
        # sets it to True on the first viewport resize.
        self.repaint = False

        # GUI
        box = gui.vBox(self.controlArea, "Target class")
        self.class_combo = gui.comboBox(box,
                                        self,
                                        "target_class_index",
                                        callback=self._class_combo_changed,
                                        contentsLength=12)
        self.norm_check = gui.checkBox(
            box,
            self,
            "normalize_probabilities",
            "Normalize probabilities",
            hidden=True,
            callback=self._norm_check_changed,
            tooltip="For multiclass data 1 vs. all probabilities do not"
            " sum to 1 and therefore could be normalized.")

        self.scale_radio = gui.radioButtons(
            self.controlArea,
            self,
            "scale", ["Point scale", "Log odds ratios"],
            box="Scale",
            callback=self._radio_button_changed)

        # "Display features": the two radio buttons and the spin box are
        # created detached and then placed into a grid by hand.
        box = gui.vBox(self.controlArea, "Display features")
        grid = QGridLayout()
        self.display_radio = gui.radioButtonsInBox(
            box,
            self,
            "display_index", [],
            orientation=grid,
            callback=self._display_radio_button_changed)
        radio_all = gui.appendRadioButton(self.display_radio,
                                          "All:",
                                          addToLayout=False)
        radio_best = gui.appendRadioButton(self.display_radio,
                                           "Best ranked:",
                                           addToLayout=False)
        spin_box = gui.hBox(None, margin=0)
        self.n_spin = gui.spin(spin_box,
                               self,
                               "n_attributes",
                               1,
                               self.MAX_N_ATTRS,
                               label=" ",
                               controlWidth=60,
                               callback=self._n_spin_changed)
        grid.addWidget(radio_all, 1, 1)
        grid.addWidget(radio_best, 2, 1)
        grid.addWidget(spin_box, 2, 2)

        self.sort_combo = gui.comboBox(box,
                                       self,
                                       "sort_index",
                                       label="Sort by: ",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._sort_combo_changed)

        self.cont_feature_dim_combo = gui.comboBox(
            box,
            self,
            "cont_feature_dim_index",
            label="Continuous features: ",
            items=["1D projection", "2D curve"],
            orientation=Qt.Horizontal,
            callback=self._cont_feature_dim_combo_changed)

        gui.rubber(self.controlArea)

        # Main area: the graphics view showing the nomogram scene.
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(
            self.scene,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            renderHints=QPainter.Antialiasing | QPainter.TextAntialiasing
            | QPainter.SmoothPixmapTransform,
            alignment=Qt.AlignLeft)
        # Viewport events are routed through self.eventFilter, which
        # redraws the scene on resize.
        self.view.viewport().installEventFilter(self)
        self.view.viewport().setMinimumWidth(300)
        self.view.sizeHint = lambda: QSize(600, 500)
        self.mainArea.layout().addWidget(self.view)

    def _class_combo_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        coeffs = [
            np.nan_to_num(p[self.target_class_index] /
                          p[self.old_target_class_index]) for p in self.points
        ]
        points = [p[self.old_target_class_index] for p in self.points]
        self.feature_marker_values = [
            self.get_points_from_coeffs(v, c, p)
            for (v, c, p) in zip(self.feature_marker_values, coeffs, points)
        ]
        self.update_scene()
        self.old_target_class_index = self.target_class_index

    def _norm_check_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def _radio_button_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def _display_radio_button_changed(self):
        self.__hide_attrs(self.n_attributes if self.display_index else None)

    def _n_spin_changed(self):
        """React to a change of the "Best ranked" spin box."""
        # Changing the number implies the "Best ranked" option (index 1).
        self.display_index = 1
        self.__hide_attrs(self.n_attributes)

    def __hide_attrs(self, n_show):
        """Show only ``n_show`` features and resize the scene accordingly.

        Does nothing until the main nomogram exists.  ``n_show`` is passed
        straight to ``nomogram_main.hide``.  When the vertical marker lines
        exist they are stretched to the new layout height, and the scene
        rectangle is recomputed with a 70 px margin on the right and bottom.
        """
        if self.nomogram_main is None:
            return
        self.nomogram_main.hide(n_show)

        if self.vertical_line:
            x = self.vertical_line.line().x1()
            y = self.nomogram_main.layout.preferredHeight() + 30
            for marker_line in (self.vertical_line, self.hidden_vertical_line):
                marker_line.setLine(x, -6, x, y)

        left = self.scene.sceneRect().x()
        top = self.scene.sceneRect().y()
        width = self.scene.itemsBoundingRect().width()
        height = self.nomogram.preferredSize().height()
        self.scene.setSceneRect(
            QRectF(left, top, width, height).adjusted(0, 0, 70, 70))

    def _sort_combo_changed(self):
        if self.nomogram_main is None:
            return
        self.nomogram_main.hide(None)
        self.nomogram_main.sort(self.sort_index)
        self.__hide_attrs(self.n_attributes if self.display_index else None)

    def _cont_feature_dim_combo_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def eventFilter(self, obj, event):
        """Rebuild the scene whenever the view's viewport is resized.

        A resize of the viewport enables repainting, stores the current dot
        values (converted with ``scale_back``) as marker values and redraws
        the scene.  The event is then always handed to the base class
        filter, whose result is returned.
        """
        is_viewport_resize = (obj is self.view.viewport()
                              and event.type() == QEvent.Resize)
        if is_viewport_resize:
            self.repaint = True
            dot_values = [feature.dot.value for feature in self.feature_items]
            self.feature_marker_values = self.scale_back(dot_values)
            self.update_scene()
        return super().eventFilter(obj, event)

    def update_controls(self):
        """Synchronise the control widgets with the current domain/classifier.

        * The class combo is refilled with the values of the first class
          variable.
        * Domains with more than ``MAX_N_ATTRS`` attributes force the
          "Best ranked" display option.
        * The normalisation check box is shown only for more than two class
          values.
        * The continuous-feature combo is disabled (and reset to 0) when the
          domain has no continuous attributes.
        * For logistic regression the nomogram is left-aligned, the
          "positive" and "negative" sort options are disabled, and a
          currently selected one is replaced by ``SortBy.NO_SORTING``.
        """
        self.class_combo.clear()
        self.norm_check.setHidden(True)
        self.cont_feature_dim_combo.setEnabled(True)
        if self.domain:
            self.class_combo.addItems(self.domain.class_vars[0].values)
            if len(self.domain.attributes) > self.MAX_N_ATTRS:
                self.display_index = 1
            if len(self.domain.class_vars[0].values) > 2:
                self.norm_check.setHidden(False)
            if not self.domain.has_continuous_attributes():
                self.cont_feature_dim_combo.setEnabled(False)
                self.cont_feature_dim_index = 0

        signed_sorts = (SortBy.POSITIVE, SortBy.NEGATIVE)
        model = self.sort_combo.model()
        # Start from a state in which both signed sort options are enabled.
        for sort_key in signed_sorts:
            item = model.item(sort_key)
            item.setFlags(item.flags() | Qt.ItemIsEnabled)
        self.align = OWNomogram.ALIGN_ZERO
        if self.classifier and isinstance(self.classifier,
                                          LogisticRegressionClassifier):
            self.align = OWNomogram.ALIGN_LEFT
            for sort_key in signed_sorts:
                item = model.item(sort_key)
                item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            # Both disabled options must be checked here; the previous test
            # listed POSITIVE twice and so left a NEGATIVE sort selected.
            if self.sort_index in signed_sorts:
                self.sort_index = SortBy.NO_SORTING

    def set_instance(self, data):
        """Handler of the "Data" input: store ``data`` and reset the markers.

        The stored marker values are cleared first, so
        ``set_feature_marker_values`` recomputes them (via
        ``_init_feature_marker_values``) when the scene has feature items.
        """
        self.instances = data
        self.feature_marker_values = []
        self.set_feature_marker_values()

    def set_classifier(self, classifier):
        """Handler of the "Classifier" input.

        Closes the current settings context, validates the classifier
        against ``ACCEPTABLE`` (raising ``Error.invalid_classifier`` and
        dropping it otherwise), recomputes the per-attribute points, updates
        the controls, reopens the context and redraws the scene.
        """
        self.closeContext()
        self.classifier = classifier
        self.Error.clear()
        if self.classifier and not isinstance(self.classifier,
                                              self.ACCEPTABLE):
            self.Error.invalid_classifier()
            self.classifier = None
        self.domain = self.classifier.domain if self.classifier else None
        self.data = None
        # Each calculate_* method resets its own state and returns early
        # unless the classifier has the matching type.
        self.calculate_log_odds_ratios()
        self.calculate_log_reg_coefficients()
        self.update_controls()
        # Default target class, assigned before the context is opened.
        self.target_class_index = 0
        self.openContext(self.domain and self.domain.class_var)
        # Whichever of the two lists is non-empty supplies the points.
        self.points = self.log_odds_ratios or self.log_reg_coeffs
        self.feature_marker_values = []
        self.old_target_class_index = self.target_class_index
        self.update_scene()

    def calculate_log_odds_ratios(self):
        self.log_odds_ratios = []
        self.p = None
        if self.classifier is None or self.domain is None:
            return
        if not isinstance(self.classifier, NaiveBayesModel):
            return

        log_cont_prob = self.classifier.log_cont_prob
        class_prob = self.classifier.class_prob
        for i in range(len(self.domain.attributes)):
            ca = np.exp(log_cont_prob[i]) * class_prob[:, None]
            _or = (ca / (1 - ca)) / (class_prob / (1 - class_prob))[:, None]
            self.log_odds_ratios.append(np.log(_or))
        self.p = class_prob

    def calculate_log_reg_coefficients(self):
        self.log_reg_coeffs = []
        self.log_reg_cont_data_extremes = []
        self.b0 = None
        if self.classifier is None or self.domain is None:
            return
        if not isinstance(self.classifier, LogisticRegressionClassifier):
            return

        self.domain = self.reconstruct_domain(self.classifier.original_domain,
                                              self.domain)
        self.data = self.classifier.original_data.transform(self.domain)
        attrs, ranges, start = self.domain.attributes, [], 0
        for attr in attrs:
            stop = start + len(attr.values) if attr.is_discrete else start + 1
            ranges.append(slice(start, stop))
            start = stop

        self.b0 = self.classifier.intercept
        coeffs = self.classifier.coefficients
        if len(self.domain.class_var.values) == 2:
            self.b0 = np.hstack((self.b0 * (-1), self.b0))
            coeffs = np.vstack((coeffs * (-1), coeffs))
        self.log_reg_coeffs = [coeffs[:, ranges[i]] for i in range(len(attrs))]
        self.log_reg_coeffs_orig = self.log_reg_coeffs.copy()

        min_values = nanmin(self.data.X, axis=0)
        max_values = nanmax(self.data.X, axis=0)

        for i, min_t, max_t in zip(range(len(self.log_reg_coeffs)), min_values,
                                   max_values):
            if self.log_reg_coeffs[i].shape[1] == 1:
                coef = self.log_reg_coeffs[i]
                self.log_reg_coeffs[i] = np.hstack(
                    (coef * min_t, coef * max_t))
                self.log_reg_cont_data_extremes.append(
                    [sorted([min_t, max_t], reverse=(c < 0)) for c in coef])
            else:
                self.log_reg_cont_data_extremes.append([None])

    def update_scene(self):
        if not self.repaint:
            return
        self.clear_scene()
        if self.domain is None or not len(self.points[0]):
            return

        name_items = [
            QGraphicsTextItem(a.name) for a in self.domain.attributes
        ]
        point_text = QGraphicsTextItem("Points")
        probs_text = QGraphicsTextItem("Probabilities (%)")
        all_items = name_items + [point_text, probs_text]
        name_offset = -max(t.boundingRect().width() for t in all_items) - 50
        w = self.view.viewport().rect().width()
        max_width = w + name_offset - 100

        points = [pts[self.target_class_index] for pts in self.points]
        minimums = [min(p) for p in points]
        if self.align == OWNomogram.ALIGN_LEFT:
            points = [p - m for m, p in zip(minimums, points)]
        max_ = np.nan_to_num(max(max(abs(p)) for p in points))
        d = 100 / max_ if max_ else 1
        if self.scale == OWNomogram.POINT_SCALE:
            points = [p * d for p in points]

        if self.scale == OWNomogram.POINT_SCALE and \
                self.align == OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [
                p / d + m for m, p in zip(minimums, x)
            ]
            self.scale_forth = lambda x: [(p - m) * d
                                          for m, p in zip(minimums, x)]
        if self.scale == OWNomogram.POINT_SCALE and \
                self.align != OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [p / d for p in x]
            self.scale_forth = lambda x: [p * d for p in x]
        if self.scale != OWNomogram.POINT_SCALE and \
                self.align == OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [p + m for m, p in zip(minimums, x)]
            self.scale_forth = lambda x: [p - m for m, p in zip(minimums, x)]
        if self.scale != OWNomogram.POINT_SCALE and \
                self.align != OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: x
            self.scale_forth = lambda x: x

        point_item, nomogram_head = self.create_main_nomogram(
            name_items, points, max_width, point_text, name_offset)
        probs_item, nomogram_foot = self.create_footer_nomogram(
            probs_text, d, minimums, max_width, name_offset)
        for item in self.feature_items:
            item.dot.point_dot = point_item.dot
            item.dot.probs_dot = probs_item.dot
            item.dot.vertical_line = self.hidden_vertical_line

        self.nomogram = nomogram = NomogramItem()
        nomogram.add_items([nomogram_head, self.nomogram_main, nomogram_foot])
        self.scene.addItem(nomogram)
        self.set_feature_marker_values()
        rect = QRectF(self.scene.itemsBoundingRect().x(),
                      self.scene.itemsBoundingRect().y(),
                      self.scene.itemsBoundingRect().width(),
                      self.nomogram.preferredSize().height())
        self.scene.setSceneRect(rect.adjusted(0, 0, 70, 70))

    def create_main_nomogram(self, name_items, points, max_width, point_text,
                             name_offset):
        """Build the header ruler and the per-feature nomogram items.

        Sets ``self.nomogram_main``, ``self.feature_items``,
        ``self.vertical_line`` and ``self.hidden_vertical_line``.

        ``name_items`` and ``points`` are indexed by attribute position;
        ``max_width`` is the horizontal space available for the rulers and
        ``name_offset`` is passed on to every created item.

        Returns the tuple ``(point_item, nomogram_header)``.
        """
        cls_index = self.target_class_index
        min_p = min(min(p) for p in points)
        max_p = max(max(p) for p in points)
        # Snap the range to the ruler ticks and derive the pixels-per-unit
        # factor from it.
        values = self.get_ruler_values(min_p, max_p, max_width)
        min_p, max_p = min(values), max(values)
        diff_ = np.nan_to_num(max_p - min_p)
        scale_x = max_width / diff_ if diff_ else max_width

        nomogram_header = NomogramItem()
        point_item = RulerItem(point_text, values, scale_x, name_offset,
                               -scale_x * min_p)
        point_item.setPreferredSize(point_item.preferredWidth(), 35)
        nomogram_header.add_items([point_item])

        self.nomogram_main = SortableNomogramItem()
        cont_feature_item_class = ContinuousFeature2DItem if \
            self.cont_feature_dim_index else ContinuousFeatureItem
        # One item per attribute: discrete attributes get a
        # DiscreteFeatureItem, all others the continuous item class chosen
        # above.
        self.feature_items = [
            DiscreteFeatureItem(name_items[i], [val for val in att.values],
                                points[i], scale_x, name_offset, -scale_x *
                                min_p, self.points[i][cls_index])
            if att.is_discrete else cont_feature_item_class(
                name_items[i], self.log_reg_cont_data_extremes[i][cls_index],
                self.get_ruler_values(
                    np.min(points[i]), np.max(points[i]),
                    scale_x * (np.max(points[i]) - np.min(points[i])),
                    False), scale_x, name_offset, -scale_x *
                min_p, self.log_reg_coeffs_orig[i][cls_index][0])
            for i, att in enumerate(self.domain.attributes)
        ]
        self.nomogram_main.add_items(
            self.feature_items, self.sort_index,
            self.n_attributes if self.display_index else None)

        # Two vertical lines with identical geometry, both parented to the
        # header ruler: a dotted one and a red dashed one.
        x = -scale_x * min_p
        y = self.nomogram_main.layout.preferredHeight() + 30
        self.vertical_line = QGraphicsLineItem(x, -6, x, y)
        self.vertical_line.setPen(QPen(Qt.DotLine))
        self.vertical_line.setParentItem(point_item)
        self.hidden_vertical_line = QGraphicsLineItem(x, -6, x, y)
        pen = QPen(Qt.DashLine)
        pen.setBrush(QColor(Qt.red))
        self.hidden_vertical_line.setPen(pen)
        self.hidden_vertical_line.setParentItem(point_item)

        return point_item, nomogram_header

    def create_footer_nomogram(self, probs_text, d, minimums, max_width,
                               name_offset):
        """Build the probabilities ruler shown below the features.

        ``d`` is the point-scale factor and ``minimums`` the per-attribute
        minima computed in ``update_scene``.

        Returns the tuple ``(probs_item, nomogram_footer)``.
        """
        eps, d_ = 0.05, 1
        # Per-class offset: negative log prior odds when class
        # probabilities are available (self.p), otherwise the negated b0.
        k = -np.log(self.p / (1 - self.p)) if self.p is not None else -self.b0
        # Point totals corresponding to probabilities eps and 1 - eps.
        min_sum = k[self.target_class_index] - np.log((1 - eps) / eps)
        max_sum = k[self.target_class_index] - np.log(eps / (1 - eps))
        if self.align == OWNomogram.ALIGN_LEFT:
            max_sum = max_sum - sum(minimums)
            min_sum = min_sum - sum(minimums)
            # Shift every class' offset by the sum of that class' minima
            # (k is modified in place).
            for i in range(len(k)):
                k[i] = k[i] - sum(
                    [min(q) for q in [p[i] for p in self.points]])
        if self.scale == OWNomogram.POINT_SCALE:
            min_sum *= d
            max_sum *= d
            d_ = d

        values = self.get_ruler_values(min_sum, max_sum, max_width)
        min_sum, max_sum = min(values), max(values)
        diff_ = np.nan_to_num(max_sum - min_sum)
        scale_x = max_width / diff_ if diff_ else max_width
        cls_var, cls_index = self.domain.class_var, self.target_class_index
        nomogram_footer = NomogramItem()

        # Maps a point total to a probability, divided by the summed
        # per-class probabilities when normalisation is enabled.
        def get_normalized_probabilities(val):
            if not self.normalize_probabilities:
                return 1 / (1 + np.exp(k[cls_index] - val / d_))
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return 1 / (1 + np.exp(k[cls_index] - val / d_)) / p_sum

        # Inverse mapping: probability to point total.
        def get_points(prob):
            if not self.normalize_probabilities:
                return (k[cls_index] - np.log(1 / prob - 1)) * d_
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return (k[cls_index] - np.log(1 / (prob * p_sum) - 1)) * d_

        # markers_set is False only while the ruler item is constructed;
        # __get_totals_for_class_values reads the flag to decide whether
        # the dot values still need scale_forth.
        self.markers_set = False
        probs_item = ProbabilitiesRulerItem(
            probs_text,
            values,
            scale_x,
            name_offset,
            -scale_x * min_sum,
            get_points=get_points,
            title="{}='{}'".format(cls_var.name, cls_var.values[cls_index]),
            get_probabilities=get_normalized_probabilities)
        self.markers_set = True
        nomogram_footer.add_items([probs_item])
        return probs_item, nomogram_footer

    def __get_totals_for_class_values(self, minimums):
        """Return the marker point totals for every class value.

        The entry of the target class is the plain sum of the current
        marker values.  For every other class the markers are converted
        with ``scale_back`` and mapped through ``get_points_from_coeffs``
        before being summed, shifted (left alignment) and scaled (point
        scale).

        Returns a NumPy array with one total per class value.
        """
        cls_index = self.target_class_index
        marker_values = [item.dot.value for item in self.feature_items]
        # While the footer ruler is being constructed (markers_set False)
        # the dot values are passed through scale_forth first.
        if not self.markers_set:
            marker_values = self.scale_forth(marker_values)
        totals = np.empty(len(self.domain.class_var.values))
        totals[cls_index] = sum(marker_values)
        marker_values = self.scale_back(marker_values)
        for i in range(len(self.domain.class_var.values)):
            if i == cls_index:
                continue
            # Ratio of class i's points to the target class' points.
            coeffs = [np.nan_to_num(p[i] / p[cls_index]) for p in self.points]
            points = [p[cls_index] for p in self.points]
            total = sum([
                self.get_points_from_coeffs(v, c, p)
                for (v, c, p) in zip(marker_values, coeffs, points)
            ])
            if self.align == OWNomogram.ALIGN_LEFT:
                points = [p - m for m, p in zip(minimums, points)]
                total -= sum([min(p) for p in [p[i] for p in self.points]])
            # NOTE(review): unlike update_scene, this division has no guard
            # for a zero maximum -- confirm whether that case can occur.
            d = 100 / max(max(abs(p)) for p in points)
            if self.scale == OWNomogram.POINT_SCALE:
                total *= d
            totals[i] = total
        return totals

    def set_feature_marker_values(self):
        if not (len(self.points) and len(self.feature_items)):
            return
        if not len(self.feature_marker_values):
            self._init_feature_marker_values()
        self.feature_marker_values = self.scale_forth(
            self.feature_marker_values)
        item = self.feature_items[0]
        for i, item in enumerate(self.feature_items):
            item.dot.move_to_val(self.feature_marker_values[i])
        item.dot.probs_dot.move_to_sum()

    def _init_feature_marker_values(self):
        self.feature_marker_values = []
        cls_index = self.target_class_index
        instances = Table(self.domain, self.instances) \
            if self.instances else None
        for i, attr in enumerate(self.domain.attributes):
            value, feature_val = 0, None
            if len(self.log_reg_coeffs):
                if attr.is_discrete:
                    ind, n = unique(self.data.X[:, i], return_counts=True)
                    feature_val = np.nan_to_num(ind[np.argmax(n)])
                else:
                    feature_val = mean(self.data.X[:, i])
            inst_in_dom = instances and attr in instances.domain
            if inst_in_dom and not np.isnan(instances[0][attr]):
                feature_val = instances[0][attr]
            if feature_val is not None:
                value = self.points[i][cls_index][int(feature_val)] \
                    if attr.is_discrete else \
                    self.log_reg_coeffs_orig[i][cls_index][0] * feature_val
            self.feature_marker_values.append(value)

    def clear_scene(self):
        self.feature_items = []
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.scene.clear()

    def send_report(self):
        """Add the widget's plot to the report via ``report_plot``."""
        self.report_plot()

    @staticmethod
    def reconstruct_domain(original, preprocessed):
        """Build a domain from the source variables of ``preprocessed``.

        For every preprocessed attribute the source variable is looked up
        through its chain of ``_compute_value`` objects, falling back to the
        variable of the same name in ``original``.  Each source variable is
        kept once, in order of first appearance.  Class variable and metas
        are taken from ``original``.
        """
        # An ordered mapping gives fast membership tests while keeping the
        # order of first appearance; only the keys matter.
        sources = OrderedDict()
        for attr in preprocessed.attributes:
            inner = attr._compute_value.variable._compute_value
            if inner:
                source = inner.variable
            else:
                source = original[attr.name]
            sources.setdefault(source, None)
        return Domain(list(sources.keys()), original.class_var,
                      original.metas)

    @staticmethod
    def get_ruler_values(start, stop, max_width, round_to_nearest=True):
        if max_width == 0:
            return [0]
        diff = np.nan_to_num((stop - start) / max_width)
        if diff <= 0:
            return [0]
        decimals = int(np.floor(np.log10(diff)))
        if diff > 4 * pow(10, decimals):
            step = 5 * pow(10, decimals + 2)
        elif diff > 2 * pow(10, decimals):
            step = 2 * pow(10, decimals + 2)
        elif diff > 1 * pow(10, decimals):
            step = 1 * pow(10, decimals + 2)
        else:
            step = 5 * pow(10, decimals + 1)
        round_by = int(-np.floor(np.log10(step)))
        r = start % step
        if not round_to_nearest:
            _range = np.arange(start + step, stop + r, step) - r
            start, stop = np.floor(start * 100) / 100, np.ceil(
                stop * 100) / 100
            return np.round(np.hstack((start, _range, stop)), 2)
        return np.round(np.arange(start, stop + r + step, step) - r, round_by)

    @staticmethod
    def get_points_from_coeffs(current_value, coefficients, possible_values):
        if any(np.isnan(possible_values)):
            return 0
        indices = np.argsort(possible_values)
        sorted_values = possible_values[indices]
        sorted_coefficients = coefficients[indices]
        for i, val in enumerate(sorted_values):
            if current_value < val:
                break
        diff = sorted_values[i] - sorted_values[i - 1]
        k = 0 if diff < 1e-6 else (sorted_values[i] - current_value) / \
                                  (sorted_values[i] - sorted_values[i - 1])
        return sorted_coefficients[i - 1] * sorted_values[i - 1] * k + \
               sorted_coefficients[i] * sorted_values[i] * (1 - k)
class TestClassValuesContextHandler(TestCase):
    """Tests of ``ClassValuesContextHandler.open_context``."""

    def setUp(self):
        attributes = [
            ContinuousVariable('c1'),
            DiscreteVariable('d1', values='abc'),
            DiscreteVariable('d2', values='def'),
        ]
        class_vars = [DiscreteVariable('d3', values='ghi')]
        metas = [
            ContinuousVariable('c2'),
            DiscreteVariable('d4', values='jkl'),
        ]
        self.domain = Domain(attributes=attributes,
                             class_vars=class_vars,
                             metas=metas)

        attribute_types = {
            'c1': Continuous,
            'd1': Discrete,
            'd2': Discrete,
            'd3': Discrete,
        }
        meta_types = {'c2': Continuous, 'd4': Discrete}
        self.args = (self.domain, attribute_types, meta_types)

        self.handler = ClassValuesContextHandler()
        # Keep the handler from reading stored defaults.
        self.handler.read_defaults = lambda: None

    def _install_contexts(self):
        """Bind the handler and give it three global contexts.

        The middle one stores the classes g, h, i; the other two are empty.
        """
        self.handler.bind(SimpleWidget)
        stored = Mock(classes=['g', 'h', 'i'],
                      values=dict(text='u',
                                  with_metas=[('d1', Discrete),
                                              ('d2', Discrete)]))
        self.handler.global_contexts = \
            [Mock(values={}), stored, Mock(values={})]

    def test_open_context(self):
        self._install_contexts()

        widget = SimpleWidget()
        self.handler.initialize(widget)
        class_var = self.args[0].class_var
        self.handler.open_context(widget, class_var)

        self.assertEqual(widget.text, 'u')
        expected_metas = [('d1', Discrete), ('d2', Discrete)]
        self.assertEqual(widget.with_metas, expected_metas)

    def test_open_context_with_no_match(self):
        self._install_contexts()

        widget = SimpleWidget()
        self.handler.initialize(widget)
        widget.text = 'u'

        # Variable d1 has the values a, b, c, which match no stored context.
        unmatched_var = self.args[0][1]
        self.handler.open_context(widget, unmatched_var)

        context = widget.current_context
        self.assertEqual(context.classes, ['a', 'b', 'c'])
        self.assertEqual(widget.text, 'u')
        self.assertEqual(widget.with_metas, [])
# Esempio n. 9
# 0
class OWNomogram(OWWidget):
    # Widget metadata. The user-visible strings in this class are Chinese
    # and are part of the UI, so they must not be edited as if they were
    # comments.
    name = "列线图"
    description = "朴素贝叶斯可视化的诺莫图和逻辑回归分类器"
    icon = "icons/Nomogram.svg"
    priority = 2000
    keywords = []

    class Inputs:
        # Input signals; handled by set_classifier and set_data below.
        classifier = Input("分类器", Model)
        data = Input("数据", Table)

    # Upper bound of the "n_attributes" spin box; update_controls also
    # forces display_index to 1 when the domain has more attributes.
    MAX_N_ATTRS = 1000
    # Value of ``scale`` for which points are rescaled (see update_scene).
    POINT_SCALE = 0
    # Alignment modes; update_controls selects ALIGN_LEFT for
    # LogisticRegressionClassifier and ALIGN_ZERO otherwise.
    ALIGN_LEFT = 0
    ALIGN_ZERO = 1
    # Classifier types accepted by set_classifier; anything else raises
    # Error.invalid_classifier.
    ACCEPTABLE = (NaiveBayesModel, LogisticRegressionClassifier)
    settingsHandler = ClassValuesContextHandler()
    target_class_index = ContextSetting(0)
    normalize_probabilities = Setting(False)
    scale = Setting(1)
    display_index = Setting(1)
    n_attributes = Setting(10)
    sort_index = Setting(SortBy.ABSOLUTE)
    cont_feature_dim_index = Setting(0)

    graph_name = "场景"

    class Error(OWWidget.Error):
        # Shown by set_classifier when the classifier is not in ACCEPTABLE.
        invalid_classifier = Msg("诺莫图只接受朴素贝叶斯和逻辑回归分类器")

    def __init__(self):
        """Initialise the widget state, the control area and the three views.

        The main area gets three graphics views that all show the same
        ``QGraphicsScene``: a fixed-size one on top, the main one in the
        middle and another fixed-size one at the bottom.
        """
        super().__init__()
        # Input data and classifier-derived state; filled by set_data and
        # set_classifier.
        self.instances = None
        self.domain = None
        self.data = None
        self.classifier = None
        self.align = OWNomogram.ALIGN_ZERO
        self.log_odds_ratios = []
        self.log_reg_coeffs = []
        self.log_reg_coeffs_orig = []
        self.log_reg_cont_data_extremes = []
        self.p = None
        self.b0 = None
        self.points = []
        # Scene items and marker state; rebuilt by update_scene.
        self.feature_items = {}
        self.feature_marker_values = []
        self.scale_marker_values = lambda x: x
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        # Remembered so _class_combo_changed can convert marker values from
        # the previous target class to the new one.
        self.old_target_class_index = self.target_class_index
        self.repaint = False

        # GUI: target class selection
        box = gui.vBox(self.controlArea, "目标类")
        self.class_combo = gui.comboBox(
            box, self, "target_class_index", callback=self._class_combo_changed,
            contentsLength=12)
        # Created hidden; update_controls shows it when the class variable
        # has more than two values.
        self.norm_check = gui.checkBox(
            box, self, "normalize_probabilities", "Normalize probabilities",
            hidden=True, callback=self.update_scene,
            tooltip="For multiclass data 1 vs. all probabilities do not"
                    " sum to 1 and therefore could be normalized.")

        # GUI: scale selection (index 0 corresponds to POINT_SCALE)
        self.scale_radio = gui.radioButtons(
            self.controlArea, self, "scale", ["点量表", "记录比值比"],
            box="规模", callback=self.update_scene)

        # GUI: which features to display ("all" vs. "best ranked: n")
        box = gui.vBox(self.controlArea, "显示特征")
        grid = QGridLayout()
        radio_group = gui.radioButtonsInBox(
            box, self, "display_index", [], orientation=grid,
            callback=self.update_scene)
        radio_all = gui.appendRadioButton(
            radio_group, "所有", addToLayout=False)
        radio_best = gui.appendRadioButton(
            radio_group, "最佳排名:", addToLayout=False)
        spin_box = gui.hBox(None, margin=0)
        self.n_spin = gui.spin(
            spin_box, self, "n_attributes", 1, self.MAX_N_ATTRS, label=" ",
            controlWidth=60, callback=self._n_spin_changed)
        # The radio buttons were created with addToLayout=False, so they are
        # placed into the grid by hand, with the spin box next to "best".
        grid.addWidget(radio_all, 1, 1)
        grid.addWidget(radio_best, 2, 1)
        grid.addWidget(spin_box, 2, 2)

        self.sort_combo = gui.comboBox(
            box, self, "sort_index", label="排名靠前:", items=SortBy.items(),
            orientation=Qt.Horizontal, callback=self.update_scene)

        self.cont_feature_dim_combo = gui.comboBox(
            box, self, "cont_feature_dim_index", label="数字特征: ",
            items=["1维投影", "二维曲线"], orientation=Qt.Horizontal,
            callback=self.update_scene)

        gui.rubber(self.controlArea)

        class _GraphicsView(QGraphicsView):
            # Base view: fills in default keyword arguments (scroll bars off,
            # antialiasing, top-left alignment, expanding size policy) unless
            # the caller supplied its own value for a key.
            def __init__(self, scene, parent, **kwargs):
                for k, v in dict(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 viewportUpdateMode=QGraphicsView.BoundingRectViewportUpdate,
                                 renderHints=(QPainter.Antialiasing |
                                              QPainter.TextAntialiasing |
                                              QPainter.SmoothPixmapTransform),
                                 alignment=(Qt.AlignTop |
                                            Qt.AlignLeft),
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.MinimumExpanding)).items():
                    kwargs.setdefault(k, v)

                super().__init__(scene, parent, **kwargs)

        class GraphicsView(_GraphicsView):
            # Main (middle) view: always shows the vertical scroll bar and
            # rebuilds the scene when its width changes.
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                 styleSheet='QGraphicsView {background: white}')
                self.viewport().setMinimumWidth(300)  # XXX: This prevents some tests failing
                self._is_resizing = False

            # ``self`` here is the widget being constructed (the enclosing
            # __init__'s ``self``); it is stored on the class so that
            # resizeEvent can reach the widget as ``self.w``.
            w = self

            def resizeEvent(self, resizeEvent):
                # Recompute main scene on window width change
                if resizeEvent.size().width() != resizeEvent.oldSize().width():
                    self._is_resizing = True
                    self.w.update_scene()
                    self._is_resizing = False
                return super().resizeEvent(resizeEvent)

            def is_resizing(self):
                # True only while resizeEvent is rebuilding the scene.
                return self._is_resizing

            def sizeHint(self):
                return QSize(400, 200)

        class FixedSizeGraphicsView(_GraphicsView):
            # View used for the top and bottom strips; its vertical size
            # policy is Minimum instead of MinimumExpanding.
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.Minimum))

            def sizeHint(self):
                return QSize(400, 85)

        scene = self.scene = QGraphicsScene(self)

        # All three views share one scene; update_scene assigns each view a
        # different scene rectangle.
        top_view = self.top_view = FixedSizeGraphicsView(scene, self)
        mid_view = self.view = GraphicsView(scene, self)
        bottom_view = self.bottom_view = FixedSizeGraphicsView(scene, self)

        for view in (top_view, mid_view, bottom_view):
            self.mainArea.layout().addWidget(view)

    def _class_combo_changed(self):
        """Convert marker values to the newly selected target class and redraw.

        For every attribute the ratio between the new and the previous
        class's points is computed (NaNs from 0/0 become 0) and passed to
        ``get_points_from_coeffs`` together with the previous class's points.
        """
        new_cls = self.target_class_index
        old_cls = self.old_target_class_index

        ratios = []
        with np.errstate(invalid='ignore'):
            for per_class in self.points:
                ratios.append(
                    np.nan_to_num(per_class[new_cls] / per_class[old_cls]))
        old_points = [per_class[old_cls] for per_class in self.points]

        converted = [
            self.get_points_from_coeffs(marker, ratio, pts)
            for marker, ratio, pts
            in zip(self.feature_marker_values, ratios, old_points)
        ]
        self.feature_marker_values = np.asarray(converted)
        self.update_scene()
        self.old_target_class_index = self.target_class_index

    def _n_spin_changed(self):
        """Callback of the ``n_attributes`` spin box.

        Sets ``display_index`` to 1 (the value under which update_scene
        limits the number of attributes to ``n_attributes``) and redraws.
        """
        self.display_index = 1
        self.update_scene()

    def update_controls(self):
        """Synchronise the control widgets with the current domain/classifier.

        Refills the class combo, decides whether the normalisation check box
        and the continuous-feature combo are available, re-enables the
        positive/negative sort entries and picks the alignment mode.
        """
        self.class_combo.clear()
        self.norm_check.setHidden(True)
        self.cont_feature_dim_combo.setEnabled(True)

        domain = self.domain
        if domain is not None:
            class_values = domain.class_vars[0].values
            self.class_combo.addItems(class_values)
            if len(domain.attributes) > self.MAX_N_ATTRS:
                self.display_index = 1
            if len(class_values) > 2:
                self.norm_check.setHidden(False)
            if not domain.has_continuous_attributes():
                self.cont_feature_dim_combo.setEnabled(False)
                self.cont_feature_dim_index = 0

        sort_model = self.sort_combo.model()
        for sort_kind in (SortBy.POSITIVE, SortBy.NEGATIVE):
            entry = sort_model.item(sort_kind)
            entry.setFlags(entry.flags() | Qt.ItemIsEnabled)

        is_log_reg = self.classifier and isinstance(
            self.classifier, LogisticRegressionClassifier)
        if is_log_reg:
            self.align = OWNomogram.ALIGN_LEFT
        else:
            self.align = OWNomogram.ALIGN_ZERO

    @Inputs.data
    def set_data(self, data):
        """Handle the data input signal.

        Stores ``data`` in ``self.instances``, discards the current marker
        values so they are recomputed, repositions the markers and redraws
        the scene.
        """
        self.instances = data
        # Emptying the list makes set_feature_marker_values re-initialise it.
        self.feature_marker_values = []
        self.set_feature_marker_values()
        self.update_scene()

    @Inputs.classifier
    def set_classifier(self, classifier):
        """Handle the classifier input signal.

        A classifier that is not an instance of ``ACCEPTABLE`` is rejected
        with ``Error.invalid_classifier`` and treated as ``None``. The
        derived state is then recomputed, the controls updated and the
        scene redrawn.
        """
        self.closeContext()
        self.classifier = classifier
        self.Error.clear()
        if self.classifier and not isinstance(self.classifier, self.ACCEPTABLE):
            self.Error.invalid_classifier()
            self.classifier = None
        self.domain = self.classifier.domain if self.classifier else None
        self.data = None
        # Each of these resets its own outputs and returns early unless the
        # classifier is of the matching type.
        self.calculate_log_odds_ratios()
        self.calculate_log_reg_coefficients()
        self.update_controls()
        # Default target class, assigned before the context is opened.
        self.target_class_index = 0
        self.openContext(self.domain.class_var if self.domain is not None
                         else None)
        # Whichever of the two lists is non-empty provides the points.
        self.points = self.log_odds_ratios or self.log_reg_coeffs
        self.feature_marker_values = []
        self.old_target_class_index = self.target_class_index
        self.update_scene()

    def calculate_log_odds_ratios(self):
        """Fill ``log_odds_ratios`` and ``p`` from a ``NaiveBayesModel``.

        Both are reset first; nothing else happens when there is no
        classifier/domain or the classifier is of another type.
        """
        self.log_odds_ratios = []
        self.p = None
        model = self.classifier
        if model is None or self.domain is None \
                or not isinstance(model, NaiveBayesModel):
            return

        cond_log_probs = self.classifier.log_cont_prob
        priors = self.classifier.class_prob
        for attr_idx, _ in enumerate(self.domain.attributes):
            weighted = np.exp(cond_log_probs[attr_idx]) * priors[:, None]
            odds = weighted / (1 - weighted)
            prior_odds = (priors / (1 - priors))[:, None]
            self.log_odds_ratios.append(np.log(odds / prior_odds))
        self.p = priors

    def calculate_log_reg_coefficients(self):
        self.log_reg_coeffs = []
        self.log_reg_cont_data_extremes = []
        self.b0 = None
        if self.classifier is None or self.domain is None:
            return
        if not isinstance(self.classifier, LogisticRegressionClassifier):
            return

        self.domain = self.reconstruct_domain(self.classifier.original_domain,
                                              self.domain)
        self.data = self.classifier.original_data.transform(self.domain)
        attrs, ranges, start = self.domain.attributes, [], 0
        for attr in attrs:
            stop = start + len(attr.values) if attr.is_discrete else start + 1
            ranges.append(slice(start, stop))
            start = stop

        self.b0 = self.classifier.intercept
        coeffs = self.classifier.coefficients
        if len(self.domain.class_var.values) == 2:
            self.b0 = np.hstack((self.b0 * (-1), self.b0))
            coeffs = np.vstack((coeffs * (-1), coeffs))
        self.log_reg_coeffs = [coeffs[:, ranges[i]] for i in range(len(attrs))]
        self.log_reg_coeffs_orig = self.log_reg_coeffs.copy()

        min_values = nanmin(self.data.X, axis=0)
        max_values = nanmax(self.data.X, axis=0)

        for i, min_t, max_t in zip(range(len(self.log_reg_coeffs)),
                                   min_values, max_values):
            if self.log_reg_coeffs[i].shape[1] == 1:
                coef = self.log_reg_coeffs[i]
                self.log_reg_coeffs[i] = np.hstack((coef * min_t, coef * max_t))
                self.log_reg_cont_data_extremes.append(
                    [sorted([min_t, max_t], reverse=(c < 0)) for c in coef])
            else:
                self.log_reg_cont_data_extremes.append([None])

    def update_scene(self):
        self.clear_scene()
        if self.domain is None or not len(self.points[0]):
            return

        n_attrs = self.n_attributes if self.display_index else int(1e10)
        attr_inds, attributes = zip(*self.get_ordered_attributes()[:n_attrs])

        name_items = [QGraphicsTextItem(attr.name) for attr in attributes]
        point_text = QGraphicsTextItem("Points")
        probs_text = QGraphicsTextItem("Probabilities (%)")
        all_items = name_items + [point_text, probs_text]
        name_offset = -max(t.boundingRect().width() for t in all_items) - 10
        w = self.view.viewport().rect().width()
        max_width = w + name_offset - 30

        points = [self.points[i][self.target_class_index]
                  for i in attr_inds]
        if self.align == OWNomogram.ALIGN_LEFT:
            points = [p - p.min() for p in points]
        max_ = np.nan_to_num(max(max(abs(p)) for p in points))
        d = 100 / max_ if max_ else 1
        minimums = [p[self.target_class_index].min() for p in self.points]
        if self.scale == OWNomogram.POINT_SCALE:
            points = [p * d for p in points]

            if self.align == OWNomogram.ALIGN_LEFT:
                self.scale_marker_values = lambda x: (x - minimums) * d
            else:
                self.scale_marker_values = lambda x: x * d
        else:
            if self.align == OWNomogram.ALIGN_LEFT:
                self.scale_marker_values = lambda x: x - minimums
            else:
                self.scale_marker_values = lambda x: x

        point_item, nomogram_head = self.create_main_nomogram(
            attributes, attr_inds,
            name_items, points, max_width, point_text, name_offset)
        probs_item, nomogram_foot = self.create_footer_nomogram(
            probs_text, d, minimums, max_width, name_offset)
        for item in self.feature_items.values():
            item.dot.point_dot = point_item.dot
            item.dot.probs_dot = probs_item.dot
            item.dot.vertical_line = self.hidden_vertical_line

        self.nomogram = nomogram = NomogramItem()
        nomogram.add_items([nomogram_head, self.nomogram_main, nomogram_foot])
        self.scene.addItem(nomogram)

        self.set_feature_marker_values()

        rect = QRectF(self.scene.itemsBoundingRect().x(),
                      self.scene.itemsBoundingRect().y(),
                      self.scene.itemsBoundingRect().width(),
                      self.nomogram.preferredSize().height()).adjusted(10, 0, 20, 0)
        self.scene.setSceneRect(rect)

        # Clip top and bottom (60 and 150) parts from the main view
        self.view.setSceneRect(rect.x(), rect.y() + 80, rect.width() - 10, rect.height() - 160)
        self.view.viewport().setMaximumHeight(rect.height() - 160)
        # Clip main part from top/bottom views
        # below point values are imprecise (less/more than required) but this
        # is not a problem due to clipped scene content still being drawn
        self.top_view.setSceneRect(rect.x(), rect.y() + 3, rect.width() - 10, 20)
        self.bottom_view.setSceneRect(rect.x(), rect.height() - 110, rect.width() - 10, 30)

    def create_main_nomogram(self, attributes, attr_inds, name_items, points,
                             max_width, point_text, name_offset):
        """Build the header ruler and one feature item per shown attribute.

        Sets ``self.nomogram_main``, ``self.feature_items`` (keyed by the
        attribute's index in the domain), ``self.vertical_line`` and
        ``self.hidden_vertical_line``.

        Returns:
            tuple: ``(point_item, nomogram_header)`` - the points ruler and
            the ``NomogramItem`` that contains it.
        """
        cls_index = self.target_class_index
        min_p = min(p.min() for p in points)
        max_p = max(p.max() for p in points)
        values = self.get_ruler_values(min_p, max_p, max_width)
        # The ruler may extend beyond the data range, so use its own ends.
        min_p, max_p = min(values), max(values)
        diff_ = np.nan_to_num(max_p - min_p)
        scale_x = max_width / diff_ if diff_ else max_width

        nomogram_header = NomogramItem()
        point_item = RulerItem(point_text, values, scale_x, name_offset,
                               - scale_x * min_p)
        point_item.setPreferredSize(point_item.preferredWidth(), 35)
        nomogram_header.add_items([point_item])

        self.nomogram_main = NomogramItem()
        cont_feature_item_class = ContinuousFeature2DItem if \
            self.cont_feature_dim_index else ContinuousFeatureItem

        # np.ptp is used instead of ndarray.ptp(), which no longer exists
        # in NumPy 2.0.
        feature_items = [
            DiscreteFeatureItem(
                name_item, attr.values, point,
                scale_x, name_offset, - scale_x * min_p)
            if attr.is_discrete else
            cont_feature_item_class(
                name_item, self.log_reg_cont_data_extremes[i][cls_index],
                self.get_ruler_values(
                    point.min(), point.max(),
                    scale_x * np.ptp(point), False),
                scale_x, name_offset, - scale_x * min_p)
            for i, attr, name_item, point in zip(attr_inds, attributes, name_items, points)]

        self.nomogram_main.add_items(feature_items)
        self.feature_items = OrderedDict(sorted(zip(attr_inds, feature_items)))

        # Two vertical lines with the same geometry, both children of the
        # points ruler: a dotted one and a red dashed one.
        x = - scale_x * min_p
        y = self.nomogram_main.layout().preferredHeight() + 10
        self.vertical_line = QGraphicsLineItem(x, -6, x, y)
        self.vertical_line.setPen(QPen(Qt.DotLine))
        self.vertical_line.setParentItem(point_item)
        self.hidden_vertical_line = QGraphicsLineItem(x, -6, x, y)
        pen = QPen(Qt.DashLine)
        pen.setBrush(QColor(Qt.red))
        self.hidden_vertical_line.setPen(pen)
        self.hidden_vertical_line.setParentItem(point_item)

        return point_item, nomogram_header

    def get_ordered_attributes(self):
        """Return (in_domain_index, attr) pairs, ordered by method in SortBy combo"""
        if self.domain is None or not self.domain.attributes:
            return []

        attrs = self.domain.attributes
        sort_by = self.sort_index
        class_value = self.target_class_index

        if sort_by == SortBy.NO_SORTING:
            return list(enumerate(attrs))

        elif sort_by == SortBy.NAME:

            def key(x):
                _, attr = x
                return attr.name.lower()

        elif sort_by == SortBy.ABSOLUTE:

            def key(x):
                i, attr = x
                if attr.is_discrete:
                    ptp = self.points[i][class_value].ptp()
                else:
                    coef = np.abs(self.log_reg_coeffs_orig[i][class_value]).mean()
                    ptp = coef * np.ptp(self.log_reg_cont_data_extremes[i][class_value])
                return -ptp

        elif sort_by == SortBy.POSITIVE:

            def key(x):
                i, attr = x
                max_value = (self.points[i][class_value].max()
                             if attr.is_discrete else
                             np.mean(self.log_reg_cont_data_extremes[i][class_value]))
                return -max_value

        elif sort_by == SortBy.NEGATIVE:

            def key(x):
                i, attr = x
                min_value = (self.points[i][class_value].min()
                             if attr.is_discrete else
                             np.mean(self.log_reg_cont_data_extremes[i][class_value]))
                return min_value

        return sorted(enumerate(attrs), key=key)


    def create_footer_nomogram(self, probs_text, d, minimums,
                               max_width, name_offset):
        """Build the probabilities ruler shown below the feature rulers.

        Args:
            probs_text: text item passed to the ruler as its label.
            d: point-scale factor; applied only when ``self.scale`` is
                ``POINT_SCALE``.
            minimums: per-attribute minimum points of the target class,
                subtracted when the alignment is ``ALIGN_LEFT``.
            max_width: width available for the ruler.
            name_offset: horizontal offset passed to the ruler.

        Returns:
            tuple: ``(probs_item, nomogram_footer)`` - the ruler and the
            ``NomogramItem`` that contains it.
        """
        # pylint: disable=invalid-unary-operand-type
        eps, d_ = 0.05, 1
        # Per-class offset: negated log-odds of ``self.p`` when it is set,
        # the negated intercepts ``self.b0`` otherwise.
        k = - np.log(self.p / (1 - self.p)) if self.p is not None else - self.b0
        # Ruler ends: the sums at which the non-normalised probability below
        # equals eps and 1 - eps.
        min_sum = k[self.target_class_index] - np.log((1 - eps) / eps)
        max_sum = k[self.target_class_index] - np.log(eps / (1 - eps))
        if self.align == OWNomogram.ALIGN_LEFT:
            max_sum = max_sum - sum(minimums)
            min_sum = min_sum - sum(minimums)
            for i in range(len(k)):  # pylint: disable=consider-using-enumerate
                k[i] = k[i] - sum([min(q) for q in [p[i] for p in self.points]])
        if self.scale == OWNomogram.POINT_SCALE:
            min_sum *= d
            max_sum *= d
            d_ = d

        values = self.get_ruler_values(min_sum, max_sum, max_width)
        min_sum, max_sum = min(values), max(values)
        diff_ = np.nan_to_num(max_sum - min_sum)
        scale_x = max_width / diff_ if diff_ else max_width
        cls_var, cls_index = self.domain.class_var, self.target_class_index
        nomogram_footer = NomogramItem()

        def get_normalized_probabilities(val):
            # Maps a point total to a probability of the target class; when
            # normalisation is on, the result is divided by the sum of the
            # same expression evaluated over all classes.
            if not self.normalize_probabilities:
                return 1 / (1 + np.exp(k[cls_index] - val / d_))
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return 1 / (1 + np.exp(k[cls_index] - val / d_)) / p_sum

        def get_points(prob):
            # Algebraic inverse of get_normalized_probabilities: maps a
            # probability back to a point total.
            if not self.normalize_probabilities:
                return (k[cls_index] - np.log(1 / prob - 1)) * d_
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return (k[cls_index] - np.log(1 / (prob * p_sum) - 1)) * d_

        probs_item = ProbabilitiesRulerItem(
            probs_text, values, scale_x, name_offset, - scale_x * min_sum,
            get_points=get_points,
            title="{}='{}'".format(cls_var.name, cls_var.values[cls_index]),
            get_probabilities=get_normalized_probabilities)
        nomogram_footer.add_items([probs_item])
        return probs_item, nomogram_footer

    def __get_totals_for_class_values(self, minimums):
        cls_index = self.target_class_index
        marker_values = self.scale_marker_values(self.feature_marker_values)
        totals = np.full(len(self.domain.class_var.values), np.nan)
        totals[cls_index] = marker_values.sum()
        for i in range(len(self.domain.class_var.values)):
            if i == cls_index:
                continue
            coeffs = [np.nan_to_num(p[i] / p[cls_index]) for p in self.points]
            points = [p[cls_index] for p in self.points]
            total = sum([self.get_points_from_coeffs(v, c, p) for (v, c, p)
                         in zip(self.feature_marker_values, coeffs, points)])
            if self.align == OWNomogram.ALIGN_LEFT:
                points = [p - m for m, p in zip(minimums, points)]
                total -= sum([min(p) for p in [p[i] for p in self.points]])
            d = 100 / max(max(abs(p)) for p in points)
            if self.scale == OWNomogram.POINT_SCALE:
                total *= d
            totals[i] = total
        assert not np.any(np.isnan(totals))
        return totals

    def set_feature_marker_values(self):
        if not (len(self.points) and len(self.feature_items)):
            return
        if not len(self.feature_marker_values):
            self._init_feature_marker_values()
        marker_values = self.scale_marker_values(self.feature_marker_values)

        invisible_sum = 0
        for i, marker in enumerate(marker_values):
            try:
                item = self.feature_items[i]
            except KeyError:
                invisible_sum += marker
            else:
                item.dot.move_to_val(marker)

        item.dot.probs_dot.move_to_sum(invisible_sum)

    def _init_feature_marker_values(self):
        """Compute the initial marker value of every attribute in the domain.

        With logistic-regression coefficients present, the starting feature
        value is the most frequent value (discrete) or the NaN-ignoring mean
        (otherwise) of the attribute's column in ``self.data``. A non-NaN
        value in the first instance of the separately supplied data takes
        precedence. Attributes without a feature value get the marker 0.
        """
        self.feature_marker_values = []
        target = self.target_class_index
        if self.instances:
            instances = Table(self.domain, self.instances)
        else:
            instances = None

        markers = []
        for col, attr in enumerate(self.domain.attributes):
            feature_val = None
            if len(self.log_reg_coeffs):
                column = self.data.X[:, col]
                if attr.is_discrete:
                    ind, counts = unique(column, return_counts=True)
                    feature_val = np.nan_to_num(ind[np.argmax(counts)])
                else:
                    feature_val = nanmean(column)

            # If data is provided on a separate signal, use the first data
            # instance to position the points instead of the mean
            inst_in_dom = instances and attr in instances.domain
            if inst_in_dom and not np.isnan(instances[0][attr]):
                feature_val = instances[0][attr]

            if feature_val is None:
                marker = 0
            elif attr.is_discrete:
                marker = self.points[col][target][int(feature_val)]
            else:
                marker = self.log_reg_coeffs_orig[col][target][0] * feature_val
            markers.append(marker)
        self.feature_marker_values = np.asarray(markers)

    def clear_scene(self):
        self.feature_items = {}
        self.scale_marker_values = lambda x: x
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.scene.clear()

    def send_report(self):
        """Add the widget's plot to the report via ``report_plot``."""
        self.report_plot()

    @staticmethod
    def reconstruct_domain(original, preprocessed):
        """Build a domain from the variables behind the preprocessed attributes.

        For every preprocessed attribute the source variable is taken from
        its nested compute value when there is one, otherwise it is looked
        up by name in ``original``. Duplicates are dropped, keeping the
        order of first appearance. Class variable and metas come from
        ``original``.
        """
        # An ordered mapping serves as an ordered set with fast membership.
        unique_vars = OrderedDict()
        for attr in preprocessed.attributes:
            inner_cv = attr._compute_value.variable._compute_value
            if inner_cv:
                source_var = inner_cv.variable
            else:
                source_var = original[attr.name]
            unique_vars.setdefault(source_var, None)
        return Domain(list(unique_vars), original.class_var, original.metas)

    @staticmethod
    def get_ruler_values(start, stop, max_width, round_to_nearest=True):
        if max_width == 0:
            return [0]
        diff = np.nan_to_num((stop - start) / max_width)
        if diff <= 0:
            return [0]
        decimals = int(np.floor(np.log10(diff)))
        if diff > 4 * pow(10, decimals):
            step = 5 * pow(10, decimals + 2)
        elif diff > 2 * pow(10, decimals):
            step = 2 * pow(10, decimals + 2)
        elif diff > 1 * pow(10, decimals):
            step = 1 * pow(10, decimals + 2)
        else:
            step = 5 * pow(10, decimals + 1)
        round_by = int(- np.floor(np.log10(step)))
        r = start % step
        if not round_to_nearest:
            _range = np.arange(start + step, stop + r, step) - r
            start, stop = np.floor(start * 100) / 100, np.ceil(stop * 100) / 100
            return np.round(np.hstack((start, _range, stop)), 2)
        return np.round(np.arange(start, stop + r + step, step) - r, round_by)

    @staticmethod
    def get_points_from_coeffs(current_value, coefficients, possible_values):
        if np.isnan(possible_values).any():
            return 0
        # pylint: disable=undefined-loop-variable
        indices = np.argsort(possible_values)
        sorted_values = possible_values[indices]
        sorted_coefficients = coefficients[indices]
        for i, val in enumerate(sorted_values):
            if current_value < val:
                break
        diff = sorted_values[i] - sorted_values[i - 1]
        k = 0 if diff < 1e-6 else (sorted_values[i] - current_value) / \
                                  (sorted_values[i] - sorted_values[i - 1])
        return sorted_coefficients[i - 1] * sorted_values[i - 1] * k + \
               sorted_coefficients[i] * sorted_values[i] * (1 - k)
class OWRegressionTreeGraph(OWTreeGraph):
    """Tree viewer widget for regression trees with selectable node colouring."""

    name = "Regression Tree Viewer"
    description = "A graphical visualization of a regression tree."
    icon = "icons/RegressionTreeGraph.svg"
    priority = 35

    settingsHandler = ClassValuesContextHandler()
    color_index = Setting(0)
    color_settings = Setting(None)
    selected_color_settings_index = Setting(0)

    inputs = [("Regression Tree", TreeRegressor, "ctree")]
    NODE = RegressionTreeNode

    def __init__(self):
        """Add the node colour combo box to the control area."""
        super().__init__()
        nodes_box = gui.vBox(self.controlArea, "Nodes", addSpace=True)
        self.color_combo = gui.comboBox(
            nodes_box, self, "color_index",
            orientation=Qt.Horizontal, items=[], label="Colors",
            callback=self.toggle_color, contentsLength=8)
        gui.separator(nodes_box)
        gui.rubber(self.controlArea)

    def ctree(self, model=None):
        """Input handler: refill the colour combo and palette, then delegate."""
        if model is not None:
            combo = self.color_combo
            combo.clear()
            for label in ("Default", "Instances in node", "Impurity"):
                combo.addItem(label)
            combo.setCurrentIndex(self.color_index)
            class_colors = model.domain.class_var.colors
            self.scene.colorPalette = ContinuousPaletteGenerator(*class_colors)
        super().ctree(model)

    def update_node_info(self, node):
        """Set the node's HTML: distribution sum, instance share and impurity."""
        distr = node.get_distribution()
        n_here = node.num_instances()
        n_root = self.tree.n_node_samples[0]
        impurity = node.impurity()
        pieces = [
            "{:2.1f}<br/>".format(sum(distr.reshape(1))),
            "{:2.1f}%, {}/{}<br/>".format(100 * n_here / n_root, n_here,
                                          n_root),
            "{:2.3f}".format(impurity),
        ]
        body = self._update_node_info_attr_name(node, "".join(pieces))
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.format(body))

    def toggle_node_color(self):
        """Recolour every node according to the selected colour mode."""
        colors = self.scene.colorPalette
        root_samples = self.tree.n_node_samples[0]
        root_impurity = self.tree.impurity[0]
        for node in self.scene.nodes():
            # All three candidates are evaluated; color_index selects one.
            candidates = (
                0.5,
                node.num_instances() / root_samples,
                node.impurity() / root_impurity,
            )
            level = candidates[self.color_index]
            base_color = colors[self.color_index]
            node.backgroundBrush = QBrush(base_color.lighter(180 - level * 150))
        self.scene.update()

    def send_report(self):
        """Report the tree size, the edge width mode and the plot, if a tree is set."""
        if self.tree:
            width_modes = ("Fixed", "Relative to root", "Relative to parent")
            self.report_items(
                (("Tree size", self.info.text()),
                 ("Edge widths", width_modes[self.line_width_method])))
            self.report_plot(self.scene)
Esempio n. 11
0
class OWTreeGraph(OWTreeViewer2D):
    """Graphical visualization of tree models"""

    name = "Tree Viewer"
    icon = "icons/TreeViewer.svg"
    priority = 35
    inputs = [("Tree", TreeModel, "ctree")]
    outputs = [("Selected Data", Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Table)]

    settingsHandler = ClassValuesContextHandler()
    # Combo index for classification trees: 0 is "None", i > 0 selects the
    # (i - 1)-th class value.
    target_class_index = ContextSetting(0)
    # Index into COL_OPTIONS, used for regression trees.
    regression_colors = Setting(0)

    COL_OPTIONS = ["Default", "Number of instances", "Mean value", "Variance"]
    COL_DEFAULT, COL_INSTANCE, COL_MEAN, COL_VARIANCE = range(4)

    def __init__(self):
        """Create the widget and add the colour selector to the display box."""
        super().__init__()
        self.domain = None
        self.dataset = None
        self.clf_dataset = None

        self.color_label = QLabel("Target class: ")
        combo = self.color_combo = gui.OrangeComboBox()
        combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
        combo.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLengthWithIcon)
        combo.setMinimumContentsLength(8)
        combo.activated[int].connect(self.color_changed)
        self.display_box.layout().addRow(self.color_label, combo)

    def set_node_info(self):
        """Set the content of the node"""
        # Reset each rectangle before rendering the node text.
        for node in self.scene.nodes():
            node.set_rect(QRectF())
            self.update_node_info(node)
        # All nodes share one width: the widest node, capped at the maximum.
        w = max([n.rect().width() for n in self.scene.nodes()] + [0])
        w = min(w, self.max_node_width)
        for node in self.scene.nodes():
            rect = node.rect()
            node.set_rect(QRectF(rect.x(), rect.y(), w, rect.height()))
        self.scene.fix_pos(self.root_node, 10, 10)

    @staticmethod
    def _update_node_info_attr_name(node, text):
        """Append the node's split attribute name (if any) to *text*."""
        attr = node.node_inst.attr
        if attr is not None:
            text += "<hr/>{}".format(attr.name)
        return text

    def activate_loaded_settings(self):
        """Apply restored settings to the combo, node colours and node text."""
        if not self.model:
            return
        super().activate_loaded_settings()
        if self.domain.class_var.is_discrete:
            self.color_combo.setCurrentIndex(self.target_class_index)
            self.toggle_node_color_cls()
        else:
            self.color_combo.setCurrentIndex(self.regression_colors)
            self.toggle_node_color_reg()
        self.set_node_info()

    def color_changed(self, i):
        """Store combo index *i* in the matching setting and recolour nodes."""
        if self.domain.class_var.is_discrete:
            self.target_class_index = i
            self.toggle_node_color_cls()
        else:
            self.regression_colors = i
            self.toggle_node_color_reg()

    def toggle_node_size(self):
        """Re-render node contents and repaint the view."""
        self.set_node_info()
        self.scene.update()
        self.scene_view.repaint()

    def toggle_color_cls(self):
        """Recolour and re-render nodes of a classification tree."""
        self.toggle_node_color_cls()
        self.set_node_info()
        self.scene.update()

    def toggle_color_reg(self):
        """Recolour and re-render nodes of a regression tree."""
        self.toggle_node_color_reg()
        self.set_node_info()
        self.scene.update()

    def ctree(self, model=None):
        """Input signal handler

        Clears the scene, stores *model* and, when it is not ``None``,
        rebuilds the colour selector and the node graph from it.  Both
        outputs are reset afterwards.
        """
        self.clear_scene()
        self.color_combo.clear()
        self.closeContext()
        self.model = model
        if model is None:
            self.info.setText('No tree.')
            self.root_node = None
            self.dataset = None
            # Do not keep the previous model's converted data around.
            self.clf_dataset = None
        else:
            self.domain = model.domain
            self.dataset = model.instances
            if self.dataset is not None and self.dataset.domain != self.domain:
                self.clf_dataset = Table.from_table(model.domain, self.dataset)
            else:
                self.clf_dataset = self.dataset
            class_var = self.domain.class_var
            if class_var.is_discrete:
                self.scene.colors = [QColor(*col) for col in class_var.colors]
                self.color_label.setText("Target class: ")
                self.color_combo.addItem("None")
                self.color_combo.addItems(self.domain.class_vars[0].values)
                self.color_combo.setCurrentIndex(self.target_class_index)
            else:
                self.scene.colors = \
                    ContinuousPaletteGenerator(*model.domain.class_var.colors)
                self.color_label.setText("Color by: ")
                self.color_combo.addItems(self.COL_OPTIONS)
                self.color_combo.setCurrentIndex(self.regression_colors)
            self.openContext(self.domain.class_var)
            self.root_node = self.walkcreate(model.root, None)
            self.scene.addItem(self.root_node)
            self.info.setText('{} nodes, {} leaves'.format(
                model.node_count(), model.leaf_count()))
        self.setup_scene()
        self.send("Selected Data", None)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.dataset, None))

    def walkcreate(self, node_inst, parent=None):
        """Create a structure of tree nodes from the given model"""
        node = TreeNode(self.model, node_inst, parent)
        self.scene.addItem(node)
        if parent:
            edge = GraphicsEdge(node1=parent, node2=node)
            self.scene.addItem(edge)
            parent.graph_add_edge(edge)
        for child_inst in node_inst.children:
            if child_inst is not None:
                self.walkcreate(child_inst, node)
        return node

    def node_tooltip(self, node):
        """Return the rules for *node* as HTML, one rule per line."""
        return "<br>".join(
            to_html(rule) for rule in self.model.rule(node.node_inst))

    def update_selection(self):
        """Send the instances of the selected nodes to both outputs."""
        if self.model is None:
            return
        nodes = [
            item.node_inst for item in self.scene.selectedItems()
            if isinstance(item, TreeNode)
        ]
        data = self.model.get_instances(nodes)
        self.send("Selected Data", data)
        self.send(
            ANNOTATED_DATA_SIGNAL_NAME,
            create_annotated_table(self.dataset,
                                   self.model.get_indices(nodes)))

    def send_report(self):
        """Report tree size, edge-width mode, colouring choice and the plot."""
        if not self.model:
            return
        items = [
            ("Tree size", self.info.text()),
            (
                "Edge widths",
                ("Fixed", "Relative to root", "Relative to parent")[
                    # pylint: disable=invalid-sequence-index
                    self.line_width_method])
        ]
        if self.domain.class_var.is_discrete:
            items.append(("Target class", self.color_combo.currentText()))
        elif self.regression_colors != self.COL_DEFAULT:
            items.append(
                ("Color by", self.COL_OPTIONS[self.regression_colors]))
        self.report_items(items)
        self.report_plot(self.scene)

    def update_node_info(self, node):
        """Render *node*'s text with the classification or regression layout."""
        if self.domain.class_var.is_discrete:
            self.update_node_info_cls(node)
        else:
            self.update_node_info_reg(node)

    def update_node_info_cls(self, node):
        """Update the printed contents of the node for classification trees"""
        node_inst = node.node_inst
        distr = node_inst.value
        total = len(node_inst.subset)
        # Guard against an all-zero distribution.
        distr = distr / (np.sum(distr) or 1)
        if self.target_class_index:
            tabs = distr[self.target_class_index - 1]
            text = ""
        else:
            modus = np.argmax(distr)
            tabs = distr[modus]
            text = self.domain.class_vars[0].values[int(modus)] + "<br/>"
        if tabs > 0.999:
            text += "100%, {}/{}".format(total, total)
        else:
            text += "{:2.1f}%, {}/{}".format(100 * tabs, int(total * tabs),
                                             total)

        text = self._update_node_info_attr_name(node, text)
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.format(text))

    def update_node_info_reg(self, node):
        """Update the printed contents of the node for regression trees"""
        node_inst = node.node_inst
        mean, var = node_inst.value
        insts = len(node_inst.subset)
        text = "{:.1f} ± {:.1f}<br/>".format(mean, var)
        text += "{} instances".format(insts)
        text = self._update_node_info_attr_name(node, text)
        node.setHtml(
            '<p style="line-height: 120%; margin-bottom: 0">{}</p>'.format(
                text))

    def toggle_node_color_cls(self):
        """Update the node color for classification trees"""
        colors = self.scene.colors
        for node in self.scene.nodes():
            distr = node.node_inst.value
            # Same zero guard in both branches below.
            total = sum(distr) or 1
            # QColor.lighter() takes an integer factor, hence the int() calls.
            if self.target_class_index:
                p = distr[self.target_class_index - 1] / total
                color = colors[self.target_class_index - 1].lighter(
                    int(200 - 100 * p))
            else:
                modus = np.argmax(distr)
                p = distr[modus] / total
                color = colors[int(modus)].lighter(int(300 - 200 * p))
            node.backgroundBrush = QBrush(color)
        self.scene.update()

    def toggle_node_color_reg(self):
        """Update the node color for regression trees"""
        def_color = QColor(192, 192, 255)
        if self.regression_colors == self.COL_DEFAULT:
            brush = QBrush(def_color.lighter(100))
            for node in self.scene.nodes():
                node.backgroundBrush = brush
        elif self.regression_colors == self.COL_INSTANCE:
            # ``or 1`` avoids dividing by zero for an empty training set.
            max_insts = len(self.model.instances) or 1
            for node in self.scene.nodes():
                share = len(node.node_inst.subset) / max_insts
                node.backgroundBrush = QBrush(
                    def_color.lighter(int(120 - 20 * share)))
        elif self.regression_colors == self.COL_MEAN:
            minv = np.nanmin(self.dataset.Y)
            maxv = np.nanmax(self.dataset.Y)
            fact = 1 / (maxv - minv) if minv != maxv else 1
            colors = self.scene.colors
            for node in self.scene.nodes():
                node.backgroundBrush = QBrush(
                    colors[fact * (node.node_inst.value[0] - minv)])
        else:
            nodes = list(self.scene.nodes())
            variances = [node.node_inst.value[1] for node in nodes]
            # All-zero variances (or no nodes) must not divide by zero.
            max_var = max(variances, default=0) or 1
            for node, var in zip(nodes, variances):
                node.backgroundBrush = QBrush(
                    def_color.lighter(int(120 - 20 * var / max_var)))
        self.scene.update()
class TestClassValuesContextHandler(TestCase):
    """Tests for ``ClassValuesContextHandler.open_context``."""

    def setUp(self):
        attributes = [
            ContinuousVariable("c1"),
            DiscreteVariable("d1", values="abc"),
            DiscreteVariable("d2", values="def"),
        ]
        class_vars = [DiscreteVariable("d3", values="ghi")]
        metas = [
            ContinuousVariable("c2"),
            DiscreteVariable("d4", values="jkl"),
        ]
        self.domain = Domain(
            attributes=attributes, class_vars=class_vars, metas=metas)
        self.args = (
            self.domain,
            {"c1": Continuous, "d1": Discrete,
             "d2": Discrete, "d3": Discrete},
            {"c2": Continuous, "d4": Discrete},
        )
        self.handler = ClassValuesContextHandler()
        # Skip reading stored defaults.
        self.handler.read_defaults = lambda: None

    def _install_contexts(self):
        """Give the handler three contexts; the middle one has g/h/i classes."""
        stored = Mock(
            classes=["g", "h", "i"],
            values=dict(
                text="u",
                with_metas=[("d1", Discrete), ("d2", Discrete)],
            ),
        )
        self.handler.global_contexts = [
            Mock(values={}), stored, Mock(values={})]

    def test_open_context(self):
        """Opening with the class variable restores the stored values."""
        self.handler.bind(SimpleWidget)
        self._install_contexts()

        widget = SimpleWidget()
        self.handler.initialize(widget)
        self.handler.open_context(widget, self.args[0].class_var)

        self.assertEqual(widget.text, "u")
        expected_metas = [("d1", Discrete), ("d2", Discrete)]
        self.assertEqual(widget.with_metas, expected_metas)

    def test_open_context_with_no_match(self):
        """A variable with other values gets a fresh context."""
        self.handler.bind(SimpleWidget)
        self._install_contexts()
        widget = SimpleWidget()
        self.handler.initialize(widget)
        widget.text = "u"

        # Domain index 1 is "d1", whose values are a/b/c.
        self.handler.open_context(widget, self.args[0][1])

        opened = widget.current_context
        self.assertEqual(opened.classes, ["a", "b", "c"])
        self.assertEqual(widget.text, "u")
        self.assertEqual(widget.with_metas, [])
Esempio n. 13
0
class OWTreeGraph(OWTreeViewer2D):
    """Graphical visualization of tree models"""

    name = "Tree Viewer"
    icon = "icons/TreeViewer.svg"
    priority = 35
    keywords = []

    class Inputs:
        # Had different input names before merging from
        # Classification/Regression tree variants
        tree = Input("Tree", TreeModel, replaces=["Classification Tree", "Regression Tree"])

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True, id="selected-data")
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table, id="annotated-data")

    settingsHandler = ClassValuesContextHandler()
    # Combo index for classification trees: 0 is "None", i > 0 selects the
    # (i - 1)-th class value.
    target_class_index = ContextSetting(0)
    # Index into COL_OPTIONS, used for regression trees.
    regression_colors = Setting(0)

    replaces = [
        "Orange.widgets.classify.owclassificationtreegraph.OWClassificationTreeGraph",
        "Orange.widgets.classify.owregressiontreegraph.OWRegressionTreeGraph"
    ]

    COL_OPTIONS = ["Default", "Number of instances", "Mean value", "Variance"]
    COL_DEFAULT, COL_INSTANCE, COL_MEAN, COL_VARIANCE = range(4)

    def __init__(self):
        """Create the widget and add the colour selector to the display box."""
        super().__init__()
        self.domain = None
        self.dataset = None
        self.clf_dataset = None
        self.tree_adapter = None

        self.color_label = QLabel("Target class: ")
        combo = self.color_combo = ComboBoxSearch()
        combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
        combo.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLengthWithIcon)
        combo.setMinimumContentsLength(8)
        combo.activated[int].connect(self.color_changed)
        self.display_box.layout().addRow(self.color_label, combo)

    def set_node_info(self):
        """Set the content of the node"""
        # Reset each rectangle before rendering the node text.
        for node in self.scene.nodes():
            node.set_rect(QRectF())
            self.update_node_info(node)
        # All nodes share one width: the widest node, capped at the maximum.
        w = max([n.rect().width() for n in self.scene.nodes()] + [0])
        w = min(w, self.max_node_width)
        for node in self.scene.nodes():
            rect = node.rect()
            node.set_rect(QRectF(rect.x(), rect.y(), w, rect.height()))
        self.scene.fix_pos(self.root_node, 10, 10)

    def _update_node_info_attr_name(self, node, text):
        """Append the node's split attribute name (if any) to *text*."""
        attr = self.tree_adapter.attribute(node.node_inst)
        if attr is not None:
            text += "<hr/>{}".format(attr.name)
        return text

    def activate_loaded_settings(self):
        """Apply restored settings to the combo, node colours and node text."""
        if not self.model:
            return
        super().activate_loaded_settings()
        if self.domain.class_var.is_discrete:
            self.color_combo.setCurrentIndex(self.target_class_index)
            self.toggle_node_color_cls()
        else:
            self.color_combo.setCurrentIndex(self.regression_colors)
            self.toggle_node_color_reg()
        self.set_node_info()

    def color_changed(self, i):
        """Store combo index *i* in the matching setting and recolour nodes."""
        if self.domain.class_var.is_discrete:
            self.target_class_index = i
            self.toggle_node_color_cls()
            self.set_node_info()
        else:
            self.regression_colors = i
            self.toggle_node_color_reg()

    def toggle_node_size(self):
        """Re-render node contents and repaint the view."""
        self.set_node_info()
        self.scene.update()
        self.scene_view.repaint()

    def toggle_color_cls(self):
        """Recolour and re-render nodes of a classification tree."""
        self.toggle_node_color_cls()
        self.set_node_info()
        self.scene.update()

    def toggle_color_reg(self):
        """Recolour and re-render nodes of a regression tree."""
        self.toggle_node_color_reg()
        self.set_node_info()
        self.scene.update()

    @Inputs.tree
    def ctree(self, model=None):
        """Input signal handler

        Clears the scene, stores *model* and, when it is not ``None``,
        builds a tree adapter for it and rebuilds the colour selector and
        the node graph.  Both outputs are reset afterwards.
        """
        self.clear_scene()
        self.color_combo.clear()
        self.closeContext()
        self.model = model
        self.target_class_index = 0
        if model is None:
            self.infolabel.setText('No tree.')
            self.root_node = None
            self.dataset = None
            # Do not keep the previous model's converted data around.
            self.clf_dataset = None
            self.tree_adapter = None
        else:
            self.tree_adapter = self._get_tree_adapter(model)
            self.domain = model.domain
            self.dataset = model.instances
            if self.dataset is not None and self.dataset.domain != self.domain:
                self.clf_dataset = self.dataset.transform(model.domain)
            else:
                self.clf_dataset = self.dataset
            class_var = self.domain.class_var
            self.scene.colors = class_var.palette
            if class_var.is_discrete:
                self.color_label.setText("Target class: ")
                self.color_combo.addItem("None")
                self.color_combo.addItems(self.domain.class_vars[0].values)
                self.color_combo.setCurrentIndex(self.target_class_index)
            else:
                self.color_label.setText("Color by: ")
                self.color_combo.addItems(self.COL_OPTIONS)
                self.color_combo.setCurrentIndex(self.regression_colors)
            self.openContext(self.domain.class_var)
            self.root_node = self.walkcreate(self.tree_adapter.root)
            self.infolabel.setText('{} nodes, {} leaves'.format(
                self.tree_adapter.num_nodes,
                len(self.tree_adapter.leaves(self.tree_adapter.root))))
        self.setup_scene()
        self.Outputs.selected_data.send(None)
        self.Outputs.annotated_data.send(create_annotated_table(self.dataset, []))

    def walkcreate(self, node, parent=None):
        """Create a structure of tree nodes from the given model"""
        node_obj = TreeNode(self.tree_adapter, node, parent)
        self.scene.addItem(node_obj)
        if parent:
            edge = GraphicsEdge(node1=parent, node2=node_obj)
            self.scene.addItem(edge)
            parent.graph_add_edge(edge)
        for child_inst in self.tree_adapter.children(node):
            if child_inst is not None:
                self.walkcreate(child_inst, node_obj)
        return node_obj

    def node_tooltip(self, node):
        """Return the rules for *node* as HTML, one rule per line."""
        return "<br>".join(to_html(str(rule))
                           for rule in self.tree_adapter.rules(node.node_inst))

    def update_selection(self):
        """Send the instances of the selected nodes to both outputs."""
        if self.model is None:
            return
        nodes = [item.node_inst for item in self.scene.selectedItems()
                 if isinstance(item, TreeNode)]
        data = self.tree_adapter.get_instances_in_nodes(nodes)

        self.Outputs.selected_data.send(data)
        self.Outputs.annotated_data.send(create_annotated_table(
            self.dataset, self.tree_adapter.get_indices(nodes)))

    def send_report(self):
        """Report tree size, edge-width mode, colouring choice and the plot."""
        if not self.model:
            return
        items = [("Tree size", self.infolabel.text()),
                 ("Edge widths",
                  ("Fixed", "Relative to root", "Relative to parent")[
                      # pylint: disable=invalid-sequence-index
                      self.line_width_method])]
        if self.domain.class_var.is_discrete:
            items.append(("Target class", self.color_combo.currentText()))
        elif self.regression_colors != self.COL_DEFAULT:
            items.append(("Color by", self.COL_OPTIONS[self.regression_colors]))
        self.report_items(items)
        self.report_plot(self.scene)

    def update_node_info(self, node):
        """Render *node*'s text with the classification or regression layout."""
        if self.domain.class_var.is_discrete:
            self.update_node_info_cls(node)
        else:
            self.update_node_info_reg(node)

    def update_node_info_cls(self, node):
        """Update the printed contents of the node for classification trees"""
        node_inst = node.node_inst
        distr = self.tree_adapter.get_distribution(node_inst)[0]
        total = self.tree_adapter.num_samples(node_inst)
        # Guard against an all-zero distribution.
        distr = distr / (np.sum(distr) or 1)
        if self.target_class_index:
            tabs = distr[self.target_class_index - 1]
            text = ""
        else:
            modus = np.argmax(distr)
            tabs = distr[modus]
            text = f"<b>{self.domain.class_vars[0].values[int(modus)]}</b><br/>"
        if tabs > 0.999:
            text += f"100%, {total}/{total}"
        else:
            text += f"{100 * tabs:2.1f}%, {int(total * tabs)}/{total}"

        text = self._update_node_info_attr_name(node, text)
        node.setHtml(
            f'<p style="line-height: 120%; margin-bottom: 0">{text}</p>')

    def update_node_info_reg(self, node):
        """Update the printed contents of the node for regression trees"""
        node_inst = node.node_inst
        mean, var = self.tree_adapter.get_distribution(node_inst)[0]
        insts = self.tree_adapter.num_samples(node_inst)
        text = f"<b>{mean:.1f}</b> ± {var:.1f}<br/>"
        text += f"{insts} instances"
        text = self._update_node_info_attr_name(node, text)
        node.setHtml(
            f'<p style="line-height: 120%; margin-bottom: 0">{text}</p>')

    def toggle_node_color_cls(self):
        """Update the node color for classification trees"""
        colors = self.scene.colors
        for node in self.scene.nodes():
            distr = self.tree_adapter.get_distribution(node.node_inst)[0]
            # Same zero guard in both branches below.
            total = sum(distr) or 1
            # Both branches look the colour up the same way;
            # QColor.lighter() takes an integer factor.
            if self.target_class_index:
                p = distr[self.target_class_index - 1] / total
                color = colors.value_to_qcolor(self.target_class_index - 1)
                color = color.lighter(int(200 - 100 * p))
            else:
                modus = np.argmax(distr)
                p = distr[modus] / total
                color = colors.value_to_qcolor(int(modus))
                color = color.lighter(int(300 - 200 * p))
            node.backgroundBrush = QBrush(color)
        self.scene.update()

    def toggle_node_color_reg(self):
        """Update the node color for regression trees"""
        def_color = QColor(192, 192, 255)
        if self.regression_colors == self.COL_DEFAULT:
            brush = QBrush(def_color.lighter(100))
            for node in self.scene.nodes():
                node.backgroundBrush = brush
        elif self.regression_colors == self.COL_INSTANCE:
            # ``or 1`` avoids dividing by zero for an empty root.
            max_insts = len(self.tree_adapter.get_instances_in_nodes(
                [self.tree_adapter.root])) or 1
            for node in self.scene.nodes():
                node_insts = len(self.tree_adapter.get_instances_in_nodes(
                    [node.node_inst]))
                node.backgroundBrush = QBrush(def_color.lighter(
                    int(120 - 20 * node_insts / max_insts)))
        elif self.regression_colors == self.COL_MEAN:
            minv = np.nanmin(self.dataset.Y)
            maxv = np.nanmax(self.dataset.Y)
            colors = self.scene.colors
            for node in self.scene.nodes():
                node_mean = self.tree_adapter.get_distribution(node.node_inst)[0][0]
                color = colors.value_to_qcolor(node_mean, minv, maxv)
                node.backgroundBrush = QBrush(color)
        else:
            nodes = list(self.scene.nodes())
            variances = [self.tree_adapter.get_distribution(node.node_inst)[0][1]
                         for node in nodes]
            # All-zero variances (or no nodes) must not divide by zero.
            max_var = max(variances, default=0) or 1
            for node, var in zip(nodes, variances):
                node.backgroundBrush = QBrush(def_color.lighter(
                    int(120 - 20 * var / max_var)))
        self.scene.update()

    def _get_tree_adapter(self, model):
        """Return the adapter matching *model*'s type."""
        if isinstance(model, SklModel):
            return SklTreeAdapter(model)
        return TreeAdapter(model)
Esempio n. 14
0
class OWClassificationTreeGraph(OWTreeGraph):
    """Tree viewer for classification trees with a target-class selector."""

    name = "Classification Tree Viewer"
    description = "Graphical visualization of a classification tree."
    icon = "icons/ClassificationTreeGraph.svg"

    settingsHandler = ClassValuesContextHandler()
    # Combo index: 0 is "None", i > 0 selects the (i - 1)-th class value.
    target_class_index = ContextSetting(0)

    inputs = [("Classification Tree", TreeClassifier, "ctree")]
    NODE = ClassificationTreeNode

    def __init__(self):
        """Create the widget and add the target-class combo box."""
        super().__init__()
        self.target_combo = gui.comboBox(
            None, self, "target_class_index", orientation=0, items=[],
            callback=self.toggle_color, contentsLength=8, addToLayout=False,
            sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                   QSizePolicy.Fixed))
        self.display_box.layout().addRow("Target class ", self.target_combo)
        gui.rubber(self.controlArea)

    def ctree(self, model=None):
        """Pass *model* to the parent, then refill the target-class combo."""
        super().ctree(model)
        if model is not None:
            self.target_combo.clear()
            self.target_combo.addItem("None")
            self.target_combo.addItems(self.domain.class_vars[0].values)
            self.target_combo.setCurrentIndex(self.target_class_index)

    def update_node_info(self, node):
        """Render the text shown inside *node*.

        With a target class selected, only that class's share is printed;
        otherwise the majority class name is printed above its share.
        """
        distr = node.get_distribution()
        total = int(node.num_instances())
        if self.target_class_index:
            tabs = distr[self.target_class_index - 1]
            text = ""
        else:
            modus = node.majority()
            tabs = distr[modus]
            text = self.domain.class_vars[0].values[modus] + "<br/>"
        if tabs > 0.999:
            text += "100%, {}/{}".format(total, total)
        else:
            text += "{:2.1f}%, {}/{}".format(100 * tabs,
                                             int(total * tabs), total)

        text = self._update_node_info_attr_name(node, text)
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.
                     format(text))

    def toggle_node_color(self):
        """Recolour nodes by the target class or, if none, the majority."""
        colors = self.scene.colors
        for node in self.scene.nodes():
            distr = node.get_distribution()
            # Same zero guard in both branches below.
            total = numpy.sum(distr) or 1
            # QColor.light() is obsolete; lighter() replaces it and takes
            # an integer factor.
            if self.target_class_index:
                p = distr[self.target_class_index - 1] / total
                color = colors[self.target_class_index - 1].lighter(
                    int(200 - 100 * p))
            else:
                modus = node.majority()
                p = distr[modus] / total
                color = colors[int(modus)].lighter(int(400 - 300 * p))
            node.backgroundBrush = QBrush(color)
        self.scene.update()
Esempio n. 15
0
class OWRegressionTreeGraph(OWTreeGraph):
    """Tree viewer for regression trees with a node-colouring selector."""

    name = "Regression Tree Viewer"
    description = "Graphical visualization of a regression tree."
    icon = "icons/RegressionTreeGraph.svg"

    settingsHandler = ClassValuesContextHandler()
    # Selected colouring mode: 0 default, 1 instances in node, 2 impurity.
    color_index = Setting(0)
    color_settings = Setting(None)
    selected_color_settings_index = Setting(0)

    inputs = [("Regression Tree", TreeRegressor, "ctree")]
    NODE = RegressionTreeNode

    def __init__(self):
        """Create the widget and add the "Nodes" box with its controls."""
        super().__init__()
        box = gui.widgetBox(self.controlArea, "Nodes", addSpace=True)
        self.color_combo = gui.comboBox(box,
                                        self,
                                        "color_index",
                                        orientation=0,
                                        items=[],
                                        label="Target class",
                                        callback=self.toggle_color,
                                        contentsLength=8)
        gui.separator(box)
        gui.button(box, self, "Set Colors", callback=self.set_colors)
        gui.rubber(self.controlArea)

    def ctree(self, model=None):
        """Pass *model* to the parent, then refill the colour-mode combo."""
        super().ctree(model)
        if model is not None:
            self.color_combo.clear()
            self.color_combo.addItem("Default")
            self.color_combo.addItem("Instances in node")
            self.color_combo.addItem("Impurity")
            self.color_combo.setCurrentIndex(self.color_index)

    def update_node_info(self, node):
        """Render the text shown inside *node*.

        The label has three lines: the summed distribution value, the node's
        share of the root's instances, and the node's impurity.
        """
        distr = node.get_distribution()
        total = node.num_instances()
        total_tree = self.tree.n_node_samples[0]
        impurity = node.impurity()
        text = "{:2.1f}<br/>".format(sum(distr.reshape(1)))
        text += "{:2.1f}%, {}/{}<br/>".format(100 * total / total_tree, total,
                                              total_tree)
        text += "{:2.3f}".format(impurity)
        text = self._update_node_info_attr_name(node, text)
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.format(text))

    def toggle_node_color(self):
        """Recolour every node according to ``self.color_index``.

        The index selects the source of the lightness ratio: 0 uses the
        constant 0.5, 1 the node's share of the root's instances, and 2 the
        node's impurity relative to the root's impurity.
        """
        palette = self.scene.colorPalette
        all_instances = self.tree.n_node_samples[0]
        max_impurity = self.tree.impurity[0]
        # Pick the ratio function up front so that only the selected ratio is
        # evaluated; a zero root impurity must not break the other modes.
        ratio = (
            lambda node: 0.5,
            lambda node: node.num_instances() / (all_instances or 1),
            lambda node: node.impurity() / (max_impurity or 1),
        )[self.color_index]
        for node in self.scene.nodes():
            li = ratio(node)
            # QColor.light() is obsolete; lighter() replaces it and takes
            # an integer factor.
            node.backgroundBrush = QBrush(
                palette[self.color_index].lighter(int(180 - li * 150)))
        self.scene.update()
class OWExplainModel(OWWidget, ConcurrentWidgetMixin):
    """Widget that plots per-feature explanation values of a model on data.

    The computation is started through ``ConcurrentWidgetMixin``. Its result
    is drawn as a ``ViolinPlot`` and summarized on the "Scores" output; rows
    picked in the plot are sent on "Selected Data".
    """

    name = "Explain Model"
    description = "Model explanation widget."
    keywords = ["explain", "explain prediction", "explain model"]
    icon = "icons/ExplainModel.svg"
    priority = 100
    replaces = [
        "orangecontrib.prototypes.widgets.owexplainmodel.OWExplainModel"
    ]

    class Inputs:
        data = Input("Data", Table, default=True)
        model = Input("Model", Model)

    class Outputs:
        selected_data = Output("Selected Data", Table)
        scores = Output("Scores", Table)

    class Error(OWWidget.Error):
        # Both messages simply display the text of the raised exception.
        domain_transform_err = Msg("{}")
        unknown_err = Msg("{}")

    class Information(OWWidget.Information):
        data_sampled = Msg("Data has been sampled.")

    settingsHandler = ClassValuesContextHandler()
    target_index = ContextSetting(0)
    n_attributes = Setting(10)
    show_legend = Setting(True)
    selection = Setting((), schema_only=True)  # type: Tuple[str, List[int]]
    auto_send = Setting(True)
    visual_settings = Setting({}, schema_only=True)

    graph_name = "scene"

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        self.__results: Optional[Results] = None
        self.data: Optional[Table] = None
        self.model: Optional[Model] = None
        self._violin_plot: Optional[ViolinPlot] = None
        self.setup_gui()
        # Keep a copy of the stored selection: clear_selection() resets
        # ``self.selection`` before any results are available.
        self.__pending_selection = self.selection

        VisualSettingsDialog(
            self, ViolinPlot().parameter_setter.initial_settings)

    def setup_gui(self):
        """Build the control area and the plot; reset both summaries."""
        self._add_controls()
        self._add_plot()
        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    def _add_plot(self):
        """Create the graphics scene and view and put the view in mainArea."""
        self.scene = GraphicsScene()
        self.view = GraphicsView(self.scene)
        view = self.view
        view.resized.connect(self.update_plot)
        view.setRenderHint(QPainter.Antialiasing, True)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(view)

    def _add_controls(self):
        """Create the target combo, feature spin, legend check and auto-send."""
        target_box = gui.vBox(self.controlArea, "Target class")
        self._target_combo = gui.comboBox(
            target_box,
            self,
            "target_index",
            callback=self.__target_combo_changed,
            contentsLength=12,
        )

        features_box = gui.hBox(self.controlArea, "Display features")
        gui.label(features_box, self, "Best ranked: ")
        gui.spin(
            features_box,
            self,
            "n_attributes",
            1,
            ViolinPlot.MAX_N_ITEMS,
            controlWidth=80,
            callback=self.__n_spin_changed,
        )
        legend_box = gui.hBox(self.controlArea, True)
        gui.checkBox(
            legend_box,
            self,
            "show_legend",
            "Show legend",
            callback=self.__show_check_changed,
        )

        gui.rubber(self.controlArea)
        send_box = gui.vBox(self.controlArea, box=True)
        gui.auto_send(send_box, self, "auto_send", box=False)

    def __target_combo_changed(self):
        # A different target invalidates both the plot and the selection.
        self.update_scene()
        self.clear_selection()

    def __n_spin_changed(self):
        plot = self._violin_plot
        if plot is None:
            return
        plot.set_n_visible(self.n_attributes)

    def __show_check_changed(self):
        plot = self._violin_plot
        if plot is None:
            return
        plot.show_legend(self.show_legend)

    @Inputs.data
    @check_sql_input
    def set_data(self, data: Optional[Table]):
        """Store the input data and refresh the input summary."""
        self.data = data
        if data:
            summary = len(data)
            details = format_summary_details(data)
        else:
            summary = self.info.NoInput
            details = ""
        self.info.set_input_summary(summary, details)

    @Inputs.model
    def set_model(self, model: Optional[Model]):
        """Store the model and reopen the settings context for its class."""
        self.closeContext()
        self.model = model
        self.setup_controls()
        if self.model:
            class_var = self.model.domain.class_var
        else:
            class_var = None
        self.openContext(class_var)

    def setup_controls(self):
        """Refill the target combo to match the current model's class.

        Raises:
            NotImplementedError: if the model's domain has neither a discrete
                nor a continuous class.
        """
        combo = self._target_combo
        combo.clear()
        combo.setEnabled(True)
        if self.model is None:
            return
        domain = self.model.domain
        if domain.has_discrete_class:
            combo.addItems(domain.class_var.values)
            self.target_index = 0
        elif domain.has_continuous_class:
            # No class values to choose from: index -1 with a disabled combo.
            self.target_index = -1
            combo.setEnabled(False)
        else:
            raise NotImplementedError

    def handleNewSignals(self):
        """Reset the widget and start the computation on the new inputs."""
        self.clear()
        self.start(run, self.data, self.model)

    def clear(self):
        """Drop results, cancel running work and reset selection/scene/messages."""
        self.__results = None
        self.cancel()
        self.clear_selection()
        self.clear_scene()
        self.clear_messages()

    def clear_selection(self):
        """Empty a non-empty selection and commit the change."""
        if not self.selection:
            return
        self.selection = ()
        self.commit()

    def clear_scene(self):
        """Remove all scene items, reset every scene rect and forget the plot."""
        self.scene.clear()
        rect_setters = (
            self.scene.setSceneRect,
            self.view.setSceneRect,
            self.view.setHeaderSceneRect,
            self.view.setFooterSceneRect,
        )
        for set_rect in rect_setters:
            set_rect(QRectF())
        self._violin_plot = None

    def commit(self):
        """Send the selected rows of the data, or None when nothing is selected."""
        if self.selection and self.selection[1]:
            selected = self.data[self.selection[1]]
            detail = format_summary_details(selected)
            self.info.set_output_summary(len(selected), detail)
            self.Outputs.selected_data.send(selected)
            return
        self.info.set_output_summary(self.info.NoOutput)
        self.Outputs.selected_data.send(None)

    def update_scene(self):
        """Redraw the plot for the current target and send the scores table.

        Features are ordered by descending mean absolute value of their
        column in the result matrix of the selected target.
        """
        self.clear_scene()
        scores_table = None
        results = self.__results
        if results is not None:
            assert isinstance(results.x, list)
            values = results.x[self.target_index]
            mean_abs = np.mean(np.abs(values), axis=0)
            order = np.argsort(mean_abs)[::-1]
            colors = results.colors
            ordered_names = [results.names[i] for i in order]
            # Nothing to draw when the result matrix has no columns.
            if values.shape[1]:
                self.setup_plot(values[:, order], colors[:, order],
                                ordered_names)
            scores_table = self.create_scores_table(mean_abs, results.names)
        self.Outputs.scores.send(scores_table)

    def setup_plot(self, x: np.ndarray, colors: np.ndarray, names: List[str]):
        """Create a ViolinPlot for the given values, wire it up and show it."""
        viewport_width = int(self.view.viewport().rect().width())
        self._violin_plot = ViolinPlot()
        plot = self._violin_plot
        plot.set_data(x, colors, names, self.n_attributes, viewport_width)
        plot.show_legend(self.show_legend)
        plot.apply_visual_settings(self.visual_settings)
        plot.selection_cleared.connect(self.clear_selection)
        plot.selection_changed.connect(self.update_selection)
        plot.layout().activate()
        plot.geometryChanged.connect(self.update_scene_rect)
        plot.resized.connect(self.update_plot)
        self.scene.addItem(plot)
        self.scene.mouse_clicked.connect(plot.deselect)
        self.update_scene_rect()
        self.update_plot()

    def update_plot(self):
        """Rescale the plot, if there is one, to the viewport width."""
        if self._violin_plot is None:
            return
        viewport_width = int(self.view.viewport().rect().width())
        self._violin_plot.rescale(viewport_width)

    def update_selection(self, min_val: float, max_val: float, attr_name: str):
        """Select the rows whose value for ``attr_name`` lies in the range.

        The bounds are inclusive. The in-range flags are written into a copy
        of the results mask at the positions where that mask is set.
        """
        results = self.__results
        assert results is not None
        values = results.x[self.target_index]
        col = results.names.index(attr_name)
        row_mask = results.mask.copy()
        row_mask[results.mask] = np.logical_and(values[:, col] <= max_val,
                                                values[:, col] >= min_val)
        # Nothing was selected and nothing gets selected: no commit needed.
        if not (self.selection or any(row_mask)):
            return
        self.selection = (attr_name, list(np.flatnonzero(row_mask)))
        self.commit()

    def update_scene_rect(self):
        """Fit the scene and view rects, and the footer, to the plot geometry."""
        plot_geom = self._violin_plot.geometry()
        self.scene.setSceneRect(plot_geom)
        self.view.setSceneRect(plot_geom)

        axis_geom = self._violin_plot.bottom_axis.geometry()
        # The footer spans the full plot width, with some vertical padding
        # around the bottom axis.
        footer = QRectF(axis_geom.adjusted(0, -3, 0, 10))
        footer.setLeft(plot_geom.left())
        footer.setRight(plot_geom.right())
        self.view.setFooterSceneRect(footer)

    @staticmethod
    def create_scores_table(scores: np.ndarray, names: List[str]):
        """Return a table with a "Score" column and a "Feature" meta column."""
        score_domain = Domain(
            [ContinuousVariable("Score")],
            metas=[StringVariable("Feature")],
        )
        table = Table(score_domain, scores[:, None],
                      metas=np.array(names)[:, None])
        table.name = "Feature Scores"
        return table

    def on_partial_result(self, _):
        """Partial results are ignored."""

    def on_done(self, results: Optional[Results]):
        """Store finished results, redraw and re-apply the pending selection."""
        self.__results = results
        if self.data and results is not None and not all(results.mask):
            self.Information.data_sampled()
        self.update_scene()
        self.select_pending()

    def select_pending(self):
        """Re-apply the selection remembered at construction, if possible.

        The remembered row indices are intersected with the results mask;
        the value range of the remaining rows in the remembered feature's
        column is then handed to the plot as the selection range.
        """
        pending = self.__pending_selection
        results = self.__results
        if not (pending and pending[1]) or results is None:
            return

        attr_name, stored_rows = pending
        feature_names = results.names
        if not feature_names or attr_name not in feature_names:
            return
        col = feature_names.index(attr_name)
        wanted = np.zeros(results.mask.shape, dtype=bool)
        wanted[stored_rows] = True
        wanted = np.logical_and(results.mask, wanted)
        # Translate data-row positions to positions among the masked rows.
        rows = np.flatnonzero(wanted[results.mask])
        column = results.x[self.target_index][rows, col]
        low, high = np.min(column), np.max(column)
        self._violin_plot.select_from_settings(low, high, attr_name)
        self.__pending_selection = ()
        self.unconditional_commit()

    def on_exception(self, ex: Exception):
        """Show the exception as a domain-transformation or unknown error."""
        if isinstance(ex, DomainTransformationError):
            show_error = self.Error.domain_transform_err
        else:
            show_error = self.Error.unknown_err
        show_error(ex)

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def sizeHint(self):
        """Return the control area's size hint, expanded to at least 800x520."""
        return self.controlArea.sizeHint().expandedTo(QSize(800, 520))

    def send_report(self):
        """Report the target class and the plot; needs both data and model."""
        if not (self.data and self.model):
            return
        target = "None"
        if self.model.domain.has_discrete_class:
            target = self.model.domain.class_var.values[self.target_index]
        self.report_items({"Target class": target})
        self.report_plot()

    def set_visual_settings(self, key, value):
        """Store a visual setting and forward it to the plot when one exists."""
        self.visual_settings[key] = value
        if self._violin_plot is None:
            return
        self._violin_plot.parameter_setter.set_parameter(key, value)
class TestClassValuesContextHandler(TestCase):
    # Exercises ClassValuesContextHandler.open_context against a list of
    # stored contexts, once with a matching class variable and once without.

    def setUp(self):
        attributes = [
            ContinuousVariable("c1"),
            DiscreteVariable("d1", values="abc"),
            DiscreteVariable("d2", values="def"),
        ]
        class_vars = [DiscreteVariable("d3", values="ghi")]
        metas = [
            ContinuousVariable("c2"),
            DiscreteVariable("d4", values="jkl"),
        ]
        self.domain = Domain(attributes=attributes,
                             class_vars=class_vars,
                             metas=metas)

        attribute_types = {
            "c1": Continuous,
            "d1": Discrete,
            "d2": Discrete,
            "d3": Discrete,
        }
        meta_types = {"c2": Continuous, "d4": Discrete}
        self.args = (self.domain, attribute_types, meta_types)

        self.handler = ClassValuesContextHandler()
        # Replace default loading with a no-op for the test.
        self.handler.read_defaults = lambda: None

    def _install_stored_contexts(self):
        # One context for classes g/h/i, surrounded by two contexts that
        # carry no values.
        stored = Mock(
            classes=["g", "h", "i"],
            values={
                "text": "u",
                "with_metas": [("d1", Discrete), ("d2", Discrete)],
            },
        )
        self.handler.global_contexts = [
            Mock(values={}),
            stored,
            Mock(values={}),
        ]

    def test_open_context(self):
        self.handler.bind(SimpleWidget)
        self._install_stored_contexts()

        widget = SimpleWidget()
        self.handler.initialize(widget)
        # The domain's class variable has values g/h/i, like the stored
        # context, so its settings are expected on the widget.
        self.handler.open_context(widget, self.domain.class_var)

        self.assertEqual(widget.text, "u")
        expected_metas = [("d1", Discrete), ("d2", Discrete)]
        self.assertEqual(widget.with_metas, expected_metas)

    def test_open_context_with_no_match(self):
        self.handler.bind(SimpleWidget)
        self._install_stored_contexts()

        widget = SimpleWidget()
        self.handler.initialize(widget)
        widget.text = "u"

        # Variable "d1" has values a/b/c, unlike the stored context.
        self.handler.open_context(widget, self.domain[1])

        opened = widget.current_context
        self.assertEqual(opened.classes, ["a", "b", "c"])
        self.assertEqual(widget.text, "u")
        self.assertEqual(widget.with_metas, [])
class OWClassificationTreeGraph(OWTreeGraph):
    """Tree viewer for classification trees.

    Adds a "Target class" combo to the controls of ``OWTreeGraph``. Index 0
    of the combo means "None" (nodes show and are colored by their majority
    class); index ``k > 0`` refers to class value ``k - 1``.
    """

    name = "Classification Tree Viewer"
    description = "Graphical visualization of a classification tree."
    icon = "icons/ClassificationTree.svg"

    settingsHandler = ClassValuesContextHandler()
    target_class_index = ContextSetting(0)
    color_settings = Setting(None)
    selected_color_settings_index = Setting(0)

    inputs = [("Classification Tree", TreeClassifier, "ctree")]
    NODE = ClassificationTreeNode

    def __init__(self):
        super().__init__()
        box = gui.widgetBox(self.controlArea, "Nodes", addSpace=True)
        self.target_combo = gui.comboBox(box,
                                         self,
                                         "target_class_index",
                                         orientation=0,
                                         items=[],
                                         label="Target class",
                                         callback=self.toggle_color,
                                         contentsLength=8)
        gui.separator(box)
        gui.button(box, self, "Set Colors", callback=self.set_colors)
        gui.rubber(self.controlArea)

    def ctree(self, model=None):
        """Accept a tree model and refill the target-class combo.

        The combo gets a leading "None" entry followed by the values of the
        first class variable, and is set to ``target_class_index``.
        """
        super().ctree(model)
        if model is not None:
            self.target_combo.clear()
            self.target_combo.addItem("None")
            self.target_combo.addItems(self.domain.class_vars[0].values)
            self.target_combo.setCurrentIndex(self.target_class_index)

    def update_node_info(self, node):
        """Compose the HTML label of ``node`` and install it on the node.

        With a target class selected, the label shows that class's entry of
        the node distribution; otherwise it names the majority class and
        shows its entry. The entry is rendered as a percentage plus an
        instance count out of the node total.
        """
        distr = node.get_distribution()
        total = int(node.num_instances())
        if self.target_class_index:
            # Combo index 0 is "None", so class values are shifted by one.
            tabs = distr[self.target_class_index - 1]
            text = ""
        else:
            modus = node.majority()
            tabs = distr[modus]
            text = self.domain.class_vars[0].values[modus] + "<br/>"
        if tabs > 0.999:
            text += "100%, {}/{}".format(total, total)
        else:
            text += "{:2.1f}%, {}/{}".format(100 * tabs, int(total * tabs),
                                             total)

        text = self._update_node_info_attr_name(node, text)
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.format(text))

    def toggle_node_color(self):
        """Recolor every node by the target class or by its majority class.

        The share of the relevant class in the node distribution controls
        how much the palette color is lightened.
        """
        palette = self.scene.colorPalette
        for node in self.scene.nodes():
            distr = node.get_distribution()
            # A distribution that sums to zero would make the share below
            # undefined; dividing by 1 instead gives a share of 0 in both
            # branches (previously only the majority branch was guarded).
            total = numpy.sum(distr) or 1
            if self.target_class_index:
                p = distr[self.target_class_index - 1] / total
                color = palette[self.target_class_index].light(200 - 100 * p)
            else:
                modus = node.majority()
                p = distr[modus] / total
                color = palette[int(modus)].light(400 - 300 * p)
            node.backgroundBrush = QBrush(color)
        self.scene.update()
class OWExplainPrediction(OWWidget, ConcurrentWidgetMixin):
    """Widget that explains the model's prediction for one data instance.

    Only the first row of "Data" is explained. The result is drawn as a
    ``StripePlot`` and the per-feature values are sent on the "Scores"
    output.
    """

    name = "Explain Prediction"
    description = "Prediction explanation widget."
    icon = "icons/ExplainPred.svg"
    priority = 110

    class Inputs:
        model = Input("Model", Model)
        background_data = Input("Background Data", Table)
        data = Input("Data", Table)

    class Outputs:
        scores = Output("Scores", Table)

    class Error(OWWidget.Error):
        # Both messages simply display the text of the raised exception.
        domain_transform_err = Msg("{}")
        unknown_err = Msg("{}")

    class Information(OWWidget.Information):
        multiple_instances = Msg("Explaining prediction for the first "
                                 "instance in 'Data'.")

    settingsHandler = ClassValuesContextHandler()
    target_index = ContextSetting(0)
    stripe_len = Setting(10)

    graph_name = "scene"

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        self.__results: Optional[Results] = None
        self.model: Optional[Model] = None
        self.background_data: Optional[Table] = None
        self.data: Optional[Table] = None
        self._stripe_plot: Optional[StripePlot] = None
        # Texts shown by the two labels in the "Prediction info" box.
        self.mo_info = ""
        self.bv_info = ""
        self.setup_gui()

    def setup_gui(self):
        """Build the control area and the plot; reset the input summary."""
        self._add_controls()
        self._add_plot()
        self.info.set_input_summary(self.info.NoInput)

    def _add_plot(self):
        """Create the graphics scene and view and put the view in mainArea."""
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        view = self.view
        view.setRenderHint(QPainter.Antialiasing, True)
        view.setAlignment(Qt.AlignVCenter | Qt.AlignLeft)
        self.mainArea.layout().addWidget(view)

    def _add_controls(self):
        """Create the target combo, the zoom slider and the info labels."""
        target_box = gui.vBox(self.controlArea, "Target class")
        self._target_combo = gui.comboBox(
            target_box,
            self,
            "target_index",
            callback=self.__target_combo_changed,
            contentsLength=12,
        )

        zoom_box = gui.hBox(self.controlArea, "Zoom")
        gui.hSlider(
            zoom_box,
            self,
            "stripe_len",
            None,
            minValue=1,
            maxValue=500,
            createLabel=False,
            callback=self.__size_slider_changed,
        )

        gui.rubber(self.controlArea)

        info_box = gui.vBox(self.controlArea, "Prediction info")
        gui.label(info_box, self, "%(mo_info)s")  # type: QLabel
        bv_label = gui.label(info_box, self, "%(bv_info)s")  # type: QLabel
        bv_label.setToolTip("The average prediction for selected class.")

    def __target_combo_changed(self):
        self.update_scene()

    def __size_slider_changed(self):
        plot = self._stripe_plot
        if plot is None:
            return
        plot.set_height(self.stripe_len)

    @Inputs.data
    @check_sql_input
    def set_data(self, data: Optional[Table]):
        """Store the data holding the instance to explain."""
        self.data = data

    @Inputs.background_data
    @check_sql_input
    def set_background_data(self, data: Optional[Table]):
        """Store the background data."""
        self.background_data = data

    @Inputs.model
    def set_model(self, model: Optional[Model]):
        """Store the model and reopen the settings context for its class."""
        self.closeContext()
        self.model = model
        self.setup_controls()
        if self.model:
            class_var = self.model.domain.class_var
        else:
            class_var = None
        self.openContext(class_var)

    def setup_controls(self):
        """Refill the target combo to match the current model's class.

        Raises:
            NotImplementedError: if the model's domain has neither a discrete
                nor a continuous class.
        """
        combo = self._target_combo
        combo.clear()
        combo.setEnabled(True)
        if self.model is None:
            return
        domain = self.model.domain
        if domain.has_discrete_class:
            combo.addItems(domain.class_var.values)
            self.target_index = 0
        elif domain.has_continuous_class:
            # No class values to choose from: index -1 with a disabled combo.
            self.target_index = -1
            combo.setEnabled(False)
        else:
            raise NotImplementedError

    def handleNewSignals(self):
        """Reset the widget, validate inputs and start the computation."""
        self.clear()
        self.check_inputs()
        # Only the first row is explained; falsy data is passed on unchanged.
        first_row = self.data[:1] if self.data else self.data
        self.start(run, first_row, self.background_data, self.model)

    def clear(self):
        """Reset info texts and results, cancel work, clear scene/messages."""
        self.mo_info = ""
        self.bv_info = ""
        self.__results = None
        self.cancel()
        self.clear_scene()
        self.clear_messages()

    def check_inputs(self):
        """Show the multi-instance notice and update the input summary."""
        if self.data and len(self.data) > 1:
            self.Information.multiple_instances()

        if not (self.data or self.background_data):
            self.info.set_input_summary(self.info.NoInput, "")
            return

        n_data = len(self.data) if self.data else 0
        n_background_data = 0
        if self.background_data:
            n_background_data = len(self.background_data)
        summary = f"{self.info.format_number(n_background_data)}, " \
                  f"{self.info.format_number(n_data)}"
        details = format_multiple_summaries([
            ("Background data", self.background_data),
            ("Data", self.data),
        ])
        self.info.set_input_summary(summary, details, format=Qt.RichText)

    def clear_scene(self):
        """Remove all scene items, reset the scene rects and forget the plot."""
        self.scene.clear()
        for set_rect in (self.scene.setSceneRect, self.view.setSceneRect):
            set_rect(QRectF())
        self._stripe_plot = None

    def update_scene(self):
        """Redraw the plot for the current target and send the scores table."""
        self.clear_scene()
        self.mo_info = ""
        self.bv_info = ""
        scores_table = None
        results = self.__results
        if results is not None:
            transformed = results.transformed_data
            predictions = results.predictions
            base_values = results.base_value
            values, _, labels, ranges = prepare_force_plot_data(
                results.values, transformed, predictions, self.target_index)

            # Only one instance is explained, hence row 0. Each row holds a
            # pair: position 0 feeds the "high" side, position 1 the "low"
            # side, which is reversed.
            row = 0
            high, low = 0, 1
            plot_data = PlotData(
                high_values=values[row][high],
                low_values=values[row][low][::-1],
                high_labels=labels[row][high],
                low_labels=labels[row][low][::-1],
                value_range=ranges[row],
                model_output=predictions[row][self.target_index],
                base_value=base_values[self.target_index],
            )
            self.setup_plot(plot_data)

            self.mo_info = f"Model prediction: {_str(plot_data.model_output)}"
            self.bv_info = f"Base value: {_str(plot_data.base_value)}"

            assert isinstance(results.values, list)
            instance_scores = results.values[self.target_index][0, :]
            feature_names = [a.name for a in transformed.domain.attributes]
            scores_table = self.create_scores_table(instance_scores,
                                                    feature_names)
        self.Outputs.scores.send(scores_table)

    def setup_plot(self, plot_data: PlotData):
        """Create a StripePlot for ``plot_data`` and add it to the scene."""
        self._stripe_plot = StripePlot()
        plot = self._stripe_plot
        plot.set_data(plot_data, self.stripe_len)
        plot.layout().activate()
        plot.geometryChanged.connect(self.update_scene_rect)
        self.scene.addItem(plot)
        self.update_scene_rect()

    def update_scene_rect(self):
        """Fit the scene and view rects to the plot geometry."""
        plot_geom = self._stripe_plot.geometry()
        self.scene.setSceneRect(plot_geom)
        self.view.setSceneRect(plot_geom)

    @staticmethod
    def create_scores_table(scores: np.ndarray, names: List[str]) -> Table:
        """Return a table with a "Score" column and a "Feature" meta column."""
        score_domain = Domain(
            [ContinuousVariable("Score")],
            metas=[StringVariable("Feature")],
        )
        table = Table(score_domain, scores[:, None],
                      metas=np.array(names)[:, None])
        table.name = "Feature Scores"
        return table

    def on_partial_result(self, _):
        """Partial results are ignored."""

    def on_done(self, results: Optional[RunnerResults]):
        """Store the finished results and redraw."""
        self.__results = results
        self.update_scene()

    def on_exception(self, ex: Exception):
        """Show the exception as a domain-transformation or unknown error."""
        if isinstance(ex, DomainTransformationError):
            show_error = self.Error.domain_transform_err
        else:
            show_error = self.Error.unknown_err
        show_error(ex)

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def sizeHint(self) -> QSizeF:
        """Return the control area's size hint, expanded to at least 700x700."""
        return self.controlArea.sizeHint().expandedTo(QSize(700, 700))

    def send_report(self):
        """Report the target class and the plot; needs all three inputs."""
        if not (self.data and self.background_data and self.model):
            return
        target = "None"
        if self.model.domain.has_discrete_class:
            target = self.model.domain.class_var.values[self.target_index]
        self.report_items({"Target class": target})
        self.report_plot()
Esempio n. 20
0
class OWExplainModel(OWExplainFeatureBase):
    """Model explanation widget built on ``OWExplainFeatureBase``.

    Adds a target-class combo, a legend toggle and an "Impact" output to the
    base widget, and plots the results with ``ViolinPlot``.
    """

    name = "Explain Model"
    description = "Model explanation widget."
    keywords = ["explain", "explain prediction", "explain model"]
    icon = "icons/ExplainModel.svg"
    priority = 100
    replaces = [
        "orangecontrib.prototypes.widgets.owexplainmodel.OWExplainModel"
    ]

    class Outputs(OWExplainFeatureBase.Outputs):
        impact = Output("Impact", Table)

    settingsHandler = ClassValuesContextHandler()
    target_index = ContextSetting(0)
    show_legend = Setting(True)

    PLOT_CLASS = ViolinPlot

    def __init__(self):
        # Set before the base constructor runs; the combo itself is created
        # in _add_controls().
        self._target_combo: QComboBox = None
        super().__init__()

    # GUI setup
    def _add_controls(self):
        """Add the target combo, the base controls, then the legend check."""
        target_box = gui.vBox(self.controlArea, "Target class")
        self._target_combo = gui.comboBox(
            target_box,
            self,
            "target_index",
            callback=self.__target_combo_changed,
            contentsLength=12,
        )

        super()._add_controls()
        gui.checkBox(
            self.display_box,
            self,
            "show_legend",
            "Show legend",
            callback=self.__show_check_changed,
        )

    def __target_combo_changed(self):
        # A different target invalidates the plot, both outputs and the
        # selection.
        self.update_scene()
        self.update_scores()
        self.update_impact()
        self._clear_selection()

    def __show_check_changed(self):
        if self.plot is None:
            return
        self.plot.show_legend(self.show_legend)

    def openContext(self, model: Optional[Model]):
        """Open the settings context for the model's class variable."""
        class_var = model.domain.class_var if model else None
        super().openContext(class_var)

    def setup_controls(self):
        """Refill the target combo to match the current model's class.

        Raises:
            NotImplementedError: if the model's domain has neither a discrete
                nor a continuous class.
        """
        combo = self._target_combo
        combo.clear()
        combo.setEnabled(True)
        if self.model is None:
            return
        domain = self.model.domain
        if domain.has_discrete_class:
            combo.addItems(domain.class_var.values)
            self.target_index = 0
        elif domain.has_continuous_class:
            # No class values to choose from: index -1 with a disabled combo.
            self.target_index = -1
            combo.setEnabled(False)
        else:
            raise NotImplementedError

    # Plot setup
    def update_scene(self):
        """Redraw the plot for the current target.

        Features are ordered by descending mean absolute value of their
        column in the result matrix of the selected target.
        """
        super().update_scene()
        results = self.results
        if results is None:
            return
        assert isinstance(results.x, list)
        values = results.x[self.target_index]
        mean_abs = np.mean(np.abs(values), axis=0)
        order = np.argsort(mean_abs)[::-1]
        colors = results.colors
        ordered_names = [results.names[i] for i in order]
        # Nothing to draw when the result matrix has no columns.
        if values.shape[1]:
            self.setup_plot(values[:, order], ordered_names,
                            colors[:, order])

    def setup_plot(self, values, names, *plot_args):
        """Let the base class build the plot, then apply the legend setting."""
        super().setup_plot(values, names, *plot_args)
        self.plot.show_legend(self.show_legend)

    # Selection
    def update_selection(self, min_val: float, max_val: float, attr_name: str):
        """Select the rows whose value for ``attr_name`` lies in the range.

        The bounds are inclusive. The in-range flags are written into a copy
        of the results mask at the positions where that mask is set.
        """
        results = self.results
        assert results is not None
        values = results.x[self.target_index]
        col = results.names.index(attr_name)
        row_mask = results.mask.copy()
        row_mask[results.mask] = np.logical_and(values[:, col] <= max_val,
                                                values[:, col] >= min_val)
        # Nothing was selected and nothing gets selected: no commit needed.
        if not (self.selection or any(row_mask)):
            return
        self.selection = (attr_name, list(np.flatnonzero(row_mask)))
        self.commit()

    def select_pending(self, pending_selection: Tuple):
        """Re-apply a stored ``(feature name, row indices)`` selection.

        The stored row indices are intersected with the results mask; the
        value range of the remaining rows in the stored feature's column is
        then handed to the plot as the selection range.
        """
        if not (pending_selection and pending_selection[1]):
            return
        results = self.results
        if results is None:
            return

        attr_name, stored_rows = pending_selection
        feature_names = results.names
        if not feature_names or attr_name not in feature_names:
            return
        col = feature_names.index(attr_name)
        wanted = np.zeros(results.mask.shape, dtype=bool)
        wanted[stored_rows] = True
        wanted = np.logical_and(results.mask, wanted)
        # Translate data-row positions to positions among the masked rows.
        rows = np.flatnonzero(wanted[results.mask])
        column = results.x[self.target_index][rows, col]
        low, high = np.min(column), np.max(column)
        self.plot.select_from_settings(low, high, attr_name)
        super().select_pending(())

    # Outputs
    def get_selected_data(self):
        """Return the selected rows of the data, or None without a selection."""
        if self.selection and self.selection[1]:
            return self.data[self.selection[1]]
        return None

    def get_scores_table(self) -> Table:
        """Return mean absolute values per feature as a "Feature Scores" table."""
        values = self.results.x[self.target_index]
        mean_abs = np.mean(np.abs(values), axis=0)
        score_domain = Domain(
            [ContinuousVariable("Score")],
            metas=[StringVariable("Feature")],
        )
        table = Table(score_domain, mean_abs[:, None],
                      metas=np.array(self.results.names)[:, None])
        table.name = "Feature Scores"
        return table

    def update_impact(self):
        """Send the impact table, or None when there are no results."""
        if self.results is not None:
            impact = self.get_impact_table()
        else:
            impact = None
        self.Outputs.impact.send(impact)

    def get_impact_table(self) -> Table:
        """Return the result values of the current target as a table.

        Columns are named ``I(<feature>)``, made unique against the names of
        the data's class variables and metas. The data's class values and
        metas of the masked rows are carried over.
        """
        data = self.data
        values = self.results.x[self.target_index]
        row_mask = self.results.mask
        candidates = [f"I({n})" for n in self.results.names]
        taken = [v.name for v in data.domain.class_vars + data.domain.metas]
        column_names = get_unique_names(taken, candidates)
        impact_domain = Domain(
            [ContinuousVariable(n) for n in column_names],
            data.domain.class_vars,
            metas=data.domain.metas,
        )
        table = Table(impact_domain, values, data.Y[row_mask],
                      data.metas[row_mask])
        table.name = "Feature Impact"
        return table

    # Concurrent
    def on_done(self, results: Optional[BaseResults]):
        """Run the base completion handling, then refresh the impact output."""
        super().on_done(results)
        self.update_impact()

    # Misc
    def send_report(self):
        """Report the target class, then the base report; needs data and model."""
        if not (self.data and self.model):
            return
        target = "None"
        if self.model.domain.has_discrete_class:
            target = self.model.domain.class_var.values[self.target_index]
        self.report_items({"Target class": target})
        super().send_report()

    @staticmethod
    def run(data: Table, model: Model, state: TaskState) -> Results:
        """Compute values and colors for ``model`` on ``data``.

        Returns None when either input is missing. Progress and status are
        reported through ``state``; a requested interruption aborts the
        computation by raising from the progress callback.
        """
        if not (data and model):
            return None

        def report_progress(i: float, status=""):
            # ``i`` is scaled by 100 before being reported as progress.
            state.set_progress_value(i * 100)
            if status:
                state.set_status(status)
            if state.is_interruption_requested():
                raise Exception

        x, names, mask, colors = get_shap_values_and_colors(
            model, data, report_progress)
        return Results(x=x, colors=colors, names=names, mask=mask)