Exemplo n.º 1
0
class OWQualityControl(widget.OWWidget):
    """Widget plotting distances between data columns ("experiments").

    The columns (domain attributes) of the input table are split into
    groups by their attribute labels.  For every group the distance from
    a base column to all columns is computed and drawn in the main area.
    """
    # Widget metadata.
    name = "Quality Control"
    description = "Experiment quality control"
    icon = "../widgets/icons/QualityControl.svg"
    priority = 5000

    # One input channel, handled by ``set_data``; the widget has no outputs.
    inputs = [("Experiment Data", Orange.data.Table, "set_data")]
    outputs = []

    # (display name, distance function) pairs.  The names fill the
    # "Distance Measure" combo box; ``selected_distance_index`` picks one.
    DISTANCE_FUNCTIONS = [("Distance from Pearson correlation",
                           dist_pcorr),
                          ("Euclidean distance",
                           dist_eucl),
                          ("Distance from Spearman correlation",
                           dist_spearman)]

    # Context handler; contexts are opened with the list of suitable
    # label keys (see ``on_new_data``).
    settingsHandler = SetContextHandler()

    # Labels selected in the "Separate By" / "Sort By" views, stored
    # per context.
    split_by_labels = settings.ContextSetting({})
    sort_by_labels = settings.ContextSetting({})

    # Index into DISTANCE_FUNCTIONS.
    selected_distance_index = settings.Setting(0)

    def __init__(self, parent=None):
        """Build the control-area GUI and the graphics scene.

        :param parent: forwarded unchanged to the base widget constructor.
        """
        super().__init__(parent)

        # State derived from the input data; filled once data arrives.
        self.data = None
        self.distances = None
        self.groups = None
        self.unique_pos = None
        self.base_group_index = 0

        # Info box
        info_container = gui.widgetBox(self.controlArea, "Info")
        self.info_box = gui.widgetLabel(info_container, "\n")

        def add_label_list(title, on_selection_changed):
            """Add a titled multi-selection list view to the control area
            and return its ``(model, view)`` pair.
            """
            container = gui.widgetBox(self.controlArea, title)
            model = itemmodels.PyListModel(parent=self)
            view = QListView()
            view.setSelectionMode(QListView.ExtendedSelection)
            view.setModel(model)
            container.layout().addWidget(view)
            view.selectionModel().selectionChanged.connect(
                on_selection_changed)
            return model, view

        # "Separate By" and "Sort By" label pickers
        self.split_by_model, self.split_by_view = add_label_list(
            "Separate By", self.on_split_key_changed)
        self.sort_by_model, self.sort_by_view = add_label_list(
            "Sort By", self.on_sort_key_changed)

        # Distance measure picker
        distance_container = gui.widgetBox(self.controlArea,
                                           "Distance Measure")
        gui.comboBox(distance_container, self, "selected_distance_index",
                     items=[label for label, _ in self.DISTANCE_FUNCTIONS],
                     callback=self.on_distance_measure_changed)

        # Main area: the scene holding the quality plot.
        self.scene = QGraphicsScene()
        self.scene_view = QGraphicsView(self.scene)
        self.scene_view.setRenderHints(QPainter.Antialiasing)
        self.scene_view.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
        self.mainArea.layout().addWidget(self.scene_view)

        # Resize events of the view are routed through ``eventFilter``.
        self.scene_view.installEventFilter(self)

        self._disable_updates = False
        self._cached_distances = {}
        self._base_index_hints = {}
        self.main_widget = None

        self.resize(800, 600)

    def clear(self):
        """Clear the widget state.

        Drops the data and everything derived from it, empties both label
        list models, clears the scene and resets the info text.
        """
        self.data = None
        self.distances = None
        self.groups = None
        self.unique_pos = None

        # The list models are emptied inside ``disable_updates`` so the
        # selection handlers (which check ``self._disable_updates``) do
        # not try to re-split data that is already gone.
        with disable_updates(self):
            self.split_by_model[:] = []
            self.sort_by_model[:] = []

        self.main_widget = None
        self.scene.clear()
        self.info_box.setText("\n")
        self._cached_distances = {}
        # The hints map group labels to column indices of the data set
        # being discarded; keeping them could point past the columns of
        # the next data set.
        self._base_index_hints = {}

    def set_data(self, data=None):
        """Input handler for the "Experiment Data" channel.

        Data without any suitable label key is rejected with an error
        message and treated like no data at all.
        """
        self.closeContext()
        self.clear()

        self.error(0)
        self.warning(0)

        if data is not None and not self.get_suitable_keys(data):
            self.error(0, "Data has no suitable feature labels.")
            data = None

        self.data = data
        if data is None:
            return
        self.on_new_data()

    def update_label_candidates(self):
        """Refill both label list models ("Separate By" and "Sort By")
        with the suitable label keys of the current data.

        """
        candidates = self.get_suitable_keys(self.data)
        # Both models get the same keys; done inside ``disable_updates``.
        with disable_updates(self):
            for model in (self.split_by_model, self.sort_by_model):
                model[:] = candidates

    def get_suitable_keys(self, data):
        """Return the attribute label keys that have at least two
        distinct values across the columns of ``data``.

        Keys and values are compared by their ``str()`` form, in case
        non-string objects were put into an ``attributes`` dict.

        :param data: object exposing ``domain.attributes``, each item
            having an ``attributes`` mapping of label key -> value.
        :return: list of keys (``str``), ordered by first appearance so
            that the result is the same on every run (iterating a set of
            strings, as done previously, is not).
        """
        values = defaultdict(set)
        for attr in data.domain.attributes:
            for key, value in attr.attributes.items():
                values[str(key)].add(str(value))
        return [key for key, seen in values.items() if len(seen) > 1]

    def selected_split_by_labels(self):
        """Return the labels currently selected in the "Separate By" view.

        The rows reported by the view's selection model are looked up in
        ``split_by_model``, the model that view actually displays (the
        previous code read them from ``sort_by_model``).
        """
        sel_m = self.split_by_view.selectionModel()
        indices = [r.row() for r in sel_m.selectedRows()]
        return [self.split_by_model[i] for i in indices]

    def selected_sort_by_labels(self):
        """Return the labels currently selected in the "Sort By" view.
        """
        selection = self.sort_by_view.selectionModel()
        rows = [model_index.row() for model_index in selection.selectedRows()]
        labels = []
        for row in rows:
            labels.append(self.sort_by_model[row])
        return labels

    def selected_distance(self):
        """Return the distance function chosen in the combo box, i.e. the
        second item of the selected ``DISTANCE_FUNCTIONS`` entry.
        """
        entry = self.DISTANCE_FUNCTIONS[self.selected_distance_index]
        return entry[1]

    def selected_base_group_index(self):
        """Return the selected base group index.

        This is the value of ``self.base_group_index``; callers use it as
        a position inside each group's index list.
        """
        return self.base_group_index

    def selected_base_indices(self, base_group_index=None):
        """Return one base column index per group in ``self.groups``.

        :param base_group_index: when given, the base of every group is
            the entry at this position of the group's index list.  When
            ``None``, the base is the hint stored for the group's label in
            ``self._base_index_hints``, falling back to the group's first
            non-``None`` index (or ``None`` if there is none).
        """
        if base_group_index is not None:
            return [members[base_group_index] for _, members in self.groups]

        split_labels = self.selected_split_by_labels()
        bases = []
        for group, members in self.groups:
            label = group_label(split_labels, group)
            present = [member for member in members if member is not None]
            fallback = present[0] if present else None
            bases.append(self._base_index_hints.get(label, fallback))
        return bases

    def on_new_data(self):
        """We have new data and need to recompute all.

        Refills the label lists, updates the info text, opens the context
        for the data's label keys, restores the saved split/sort
        selections in the views and finally re-splits and redraws.
        """
        self.closeContext()

        self.update_label_candidates()
        self.info_box.setText(
            "%s genes \n%s experiments" %
            (len(self.data),  len(self.data.domain.attributes))
        )

        self.base_group_index = 0

        keys = self.get_suitable_keys(self.data)
        self.openContext(keys)

        ## Restore saved context settings (split/sort selection)
        split_by_labels = self.split_by_labels
        sort_by_labels = self.sort_by_labels

        def select(model, selection_model, selected_items):
            """Select ``selected_items`` in a Qt item model view.

            If any of the items is missing from the model nothing is
            selected at all.
            """
            all_items = list(model)
            try:
                indices = [all_items.index(item) for item in selected_items]
            except ValueError:
                # ``list.index`` signals a missing item with ValueError;
                # anything else is a real error and must not be hidden.
                indices = []
            for ind in indices:
                selection_model.select(model.index(ind),
                                       QItemSelectionModel.Select)

        # Restoring the selection must not trigger the selection-changed
        # handlers; the single update below does the work once.
        with disable_updates(self):
            select(self.split_by_view.model(),
                   self.split_by_view.selectionModel(),
                   split_by_labels)

            select(self.sort_by_view.model(),
                   self.sort_by_view.selectionModel(),
                   sort_by_labels)

        with widget_disable(self):
            self.split_and_update()

    def on_split_key_changed(self, *args):
        """Slot for selection changes in the "Separate By" view.

        Does nothing while ``self._disable_updates`` is set; otherwise
        resets the base group, stores the new selection and redraws.
        """
        with widget_disable(self):
            if self._disable_updates:
                return
            self.base_group_index = 0
            self.split_by_labels = self.selected_split_by_labels()
            self.split_and_update()

    def on_sort_key_changed(self, *args):
        """Slot for selection changes in the "Sort By" view.

        Does nothing while ``self._disable_updates`` is set; otherwise
        resets the base group, stores the new selection and redraws.
        """
        with widget_disable(self):
            if self._disable_updates:
                return
            self.base_group_index = 0
            self.sort_by_labels = self.selected_sort_by_labels()
            self.split_and_update()

    def on_distance_measure_changed(self):
        """Distance measure has changed.

        Recomputes the distances for the current groups and redraws.  The
        base columns are chosen exactly as in ``split_and_update`` so
        that a base picked by clicking a plot item (stored in
        ``self._base_index_hints``) survives a change of the measure.
        Nothing happens without data or without groups.
        """
        if self.data is None or not self.groups:
            return

        with widget_disable(self):
            if self.selected_sort_by_labels():
                group_base = self.selected_base_group_index()
                base_indices = self.selected_base_indices(group_base)
            else:
                base_indices = self.selected_base_indices()
            self.update_distances(base_indices)
            self.replot_experiments()

    def on_view_resize(self, size):
        """Fit the plot widget to a new view size.

        :param size: object with a ``width()`` method (the new view size).

        The plot keeps its height; its width becomes the view width minus
        6.  Without a plot widget nothing is done.
        """
        plot = self.main_widget
        if not plot:
            return

        plot_height = plot.size().height()
        plot.resize(size.width() - 6, plot_height)

        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def on_rug_item_clicked(self, item):
        """An ``item`` in the quality plot has been clicked.

        With sort labels selected and an in-group item, the item's
        position inside its group becomes the new base group index.
        Otherwise the sort selection (if any) is cleared and the clicked
        column is remembered as the base for the item's group label.
        The plot is recomputed only if something actually changed.

        :param item: clicked plot item carrying the ``in_group``,
            ``group``, ``index`` and (for in-group items) ``group_index``
            attributes set in ``replot_experiments``.
        """
        update = False
        sort_by_labels = self.selected_sort_by_labels()
        if sort_by_labels and item.in_group:
            ## The item is part of the group
            if item.group_index != self.base_group_index:
                self.base_group_index = item.group_index
                update = True

        else:
            if sort_by_labels:
                # If the user clicked on a background item it
                # invalidates the sorted labels selection
                with disable_updates(self):
                    self.sort_by_view.selectionModel().clear()
                    update = True

            index = item.index
            group = item.group
            label = group_label(self.selected_split_by_labels(), group)

            # Compare against the stored hint only.  Defaulting a missing
            # hint to 0 made a click on column 0 a silent no-op even when
            # column 0 was not the group's current base.
            if self._base_index_hints.get(label) != index:
                self._base_index_hints[label] = index
                update = True

        if update:
            with widget_disable(self):
                self.split_and_update()

    def eventFilter(self, obj, event):
        """Forward resize events of the scene view to ``on_view_resize``.

        The event is always passed on to the base class filter, whose
        result is returned.
        """
        view_resized = (obj is self.scene_view
                        and event.type() == QEvent.Resize)
        if view_resized:
            self.on_view_resize(event.size())
        return super().eventFilter(obj, event)

    def split_and_update(self):
        """
        Group the data columns by the selected split/sort labels, then
        recompute the distances and redraw the quality plot.

        A warning is shown when no "Separate By" label is selected.  If
        the split yields no groups, distances and plot are left alone.

        """
        split_labels = self.selected_split_by_labels()
        sort_labels = self.selected_sort_by_labels()

        self.warning(0)
        if not split_labels:
            self.warning(0, "No separate by label selected.")

        self.groups, self.unique_pos = exp.separate_by(
            self.data, split_labels, consider=sort_labels, add_empty=True)

        def label_values_key(entry):
            # Order (key, value) pairs by their key tuple, each element
            # passed through ``float_if_posible``.
            return list(map(float_if_posible, entry[0]))

        self.groups = sorted(self.groups.items(), key=label_values_key)
        self.unique_pos = sorted(self.unique_pos.items(),
                                 key=label_values_key)

        if not self.groups:
            return

        if sort_labels:
            base_indices = self.selected_base_indices(
                self.selected_base_group_index())
        else:
            base_indices = self.selected_base_indices()
        self.update_distances(base_indices)
        self.replot_experiments()

    def get_cached_distances(self, measure):
        """Return the ``(matrix, computed)`` cache entry for ``measure``,
        creating it on first use.

        ``matrix`` is a zero-filled square array with one row/column per
        data column; ``computed`` is the set of index pairs already
        filled in and initially holds only the diagonal pairs.
        """
        cache = self._cached_distances
        if measure not in cache:
            size = len(self.data.domain.attributes)
            diagonal = {(k, k) for k in range(size)}
            cache[measure] = (numpy.zeros((size, size)), diagonal)

        return cache[measure]

    def get_cached_distance(self, measure, i, j):
        """Return the cached distance between columns ``i`` and ``j`` for
        ``measure``, or ``None`` if the pair has not been computed yet.

        The pair is looked up in the computed set with the smaller index
        first; the value itself is read from ``matrix[i, j]``.
        """
        matrix, computed = self.get_cached_distances(measure)
        pair = (i, j) if i < j else (j, i)
        if pair not in computed:
            return None
        return matrix[i, j]

    def get_distance(self, measure, i, j):
        """Return the distance between data columns ``i`` and ``j``.

        A cached value is returned when available.  Otherwise the
        distance is computed with ``measure`` and stored through
        ``store_distance``, which fills both ``[i, j]`` and ``[j, i]``.
        Writing only ``[i, j]`` (as before) left the mirrored cell at 0
        while the pair was already flagged as computed, so a later lookup
        with swapped indices returned 0.

        :param measure: callable taking the two column vectors.
        """
        d = self.get_cached_distance(measure, i, j)
        if d is None:
            vec_i = take_columns(self.data, [i])
            vec_j = take_columns(self.data, [j])
            d = measure(vec_i, vec_j)
            self.store_distance(measure, i, j, d)
        return d

    def store_distance(self, measure, i, j, dist):
        """Record ``dist`` as the distance between columns ``i`` and
        ``j`` for ``measure``.

        Both mirrored matrix cells are written and the pair (smaller
        index first) is added to the computed set.
        """
        matrix, computed = self.get_cached_distances(measure)
        matrix[i, j] = dist
        matrix[j, i] = dist
        computed.add((i, j) if i < j else (j, i))

    def update_distances(self, base_indices=()):
        """Recompute the experiment distances.

        For every group, computes the distance from the group's base
        column to every data column and stores the per-group lists in
        ``self.distances``.  A group whose base index is ``None`` gets
        ``None`` instead of a list.

        :param base_indices: one base column index (or ``None``) per
            group.  The default ``()`` means "take the entry at the
            selected base group index from every group".
        """
        distance = self.selected_distance()
        if base_indices == ():
            base_group_index = self.selected_base_group_index()
            base_indices = [ind[base_group_index]
                            for _, ind in self.groups]

        assert(len(base_indices) == len(self.groups))

        base_distances = []
        n_columns = len(self.data.domain.attributes)
        pb = gui.ProgressBar(self, len(self.groups) * n_columns)

        # ``finally`` makes sure the progress bar is finished even when a
        # distance function raises.
        try:
            for _, base_index in zip(self.groups, base_indices):
                if base_index is None:
                    base_distances.append(None)
                    continue

                # Base column of the group
                base_vec = take_columns(self.data, [base_index])
                distances = []
                # Compute the distances between base column
                # and all the rest data columns.
                for i in range(n_columns):
                    if i == base_index:
                        dist = 0.0
                    else:
                        # One cache lookup; compute and store on a miss.
                        dist = self.get_cached_distance(distance, i,
                                                        base_index)
                        if dist is None:
                            vec_i = take_columns(self.data, [i])
                            dist = distance(base_vec, vec_i)
                            self.store_distance(distance, i, base_index,
                                                dist)
                    distances.append(dist)
                    pb.advance()

                base_distances.append(distances)
        finally:
            pb.finish()

        self.distances = base_distances

    def replot_experiments(self):
        """Replot the whole quality plot.

        Clears the scene and builds one row per group: a text label and a
        rug widget holding one clickable item per data column, placed at
        the column's distance divided by the largest distance.  In-group
        columns use the thick black pen, all others the thin background
        pen.  Without data or distances the plot is just left empty.
        """
        self.scene.clear()

        if self.data is None or self.distances is None:
            # Nothing to draw; the previous plot items went away with
            # the scene contents, so drop the references to them.
            self.main_widget = None
            self.rug_widgets = []
            self.labels = []
            return

        # Groups without a base column have no distance list.
        computed = [dist_vec for dist_vec in self.distances if dist_vec]
        max_dist = numpy.nanmax(computed) if computed else 1.0
        if not max_dist > 0:
            # All distances are zero (or NaN): avoid dividing by zero.
            max_dist = 1.0

        group_pen = QPen(Qt.black)
        group_pen.setWidth(2)
        group_pen.setCapStyle(Qt.RoundCap)
        background_pen = QPen(QColor(0, 0, 250, 150))
        background_pen.setWidth(1)
        background_pen.setCapStyle(Qt.RoundCap)

        main_widget = QGraphicsWidget()
        layout = QGraphicsGridLayout()
        attributes = self.data.domain.attributes
        split_labels = self.selected_split_by_labels()

        labels = []
        rug_widgets = []
        for (group, indices), dist_vec in zip(self.groups, self.distances):
            indices_set = set(indices)
            rug_items = []
            if dist_vec is not None:
                for i, attr in enumerate(attributes):
                    # Is this a within group distance or background
                    in_group = i in indices_set
                    if in_group:
                        height, pen = 1.0, group_pen
                    else:
                        height, pen = 0.85, background_pen

                    rug_item = ClickableRugItem(dist_vec[i] / max_dist,
                                                height,
                                                self.on_rug_item_clicked)
                    rug_item.setPen(pen)
                    rug_item.setToolTip(experiment_description(attr))
                    if in_group:
                        rug_item.group_index = indices.index(i)
                        # Draw in-group items above the background ones.
                        rug_item.setZValue(rug_item.zValue() + 1)

                    # Read back in ``on_rug_item_clicked``.
                    rug_item.group = group
                    rug_item.index = i
                    rug_item.in_group = in_group

                    rug_items.append(rug_item)

            rug_widget = RugGraphicsWidget(parent=main_widget)
            rug_widget.set_rug(rug_items)

            rug_widgets.append(rug_widget)

            label = group_label(split_labels, group)
            label_item = QGraphicsSimpleTextItem(label, main_widget)
            label_item = GraphicsSimpleTextLayoutItem(label_item,
                                                      parent=layout)
            label_item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            labels.append(label_item)

        for i, (label, rug_w) in enumerate(zip(labels, rug_widgets)):
            layout.addItem(label, i, 0, Qt.AlignVCenter)
            layout.addItem(rug_w, i, 1)
            layout.setRowMaximumHeight(i, 30)

        main_widget.setLayout(layout)
        self.scene.addItem(main_widget)
        self.main_widget = main_widget
        self.rug_widgets = rug_widgets
        self.labels = labels
        self.on_view_resize(self.scene_view.size())
Exemplo n.º 2
0
class OWPythagoreanForest(OWWidget):
    """Widget showing every tree of a random forest model in a grid.

    Each tree is drawn with a ``PythagorasTreeViewer``; the tree selected
    in the scene is sent to the 'Tree' output.
    """
    # Widget metadata.
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    # The forest input is handled by ``set_rf``; the output carries the
    # tree selected in the scene (see ``commit``).
    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    # Index into ``self.SIZE_CALCULATION`` (built in ``__init__``).
    size_calc_idx = settings.Setting(0)
    # Zoom slider value; ``_calculate_zoom`` turns it into an item size.
    zoom = settings.Setting(50)
    # Index of the selected tree, -1 when none is selected.
    selected_tree_index = settings.ContextSetting(-1)

    def __init__(self):
        """Create the widget state, the control-area GUI and the scene.

        Ends with ``self.clear()`` so the controls start out in their
        "no forest on input" state.
        """
        super().__init__()
        self.model = None
        self.forest_adapter = None
        self.instances = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []
        # In some rare cases, we need to prevent committing, the only one
        # that this currently helps is that when changing the size calculation
        # the trees are all recomputed, but we don't want to output a new tree
        # to keep things consistent with other ui controls.
        self.__prevent_commit = False

        self.color_palette = None

        # Different methods to calculate the size of squares
        # (display name, function) pairs; ``size_calc_idx`` selects one.
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x + 1)),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info)

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(box_display,
                                           self,
                                           'depth_limit',
                                           label='Depth',
                                           ticks=False,
                                           callback=self.update_depth)
        self.ui_target_class_combo = gui.comboBox(box_display,
                                                  self,
                                                  'target_class_index',
                                                  label='Target class',
                                                  orientation=Qt.Horizontal,
                                                  items=[],
                                                  contentsLength=8,
                                                  callback=self.update_colors)
        # The combo items are the display names from SIZE_CALCULATION.
        self.ui_size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
            callback=self.update_size_calc)
        self.ui_zoom_slider = gui.hSlider(box_display,
                                          self,
                                          'zoom',
                                          label='Zoom',
                                          ticks=False,
                                          minValue=20,
                                          maxValue=150,
                                          callback=self.zoom_changed,
                                          createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        # Selecting a tree in the scene outputs it.
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """Input handler for the 'Random forest' channel.

        Clears the widget first.  For a non-``None`` model the trees are
        drawn, the model's instances are stored (converted to the model's
        domain when the domains differ) and the info box, target class
        combo and depth slider are refreshed.
        """
        self.clear()
        self.model = model
        if model is None:
            return

        self.forest_adapter = self._get_forest_adapter(self.model)
        self._draw_trees()
        self.color_palette = self.forest_adapter.get_trees()[0]

        self.instances = model.instances
        # this bit is important for the regression classifier
        domains_differ = (self.instances is not None
                          and self.instances.domain != model.domain)
        if domains_differ:
            self.clf_dataset = Table.from_table(self.model.domain,
                                                self.instances)
        else:
            self.clf_dataset = self.instances

        self._update_info_box()
        self._update_target_class_combo()
        self._update_depth_slider()

        self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget.

        Besides the model, adapter, trees and grid this also drops the
        data references taken from the previous model (``instances``,
        ``clf_dataset``, ``color_palette``), so a removed forest does not
        keep its data set alive or leak into a later ``commit``.
        """
        self.model = None
        self.forest_adapter = None
        self.instances = None
        self.clf_dataset = None
        self.color_palette = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    def update_depth(self):
        """Slot of the depth slider: push the current ``depth_limit`` to
        every tree viewer."""
        for viewer in self.ptrees:
            viewer.set_depth_limit(self.depth_limit)

    def update_colors(self):
        """Slot of the target class combo: push the current
        ``target_class_index`` to every tree viewer."""
        for viewer in self.ptrees:
            viewer.target_class_changed(self.target_class_index)

    def update_size_calc(self):
        """Slot of the size combo: redraw all trees with the newly chosen
        size calculation.

        Runs inside ``_prevent_commit`` so that rebuilding the grid does
        not send a new output; the previously selected tree is selected
        again afterwards.  Does nothing without a model.
        """
        if self.model is None:
            return

        with self._prevent_commit():
            self.grid.clear()
            self._draw_trees()
            # Keep the selected item
            selected = self.selected_tree_index
            if selected != -1:
                self.grid_items[selected].setSelected(True)
            self.update_depth()

    def zoom_changed(self):
        """Slot of the "Zoom" slider: resize every grid item and reflow
        the grid to the view width minus the vertical scroll bar."""
        item_size = self._calculate_zoom(self.zoom)
        for grid_item in self.grid_items:
            grid_item.set_max_size(item_size)

        scroll_bar = self.view.verticalScrollBar()
        available = self.view.width() - scroll_bar.width()
        self.grid.reflow(available)
        self.grid.setPreferredWidth(available)

    @contextmanager
    def _prevent_commit(self):
        try:
            self.__prevent_commit = True
            yield
        finally:
            self.__prevent_commit = False

    def _update_info_box(self):
        self.ui_info.setText('Trees: {}'.format(
            len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    def _clear_info_box(self):
        """Reset the info label to its "no input" message."""
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    def _get_max_depth(self):
        return max(tree.tree_adapter.max_depth for tree in self.ptrees)

    def _get_forest_adapter(self, model):
        """Wrap ``model`` in a ``SklRandomForestAdapter`` and return it."""
        return SklRandomForestAdapter(model)

    @contextmanager
    def disable_ui(self):
        """Context manager disabling the four display controls (size
        combo, depth slider, target class combo, zoom slider) for the
        duration of the block, e.g. while trees are redrawn.

        All four are enabled again on exit, even if the body raises.
        """
        controls = (
            self.ui_size_calc_combo,
            self.ui_depth_slider,
            self.ui_target_class_combo,
            self.ui_zoom_slider,
        )
        try:
            for control in controls:
                control.setEnabled(False)
            yield
        finally:
            for control in controls:
                control.setEnabled(True)

    def _draw_trees(self):
        """Build a tree viewer and a grid item for every tree of the
        forest and lay the items out in the grid.

        Runs with the progress bar active and the display controls
        disabled.  Items are created hidden and shown only once the grid
        has been reflowed.
        """
        self.grid_items, self.ptrees = [], []

        tree_count = len(self.forest_adapter.get_trees())
        with self.progressBar(tree_count) as progress, self.disable_ui():
            weight_adjustment = self.SIZE_CALCULATION[self.size_calc_idx][1]
            for tree_adapter in self.forest_adapter.get_trees():
                viewer = PythagorasTreeViewer(
                    None,
                    tree_adapter,
                    interactive=False,
                    padding=100,
                    target_class_index=self.target_class_index,
                    weight_adjustment=weight_adjustment)
                item = GridItem(viewer,
                                self.grid,
                                max_size=self._calculate_zoom(self.zoom))
                # Keep the item hidden while the trees are being built,
                # so the grid does not flicker.
                item.setVisible(False)

                self.grid_items.append(item)
                self.ptrees.append(viewer)
                progress.advance()

            self.grid.set_items(self.grid_items)
            # This is necessary when adding items for the first time
            if self.grid:
                scroll_bar = self.view.verticalScrollBar()
                available = self.view.width() - scroll_bar.width()
                self.grid.reflow(available)
                self.grid.setPreferredWidth(available)
                # After drawing is complete, we show the trees
                for item in self.grid_items:
                    item.setVisible(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting.

        The size is the zoom level multiplied by 5.
        """
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget: run the base class hook, then clear
        the widget state."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Send the tree selected in the scene to the 'Tree' output.

        Does nothing while commits are suppressed by ``_prevent_commit``.
        Without a selection ``None`` is sent.  Before sending, the chosen
        tree is annotated with the instances and the current display
        settings.
        """
        if self.__prevent_commit:
            return

        selection = self.scene.selectedItems()
        if not selection:
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        self.selected_tree_index = self.grid_items.index(selection[0])
        tree = self.model.trees[self.selected_tree_index]
        tree.instances = self.instances
        tree.meta_target_class_index = self.target_class_index
        tree.meta_size_calc_idx = self.size_calc_idx
        tree.meta_depth_limit = self.depth_limit

        self.send('Tree', tree)

    def send_report(self):
        """Send report: adds the plot via ``report_plot``."""
        self.report_plot()

    def _update_scene_rect(self):
        """Set the scene rectangle to the bounding rectangle of all
        items; connected to the grid's ``geometryChanged`` signal."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def _update_target_class_combo(self):
        """Refill the target class combo for the current forest.

        For a discrete class the combo lists 'None' followed by the
        title-cased class values and is labelled 'Target class'; otherwise
        it lists the colouring methods of ``ContinuousTreeNode`` and is
        labelled 'Node color'.

        The domain is taken from the stored instances, falling back to
        the model's own domain when the model carries no instances
        (``set_rf`` explicitly allows ``instances`` to be ``None``).
        """
        self._clear_target_class_combo()
        label = [
            x for x in self.ui_target_class_combo.parent().children()
            if isinstance(x, QLabel)
        ][0]

        if self.instances is not None:
            domain = self.instances.domain
        else:
            domain = self.model.domain

        if domain.has_discrete_class:
            label_text = 'Target class'
            values = [c.title() for c in domain.class_vars[0].values]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.ui_target_class_combo.addItems(values)
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def resizeEvent(self, ev):
        """Reflow the grid to the view width minus the vertical scroll
        bar, then hand the event to the base class."""
        scroll_bar = self.view.verticalScrollBar()
        available = self.view.width() - scroll_bar.width()
        self.grid.reflow(available)
        self.grid.setPreferredWidth(available)

        super().resizeEvent(ev)
Exemplo n.º 3
0
class OWNomogram(OWWidget):
    """Widget that draws a nomogram for a classifier.

    The widget accepts a model on the "Classifier" input (only the classes
    listed in ``ACCEPTABLE`` are used) and, optionally, a data table on the
    "Data" input, and renders the nomogram into a ``QGraphicsScene``.
    """
    name = "Nomogram"
    description = " Nomograms for Visualization of Naive Bayesian" \
                  " and Logistic Regression Classifiers."
    icon = "icons/Nomogram.svg"
    priority = 2000

    # (signal name, signal type, name of the handler method)
    inputs = [("Classifier", Model, "set_classifier"),
              ("Data", Table, "set_instance")]

    # Upper bound of the "Best ranked" spin box; domains with more attributes
    # are switched to the "Best ranked" display in ``update_controls``.
    MAX_N_ATTRS = 1000
    # Value of ``scale`` that selects the "Point scale" radio button.
    POINT_SCALE = 0
    # Values of ``self.align``: ALIGN_LEFT is set for logistic regression
    # models, ALIGN_ZERO otherwise (see ``update_controls``).
    ALIGN_LEFT = 0
    ALIGN_ZERO = 1
    # Model classes accepted by ``set_classifier``.
    ACCEPTABLE = (NaiveBayesModel, LogisticRegressionClassifier)
    settingsHandler = ClassValuesContextHandler()
    # Index of the target class in the class combo box.
    target_class_index = ContextSetting(0)
    normalize_probabilities = Setting(False)
    # Index into the "Scale" radio buttons: 0 "Point scale",
    # 1 "Log odds ratios".
    scale = Setting(1)
    # Index into the "Display features" radio buttons: 0 "All",
    # 1 "Best ranked".
    display_index = Setting(1)
    # Number of features shown when "Best ranked" is selected.
    n_attributes = Setting(10)
    # A ``SortBy`` value selected in the "Sort by" combo box.
    sort_index = Setting(SortBy.ABSOLUTE)
    # Index into the "Continuous features" combo: 0 "1D projection",
    # 1 "2D curve".
    cont_feature_dim_index = Setting(0)

    # NOTE(review): presumably the name of the attribute the base class uses
    # for saving/reporting the plot -- confirm against OWWidget.
    graph_name = "scene"

    class Error(OWWidget.Error):
        """Error messages the widget can show."""
        invalid_classifier = Msg("Nomogram accepts only Naive Bayes and "
                                 "Logistic Regression classifiers.")

    def __init__(self):
        """Initialise the widget state and build the controls and the view."""
        super().__init__()
        # Inputs and data derived from them.
        self.instances = None
        self.domain = None
        self.data = None
        self.classifier = None
        self.align = OWNomogram.ALIGN_ZERO
        # Per-feature values filled in by ``calculate_log_odds_ratios`` and
        # ``calculate_log_reg_coefficients``.
        self.log_odds_ratios = []
        self.log_reg_coeffs = []
        self.log_reg_coeffs_orig = []
        self.log_reg_cont_data_extremes = []
        # Class probabilities (Naive Bayes) and intercepts (logistic
        # regression); only one of them is set for a given model.
        self.p = None
        self.b0 = None
        self.points = []
        self.feature_items = []
        self.feature_marker_values = []
        # Conversions between stored marker values and displayed values;
        # replaced in ``update_scene`` according to scale and alignment.
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        # Graphics items of the current nomogram; reset by ``clear_scene``.
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.old_target_class_index = self.target_class_index
        self.markers_set = False
        # ``update_scene`` does nothing until this flag is set by the
        # viewport resize handler in ``eventFilter``.
        # NOTE(review): ``repaint`` is also the name of a QWidget method; if
        # OWWidget derives from QWidget this attribute shadows it -- confirm.
        self.repaint = False

        # GUI
        box = gui.vBox(self.controlArea, "Target class")
        self.class_combo = gui.comboBox(box,
                                        self,
                                        "target_class_index",
                                        callback=self._class_combo_changed,
                                        contentsLength=12)
        # Created hidden; ``update_controls`` shows it for more than two
        # class values.
        self.norm_check = gui.checkBox(
            box,
            self,
            "normalize_probabilities",
            "Normalize probabilities",
            hidden=True,
            callback=self._norm_check_changed,
            tooltip="For multiclass data 1 vs. all probabilities do not"
            " sum to 1 and therefore could be normalized.")

        self.scale_radio = gui.radioButtons(
            self.controlArea,
            self,
            "scale", ["Point scale", "Log odds ratios"],
            box="Scale",
            callback=self._radio_button_changed)

        # "Display features" box: the radio buttons and the spin box are laid
        # out manually in a grid instead of being added by the gui helpers.
        box = gui.vBox(self.controlArea, "Display features")
        grid = QGridLayout()
        self.display_radio = gui.radioButtonsInBox(
            box,
            self,
            "display_index", [],
            orientation=grid,
            callback=self._display_radio_button_changed)
        radio_all = gui.appendRadioButton(self.display_radio,
                                          "All:",
                                          addToLayout=False)
        radio_best = gui.appendRadioButton(self.display_radio,
                                           "Best ranked:",
                                           addToLayout=False)
        spin_box = gui.hBox(None, margin=0)
        self.n_spin = gui.spin(spin_box,
                               self,
                               "n_attributes",
                               1,
                               self.MAX_N_ATTRS,
                               label=" ",
                               controlWidth=60,
                               callback=self._n_spin_changed)
        grid.addWidget(radio_all, 1, 1)
        grid.addWidget(radio_best, 2, 1)
        grid.addWidget(spin_box, 2, 2)

        self.sort_combo = gui.comboBox(box,
                                       self,
                                       "sort_index",
                                       label="Sort by: ",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._sort_combo_changed)

        self.cont_feature_dim_combo = gui.comboBox(
            box,
            self,
            "cont_feature_dim_index",
            label="Continuous features: ",
            items=["1D projection", "2D curve"],
            orientation=Qt.Horizontal,
            callback=self._cont_feature_dim_combo_changed)

        gui.rubber(self.controlArea)

        # Main area: the scene holding the nomogram and the view showing it.
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(
            self.scene,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            renderHints=QPainter.Antialiasing | QPainter.TextAntialiasing
            | QPainter.SmoothPixmapTransform,
            alignment=Qt.AlignLeft)
        # Viewport events are routed through ``eventFilter`` so that a resize
        # triggers a redraw.
        self.view.viewport().installEventFilter(self)
        self.view.viewport().setMinimumWidth(300)
        self.view.sizeHint = lambda: QSize(600, 500)
        self.mainArea.layout().addWidget(self.view)

    def _class_combo_changed(self):
        """Convert the marker values to the newly selected target class.

        The current marker positions are scaled back, translated from the
        points of the previously selected class to the points of the new one
        with ``get_points_from_coeffs``, and the scene is redrawn.
        """
        new_index = self.target_class_index
        old_index = self.old_target_class_index
        current = self.scale_back(
            [item.dot.value for item in self.feature_items])
        self.feature_marker_values = current

        converted = []
        for marker, feature_points in zip(current, self.points):
            old_points = feature_points[old_index]
            # Ratio of new-class to old-class points; undefined ratios
            # (e.g. 0 / 0) are replaced by nan_to_num.
            ratio = np.nan_to_num(feature_points[new_index] / old_points)
            converted.append(
                self.get_points_from_coeffs(marker, ratio, old_points))
        self.feature_marker_values = converted

        self.update_scene()
        self.old_target_class_index = self.target_class_index

    def _norm_check_changed(self):
        """Remember the marker positions and redraw after the toggle."""
        shown = [feature.dot.value for feature in self.feature_items]
        restored = self.scale_back(shown)
        self.feature_marker_values = restored
        self.update_scene()

    def _radio_button_changed(self):
        """Remember the marker positions and redraw with the new scale."""
        self.feature_marker_values = self.scale_back(
            [feature.dot.value for feature in self.feature_items])
        self.update_scene()

    def _display_radio_button_changed(self):
        """Apply the "All" / "Best ranked" choice to the shown features."""
        n_show = None
        if self.display_index:
            n_show = self.n_attributes
        self.__hide_attrs(n_show)

    def _n_spin_changed(self):
        """Switch to "Best ranked" and show the chosen number of features."""
        # Index 1 is the "Best ranked" radio button.
        self.display_index = 1
        self.__hide_attrs(self.n_attributes)

    def __hide_attrs(self, n_show):
        """Pass ``n_show`` to the main nomogram's ``hide`` and refit the scene.

        After hiding, both vertical marker lines are resized to the new
        preferred height of the feature list and the scene rectangle is
        recomputed.  Does nothing when no nomogram has been drawn.
        """
        main = self.nomogram_main
        if main is None:
            return
        main.hide(n_show)

        marker = self.vertical_line
        if marker:
            x_pos = marker.line().x1()
            bottom = main.layout.preferredHeight() + 30
            for line_item in (marker, self.hidden_vertical_line):
                line_item.setLine(x_pos, -6, x_pos, bottom)

        current = self.scene.sceneRect()
        fitted = QRectF(current.x(),
                        current.y(),
                        self.scene.itemsBoundingRect().width(),
                        self.nomogram.preferredSize().height())
        self.scene.setSceneRect(fitted.adjusted(0, 0, 70, 70))

    def _sort_combo_changed(self):
        """Re-sort the feature items and re-apply the display limit."""
        main = self.nomogram_main
        if main is not None:
            # hide(None) precedes sorting; presumably it makes every feature
            # visible again -- confirm in SortableNomogramItem.
            main.hide(None)
            main.sort(self.sort_index)
            n_show = self.n_attributes if self.display_index else None
            self.__hide_attrs(n_show)

    def _cont_feature_dim_combo_changed(self):
        """Remember the marker positions and redraw the feature items."""
        dots = [feature.dot for feature in self.feature_items]
        self.feature_marker_values = self.scale_back(
            [dot.value for dot in dots])
        self.update_scene()

    def eventFilter(self, obj, event):
        """Redraw the nomogram whenever the view's viewport is resized.

        The first such event also enables drawing by setting
        ``self.repaint``.  The event is always passed on to the base class.
        """
        viewport_resized = (obj is self.view.viewport()
                            and event.type() == QEvent.Resize)
        if viewport_resized:
            self.repaint = True
            shown = [feature.dot.value for feature in self.feature_items]
            self.feature_marker_values = self.scale_back(shown)
            self.update_scene()
        return super().eventFilter(obj, event)

    def update_controls(self):
        """Synchronise the controls with the current domain and classifier.

        The class combo is refilled, the normalisation check box is shown
        only for more than two class values, and the "Continuous features"
        combo is disabled when the domain has no continuous attributes.  For
        logistic regression models the nomogram is left-aligned and the
        POSITIVE and NEGATIVE sortings are disabled; if either of them is
        currently selected, sorting falls back to ``SortBy.NO_SORTING``.
        """
        self.class_combo.clear()
        self.norm_check.setHidden(True)
        self.cont_feature_dim_combo.setEnabled(True)
        if self.domain:
            class_values = self.domain.class_vars[0].values
            self.class_combo.addItems(class_values)
            if len(self.domain.attributes) > self.MAX_N_ATTRS:
                # Too many features to show all of them: force "Best ranked".
                self.display_index = 1
            if len(class_values) > 2:
                self.norm_check.setHidden(False)
            if not self.domain.has_continuous_attributes():
                self.cont_feature_dim_combo.setEnabled(False)
                self.cont_feature_dim_index = 0

        model = self.sort_combo.model()
        signed_sortings = (SortBy.POSITIVE, SortBy.NEGATIVE)
        # Re-enable both sortings; they may have been disabled for a
        # previous logistic regression model.
        for sorting in signed_sortings:
            item = model.item(sorting)
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        self.align = OWNomogram.ALIGN_ZERO
        if self.classifier and isinstance(self.classifier,
                                          LogisticRegressionClassifier):
            self.align = OWNomogram.ALIGN_LEFT
            for sorting in signed_sortings:
                item = model.item(sorting)
                item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            # Both disabled sortings must be checked; the original tested
            # POSITIVE twice and so left a selected NEGATIVE in place.
            if self.sort_index in signed_sortings:
                self.sort_index = SortBy.NO_SORTING

    def set_instance(self, data):
        """Handle new data on the "Data" input and reposition the markers."""
        self.instances = data
        # Emptying the list makes set_feature_marker_values recompute the
        # initial marker values.
        self.feature_marker_values = []
        self.set_feature_marker_values()

    def set_classifier(self, classifier):
        """Handle a new model on the "Classifier" input.

        A model that is not an instance of one of the ``ACCEPTABLE`` classes
        raises the ``invalid_classifier`` error and is treated as no model.
        All derived values are recomputed, the controls are updated and the
        scene is redrawn.
        """
        self.closeContext()
        self.classifier = classifier
        self.Error.clear()
        if self.classifier and not isinstance(self.classifier,
                                              self.ACCEPTABLE):
            self.Error.invalid_classifier()
            self.classifier = None
        self.domain = self.classifier.domain if self.classifier else None
        self.data = None
        # Each of the two calls is a no-op unless the model is of its kind,
        # so at most one of log_odds_ratios / log_reg_coeffs ends up filled.
        self.calculate_log_odds_ratios()
        self.calculate_log_reg_coefficients()
        self.update_controls()
        # Default target class, set before the context is opened.
        self.target_class_index = 0
        self.openContext(self.domain and self.domain.class_var)
        self.points = self.log_odds_ratios or self.log_reg_coeffs
        self.feature_marker_values = []
        self.old_target_class_index = self.target_class_index
        self.update_scene()

    def calculate_log_odds_ratios(self):
        """Fill ``log_odds_ratios`` and ``p`` from a Naive Bayes model.

        Both are reset first and stay empty / ``None`` unless the classifier
        is a ``NaiveBayesModel`` and a domain is set.  One array of log odds
        ratios is appended per attribute of the domain, and ``p`` is set to
        the model's ``class_prob``.
        """
        self.log_odds_ratios = []
        self.p = None
        model = self.classifier
        if model is None or self.domain is None \
                or not isinstance(model, NaiveBayesModel):
            return

        cond_log_probs = model.log_cont_prob
        priors = model.class_prob
        prior_column = priors[:, None]
        # The odds of the class probabilities are the same for every feature.
        prior_odds = (priors / (1 - priors))[:, None]
        for index in range(len(self.domain.attributes)):
            joint = np.exp(cond_log_probs[index]) * prior_column
            odds_ratio = (joint / (1 - joint)) / prior_odds
            self.log_odds_ratios.append(np.log(odds_ratio))
        self.p = priors

    def calculate_log_reg_coefficients(self):
        """Fill the logistic-regression values from the classifier.

        ``log_reg_coeffs``, ``log_reg_cont_data_extremes`` and ``b0`` are
        reset first and stay empty / ``None`` unless the classifier is a
        ``LogisticRegressionClassifier`` and a domain is set.  Otherwise the
        domain is rebuilt with ``reconstruct_domain``, the model's original
        data is transformed into it, and the coefficient matrix is split into
        one block of columns per attribute.  Single-column blocks are
        replaced by the coefficient multiplied by the column's minimum and
        maximum in the data.
        """
        self.log_reg_coeffs = []
        self.log_reg_cont_data_extremes = []
        self.b0 = None
        model = self.classifier
        if model is None or self.domain is None:
            return
        if not isinstance(model, LogisticRegressionClassifier):
            return

        self.domain = self.reconstruct_domain(model.original_domain,
                                              self.domain)
        self.data = model.original_data.transform(self.domain)

        # A discrete attribute takes one coefficient column per value, any
        # other attribute takes a single column.
        attrs = self.domain.attributes
        column_slices, offset = [], 0
        for attr in attrs:
            width = len(attr.values) if attr.is_discrete else 1
            column_slices.append(slice(offset, offset + width))
            offset += width

        self.b0 = model.intercept
        coeffs = model.coefficients
        if len(self.domain.class_var.values) == 2:
            # With two class values the negated row is stacked in front so
            # that there is one row per class value.
            self.b0 = np.hstack((self.b0 * (-1), self.b0))
            coeffs = np.vstack((coeffs * (-1), coeffs))
        self.log_reg_coeffs = [coeffs[:, columns] for columns in column_slices]
        self.log_reg_coeffs_orig = self.log_reg_coeffs.copy()

        min_values = nanmin(self.data.X, axis=0)
        max_values = nanmax(self.data.X, axis=0)

        extremes = self.log_reg_cont_data_extremes
        per_attribute = zip(list(self.log_reg_coeffs), min_values, max_values)
        for index, (coef, low, high) in enumerate(per_attribute):
            if coef.shape[1] != 1:
                extremes.append([None])
                continue
            self.log_reg_coeffs[index] = np.hstack((coef * low, coef * high))
            # Per class: the data extremes, in descending order when the
            # coefficient is negative.
            extremes.append(
                [sorted([low, high], reverse=(c < 0)) for c in coef])

    def update_scene(self):
        """Rebuild the whole nomogram from the current model and settings.

        Nothing happens until ``self.repaint`` has been set by the viewport
        resize handler.  The scene is cleared first; drawing is skipped when
        there is no domain, no feature, or no points for the first feature.
        Otherwise the points of the target class are optionally left-aligned
        and, on the point scale, rescaled so that the largest absolute value
        becomes 100.  ``scale_back`` and ``scale_forth`` are set to convert
        between stored and displayed marker values, the header, feature and
        footer items are created and added to the scene, and the feature
        markers are repositioned.
        """
        if not self.repaint:
            return
        self.clear_scene()
        # ``self.points`` is empty for a model without features; it has to be
        # checked before its first element is inspected.
        if self.domain is None or not len(self.points) \
                or not len(self.points[0]):
            return

        name_items = [
            QGraphicsTextItem(a.name) for a in self.domain.attributes
        ]
        point_text = QGraphicsTextItem("Points")
        probs_text = QGraphicsTextItem("Probabilities (%)")
        all_items = name_items + [point_text, probs_text]
        # The widest label decides how far left of the rulers labels start.
        name_offset = -max(t.boundingRect().width() for t in all_items) - 50
        w = self.view.viewport().rect().width()
        max_width = w + name_offset - 100

        point_scale = self.scale == OWNomogram.POINT_SCALE
        align_left = self.align == OWNomogram.ALIGN_LEFT

        points = [pts[self.target_class_index] for pts in self.points]
        minimums = [min(p) for p in points]
        if align_left:
            points = [p - m for m, p in zip(minimums, points)]
        max_ = np.nan_to_num(max(max(abs(p)) for p in points))
        # Factor that maps the largest absolute value to 100 points.
        d = 100 / max_ if max_ else 1
        if point_scale:
            points = [p * d for p in points]

        # ``d`` and ``minimums`` are not rebound below, so the closures keep
        # seeing the values computed for this redraw.
        if point_scale and align_left:
            self.scale_back = lambda x: [
                p / d + m for m, p in zip(minimums, x)
            ]
            self.scale_forth = lambda x: [(p - m) * d
                                          for m, p in zip(minimums, x)]
        elif point_scale:
            self.scale_back = lambda x: [p / d for p in x]
            self.scale_forth = lambda x: [p * d for p in x]
        elif align_left:
            self.scale_back = lambda x: [p + m for m, p in zip(minimums, x)]
            self.scale_forth = lambda x: [p - m for m, p in zip(minimums, x)]
        else:
            self.scale_back = lambda x: x
            self.scale_forth = lambda x: x

        point_item, nomogram_head = self.create_main_nomogram(
            name_items, points, max_width, point_text, name_offset)
        probs_item, nomogram_foot = self.create_footer_nomogram(
            probs_text, d, minimums, max_width, name_offset)
        # Give every feature marker access to the header and footer markers
        # and to the dashed vertical line.
        for item in self.feature_items:
            item.dot.point_dot = point_item.dot
            item.dot.probs_dot = probs_item.dot
            item.dot.vertical_line = self.hidden_vertical_line

        self.nomogram = nomogram = NomogramItem()
        nomogram.add_items([nomogram_head, self.nomogram_main, nomogram_foot])
        self.scene.addItem(nomogram)
        self.set_feature_marker_values()
        rect = QRectF(self.scene.itemsBoundingRect().x(),
                      self.scene.itemsBoundingRect().y(),
                      self.scene.itemsBoundingRect().width(),
                      self.nomogram.preferredSize().height())
        self.scene.setSceneRect(rect.adjusted(0, 0, 70, 70))

    def create_main_nomogram(self, name_items, points, max_width, point_text,
                             name_offset):
        """Create the header ruler and one item per feature.

        Sets ``nomogram_main``, ``feature_items``, ``vertical_line`` and
        ``hidden_vertical_line`` as a side effect.

        Returns:
            A tuple ``(point_item, nomogram_header)``: the "Points" ruler
            and the ``NomogramItem`` that contains it.
        """
        cls_index = self.target_class_index
        lowest = min(min(p) for p in points)
        highest = max(max(p) for p in points)
        values = self.get_ruler_values(lowest, highest, max_width)
        # The ruler may extend beyond the data; use its ends from here on.
        min_p, max_p = min(values), max(values)
        diff_ = np.nan_to_num(max_p - min_p)
        scale_x = max_width / diff_ if diff_ else max_width
        origin_x = -scale_x * min_p

        nomogram_header = NomogramItem()
        point_item = RulerItem(point_text, values, scale_x, name_offset,
                               origin_x)
        point_item.setPreferredSize(point_item.preferredWidth(), 35)
        nomogram_header.add_items([point_item])

        self.nomogram_main = SortableNomogramItem()
        if self.cont_feature_dim_index:
            cont_feature_item_class = ContinuousFeature2DItem
        else:
            cont_feature_item_class = ContinuousFeatureItem

        items = []
        for i, att in enumerate(self.domain.attributes):
            if att.is_discrete:
                item = DiscreteFeatureItem(
                    name_items[i], list(att.values), points[i], scale_x,
                    name_offset, origin_x, self.points[i][cls_index])
            else:
                low, high = np.min(points[i]), np.max(points[i])
                ticks = self.get_ruler_values(
                    low, high, scale_x * (high - low), False)
                item = cont_feature_item_class(
                    name_items[i],
                    self.log_reg_cont_data_extremes[i][cls_index],
                    ticks, scale_x, name_offset, origin_x,
                    self.log_reg_coeffs_orig[i][cls_index][0])
            items.append(item)
        self.feature_items = items

        n_show = self.n_attributes if self.display_index else None
        self.nomogram_main.add_items(self.feature_items, self.sort_index,
                                     n_show)

        # Two vertical lines at the ruler origin, both children of the
        # header ruler: a dotted one and a red dashed one.
        bottom = self.nomogram_main.layout.preferredHeight() + 30
        self.vertical_line = QGraphicsLineItem(origin_x, -6, origin_x, bottom)
        self.vertical_line.setPen(QPen(Qt.DotLine))
        self.vertical_line.setParentItem(point_item)
        self.hidden_vertical_line = QGraphicsLineItem(
            origin_x, -6, origin_x, bottom)
        dashed_red = QPen(Qt.DashLine)
        dashed_red.setBrush(QColor(Qt.red))
        self.hidden_vertical_line.setPen(dashed_red)
        self.hidden_vertical_line.setParentItem(point_item)

        return point_item, nomogram_header

    def create_footer_nomogram(self, probs_text, d, minimums, max_width,
                               name_offset):
        """Create the footer with the probabilities ruler.

        Args:
            probs_text: text item used as the ruler's label.
            d: point-scale factor computed in ``update_scene``; applied only
                when the point scale is selected.
            minimums: per-feature minimum of the target-class points, used
                when the nomogram is left-aligned.
            max_width: width available for the ruler.
            name_offset: horizontal offset of the label.

        Returns:
            A tuple ``(probs_item, nomogram_footer)``: the probabilities
            ruler and the ``NomogramItem`` that contains it.
        """
        # ``eps`` bounds the probabilities covered by the ruler: it spans the
        # point totals that correspond to probabilities eps and 1 - eps.
        eps, d_ = 0.05, 1
        # Per-class offset: negated log odds of the class probabilities when
        # ``self.p`` is set (Naive Bayes), else the negated ``self.b0``.
        k = -np.log(self.p / (1 - self.p)) if self.p is not None else -self.b0
        min_sum = k[self.target_class_index] - np.log((1 - eps) / eps)
        max_sum = k[self.target_class_index] - np.log(eps / (1 - eps))
        if self.align == OWNomogram.ALIGN_LEFT:
            # Left alignment shifts every feature by its minimum, so the
            # totals and the offsets are shifted by the summed minimums.
            max_sum = max_sum - sum(minimums)
            min_sum = min_sum - sum(minimums)
            for i in range(len(k)):
                k[i] = k[i] - sum(
                    [min(q) for q in [p[i] for p in self.points]])
        if self.scale == OWNomogram.POINT_SCALE:
            min_sum *= d
            max_sum *= d
            d_ = d

        values = self.get_ruler_values(min_sum, max_sum, max_width)
        min_sum, max_sum = min(values), max(values)
        diff_ = np.nan_to_num(max_sum - min_sum)
        scale_x = max_width / diff_ if diff_ else max_width
        cls_var, cls_index = self.domain.class_var, self.target_class_index
        nomogram_footer = NomogramItem()

        def get_normalized_probabilities(val):
            # Probability of the target class for a total of ``val`` points;
            # divided by the sum over all class values when normalisation is
            # switched on.
            if not self.normalize_probabilities:
                return 1 / (1 + np.exp(k[cls_index] - val / d_))
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return 1 / (1 + np.exp(k[cls_index] - val / d_)) / p_sum

        def get_points(prob):
            # Inverse of get_normalized_probabilities: the point total that
            # yields probability ``prob``.
            if not self.normalize_probabilities:
                return (k[cls_index] - np.log(1 / prob - 1)) * d_
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return (k[cls_index] - np.log(1 / (prob * p_sum) - 1)) * d_

        # ``markers_set`` is False while the ruler is being constructed;
        # __get_totals_for_class_values scales the marker values forth only
        # in that state.
        self.markers_set = False
        probs_item = ProbabilitiesRulerItem(
            probs_text,
            values,
            scale_x,
            name_offset,
            -scale_x * min_sum,
            get_points=get_points,
            title="{}='{}'".format(cls_var.name, cls_var.values[cls_index]),
            get_probabilities=get_normalized_probabilities)
        self.markers_set = True
        nomogram_footer.add_items([probs_item])
        return probs_item, nomogram_footer

    def __get_totals_for_class_values(self, minimums):
        """Return the total marker points for every class value.

        The total of the target class is the sum of the current marker
        values.  For every other class the marker values are translated to
        that class with ``get_points_from_coeffs`` and then aligned and
        scaled in the same way as the displayed points.

        Args:
            minimums: per-feature minimum of the target-class points.

        Returns:
            A float array with one total per class value.
        """
        cls_index = self.target_class_index
        marker_values = [item.dot.value for item in self.feature_items]
        if not self.markers_set:
            marker_values = self.scale_forth(marker_values)
        n_classes = len(self.domain.class_var.values)
        totals = np.empty(n_classes)
        totals[cls_index] = sum(marker_values)
        marker_values = self.scale_back(marker_values)
        for i in range(n_classes):
            if i == cls_index:
                continue
            coeffs = [np.nan_to_num(p[i] / p[cls_index]) for p in self.points]
            points = [p[cls_index] for p in self.points]
            total = sum([
                self.get_points_from_coeffs(v, c, p)
                for (v, c, p) in zip(marker_values, coeffs, points)
            ])
            if self.align == OWNomogram.ALIGN_LEFT:
                points = [p - m for m, p in zip(minimums, points)]
                total -= sum([min(p) for p in [p[i] for p in self.points]])
            # Same guarded factor as in update_scene: a zero or NaN maximum
            # must not turn the total into inf/NaN.
            max_ = np.nan_to_num(max(max(abs(p)) for p in points))
            d = 100 / max_ if max_ else 1
            if self.scale == OWNomogram.POINT_SCALE:
                total *= d
            totals[i] = total
        return totals

    def set_feature_marker_values(self):
        """Move every feature marker to its stored value.

        Does nothing when there are no points or no feature items.  Missing
        marker values are initialised first; the values are then converted
        with ``scale_forth`` and applied, and the probability marker is
        moved to the resulting sum.
        """
        if not len(self.points) or not len(self.feature_items):
            return
        if len(self.feature_marker_values) == 0:
            self._init_feature_marker_values()
        self.feature_marker_values = self.scale_forth(
            self.feature_marker_values)
        for position, feature in enumerate(self.feature_items):
            feature.dot.move_to_val(self.feature_marker_values[position])
        # Every dot refers to the same probability marker (assigned in
        # update_scene); the last feature's reference is used here.
        self.feature_items[-1].dot.probs_dot.move_to_sum()

    def _init_feature_marker_values(self):
        """Compute the initial marker value of every feature.

        With logistic-regression coefficients present the default feature
        value is the most frequent value (discrete) or the mean (other) of
        the corresponding data column.  A non-missing value of the first
        input instance overrides that default.  A feature for which neither
        is available gets the marker value 0.
        """
        self.feature_marker_values = []
        target = self.target_class_index
        instances = None
        if self.instances:
            instances = Table(self.domain, self.instances)

        for i, attr in enumerate(self.domain.attributes):
            feature_val = None
            if len(self.log_reg_coeffs):
                column = self.data.X[:, i]
                if attr.is_discrete:
                    ind, n = unique(column, return_counts=True)
                    feature_val = np.nan_to_num(ind[np.argmax(n)])
                else:
                    feature_val = mean(column)

            inst_in_dom = instances and attr in instances.domain
            if inst_in_dom and not np.isnan(instances[0][attr]):
                feature_val = instances[0][attr]

            if feature_val is None:
                value = 0
            elif attr.is_discrete:
                # The feature value is the index of the discrete value.
                value = self.points[i][target][int(feature_val)]
            else:
                value = self.log_reg_coeffs_orig[i][target][0] * feature_val
            self.feature_marker_values.append(value)

    def clear_scene(self):
        """Forget all nomogram items and empty the graphics scene."""
        self.feature_items = []
        # Without a nomogram both conversions are the identity.
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        self.nomogram = self.nomogram_main = None
        self.vertical_line = self.hidden_vertical_line = None
        self.scene.clear()

    def send_report(self):
        """Report the widget's plot by delegating to ``report_plot``."""
        self.report_plot()

    @staticmethod
    def reconstruct_domain(original, preprocessed):
        """Build a domain from the source variables of ``preprocessed``.

        Each attribute of ``preprocessed`` is traced back through its
        ``_compute_value`` chain; when the chain ends without a further
        compute value, the variable with the same name is taken from
        ``original``.  Every source variable is kept once, in order of first
        appearance.  Class variable and metas are taken from ``original``.
        """
        # An ordered mapping is used as an ordered set with fast membership.
        sources = OrderedDict()
        for attr in preprocessed.attributes:
            inner = attr._compute_value.variable._compute_value
            if inner:
                source = inner.variable
            else:
                source = original[attr.name]
            sources.setdefault(source, None)
        return Domain(list(sources), original.class_var, original.metas)

    @staticmethod
    def get_ruler_values(start, stop, max_width, round_to_nearest=True):
        if max_width == 0:
            return [0]
        diff = np.nan_to_num((stop - start) / max_width)
        if diff <= 0:
            return [0]
        decimals = int(np.floor(np.log10(diff)))
        if diff > 4 * pow(10, decimals):
            step = 5 * pow(10, decimals + 2)
        elif diff > 2 * pow(10, decimals):
            step = 2 * pow(10, decimals + 2)
        elif diff > 1 * pow(10, decimals):
            step = 1 * pow(10, decimals + 2)
        else:
            step = 5 * pow(10, decimals + 1)
        round_by = int(-np.floor(np.log10(step)))
        r = start % step
        if not round_to_nearest:
            _range = np.arange(start + step, stop + r, step) - r
            start, stop = np.floor(start * 100) / 100, np.ceil(
                stop * 100) / 100
            return np.round(np.hstack((start, _range, stop)), 2)
        return np.round(np.arange(start, stop + r + step, step) - r, round_by)

    @staticmethod
    def get_points_from_coeffs(current_value, coefficients, possible_values):
        if any(np.isnan(possible_values)):
            return 0
        indices = np.argsort(possible_values)
        sorted_values = possible_values[indices]
        sorted_coefficients = coefficients[indices]
        for i, val in enumerate(sorted_values):
            if current_value < val:
                break
        diff = sorted_values[i] - sorted_values[i - 1]
        k = 0 if diff < 1e-6 else (sorted_values[i] - current_value) / \
                                  (sorted_values[i] - sorted_values[i - 1])
        return sorted_coefficients[i - 1] * sorted_values[i - 1] * k + \
               sorted_coefficients[i] * sorted_values[i] * (1 - k)
Exemplo n.º 4
0
class OWPythagoreanForest(OWWidget):
    """Show the trees of a random forest as a grid of Pythagoras trees.

    The widget receives a forest on the 'Random forest' input, draws one
    grid item per tree and sends the model of the tree selected in the scene
    to the 'Tree' output.
    """

    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    size_log_scale = settings.Setting(2)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    CLASSIFICATION, REGRESSION = range(2)

    def __init__(self):
        super().__init__()
        # Instance variables
        self.forest_type = self.CLASSIFICATION
        self.model = None
        self.forest_adapter = None
        self.dataset = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x * self.size_log_scale)),
        ]

        # (label, colouring function) pairs offered for regression forests;
        # the combo index selects the function.
        self.REGRESSION_COLOR_CALC = [
            ('None', lambda _, __: QColor(255, 255, 255)),
            ('Class mean', self._color_class_mean),
            ('Standard deviation', self._color_stddev),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info, label='')

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(
            box_display, self, 'depth_limit', label='Depth', ticks=False,
            callback=self.max_depth_changed)
        self.ui_target_class_combo = gui.comboBox(
            box_display, self, 'target_class_index', label='Target class',
            orientation=Qt.Horizontal, items=[], contentsLength=8,
            callback=self.target_colors_changed)
        self.ui_size_calc_combo = gui.comboBox(
            box_display, self, 'size_calc_idx', label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8,
            callback=self.size_calc_changed)
        self.ui_zoom_slider = gui.hSlider(
            box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20,
            maxValue=150, callback=self.zoom_changed, createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(
            QSizePolicy.Preferred, QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """Replace the displayed forest with `model` (``None`` clears it).

        Raises
        ------
        RuntimeError
            If `model` is neither a classification nor a regression forest.
        """
        self.clear()
        self.model = model

        if model is None:
            return

        if isinstance(model, RandomForestClassifier):
            self.forest_type = self.CLASSIFICATION
        elif isinstance(model, RandomForestRegressor):
            self.forest_type = self.REGRESSION
        else:
            raise RuntimeError('Invalid type of forest.')

        self.forest_adapter = self._get_forest_adapter(self.model)
        self.color_palette = self._type_specific('_get_color_palette')()

        # The datasets are assigned before drawing because the regression
        # colouring callbacks read `clf_dataset`.
        self.dataset = model.instances
        # this bit is important for the regression classifier
        if self.dataset is not None and \
                self.dataset.domain != model.domain:
            self.clf_dataset = Table.from_table(
                self.model.domain, self.dataset)
        else:
            self.clf_dataset = self.dataset

        self._draw_trees()

        self._update_info_box()
        self._type_specific('_update_target_class_combo')()
        self._update_depth_slider()

        self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        # Drop the data of the previous forest so it is not kept alive
        self.dataset = None
        self.clf_dataset = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    # CONTROL AREA CALLBACKS
    def max_depth_changed(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def target_colors_changed(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_has_changed()

    def size_calc_changed(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self.grid.clear()
            self._draw_trees()
            # Keep the selected item
            if self.selected_tree_index != -1:
                self.grid_items[self.selected_tree_index].setSelected(True)
            self.max_depth_changed()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        max_size = self._calculate_zoom(self.zoom)
        for item in self.grid_items:
            item.set_max_size(max_size)

        self._reflow_grid()

    # MODEL CHANGED METHODS
    def _update_info_box(self):
        """Show the number of trees of the current forest."""
        self.ui_info.setText(
            'Trees: {}'.format(len(self.forest_adapter.get_trees()))
        )

    def _update_depth_slider(self):
        """Enable the depth slider and set it to the deepest drawn tree."""
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    # MODEL CLEARED METHODS
    def _clear_info_box(self):
        """Show the 'no input' text in the info box."""
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        """Empty the target class combo and reset its index to 0."""
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        """Disable the depth slider and set its maximum to 0."""
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    # HELPFUL METHODS
    def _get_max_depth(self):
        """Return the largest depth among the drawn trees (0 if none)."""
        return max((tree.tree_adapter.max_depth for tree in self.ptrees),
                   default=0)

    def _get_forest_adapter(self, model):
        """Wrap `model` in the adapter used for drawing."""
        return SklRandomForestAdapter(model)

    def _reflow_grid(self):
        """Fit the grid to the view width minus the vertical scroll bar."""
        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    def _draw_trees(self):
        """Create a tree viewer and a grid item for every tree."""
        self.ui_size_calc_combo.setEnabled(False)
        self.grid_items, self.ptrees = [], []

        trees = self.forest_adapter.get_trees()
        max_size = self._calculate_zoom(self.zoom)
        node_color_func = self._type_specific('_get_node_color')
        with self.progressBar(len(trees)) as prg:
            for tree in trees:
                ptree = PythagorasTreeViewer(
                    None, tree,
                    node_color_func=node_color_func,
                    interactive=False, padding=100)
                self.grid_items.append(GridItem(
                    ptree, self.grid, max_size=max_size
                ))
                self.ptrees.append(ptree)
                prg.advance()
        self.grid.set_items(self.grid_items)
        # This is necessary when adding items for the first time
        if self.grid:
            self._reflow_grid()
        self.ui_size_calc_combo.setEnabled(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        selected = self.scene.selectedItems()
        if not selected:
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        self.selected_tree_index = self.grid_items.index(selected[0])
        obj = self.model.trees[self.selected_tree_index]
        # Attach the data and the current display settings to the tree
        obj.instances = self.dataset
        obj.meta_target_class_index = self.target_class_index
        obj.meta_size_calc_idx = self.size_calc_idx
        obj.meta_size_log_scale = self.size_log_scale
        obj.meta_depth_limit = self.depth_limit

        self.send('Tree', obj)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        """Shrink or grow the scene rectangle to its items."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def resizeEvent(self, ev):
        """Reflow the grid to the new width, then run the default handler."""
        self._reflow_grid()

        super().resizeEvent(ev)

    def _type_specific(self, method):
        """A best effort method getter that somewhat separates logic specific
        to classification and regression trees.
        This relies on conventional naming of specific methods, e.g.
        a method name _get_tooltip would need to be defined like so:
        _classification_get_tooltip and _regression_get_tooltip, since they are
        both specific.

        Parameters
        ----------
        method : str
            Method name that we would like to call.

        Returns
        -------
        callable or None

        """
        if self.forest_type == self.CLASSIFICATION:
            return getattr(self, '_classification' + method)
        elif self.forest_type == self.REGRESSION:
            return getattr(self, '_regression' + method)
        else:
            return None

    # CLASSIFICATION FOREST SPECIFIC METHODS
    def _classification_update_target_class_combo(self):
        """Fill the combo with 'None' followed by the class values."""
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItem('None')
        values = [c.title() for c in
                  self.model.domain.class_vars[0].values]
        self.ui_target_class_combo.addItems(values)

    def _classification_get_color_palette(self):
        """Return one QColor per class value."""
        return [QColor(*c) for c in self.model.domain.class_var.colors]

    def _classification_get_node_color(self, adapter, tree_node):
        """Return the colour of `tree_node` from its class distribution.

        With a target class selected (combo index > 0) the colour of that
        class is lightened by its share in the node; otherwise the colour of
        the majority class is used.
        """
        # this is taken almost directly from the existing classification tree
        # viewer
        colors = self.color_palette
        distribution = adapter.get_distribution(tree_node.label)[0]
        # A zero total would make the share NaN in either branch; dividing
        # by 1 yields a share of 0 instead.
        total = np.sum(distribution) or 1

        if self.target_class_index:
            # index 0 of the combo is 'None', hence the shift by one
            p = distribution[self.target_class_index - 1] / total
            color = colors[self.target_class_index - 1].lighter(200 - 100 * p)
        else:
            modus = np.argmax(distribution)
            p = distribution[modus] / total
            color = colors[int(modus)].lighter(400 - 300 * p)
        return color

    # REGRESSION FOREST SPECIFIC METHODS
    def _regression_update_target_class_combo(self):
        """Fill the combo with the regression colouring methods."""
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItems(
            list(zip(*self.REGRESSION_COLOR_CALC))[0])
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _regression_get_color_palette(self):
        """Return a continuous palette built from the class variable colours."""
        return ContinuousPaletteGenerator(
            *self.forest_adapter.domain.class_var.colors)

    def _regression_get_node_color(self, adapter, tree_node):
        """Colour `tree_node` with the currently selected colouring method."""
        return self.REGRESSION_COLOR_CALC[self.target_class_index][1](
            adapter, tree_node
        )

    def _color_class_mean(self, adapter, tree_node):
        """Colour a node by the mean target of its instances.

        The mean is scaled by the target range of the whole dataset.
        NOTE(review): a constant target makes the range zero - confirm how
        the palette treats the resulting NaN.
        """
        # calculate node colors relative to the mean of the node samples
        min_mean = np.min(self.clf_dataset.Y)
        max_mean = np.max(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        mean = np.mean(instances.Y)

        return self.color_palette[(mean - min_mean) / (max_mean - min_mean)]

    def _color_stddev(self, adapter, tree_node):
        """Colour a node by the standard deviation of its instances' target.

        The value is scaled by the standard deviation of the whole dataset.
        NOTE(review): a constant target makes the divisor zero - confirm how
        the palette treats the resulting NaN.
        """
        # calculate node colors relative to the standard deviation in the node
        # samples
        min_mean, max_mean = 0, np.std(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        std = np.std(instances.Y)

        return self.color_palette[(std - min_mean) / (max_mean - min_mean)]
Exemplo n.º 5
0
class OWVennDiagram(widget.OWWidget):
    """Venn diagram widget for up to five input tables.

    Elements compared between the inputs are either rows or columns,
    depending on the `rowwise` setting. The areas selected in the diagram
    determine what is sent to the outputs.
    """
    name = "Venn Diagram"
    description = "A graphical visualization of the overlap of data instances " \
                  "from a collection of input datasets."
    icon = "icons/VennDiagram.svg"
    priority = 280
    keywords = []
    settings_version = 2

    class Inputs:
        # Multiple tables may be connected; each arrives with its own key
        data = Input("Data", Table, multiple=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    class Error(widget.OWWidget.Error):
        instances_mismatch = Msg("Data sets do not contain the same instances.")
        too_many_inputs = Msg("Venn diagram accepts at most five datasets.")

    class Warning(widget.OWWidget.Warning):
        # The placeholder receives the comma-separated original names
        renamed_vars = Msg("Some variables have been renamed "
                           "to avoid duplicates.\n{}")

    selection: list

    settingsHandler = settings.DomainContextHandler()
    # Indices of selected disjoint areas
    selection = settings.Setting([], schema_only=True)
    #: Output unique items (one output row for every unique instance `key`)
    #: or preserve all duplicates in the output.
    output_duplicates = settings.Setting(False)
    autocommit = settings.Setting(True)
    # Truthy: compare rows (instances); falsy: compare columns (features)
    rowwise = settings.Setting(True)
    # String variable used to match rows; None means instance identity (ids)
    selected_feature = settings.ContextSetting(None)

    want_control_area = False
    graph_name = "scene"
    # Variable roles, used as keys of the per-role dictionaries below
    atr_types = ['attributes', 'metas', 'class_vars']
    # Role -> name of the Table attribute that holds the role's values
    atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
    # NOTE(review): presumably role -> attribute name on a single row
    # instance; not used in the visible part of the class - confirm
    row_vals = {'attributes': 'x', 'class_vars': 'y', 'metas': 'metas'}

    def __init__(self):
        """Set up the widget state, the diagram scene and the controls."""
        super().__init__()

        # Diagram update is in progress
        self._updating = False
        # Input update is in progress
        self._inputUpdate = False
        # Input datasets in the order they were 'connected'.
        self.data = {}
        # Extracted input item sets in the order they were 'connected'
        self.itemsets = {}
        # A list with 2 ** len(self.data) elements that store item sets
        # belonging to each area
        self.disjoint = []
        # A list with  2 ** len(self.data) elements that store keys of tables
        # intersected in each area
        self.area_keys = []

        # Main area view
        self.scene = QGraphicsScene(self)
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.setBackgroundRole(QPalette.Window)
        self.view.setFrameStyle(QGraphicsView.StyledPanel)

        self.mainArea.layout().addWidget(self.view)
        self.vennwidget = VennDiagram()
        self._resize()
        # Keep edited set titles and the selected areas in sync with the scene
        self.vennwidget.itemTextEdited.connect(self._on_itemTextEdited)
        self.scene.selectionChanged.connect(self._on_selectionChanged)

        self.scene.addItem(self.vennwidget)

        # The controls live in the main area (want_control_area is False)
        controls = gui.hBox(self.mainArea)
        # Radio button index is stored in `rowwise`: 0 = columns, 1 = rows
        box = gui.radioButtonsInBox(
            controls, self, 'rowwise',
            ["Columns (features)", "Rows (instances), matched by", ],
            box="Elements", callback=self._on_matching_changed
        )
        # Combo for the variable rows are matched by; the placeholder entry
        # stands for matching by instance identity
        gui.comboBox(
            gui.indentedBox(box), self, "selected_feature",
            model=itemmodels.VariableListModel(placeholder="Instance identity"),
            callback=self._on_inputAttrActivated
            )
        box.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)

        self.outputs_box = box = gui.vBox(controls, "Output")
        self.output_duplicates_cb = gui.checkBox(
            box, self, "output_duplicates", "Output duplicates",
            callback=lambda: self.commit())  # pylint: disable=unnecessary-lambda
        gui.auto_send(box, self, "autocommit", box=False)
        # The duplicates option is only enabled when matching rows
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        self._queue = []

    def resizeEvent(self, event):
        """Run the default handler, then refit the diagram to the view."""
        super().resizeEvent(event)
        self._resize()

    def showEvent(self, event):
        """Run the default handler, then refit the diagram to the view."""
        super().showEvent(event)
        self._resize()

    def _resize(self):
        """Resize the diagram to a square fitting the view and refit the scene."""
        # The diagram fills the geometry of vennwidget, and its labels need a
        # further 120 pixels; the side never drops below 200.
        shorter_side = min(self.view.width(), self.view.height())
        side = max(200, shorter_side - 120)
        self.vennwidget.resize(side, side)
        bounds = self.scene.itemsBoundingRect()
        self.scene.setSceneRect(bounds)

    @Inputs.data
    @check_sql_input
    def setData(self, data, key=None):
        """Add, replace or remove the input table registered under `key`.

        `data` being ``None`` removes a known input. A new input is refused
        with the `too_many_inputs` error once five tables are connected.
        """
        self.Error.too_many_inputs.clear()
        self._inputUpdate = True

        known = key in self.data
        if known and data is None:
            # The input was disconnected: forget it and its warnings
            self.Warning.clear()
            del self.data[key]
        elif known:
            # Same slot, new table
            self.data[key] = self.data[key]._replace(name=data.name, table=data)
        elif data is not None:
            # TODO: Allow setting more than 5 inputs and let the user
            # select the 5 to display.
            if len(self.data) == 5:
                self.Error.too_many_inputs()
                return
            self.data[key] = _InputData(key, data.name, data)

        self._setInterAttributes()

    def data_equality(self):
        """Return True when every input table holds the same set of ids."""
        id_sets = [set(entry.table.ids) for entry in self.data.values()]
        if not id_sets:
            return True
        shared = set.intersection(*id_sets)
        largest = max(len(ids) for ids in id_sets)
        # The sets are equal exactly when the intersection is as large as
        # the largest of them.
        return len(shared) == largest

    def settings_compatible(self):
        """Check that the inputs suit the current matching mode.

        Column-wise comparison needs all tables to hold the same instances.
        On a mismatch the diagram and item sets are cleared, the
        `instances_mismatch` error is shown and False is returned.
        """
        self.Error.instances_mismatch.clear()
        if self.rowwise or self.data_equality():
            return True

        self.vennwidget.clear()
        self.Error.instances_mismatch()
        self.itemsets = {}
        return False

    def handleNewSignals(self):
        """Rebuild the diagram after the inputs have changed.

        When the inputs do not suit the current settings only the output is
        invalidated and the method returns early.
        """
        self._inputUpdate = False
        self.vennwidget.clear()
        if not self.settings_compatible():
            self.invalidateOutput()
            return

        self._createItemsets()
        self._createDiagram()
        # If autocommit is enabled, _createDiagram already outputs data
        # If not, call unconditional_commit from here
        if not self.autocommit:
            self.unconditional_commit()

        self._updateInfo()
        super().handleNewSignals()

    def intersectionStringAttrs(self):
        """Return the string attributes present in every input table."""
        per_table = [set(string_attributes(entry.table.domain))
                     for entry in self.data.values()]
        if not per_table:
            return set()
        return set.intersection(*per_table)

    def _setInterAttributes(self):
        """Refill the matching-variable combo with the shared string attributes.

        The first entry is ``None``. A selected variable whose name is no
        longer offered is replaced by that first entry.
        """
        model = self.controls.selected_feature.model()
        model[:] = [None] + list(self.intersectionStringAttrs())

        current = self.selected_feature
        if not current:
            return
        if all(var.name != current.name for var in model if var):
            self.selected_feature = model[0]

    def _itemsForInput(self, key):
        """
        Return the items that represent the rows of the input `key`.

        Without a selected feature these are the table's ids; otherwise the
        string form of each row's value of that feature, skipping rows for
        which `np.isnan` of the value is true.
        """
        table = self.data[key].table
        attr = self.selected_feature
        if not attr:
            return list(table.ids)

        items = []
        for inst in table:
            value = inst[attr]
            if not np.isnan(value):
                items.append(str(value))
        return items

    def _createItemsets(self):
        """
        Rebuild `self.itemsets` from the rows or the column names of the inputs.
        """
        previous = dict(self.itemsets)
        self.itemsets.clear()

        for key, input_ in self.data.items():
            items = (self._itemsForInput(key) if self.rowwise
                     else [el.name for el in input_.table.domain.attributes])
            name = input_.name
            old = previous.get(key)
            # A title edited by the user survives as long as the input keeps
            # its name
            title = old.title if old is not None and old.name == name else name
            self.itemsets[key] = _ItemSet(
                key=key, name=name, title=title, items=items)

    def _createDiagram(self):
        """Fill the Venn widget from `self.itemsets` and restore the selection.

        One set item is created per input (labelled with its title and item
        counts). Every disjoint area gets its item count as text and a tooltip
        that lists at most 32 of its items. Selection-change handling is
        suppressed while the diagram is built and triggered once at the end.
        """
        max_listed = 32  # number of items spelled out in a tooltip

        self._updating = True
        try:
            oldselection = list(self.selection)

            n = len(self.itemsets)
            self.disjoint, self.area_keys = \
                self.get_disjoint(set(s.items) for s in self.itemsets.values())

            vennitems = []
            colors = colorpalettes.LimitedDiscretePalette(n, force_hsv=True)

            for i, item in enumerate(self.itemsets.values()):
                cnt = len(set(item.items))
                cnt_all = len(item.items)
                # Show the total as well when the input has duplicate items
                if cnt != cnt_all:
                    fmt = '{} <i>(all: {})</i>'
                else:
                    fmt = '{}'
                counts = fmt.format(cnt, cnt_all)
                gr = VennSetItem(text=item.title, informativeText=counts)
                color = colors[i]
                color.setAlpha(100)
                gr.setBrush(QBrush(color))
                gr.setPen(QPen(Qt.NoPen))
                vennitems.append(gr)

            self.vennwidget.setItems(vennitems)

            for i, area in enumerate(self.vennwidget.vennareas()):
                area_items = [str(elem) for elem in self.disjoint[i]]
                # Area 0 gets no count label
                if i:
                    area.setText("{0}".format(len(area_items)))

                label = disjoint_set_label(i, n, simplify=False)
                head = "<h4>|{}| = {}</h4>".format(label, len(area_items))
                if len(area_items) > max_listed:
                    items_str = ", ".join(map(escape, area_items[:max_listed]))
                    hidden = len(area_items) - max_listed
                    # Well-formed markup: a real line break and a closed span
                    tooltip = ("{}<span>{}, ...<br/>({} items not shown)</span>"
                               .format(head, items_str, hidden))
                elif area_items:
                    tooltip = "{}<span>{}</span>".format(
                        head,
                        ", ".join(map(escape, area_items))
                    )
                else:
                    tooltip = head

                area.setToolTip(tooltip)

                area.setPen(QPen(QColor(10, 10, 10, 200), 1.5))
                area.setFlag(QGraphicsPathItem.ItemIsSelectable, True)
                area.setSelected(i in oldselection)
        finally:
            # Never leave selection handling switched off after a failure
            self._updating = False

        self._on_selectionChanged()

    def _updateInfo(self):
        """Warn about inputs for which `source_attributes` finds nothing.

        The check only runs when no matching feature is selected.
        """
        # Clear all warnings
        self.warning()

        if self.selected_feature is not None:
            return

        missing = []
        for position, key in enumerate(self.data):
            if not source_attributes(self.data[key].table.domain):
                missing.append("#{}".format(position + 1))

        if not missing:
            return
        if len(missing) == 1:
            message = "Dataset {} has no suitable identifiers.".format(
                missing[0])
        else:
            message = "Datasets {} and {} have no suitable identifiers.".format(
                ", ".join(missing[:-1]), missing[-1])
        self.warning(message)

    def _on_selectionChanged(self):
        """Store the indices of the selected areas and invalidate the output."""
        # Ignore selection changes caused by rebuilding the diagram
        if self._updating:
            return

        areas = self.vennwidget.vennareas()
        self.selection = [i for i, area in enumerate(areas) if area.isSelected()]
        self.invalidateOutput()

    def _on_matching_changed(self):
        """Rebuild the diagram after the rows/columns choice has changed."""
        # Duplicates can only be output when matching rows
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        if not self.settings_compatible():
            self.invalidateOutput()
            return
        self._createItemsets()
        self._createDiagram()
        self._updateInfo()

    def _on_inputAttrActivated(self):
        """Switch to row matching when a matching variable is chosen."""
        # 1 is the index of the "Rows (instances)" radio button
        self.rowwise = 1
        self._on_matching_changed()

    def _on_itemTextEdited(self, index, text):
        """Store the title the user typed for the set at position `index`."""
        text = str(text)
        # Item sets are kept in a dict; `index` is the position among its keys
        key = list(self.itemsets)[index]
        self.itemsets[key] = self.itemsets[key]._replace(title=text)

    def invalidateOutput(self):
        """Recompute the outputs by calling `commit`."""
        self.commit()

    def merge_data(self, domain, values, ids=None):
        """Build a table from per-role variables and column arrays.

        Parameters
        ----------
        domain : dict
            Maps each role ('attributes', 'metas', 'class_vars') to a list of
            variables. Variables whose names repeat within a role are replaced
            in place by renamed copies and reported via `Warning.renamed_vars`.
        values : dict
            Maps a role to a list of column arrays, stacked horizontally.
            Roles without columns are absent.
        ids : sequence, optional
            Row ids assigned to the resulting table.

        Returns
        -------
        Table
            When no role has any column, the table has ``len(ids)`` rows, or
            no rows if `ids` is not given.
        """
        X, metas, class_vars = None, None, None
        n_rows = None
        renamed = []
        for variables in domain.values():
            names = [var.name for var in variables]
            unique_names = get_unique_names_duplicates(names)
            for idx, (old_name, new_name, var) in enumerate(
                    zip(names, unique_names, variables)):
                if old_name != new_name:
                    variables[idx] = var.copy(name=new_name)
                    renamed.append(old_name)
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        if 'attributes' in values:
            X = np.hstack(values['attributes'])
        if 'metas' in values:
            metas = np.hstack(values['metas'])
            n_rows = len(metas)
        if 'class_vars' in values:
            class_vars = np.hstack(values['class_vars'])
            n_rows = len(class_vars)
        if X is None:
            if n_rows is None:
                # No columns in any role: the row count can only come from
                # the ids; without them an empty table is produced.
                n_rows = len(ids) if ids is not None else 0
            X = np.empty((n_rows, 0))
        table = Table.from_numpy(Domain(**domain), X, class_vars, metas)
        if ids is not None:
            table.ids = ids
        return table

    def extract_columnwise(self, var_dict, columns=None):
        """Assemble the column-wise output table from merged variables.

        `var_dict` maps a role to ``{variable: [is_different, sources]}``,
        where `sources` is a list of ``(variable, table_key)`` pairs. A
        variable marked as different is copied once per source and suffixed
        with the 1-based position of its input; otherwise a single copy from
        the first source is used. When `columns` is given, attributes are
        tagged with a 'Selected' flag.
        """
        domain = {role: [] for role in self.atr_types}
        values = defaultdict(list)
        renamed = []
        for atr_type, vars_dict in var_dict.items():
            for variable, (differs, sources) in vars_dict.items():
                is_selected = bool(columns) and variable.name in columns
                annotate = columns and atr_type == 'attributes'
                if differs:
                    # The columns differ between inputs: keep a renamed copy
                    # of each
                    for var, table_key in sources:
                        idx = list(self.data).index(table_key) + 1
                        new_atr = var.copy(name=f'{variable.name} ({idx})')
                        if annotate:
                            new_atr.attributes['Selected'] = is_selected
                        domain[atr_type].append(new_atr)
                        renamed.append(variable.name)
                        column = getattr(
                            self.data[table_key].table[:, variable],
                            self.atr_vals[atr_type])
                        values[atr_type].append(column.reshape(-1, 1))
                else:
                    first_var, first_key = sources[0]
                    new_atr = first_var.copy()
                    if annotate:
                        new_atr.attributes['Selected'] = is_selected
                    domain[atr_type].append(new_atr)
                    column = getattr(
                        self.data[first_key].table[:, variable],
                        self.atr_vals[atr_type])
                    values[atr_type].append(column.reshape(-1, 1))
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        return self.merge_data(domain, values)

    def curry_merge(self, table_key, atr_type, ids=None, selection=False):
        """Return a reducer that merges the variables of one table.

        The returned function has the signature ``inner(new_atrs, atr)`` and
        is meant to be folded over the variables of the table `table_key`
        that belong to the role `atr_type`. Values are compared on shared
        rows when `rowwise` is set and on whole columns otherwise.
        """
        if self.rowwise:
            check_equality = self.arrays_equal_rows
        else:
            check_equality = self.arrays_equal_cols

        def inner(new_atrs, atr):
            """
            Record `atr` of the current table in `new_atrs` and return it.

            new_atrs - dictionary where the key is a variable and the value
                is [is_different:bool, sources:list of (variable, table_key)].
                is_different is set to True without comparing values when
                duplicates are output (and this is not a selection), and
                otherwise when the values differ from any recorded source.
            atr - the variable to merge
            """
            if atr in new_atrs:
                if not selection and self.output_duplicates:
                    # Outputting duplicates: mark as different, no comparison
                    new_atrs[atr][0] = True
                elif not new_atrs[atr][0]:
                    # Compare against every table that already has the variable
                    for var, key in new_atrs[atr][1]:
                        if not check_equality(table_key,
                                              key,
                                              atr.name,
                                              self.atr_vals[atr_type],
                                              type(var), ids):
                            new_atrs[atr][0] = True
                            break
                new_atrs[atr][1].append((atr, table_key))
            else:
                # First occurrence of the variable
                new_atrs[atr] = [False, [(atr, table_key)]]
            return new_atrs
        return inner

    def arrays_equal_rows(self, key1, key2, name, data_type, type_, ids):
        """Compare column `name` of two inputs on the rows they share.

        `ids` maps a table key to a dictionary from row identifier to row
        index. Only identifiers present in both tables are compared.
        """
        t1 = self.data[key1].table
        t2 = self.data[key2].table
        rows1, rows2 = ids[key1], ids[key2]
        # One set, iterated twice, so both index lists are in the same order
        shared = set(rows1) & set(rows2)
        idx1 = [rows1[val] for val in shared]
        idx2 = [rows2[val] for val in shared]
        col1 = getattr(t1[idx1, name], data_type).reshape(-1, 1)
        col2 = getattr(t2[idx2, name], data_type).reshape(-1, 1)
        return arrays_equal(col1, col2, type_)

    def arrays_equal_cols(self, key1, key2, name, data_type, type_, _ids=None):
        """Compare the whole column `name` of two inputs; `_ids` is unused."""
        col1 = getattr(self.data[key1].table[:, name], data_type)
        col2 = getattr(self.data[key2].table[:, name], data_type)
        return arrays_equal(col1, col2, type_)

    def create_from_columns(self, columns, relevant_keys, get_selected):
        """
        Merge the variables of the tables `relevant_keys` into one table.

        Columns are duplicated only if values differ (even if only in order
        of values); a duplicated column gets the position of its input slot
        appended to its name. With `get_selected` only the attributes named
        in `columns` are kept; without it all attributes are kept and
        `columns` is passed on to mark the selected ones.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            only_selected = get_selected and atr_type == 'attributes'
            merged = {}
            for table_key in relevant_keys:
                table = self.data[table_key].table
                if only_selected:
                    atrs = [var for var in table.domain.attributes
                            if var.name in columns]
                else:
                    atrs = getattr(table.domain, atr_type)
                merged = reduce(
                    self.curry_merge(table_key, atr_type), atrs, merged)
            var_dict[atr_type] = merged

        return self.extract_columnwise(
            var_dict, None if get_selected else columns)

    def extract_rowwise(self, var_dict, ids=None, selection=False):
        """
        Assemble one output table whose rows are the union of all ids.

        var_dict: dictionary keyed by attribute type
            (['attributes', 'metas', 'class_vars']); each value is a
            dictionary mapping a variable to a pair
            [is_different: bool, [(variable, table_key), ...]].
        ids: dict with, for each table key, a mapping from id to the row
            index in that table. Despite the `None` default it must be a
            dict: `ids.values()` is called unconditionally.
        selection: when true, the merged table is passed through
            `create_annotated_table` with a mask marking the ids that are
            in `self.selected_items`.

        Returns the table built by `self.merge_data` (annotated when
        `selection` is true).
        """
        # Sorted union of the ids of all tables; defines the output row order.
        all_ids = sorted(reduce(set.union, [set(val) for val in ids.values()], set()))

        # For each table, positions (from get_perm) of its ids within all_ids.
        permutations = {}
        for table_key, dict_ in ids.items():
            permutations[table_key] = get_perm(list(dict_), all_ids)

        domain = {type_ : [] for type_ in self.atr_types}
        values = defaultdict(list)
        renamed = []
        for atr_type, vars_dict in var_dict.items():
            for var_name, var_data in vars_dict.items():
                different = var_data[0]
                if different:
                    # Columns are different, copy and rename them.
                    # Renaming is done here to mark appropriately the source table.
                    # Additional strange clashes are checked later in merge_data
                    for var, table_key in var_data[1]:
                        temp = self.data[table_key].table
                        # 1-based position of the source table among the inputs.
                        idx = list(self.data).index(table_key) + 1
                        domain[atr_type].append(var.copy(name='{} ({})'.format(var_name, idx)))
                        # One entry per renamed copy, so a name can repeat in the warning.
                        renamed.append(var_name.name)
                        v = getattr(temp[list(ids[table_key].values()), var_name],
                                    self.atr_vals[atr_type])
                        perm = permutations[table_key]
                        # A table that lacks some of the ids yields a shorter
                        # column, which pad_columns stretches to len(all_ids).
                        if len(v) < len(all_ids):
                            values[atr_type].append(pad_columns(v, perm, len(all_ids)))
                        else:
                            values[atr_type].append(v[perm].reshape(-1, 1))
                else:
                    # Shared column: start from an all-NaN vector and let every
                    # table fill in the rows it owns.
                    value = np.full((len(all_ids), 1), np.nan)
                    domain[atr_type].append(var_data[1][0][0].copy())
                    for _, table_key in var_data[1]:
                        # Different tables hold different parts of the same attribute vector.
                        perm = permutations[table_key]
                        v = getattr(self.data[table_key].table[list(ids[table_key].values()),
                                                               var_name],
                                    self.atr_vals[atr_type]).reshape(-1, 1)
                        # Switch to the source dtype before writing the values in.
                        value = value.astype(v.dtype, copy=False)
                        value[perm] = v
                    values[atr_type].append(value)

        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        # Ids are handed to merge_data only when no feature is selected for matching.
        ids = None if self.selected_feature else np.array(all_ids)
        table = self.merge_data(domain, values, ids)
        if selection:
            mask = [idx in self.selected_items for idx in all_ids]
            return create_annotated_table(table, mask)
        return table

    def get_indices(self, table, selection):
        """
        Return a mapping of ids (row ids or feature values) to indices in `table`.

        Without a selected feature the table's row ids are mapped to the row
        positions. With a selected feature the unique values of that meta
        column are the keys: they map to the index of their first occurrence,
        or - when duplicates are output and `selection` is true - to an array
        of all row indices holding the value. When `selection` is true only
        items contained in `self.selected_items` are kept.
        """
        if not self.selected_feature:
            items = table.ids
            ids = range(len(table))
        else:
            column = table[:, self.selected_feature].metas
            if self.output_duplicates and selection:
                items, inverse = np.unique(column, return_inverse=True)
                ids = [np.nonzero(inverse == pos)[0]
                       for pos in range(len(items))]
            else:
                items, ids = np.unique(column, return_index=True)

        pairs = zip(items, ids)
        if not selection:
            return dict(pairs)
        return {item: idx for item, idx in pairs
                if item in self.selected_items}

    def get_indices_to_match_by(self, relevant_keys, selection=False):
        """Return, for each key in `relevant_keys`, that table's `get_indices` mapping."""
        return {
            key: self.get_indices(self.data[key].table, selection)
            for key in relevant_keys
        }

    def create_from_rows(self, relevant_ids, selection=False):
        """
        Build the row-wise output for the tables listed in `relevant_ids`.

        For each attribute type the variables of every table are folded into
        one dictionary with the merger from `self.curry_merge`. The result
        goes to `extract_rowwise_duplicates` when duplicates are output and
        `selection` is false, and to `extract_rowwise` otherwise.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            merged = {}
            for table_key in relevant_ids:
                merger = self.curry_merge(table_key, atr_type, relevant_ids, selection)
                table_vars = getattr(self.data[table_key].table.domain, atr_type)
                for var in table_vars:
                    merged = merger(merged, var)
            var_dict[atr_type] = merged

        if not self.output_duplicates or selection:
            return self.extract_rowwise(var_dict, relevant_ids, selection)
        return self.extract_rowwise_duplicates(var_dict, relevant_ids)

    def expand_table(self, table, atrs, metas, cv):
        """
        Spread the values of `table` over the full lists of variables.

        `table` is either a single `RowInstance` or a table. For each
        attribute type a NaN-filled array with one column per variable in the
        corresponding full list (`atrs`, `metas`, `cv`) is created, and the
        columns the table actually has are written to the positions given by
        `get_perm`. Returns the three arrays followed by the ids as a column
        vector.
        """
        if isinstance(table, RowInstance):
            n_rows = 1
            ids = table.id.reshape(-1, 1)
            value_names = self.row_vals
        else:
            n_rows = len(table)
            ids = table.ids.reshape(-1, 1)
            value_names = self.atr_vals

        expanded = []
        for full_vars, atr_type in zip([atrs, metas, cv], self.atr_types):
            present = getattr(table.domain, atr_type)
            block = np.full((n_rows, len(full_vars)), np.nan)
            if present:
                positions = get_perm(present, full_vars)
                source = getattr(table, value_names[atr_type])
                source = source.reshape(len(block), len(positions))
                # Adopt the source dtype before copying the values in.
                block = block.astype(source.dtype, copy=False)
                block[:, positions] = source
            expanded.append(block)
        return (*expanded, ids)

    def extract_rowwise_duplicates(self, var_dict, ids):
        """
        Build a row-wise table that keeps one row per (id, table) occurrence.

        The variables of each type are sorted by name; then, walking the
        sorted union of all ids, every table containing the id contributes
        its row(s), expanded over the full variable lists with
        `expand_table`. The stacked arrays are handed to `merge_data`
        together with the stacked original ids.
        """
        by_name = attrgetter("name")
        all_ids = sorted(set().union(*(set(val) for val in ids.values())))
        all_atrs = sorted(var_dict['attributes'], key=by_name)
        all_metas = sorted(var_dict['metas'], key=by_name)
        all_cv = sorted(var_dict['class_vars'], key=by_name)

        x_parts, m_parts, y_parts, id_parts = [], [], [], []
        for idx in all_ids:
            # visit every table that contains this id
            for table_key, t_indices in ids.items():
                if idx in t_indices:
                    rows = self.data[table_key].table[t_indices[idx]]
                    # pylint: disable=unbalanced-tuple-unpacking
                    x, m, y, t_ids = self.expand_table(
                        rows, all_atrs, all_metas, all_cv)
                    x_parts.append(x)
                    y_parts.append(y)
                    m_parts.append(m)
                    id_parts.append(t_ids)

        domain = {'attributes': all_atrs,
                  'metas': all_metas,
                  'class_vars': all_cv}
        values = {'attributes': [np.vstack(x_parts)],
                  'metas': [np.vstack(m_parts)],
                  'class_vars': [np.vstack(y_parts)]}
        return self.merge_data(domain, values, np.vstack(id_parts))

    def commit(self):
        """
        Compute and send the selected and the annotated output.

        With no Venn areas or no input data both outputs are cleared.
        Otherwise the items and table keys of the selected areas are
        collected (the items are stored in `self.selected_items`), and the
        outputs are built row-wise or column-wise depending on
        `self.rowwise`. The selected output stays `None` when nothing is
        selected.
        """
        if not (self.vennwidget.vennareas() and self.data):
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            return

        items = set()
        for index in self.selection:
            items.update(self.disjoint[index])
        self.selected_items = items

        selected_keys = set()
        for area in self.selection:
            selected_keys.update(self.area_keys[area])

        selected = None
        if not self.rowwise:
            annotated = self.create_from_columns(
                self.selected_items, self.data, False)
            if self.selected_items:
                selected = self.create_from_columns(
                    self.selected_items, selected_keys, True)
        else:
            if self.selected_items:
                selected_ids = self.get_indices_to_match_by(
                    selected_keys, bool(self.selection))
                selected = self.create_from_rows(selected_ids, False)
            annotated_ids = self.get_indices_to_match_by(self.data)
            annotated = self.create_from_rows(annotated_ids, True)

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        """Add the widget's plot to the report."""
        self.report_plot()

    def get_disjoint(self, sets):
        """
        Return all disjoint subsets.

        For every index `i` in `range(2 ** n)` the key `setkey(i, n)` tells
        which of the `n` input sets are included. The region for that key is
        the intersection of the included sets minus all excluded ones (an
        empty set when nothing is included). Returns the list of regions and,
        in parallel, the list of keys of `self.data` included in each region.
        """
        sets = list(sets)
        n = len(sets)
        disjoint_sets = []
        included_tables = []
        for i in range(2 ** n):
            key = setkey(i, n)
            included, excluded = [], []
            for current, inc in zip(sets, key):
                (included if inc else excluded).append(current)

            if included:
                common = reduce(set.intersection, included)
                region = reduce(set.difference, excluded, common)
            else:
                region = set()

            disjoint_sets.append(region)
            included_tables.append(
                [k for k, inc in zip(self.data, key) if inc])

        return disjoint_sets, included_tables
Exemplo n.º 6
0
class OWPythagoreanForest(OWWidget):
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    size_log_scale = settings.Setting(2)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    CLASSIFICATION, REGRESSION = range(2)

    def __init__(self):
        """Set up the widget state, the control area and the tree grid view."""
        super().__init__()
        # Instance variables
        self.forest_type = self.CLASSIFICATION
        self.model = None
        self.forest_adapter = None
        self.dataset = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x * self.size_log_scale)),
        ]

        # Node coloring options offered for regression forests; the position
        # in this list is what `target_class_index` selects.
        self.REGRESSION_COLOR_CALC = [
            ('None', lambda _, __: QColor(255, 255, 255)),
            ('Class mean', self._color_class_mean),
            ('Standard deviation', self._color_stddev),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info, label='')

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(box_display,
                                           self,
                                           'depth_limit',
                                           label='Depth',
                                           ticks=False,
                                           callback=self.max_depth_changed)
        self.ui_target_class_combo = gui.comboBox(
            box_display,
            self,
            'target_class_index',
            label='Target class',
            orientation=Qt.Horizontal,
            items=[],
            contentsLength=8,
            callback=self.target_colors_changed)
        self.ui_size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
            callback=self.size_calc_changed)
        self.ui_zoom_slider = gui.hSlider(box_display,
                                          self,
                                          'zoom',
                                          label='Zoom',
                                          ticks=False,
                                          minValue=20,
                                          maxValue=150,
                                          callback=self.zoom_changed,
                                          createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        # Every selection change in the scene re-sends the output.
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        # Start from the "no forest on input" state.
        self.clear()

    def set_rf(self, model=None):
        """Handle a new forest on the input.

        The widget is cleared first. For a non-`None` model the forest type
        is determined from the model class (anything but a classifier or a
        regressor raises `RuntimeError`), the trees are drawn, the datasets
        are stored and the info box, target combo and depth slider are
        refreshed. The tree selection index is reset to -1.
        """
        self.clear()
        self.model = model
        if model is None:
            return

        if isinstance(model, RandomForestClassifier):
            self.forest_type = self.CLASSIFICATION
        elif isinstance(model, RandomForestRegressor):
            self.forest_type = self.REGRESSION
        else:
            raise RuntimeError('Invalid type of forest.')

        self.forest_adapter = self._get_forest_adapter(self.model)
        make_palette = self._type_specific('_get_color_palette')
        self.color_palette = make_palette()
        self._draw_trees()

        self.dataset = model.instances
        # this bit is important for the regression classifier
        domains_differ = (self.dataset is not None
                          and self.dataset.domain != model.domain)
        if domains_differ:
            self.clf_dataset = Table.from_table(self.model.domain,
                                                self.dataset)
        else:
            self.clf_dataset = self.dataset

        self._update_info_box()
        update_combo = self._type_specific('_update_target_class_combo')
        update_combo()
        self._update_depth_slider()

        self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    # CONTROL AREA CALLBACKS
    def max_depth_changed(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def target_colors_changed(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_has_changed()

    def size_calc_changed(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self.grid.clear()
            self._draw_trees()
            # Keep the selected item
            if self.selected_tree_index != -1:
                self.grid_items[self.selected_tree_index].setSelected(True)
            self.max_depth_changed()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        for item in self.grid_items:
            item.set_max_size(self._calculate_zoom(self.zoom))

        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    # MODEL CHANGED METHODS
    def _update_info_box(self):
        self.ui_info.setText('Trees: {}'.format(
            len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    # MODEL CLEARED METHODS
    def _clear_info_box(self):
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    # HELPFUL METHODS
    def _get_max_depth(self):
        return max([tree.tree_adapter.max_depth for tree in self.ptrees])

    def _get_forest_adapter(self, model):
        """Wrap `model` in a `SklRandomForestAdapter`."""
        adapter = SklRandomForestAdapter(model)
        return adapter

    def _draw_trees(self):
        """Create a viewer and a grid item for every tree and lay them out.

        The size combo is disabled while the trees are being built and
        enabled again afterwards. `self.grid_items` and `self.ptrees` are
        rebuilt from scratch.
        """
        self.ui_size_calc_combo.setEnabled(False)
        self.grid_items, self.ptrees = [], []

        # One progress bar step per tree of the forest.
        with self.progressBar(len(self.forest_adapter.get_trees())) as prg:
            for tree in self.forest_adapter.get_trees():
                ptree = PythagorasTreeViewer(
                    None,
                    tree,
                    node_color_func=self._type_specific('_get_node_color'),
                    interactive=False,
                    padding=100)
                self.grid_items.append(
                    GridItem(ptree,
                             self.grid,
                             max_size=self._calculate_zoom(self.zoom)))
                self.ptrees.append(ptree)
                prg.advance()
        self.grid.set_items(self.grid_items)
        # This is necessary when adding items for the first time
        if self.grid:
            # Width of the view without its vertical scroll bar.
            width = (self.view.width() - self.view.verticalScrollBar().width())
            self.grid.reflow(width)
            self.grid.setPreferredWidth(width)
        self.ui_size_calc_combo.setEnabled(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """Run the base class clean-up, then clear the widget's own state."""
        super().onDeleteWidget()
        # Drops the model, adapter, trees and grid items (see `clear`).
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        if len(self.scene.selectedItems()) == 0:
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        selected_item = self.scene.selectedItems()[0]
        self.selected_tree_index = self.grid_items.index(selected_item)
        obj = self.model.trees[self.selected_tree_index]
        obj.instances = self.dataset
        obj.meta_target_class_index = self.target_class_index
        obj.meta_size_calc_idx = self.size_calc_idx
        obj.meta_size_log_scale = self.size_log_scale
        obj.meta_depth_limit = self.depth_limit

        self.send('Tree', obj)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def resizeEvent(self, ev):
        """Reflow the grid to the new view width, then run the base handler."""
        # Width of the view without its vertical scroll bar.
        view_width = self.view.width()
        available = view_width - self.view.verticalScrollBar().width()
        self.grid.reflow(available)
        self.grid.setPreferredWidth(available)

        super().resizeEvent(ev)

    def _type_specific(self, method):
        """A best effort method getter that somewhat separates logic specific
        to classification and regression trees.
        This relies on conventional naming of specific methods, e.g.
        a method name _get_tooltip would need to be defined like so:
        _classification_get_tooltip and _regression_get_tooltip, since they are
        both specific.

        Parameters
        ----------
        method : str
            Method name that we would like to call.

        Returns
        -------
        callable or None

        """
        if self.forest_type == self.CLASSIFICATION:
            return getattr(self, '_classification' + method)
        elif self.forest_type == self.REGRESSION:
            return getattr(self, '_regression' + method)
        else:
            return None

    # CLASSIFICATION FOREST SPECIFIC METHODS
    def _classification_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItem('None')
        values = [c.title() for c in self.model.domain.class_vars[0].values]
        self.ui_target_class_combo.addItems(values)

    def _classification_get_color_palette(self):
        """Return one `QColor` per entry of the class variable's colors."""
        palette = []
        for components in self.model.domain.class_var.colors:
            palette.append(QColor(*components))
        return palette

    def _classification_get_node_color(self, adapter, tree_node):
        # this is taken almost directly from the existing classification tree
        # viewer
        colors = self.color_palette
        distribution = adapter.get_distribution(tree_node.label)[0]
        total = np.sum(distribution)

        if self.target_class_index:
            p = distribution[self.target_class_index - 1] / total
            color = colors[self.target_class_index - 1].lighter(200 - 100 * p)
        else:
            modus = np.argmax(distribution)
            p = distribution[modus] / (total or 1)
            color = colors[int(modus)].lighter(400 - 300 * p)
        return color

    # REGRESSION FOREST SPECIFIC METHODS
    def _regression_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItems(
            list(zip(*self.REGRESSION_COLOR_CALC))[0])
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _regression_get_color_palette(self):
        """Return a `ContinuousPaletteGenerator` built from the class variable's colors."""
        class_colors = self.forest_adapter.domain.class_var.colors
        return ContinuousPaletteGenerator(*class_colors)

    def _regression_get_node_color(self, adapter, tree_node):
        return self.REGRESSION_COLOR_CALC[self.target_class_index][1](
            adapter, tree_node)

    def _color_class_mean(self, adapter, tree_node):
        # calculate node colors relative to the mean of the node samples
        min_mean = np.min(self.clf_dataset.Y)
        max_mean = np.max(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        mean = np.mean(instances.Y)

        return self.color_palette[(mean - min_mean) / (max_mean - min_mean)]

    def _color_stddev(self, adapter, tree_node):
        # calculate node colors relative to the standard deviation in the node
        # samples
        min_mean, max_mean = 0, np.std(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        std = np.std(instances.Y)

        return self.color_palette[(std - min_mean) / (max_mean - min_mean)]
class OWExplainPredictions(OWWidget):

    name = "Explain Predictions"
    description = "Computes attribute contributions to the final prediction with an approximation algorithm for shapely value"
    icon = "icons/ExplainPredictions.svg"
    priority = 200
    gui_error = settings.Setting(0.05)
    gui_p_val = settings.Setting(0.05)
    gui_num_atr = settings.Setting(20)
    sort_index = settings.Setting(SortBy.ABSOLUTE)

    class Inputs:
        data = Input("Data", Table, default=True)
        model = Input("Model", Model, multiple=False)
        sample = Input("Sample", Table)

    class Outputs:
        explanations = Output("Explanations", Table)

    class Error(OWWidget.Error):
        sample_too_big = widget.Msg("Can only explain one sample at the time.")

    class Warning(OWWidget.Warning):
        unknowns_increased = widget.Msg(
            "Number of unknown values increased, Data and Sample domains mismatch."
        )

    def __init__(self):
        super().__init__()
        self.data = None
        self.model = None
        self.to_explain = None
        self.explanations = None
        self.stop = True
        self.e = None

        self._task = None
        self._executor = ThreadExecutor()

        info_box = gui.vBox(self.controlArea, "Info")
        self.data_info = gui.widgetLabel(info_box, "Data: N/A")
        self.model_info = gui.widgetLabel(info_box, "Model: N/A")
        self.sample_info = gui.widgetLabel(info_box, "Sample: N/A")

        criteria_box = gui.vBox(self.controlArea, "Stopping criteria")
        self.error_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_error",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error < ",
                                   spinType=float,
                                   callback=self._update_error_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        self.p_val_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_p_val",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error p-value < ",
                                   spinType=float,
                                   callback=self._update_p_val_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        plot_properties_box = gui.vBox(self.controlArea, "Display features")
        self.num_atr_spin = gui.spin(plot_properties_box,
                                     self,
                                     "gui_num_atr",
                                     1,
                                     100,
                                     step=1,
                                     label="Show attributes",
                                     callback=self._update_num_atr_spin,
                                     controlWidth=80,
                                     keyboardTracking=False)

        self.sort_combo = gui.comboBox(plot_properties_box,
                                       self,
                                       "sort_index",
                                       label="Rank by",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._update_combo)

        gui.rubber(self.controlArea)

        self.cancel_button = gui.button(
            self.controlArea,
            self,
            "Stop Computation",
            callback=self.toggle_button,
            autoDefault=True,
            tooltip="Stops and restarts computation")
        self.cancel_button.setDisabled(True)

        predictions_box = gui.vBox(self.mainArea, "Model prediction")
        self.predict_info = gui.widgetLabel(predictions_box, "")

        self.mainArea.setMinimumWidth(700)
        self.resize(700, 400)

        class _GraphicsView(QGraphicsView):
            def __init__(self, scene, parent, **kwargs):
                for k, v in dict(
                        verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                        horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                        viewportUpdateMode=QGraphicsView.
                        BoundingRectViewportUpdate,
                        renderHints=(QPainter.Antialiasing
                                     | QPainter.TextAntialiasing
                                     | QPainter.SmoothPixmapTransform),
                        alignment=(Qt.AlignTop | Qt.AlignLeft),
                        sizePolicy=QSizePolicy(
                            QSizePolicy.MinimumExpanding,
                            QSizePolicy.MinimumExpanding)).items():
                    kwargs.setdefault(k, v)
                super().__init__(scene, parent, **kwargs)

        class GraphicsView(_GraphicsView):
            def __init__(self, scene, parent):
                super().__init__(
                    scene,
                    parent,
                    verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                    styleSheet='QGraphicsView {background: white}')
                self.viewport().setMinimumWidth(500)
                self._is_resizing = False

            w = self

            def resizeEvent(self, resizeEvent):
                self._is_resizing = True
                self.w.draw()
                self._is_resizing = False
                return super().resizeEvent(resizeEvent)

            def is_resizing(self):
                return self._is_resizing

            def sizeHint(self):
                return QSize(600, 300)

        class FixedSizeGraphicsView(_GraphicsView):
            def __init__(self, scene, parent):
                super().__init__(scene,
                                 parent,
                                 sizePolicy=QSizePolicy(
                                     QSizePolicy.MinimumExpanding,
                                     QSizePolicy.Minimum))

            def sizeHint(self):
                return QSize(600, 30)

        """all will share the same scene, but will show different parts of it"""
        self.box_scene = QGraphicsScene(self)

        self.box_view = GraphicsView(self.box_scene, self)
        self.header_view = FixedSizeGraphicsView(self.box_scene, self)
        self.footer_view = FixedSizeGraphicsView(self.box_scene, self)

        self.mainArea.layout().addWidget(self.header_view)
        self.mainArea.layout().addWidget(self.box_view)
        self.mainArea.layout().addWidget(self.footer_view)

        self.painter = None

    def draw(self):
        """Uses GraphAttributes class to draw the explanaitons """
        self.box_scene.clear()
        wp = self.box_view.viewport().rect()
        header_height = 30
        if self.explanations is not None:
            self.painter = GraphAttributes(
                self.box_scene,
                min(self.gui_num_atr, self.explanations.Y.shape[0]))
            self.painter.paint(wp, self.explanations, header_h=header_height)
        """set appropriate boxes for different views"""
        rect = QRectF(self.box_scene.itemsBoundingRect().x(),
                      self.box_scene.itemsBoundingRect().y(),
                      self.box_scene.itemsBoundingRect().width(),
                      self.box_scene.itemsBoundingRect().height())

        self.box_scene.setSceneRect(rect)
        self.box_view.setSceneRect(rect.x(),
                                   rect.y() + header_height + 2, rect.width(),
                                   rect.height() - 80)
        self.header_view.setSceneRect(rect.x(), rect.y(), rect.width(), 10)
        self.header_view.setFixedHeight(header_height)
        self.footer_view.setSceneRect(rect.x(),
                                      rect.y() + rect.height() - 50,
                                      rect.width(), 35)

    def sort_explanations(self):
        """sorts explanations according to users choice from combo box"""
        if self.sort_index == SortBy.POSITIVE:
            self.explanations = self.explanations[np.argsort(
                self.explanations.X[:, 0])][::-1]
        elif self.sort_index == SortBy.NEGATIVE:
            self.explanations = self.explanations[np.argsort(
                self.explanations.X[:, 0])]
        elif self.sort_index == SortBy.ABSOLUTE:
            self.explanations = self.explanations[np.argsort(
                np.abs(self.explanations.X[:, 0]))][::-1]
        elif self.sort_index == SortBy.BY_NAME:
            l = np.array(
                list(map(np.chararray.lower, self.explanations.metas[:, 0])))
            self.explanations = self.explanations[np.argsort(l)]
        else:
            return

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle the 'Data' input.

        Stores the table, discards the cached explanations and explainer and
        refreshes the info label with the instance and feature counts.

        Parameters
        ----------
        data : table-like with a 2-D ``X`` array, or None
            The new input; ``None`` resets the label to "Data: N/A".
        """
        self.data = data
        self.explanations = None
        self.e = None
        if data is None:
            self.data_info.setText("Data: N/A")
            return

        n_instances, n_features = data.X.shape[0], data.X.shape[1]
        if n_instances == 1:
            inst = "1 instance and "
        else:
            inst = str(n_instances) + " instances and "
        if n_features == 1:
            # no trailing space: this is the end of the label text
            feat = "1 feature"
        else:
            feat = str(n_features) + " features"
        self.data_info.setText("Data: " + inst + feat)

    @Inputs.model
    def set_predictor(self, model):
        """Handle the 'Model' input: store it, drop cached results, update the label."""
        self.model = model
        # anything computed with the previous model is no longer valid
        self.explanations = None
        self.e = None
        self.model_info.setText("Model: N/A")
        if model is None:
            return
        self.model_info.setText("Model: " + str(model.name))

    @Inputs.sample
    @check_sql_input
    def set_sample(self, sample):
        """Handle the 'Sample' input; only a sample with exactly one row is kept."""
        self.to_explain = sample
        self.explanations = None
        self.Error.sample_too_big.clear()
        self.sample_info.setText("Sample: N/A")
        if sample is None:
            return

        if len(sample.X) != 1:
            # reject the input and raise the widget error
            self.to_explain = None
            self.Error.sample_too_big()
            return

        n_features = sample.X.shape[1]
        feat = "1 feature" if n_features == 1 else str(n_features) + " features"
        self.sample_info.setText("Sample: " + feat)
        explainer = self.e
        if explainer is not None:
            explainer.saved = False

    def handleNewSignals(self):
        """Abort a running task, reset the status widgets and recompute or resend."""
        task_running = self._task is not None
        if task_running:
            self.cancel()
        assert self._task is None

        # put the button back into its "computation running" state
        self.stop = True
        self.cancel_button.setText("Stop Computation")
        # forget the previous prediction text and domain warning
        self.Warning.unknowns_increased.clear()
        self.predict_info.setText("")

        self.commit_calc_or_output()

    def commit_calc_or_output(self):
        """Recompute when both data and sample are present; otherwise just resend."""
        if self.data is None or self.to_explain is None:
            self.commit_output()
            return
        self.commit_calc()

    def commit_calc(self):
        """Start the background computation of contributions for the sample.

        The sample is first transformed into the domain of the data, and a
        warning is raised when that increases the number of unknown values.
        Without a model nothing can be computed; the current explanations
        are then forwarded, as ``commit_calc_or_output`` does when data or
        sample is missing, so downstream widgets do not keep stale results.
        """
        nan_before = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        self.to_explain = self.to_explain.transform(self.data.domain)
        nan_after = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        if nan_before != nan_after:
            self.Warning.unknowns_increased()

        if self.model is None:
            self.commit_output()
            return

        # the explainer is built lazily and reused until data or model change
        if self.e is None:
            self.e = ExplainPredictions(self.data,
                                        self.model,
                                        batch_size=min(
                                            len(self.data.X), 500),
                                        p_val=self.gui_p_val,
                                        error=self.gui_error)
        self._task = task = Task()

        # The callbacks below are handed to `anytime_explain`, which is
        # submitted to the executor; queued invocations hand the widget
        # updates over to the widget's own event loop.
        def callback(progress):
            # update progress bar
            QMetaObject.invokeMethod(self, "set_progress_value",
                                     Qt.QueuedConnection,
                                     Q_ARG(int, progress))
            # a true return value reports that the task was cancelled
            return bool(task.canceled)

        def callback_update(table):
            QMetaObject.invokeMethod(self, "update_view",
                                     Qt.QueuedConnection,
                                     Q_ARG(Orange.data.Table, table))

        def callback_prediction(class_value):
            QMetaObject.invokeMethod(self, "update_model_prediction",
                                     Qt.QueuedConnection,
                                     Q_ARG(float, class_value))

        self.was_canceled = False
        explain_func = partial(self.e.anytime_explain,
                               self.to_explain[0],
                               callback=callback,
                               update_func=callback_update,
                               update_prediction=callback_prediction)

        self.progressBarInit(processEvents=None)
        task.future = self._executor.submit(explain_func)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self._task_finished)
        self.cancel_button.setDisabled(False)

    @pyqtSlot(Orange.data.Table)
    def update_view(self, table):
        """Slot: adopt *table* as the current explanations, then sort, redraw and resend."""
        self.explanations = table
        # order matters: the plot and the output both use the sorted table
        self.sort_explanations()
        self.draw()
        self.commit_output()

    @pyqtSlot(float)
    def update_model_prediction(self, value):
        """Slot: display the prediction *value* via ``_print_prediction``."""
        self._print_prediction(value)

    @pyqtSlot(int)
    def set_progress_value(self, value):
        """Slot: move the progress bar to *value* without processing events."""
        self.progressBarSet(value, processEvents=False)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """Finalize the background task once its future is done.

        Parameters
        ----------
        f : concurrent.futures.Future
            Finished future of the current task. On success its result is
            indexed with ``[1]`` to get the table passed to ``update_view``.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None

        # after a user-initiated cancel the button stays enabled
        if not self.was_canceled:
            self.cancel_button.setDisabled(True)

        try:
            results = f.result()
        except Exception as ex:
            # Report the failure only. The previous code also looped over
            # `self.results`, an attribute this widget never sets, which
            # raised AttributeError here and skipped the progress-bar reset.
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.error("Exception occured during evaluation: {!r}".format(ex))
        else:
            self.update_view(results[1])

        self.progressBarFinished(processEvents=False)

    def commit_output(self):
        """Send the current (possibly partial or ``None``) explanations to the output."""
        current = self.explanations
        self.Outputs.explanations.send(current)

    def toggle_button(self):
        """Button callback: stop a running computation, or restart a stopped one."""
        if not self.stop:
            # currently stopped -> restart
            self.stop = True
            self.cancel_button.setText("Stop Computation")
            self.commit_calc_or_output()
            return
        # currently running -> stop
        self.stop = False
        self.cancel_button.setText("Restart Computation")
        self.cancel()

    def cancel(self):
        """Cancel the running task, if there is one, and finalize it right away."""
        task = self._task
        if task is None:
            return
        task.cancel()
        assert task.future.done()
        # the slot is invoked by hand below, so drop the watcher connection
        task.watcher.done.disconnect(self._task_finished)
        self.was_canceled = True
        self._task_finished(task.future)

    def _print_prediction(self, class_value):
        """Show the model's prediction next to the class variable's name.

        Parameters
        ----------
        class_value : float
            For a continuous class variable the predicted value itself;
            otherwise an index into the class variable's ``values``.
        """
        class_var = self.data.domain.class_vars[0]
        name = class_var.name
        if isinstance(class_var, ContinuousVariable):
            shown = str(class_value)
        else:
            shown = class_var.values[int(class_value)]
        self.predict_info.setText(name + ":      " + shown)

    def _update_error_spin(self):
        """Spin-box callback: hand the new error threshold to the explainer and recompute."""
        self.cancel()
        explainer = self.e
        if explainer is not None:
            explainer.error = self.gui_error
        self.handleNewSignals()

    def _update_p_val_spin(self):
        """Spin-box callback: hand the new p-value to the explainer and recompute."""
        self.cancel()
        explainer = self.e
        if explainer is not None:
            explainer.p_val = self.gui_p_val
        self.handleNewSignals()

    def _update_num_atr_spin(self):
        """Spin-box callback for the number of shown attributes: cancel and recompute."""
        self.cancel()
        self.handleNewSignals()

    def _update_combo(self):
        """Combo-box callback: re-sort, redraw and resend the current explanations.

        Does nothing while there are no explanations.
        """
        # Identity check: `!= None` would dispatch to the table's own
        # comparison operator instead of testing for a missing value.
        if self.explanations is None:
            return
        self.sort_explanations()
        self.draw()
        self.commit_output()

    def onDeleteWidget(self):
        """Cancel any running task before the base class tears the widget down."""
        self.cancel()
        super().onDeleteWidget()
Exemplo n.º 8
0
class OWPythagoreanForest(OWWidget):
    """Show every tree of a random forest as a Pythagorean tree in a grid.

    Selecting a tree in the grid sends that tree to the 'Tree' output.
    """
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    def __init__(self):
        """Build the control area, the scene with its grid, and the view."""
        super().__init__()
        self.model = None
        self.forest_adapter = None
        self.instances = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []
        # In some rare cases, we need to prevent committing, the only one
        # that this currently helps is that when changing the size calculation
        # the trees are all recomputed, but we don't want to output a new tree
        # to keep things consistent with other ui controls.
        self.__prevent_commit = False

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x + 1)),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info)

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(
            box_display, self, 'depth_limit', label='Depth', ticks=False,
            callback=self.update_depth)
        self.ui_target_class_combo = gui.comboBox(
            box_display, self, 'target_class_index', label='Target class',
            orientation=Qt.Horizontal, items=[], contentsLength=8,
            callback=self.update_colors)
        self.ui_size_calc_combo = gui.comboBox(
            box_display, self, 'size_calc_idx', label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8,
            callback=self.update_size_calc)
        self.ui_zoom_slider = gui.hSlider(
            box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20,
            maxValue=150, callback=self.zoom_changed, createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """When a different forest is given.

        Clears the widget, then, if *model* is not ``None``, draws its trees
        and refreshes the info box, target class combo and depth slider.
        """
        self.clear()
        self.model = model

        if model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self._draw_trees()
            self.color_palette = self.forest_adapter.get_trees()[0]

            self.instances = model.instances
            # this bit is important for the regression classifier
            if self.instances is not None and self.instances.domain != model.domain:
                self.clf_dataset = self.instances.transform(self.model.domain)
            else:
                self.clf_dataset = self.instances

            self._update_info_box()
            self._update_target_class_combo()
            self._update_depth_slider()

            self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        # also drop the datasets of the previous model so they are not kept
        # alive (or reused) after the forest has been removed
        self.instances = None
        self.clf_dataset = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    def update_depth(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def update_colors(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_changed(self.target_class_index)

    def update_size_calc(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            # redrawing changes the scene selection; do not emit a new tree
            with self._prevent_commit():
                self.grid.clear()
                self._draw_trees()
                # Keep the selected item
                if self.selected_tree_index != -1:
                    self.grid_items[self.selected_tree_index].setSelected(True)
                self.update_depth()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        for item in self.grid_items:
            item.set_max_size(self._calculate_zoom(self.zoom))

        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    @contextmanager
    def _prevent_commit(self):
        """Context manager that makes `commit` a no-op while it is active."""
        try:
            self.__prevent_commit = True
            yield
        finally:
            self.__prevent_commit = False

    def _update_info_box(self):
        """Show the number of trees in the forest."""
        self.ui_info.setText('Trees: {}'.format(len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        """Enable the depth slider and set it to the deepest tree's depth."""
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    def _clear_info_box(self):
        """Show the 'no input' message in the info box."""
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        """Empty the target class combo and reset its index to 0."""
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        """Disable the depth slider and set its maximum to 0."""
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    def _get_max_depth(self):
        """Return the largest `max_depth` among the drawn trees."""
        return max(tree.tree_adapter.max_depth for tree in self.ptrees)

    def _get_forest_adapter(self, model):
        """Wrap *model* in the adapter used to enumerate its trees."""
        return SklRandomForestAdapter(model)

    @contextmanager
    def disable_ui(self):
        """Temporarily disable the UI while trees may be redrawn."""
        try:
            self.ui_size_calc_combo.setEnabled(False)
            self.ui_depth_slider.setEnabled(False)
            self.ui_target_class_combo.setEnabled(False)
            self.ui_zoom_slider.setEnabled(False)
            yield
        finally:
            self.ui_size_calc_combo.setEnabled(True)
            self.ui_depth_slider.setEnabled(True)
            self.ui_target_class_combo.setEnabled(True)
            self.ui_zoom_slider.setEnabled(True)

    def _draw_trees(self):
        """Create a tree viewer and grid item per tree and put them in the grid."""
        self.grid_items, self.ptrees = [], []

        num_trees = len(self.forest_adapter.get_trees())
        with self.progressBar(num_trees) as prg, self.disable_ui():
            for tree in self.forest_adapter.get_trees():
                ptree = PythagorasTreeViewer(
                    None, tree, interactive=False, padding=100,
                    target_class_index=self.target_class_index,
                    weight_adjustment=self.SIZE_CALCULATION[self.size_calc_idx][1]
                )
                grid_item = GridItem(
                    ptree, self.grid, max_size=self._calculate_zoom(self.zoom)
                )
                # We don't want to show flickering while the trees are being
                # drawn, so items stay hidden until the grid is laid out
                grid_item.setVisible(False)

                self.grid_items.append(grid_item)
                self.ptrees.append(ptree)
                prg.advance()

            self.grid.set_items(self.grid_items)
            # This is necessary when adding items for the first time
            if self.grid:
                width = (self.view.width() - self.view.verticalScrollBar().width())
                self.grid.reflow(width)
                self.grid.setPreferredWidth(width)
                # After drawing is complete, we show the trees
                for grid_item in self.grid_items:
                    grid_item.setVisible(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        if self.__prevent_commit:
            return

        if not self.scene.selectedItems():
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        selected_item = self.scene.selectedItems()[0]
        self.selected_tree_index = self.grid_items.index(selected_item)
        tree = self.model.trees[self.selected_tree_index]
        # attach the current display settings to the tree that is sent out
        tree.instances = self.instances
        tree.meta_target_class_index = self.target_class_index
        tree.meta_size_calc_idx = self.size_calc_idx
        tree.meta_depth_limit = self.depth_limit

        self.send('Tree', tree)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        """Fit the scene rectangle to the bounding rectangle of its items."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def _update_target_class_combo(self):
        """Fill the target class combo for the current model's class variable."""
        self._clear_target_class_combo()
        label = [x for x in self.ui_target_class_combo.parent().children()
                 if isinstance(x, QLabel)][0]

        # `set_rf` tolerates a model without training instances, so fall back
        # to the model's own domain instead of dereferencing `None`.
        if self.instances is not None:
            domain = self.instances.domain
        else:
            domain = self.model.domain

        if domain.has_discrete_class:
            label_text = 'Target class'
            values = [c.title() for c in domain.class_vars[0].values]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.ui_target_class_combo.addItems(values)
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def resizeEvent(self, ev):
        """Reflow the grid to the new view width, then defer to the base class."""
        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

        super().resizeEvent(ev)
class OWExplainPredictions(OWWidget):
    """Widget that explains a model's prediction for one sample.

    Takes data, a model and a sample as inputs and outputs a table of
    explanations.
    """

    name = "Explain Predictions"
    description = "Computes attribute contributions to the final prediction with an approximation algorithm for shapely value"
    icon = "icons/ExplainPredictions.svg"
    priority = 200
    # stopping criteria passed to the explainer as `error` and `p_val`
    gui_error = settings.Setting(0.05)
    gui_p_val = settings.Setting(0.05)
    # upper bound on the number of attributes drawn
    gui_num_atr = settings.Setting(20)
    # ranking used by `sort_explanations`
    sort_index = settings.Setting(SortBy.ABSOLUTE)

    # input signals; the handlers are the methods decorated with these
    class Inputs:
        data = Input("Data", Table, default=True)
        model = Input("Model", Model, multiple=False)
        sample = Input("Sample", Table)

    # output signal carrying the explanations table
    class Outputs:
        explanations = Output("Explanations", Table)

    # raised by `set_sample` when the sample does not have exactly one row
    class Error(OWWidget.Error):
        sample_too_big = widget.Msg("Can only explain one sample at the time.")

    # raised by `commit_calc` when transforming the sample into the data
    # domain changes its number of unknown values
    class Warning(OWWidget.Warning):
        unknowns_increased = widget.Msg(
            "Number of unknown values increased, Data and Sample domains mismatch.")

    def __init__(self):
        """Initialise widget state and build the control and main areas."""
        super().__init__()
        self.data = None
        self.model = None
        self.to_explain = None
        self.explanations = None
        # True while the button reads "Stop Computation" (see toggle_button)
        self.stop = True
        # explainer object, created lazily in commit_calc
        self.e = None

        # currently running background task and the executor it is submitted to
        self._task = None
        self._executor = ThreadExecutor()

        info_box = gui.vBox(self.controlArea, "Info")
        self.data_info = gui.widgetLabel(info_box, "Data: N/A")
        self.model_info = gui.widgetLabel(info_box, "Model: N/A")
        self.sample_info = gui.widgetLabel(info_box, "Sample: N/A")

        criteria_box = gui.vBox(self.controlArea, "Stopping criteria")
        self.error_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_error",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error < ",
                                   spinType=float,
                                   callback=self._update_error_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        self.p_val_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_p_val",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error p-value < ",
                                   spinType=float,
                                   callback=self._update_p_val_spin,
                                   controlWidth=80, keyboardTracking=False)

        plot_properties_box = gui.vBox(self.controlArea, "Display features")
        self.num_atr_spin = gui.spin(plot_properties_box,
                                     self,
                                     "gui_num_atr",
                                     1,
                                     100,
                                     step=1,
                                     label="Show attributes",
                                     callback=self._update_num_atr_spin,
                                     controlWidth=80,
                                     keyboardTracking=False)

        self.sort_combo = gui.comboBox(plot_properties_box,
                                       self,
                                       "sort_index",
                                       label="Rank by",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._update_combo)

        gui.rubber(self.controlArea)

        self.cancel_button = gui.button(self.controlArea,
                                        self,
                                        "Stop Computation",
                                        callback=self.toggle_button,
                                        autoDefault=True,
                                        tooltip="Stops and restarts computation")
        # enabled only while a computation has been started (see commit_calc)
        self.cancel_button.setDisabled(True)

        predictions_box = gui.vBox(self.mainArea, "Model prediction")
        self.predict_info = gui.widgetLabel(predictions_box, "")

        self.mainArea.setMinimumWidth(700)
        self.resize(700, 400)

        # Base view: fills in default Qt properties unless the caller
        # supplies its own values for them.
        class _GraphicsView(QGraphicsView):
            def __init__(self, scene, parent, **kwargs):
                for k, v in dict(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 viewportUpdateMode=QGraphicsView.BoundingRectViewportUpdate,
                                 renderHints=(QPainter.Antialiasing |
                                              QPainter.TextAntialiasing |
                                              QPainter.SmoothPixmapTransform),
                                 alignment=(Qt.AlignTop |
                                            Qt.AlignLeft),
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.MinimumExpanding)).items():
                    kwargs.setdefault(k, v)
                super().__init__(scene, parent, **kwargs)

        # Main, vertically scrollable view; redraws the widget on resize.
        class GraphicsView(_GraphicsView):
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                 styleSheet='QGraphicsView {background: white}')
                self.viewport().setMinimumWidth(500)
                self._is_resizing = False

            # class attribute holding the enclosing widget, so that
            # resizeEvent below can call its draw()
            w = self

            def resizeEvent(self, resizeEvent):
                self._is_resizing = True
                self.w.draw()
                self._is_resizing = False
                return super().resizeEvent(resizeEvent)

            def is_resizing(self):
                return self._is_resizing

            def sizeHint(self):
                return QSize(600, 300)

        # View with a small fixed size hint, used for the header and footer.
        class FixedSizeGraphicsView(_GraphicsView):
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.Minimum))

            def sizeHint(self):
                return QSize(600, 30)

        """all will share the same scene, but will show different parts of it"""
        self.box_scene = QGraphicsScene(self)

        self.box_view = GraphicsView(self.box_scene, self)
        self.header_view = FixedSizeGraphicsView(self.box_scene, self)
        self.footer_view = FixedSizeGraphicsView(self.box_scene, self)

        # top-to-bottom order in the main area: header, body, footer
        self.mainArea.layout().addWidget(self.header_view)
        self.mainArea.layout().addWidget(self.box_view)
        self.mainArea.layout().addWidget(self.footer_view)

        # set in draw() once there are explanations to paint
        self.painter = None

    def draw(self):
        """Repaint the explanations and realign the three views on the shared scene."""
        scene = self.box_scene
        scene.clear()
        viewport_rect = self.box_view.viewport().rect()
        header_height = 30
        if self.explanations is not None:
            n_shown = min(self.gui_num_atr, self.explanations.Y.shape[0])
            self.painter = GraphAttributes(scene, n_shown)
            self.painter.paint(
                viewport_rect, self.explanations, header_h=header_height)

        # set appropriate boxes for different views
        bounds = scene.itemsBoundingRect()
        rect = QRectF(bounds.x(), bounds.y(), bounds.width(), bounds.height())
        x, y = rect.x(), rect.y()
        width, height = rect.width(), rect.height()

        scene.setSceneRect(rect)
        self.box_view.setSceneRect(x, y + header_height + 2, width, height - 80)
        self.header_view.setSceneRect(x, y, width, 10)
        self.header_view.setFixedHeight(header_height)
        self.footer_view.setSceneRect(x, y + height - 50, width, 35)

    def sort_explanations(self):
        """Reorder ``self.explanations`` by the criterion picked in the combo box.

        The positive, negative and absolute rankings use the first column of
        ``X``; ranking by name uses the lower-cased first meta column. Any
        other ``sort_index`` leaves the table as it is.
        """
        mode = self.sort_index
        table = self.explanations
        if mode == SortBy.POSITIVE:
            # ascending order of the first column, then reversed
            self.explanations = table[np.argsort(table.X[:, 0])][::-1]
        elif mode == SortBy.NEGATIVE:
            self.explanations = table[np.argsort(table.X[:, 0])]
        elif mode == SortBy.ABSOLUTE:
            magnitudes = np.abs(table.X[:, 0])
            self.explanations = table[np.argsort(magnitudes)][::-1]
        elif mode == SortBy.BY_NAME:
            lowered = np.array(
                [np.chararray.lower(label) for label in table.metas[:, 0]])
            self.explanations = table[np.argsort(lowered)]

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle the 'Data' input.

        Stores the table, discards the cached explanations and explainer and
        refreshes the info label with the instance and feature counts.

        Parameters
        ----------
        data : table-like with a 2-D ``X`` array, or None
            The new input; ``None`` resets the label to "Data: N/A".
        """
        self.data = data
        self.explanations = None
        self.e = None
        if data is None:
            self.data_info.setText("Data: N/A")
            return

        n_instances, n_features = data.X.shape[0], data.X.shape[1]
        if n_instances == 1:
            inst = "1 instance and "
        else:
            inst = str(n_instances) + " instances and "
        if n_features == 1:
            # no trailing space: this is the end of the label text
            feat = "1 feature"
        else:
            feat = str(n_features) + " features"
        self.data_info.setText("Data: " + inst + feat)

    @Inputs.model
    def set_predictor(self, model):
        """Handle the 'Model' input: store it, drop cached results, update the label."""
        self.model = model
        # anything computed with the previous model is no longer valid
        self.explanations = None
        self.e = None
        self.model_info.setText("Model: N/A")
        if model is None:
            return
        self.model_info.setText("Model: " + str(model.name))

    @Inputs.sample
    @check_sql_input
    def set_sample(self, sample):
        """Handle the 'Sample' input; only a sample with exactly one row is kept."""
        self.to_explain = sample
        self.explanations = None
        self.Error.sample_too_big.clear()
        self.sample_info.setText("Sample: N/A")
        if sample is None:
            return

        if len(sample.X) != 1:
            # reject the input and raise the widget error
            self.to_explain = None
            self.Error.sample_too_big()
            return

        n_features = sample.X.shape[1]
        feat = "1 feature" if n_features == 1 else str(n_features) + " features"
        self.sample_info.setText("Sample: " + feat)
        explainer = self.e
        if explainer is not None:
            explainer.saved = False

    def handleNewSignals(self):
        """Abort a running task, reset the status widgets and recompute or resend."""
        task_running = self._task is not None
        if task_running:
            self.cancel()
        assert self._task is None

        # put the button back into its "computation running" state
        self.stop = True
        self.cancel_button.setText("Stop Computation")
        # forget the previous prediction text and domain warning
        self.Warning.unknowns_increased.clear()
        self.predict_info.setText("")

        self.commit_calc_or_output()

    def commit_calc_or_output(self):
        """Recompute when both data and sample are present; otherwise just resend."""
        if self.data is None or self.to_explain is None:
            self.commit_output()
            return
        self.commit_calc()

    def commit_calc(self):
        """Start the background computation of contributions for the sample.

        The sample is first transformed into the domain of the data, and a
        warning is raised when that increases the number of unknown values.
        Without a model nothing can be computed; the current explanations
        are then forwarded, as ``commit_calc_or_output`` does when data or
        sample is missing, so downstream widgets do not keep stale results.
        """
        nan_before = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        self.to_explain = self.to_explain.transform(self.data.domain)
        nan_after = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        if nan_before != nan_after:
            self.Warning.unknowns_increased()

        if self.model is None:
            self.commit_output()
            return

        # the explainer is built lazily and reused until data or model change
        if self.e is None:
            self.e = ExplainPredictions(self.data,
                                        self.model,
                                        batch_size=min(
                                            len(self.data.X), 500),
                                        p_val=self.gui_p_val,
                                        error=self.gui_error)
        self._task = task = Task()

        # The callbacks below are handed to `anytime_explain`, which is
        # submitted to the executor; queued invocations hand the widget
        # updates over to the widget's own event loop.
        def callback(progress):
            # update progress bar
            QMetaObject.invokeMethod(
                self, "set_progress_value", Qt.QueuedConnection,
                Q_ARG(int, progress))
            # a true return value reports that the task was cancelled
            return bool(task.canceled)

        def callback_update(table):
            QMetaObject.invokeMethod(
                self, "update_view", Qt.QueuedConnection,
                Q_ARG(Orange.data.Table, table))

        def callback_prediction(class_value):
            QMetaObject.invokeMethod(
                self, "update_model_prediction", Qt.QueuedConnection,
                Q_ARG(float, class_value))

        self.was_canceled = False
        explain_func = partial(
            self.e.anytime_explain,
            self.to_explain[0],
            callback=callback,
            update_func=callback_update,
            update_prediction=callback_prediction)

        self.progressBarInit(processEvents=None)
        task.future = self._executor.submit(explain_func)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self._task_finished)
        self.cancel_button.setDisabled(False)

    @pyqtSlot(Orange.data.Table)
    def update_view(self, table):
        """Slot: adopt *table* as the current explanations, then sort, redraw and resend."""
        self.explanations = table
        # order matters: the plot and the output both use the sorted table
        self.sort_explanations()
        self.draw()
        self.commit_output()

    @pyqtSlot(float)
    def update_model_prediction(self, value):
        """
        Slot that shows the model's prediction in the widget.

        Parameters
        ----------
        value : float
            Passed unchanged to `_print_prediction`.
        """
        self._print_prediction(value)

    @pyqtSlot(int)
    def set_progress_value(self, value):
        """
        Slot that moves the widget's progress bar to `value`.

        Parameters
        ----------
        value : int
            New progress value, handed to `progressBarSet`.
        """
        # Called with processEvents=False, as in the original call.
        self.progressBarSet(value, processEvents=False)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Finalise the background explanation task.

        Must be called in the widget's own thread with the finished future
        of the task currently tracked in `self._task`. The task reference is
        cleared; then either the computed table is shown, or the failure is
        logged and reported through `self.error`. The progress bar is
        finished in both cases.

        Parameters
        ----------
        f : concurrent.futures.Future
            Finished future of the current task. On success, element 1 of
            its result is passed to `update_view`.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None

        # `cancel()` sets `was_canceled` before calling this method, so after
        # an explicit cancel the button is left enabled.
        if not self.was_canceled:
            self.cancel_button.setDisabled(True)

        try:
            results = f.result()
        except Exception as ex:
            # Log under this module's logger with a descriptive message;
            # `exception()` attaches the active traceback by itself.
            logging.getLogger(__name__).exception(
                "Exception occurred during evaluation")
            self.error("Exception occurred during evaluation: {!r}".format(ex))

            # NOTE(review): `self.results` is not assigned in the visible
            # part of this class -- confirm it exists.
            for key in self.results.keys():
                self.results[key] = None
        else:
            self.update_view(results[1])

        self.progressBarFinished(processEvents=False)

    def commit_output(self):
        """
        Send the current contents of `self.explanations` to the
        `explanations` output (the best result available so far).
        """
        explanations_output = self.Outputs.explanations
        explanations_output.send(self.explanations)

    def toggle_button(self):
        """
        Flip the `self.stop` flag and act on the previous state.

        If the flag was set, the button is relabelled "Restart Computation"
        and the running task is cancelled. Otherwise the button is
        relabelled "Stop Computation" and `commit_calc_or_output` is called.
        """
        was_set = self.stop
        self.stop = not was_set

        if was_set:
            self.cancel_button.setText("Restart Computation")
            self.cancel()
            return

        self.cancel_button.setText("Stop Computation")
        self.commit_calc_or_output()

    def cancel(self):
        """
        Cancel the task in progress; do nothing when there is none.

        After the task's own `cancel()` its future must already be done.
        The `_task_finished` slot is disconnected from the watcher, the
        widget is flagged as cancelled, and `_task_finished` is then called
        directly with the task's future.
        """
        task = self._task
        if task is None:
            return

        task.cancel()
        assert task.future.done()
        # The slot is called by hand below, so detach it from the watcher.
        task.watcher.done.disconnect(self._task_finished)
        self.was_canceled = True
        self._task_finished(task.future)

    def _print_prediction(self, class_value):
        """
        Show the predicted value of the first class variable.

        Parameters
        ----------
        class_value : float
            For a continuous class variable, the predicted value itself
            (regression). Otherwise the index of the predicted class, looked
            up in the variable's `values`.
        """
        target = self.data.domain.class_vars[0]
        if isinstance(target, ContinuousVariable):
            shown = str(class_value)
        else:
            # The index arrives as a float; make it a usable list index.
            shown = target.values[int(class_value)]
        self.predict_info.setText(target.name + ":      " + shown)

    def _update_error_spin(self):
        """
        React to a change of the error setting.

        Cancels the running task, copies `self.gui_error` onto the existing
        explainer (if there is one) and calls `handleNewSignals`.
        """
        self.cancel()
        explainer = self.e
        if explainer is not None:
            explainer.error = self.gui_error
        self.handleNewSignals()

    def _update_p_val_spin(self):
        """
        React to a change of the p-value setting.

        Cancels the running task, copies `self.gui_p_val` onto the existing
        explainer (if there is one) and calls `handleNewSignals`.
        """
        self.cancel()
        explainer = self.e
        if explainer is not None:
            explainer.p_val = self.gui_p_val
        self.handleNewSignals()

    def _update_num_atr_spin(self):
        """
        React to a change of the attribute-count setting: cancel the running
        task, then call `handleNewSignals`.
        """
        self.cancel()
        self.handleNewSignals()

    def _update_combo(self):
        """
        Re-sort, redraw and re-send the explanations after the combo box
        selection changes. Does nothing while no explanations exist.
        """
        # Identity test: `!= None` would be answered by the object's own
        # comparison operators rather than by whether a table is present.
        if self.explanations is None:
            return
        self.sort_explanations()
        self.draw()
        self.commit_output()

    def onDeleteWidget(self):
        """Cancel the running task, then run the base class clean-up."""
        self.cancel()
        super().onDeleteWidget()