Ejemplo n.º 1
0
class OWPythagoreanForest(OWWidget):
    """Widget that draws every tree of a random forest as a Pythagoras tree.

    The trees are placed as items of a grid inside a graphics scene.
    Selecting one of them sends the matching tree model on the 'Tree'
    output (see `commit`).
    """

    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    size_log_scale = settings.Setting(2)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    CLASSIFICATION, REGRESSION = range(2)

    def __init__(self):
        super().__init__()
        # Instance variables
        self.forest_type = self.CLASSIFICATION
        self.model = None
        self.forest_adapter = None
        self.dataset = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x * self.size_log_scale)),
        ]

        # Node colouring methods offered for regression forests; each entry
        # is (label, callable(adapter, tree_node) -> colour).
        self.REGRESSION_COLOR_CALC = [
            ('None', lambda _, __: QColor(255, 255, 255)),
            ('Class mean', self._color_class_mean),
            ('Standard deviation', self._color_stddev),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info, label='')

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(
            box_display, self, 'depth_limit', label='Depth', ticks=False,
            callback=self.max_depth_changed)
        self.ui_target_class_combo = gui.comboBox(
            box_display, self, 'target_class_index', label='Target class',
            orientation=Qt.Horizontal, items=[], contentsLength=8,
            callback=self.target_colors_changed)
        self.ui_size_calc_combo = gui.comboBox(
            box_display, self, 'size_calc_idx', label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8,
            callback=self.size_calc_changed)
        self.ui_zoom_slider = gui.hSlider(
            box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20,
            maxValue=150, callback=self.zoom_changed, createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(
            QSizePolicy.Preferred, QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """Handle a new forest on the input (``None`` removes the forest).

        Raises
        ------
        RuntimeError
            If `model` is neither a classification nor a regression forest.

        """
        self.clear()
        self.model = model

        if model is not None:
            if isinstance(model, RandomForestClassifier):
                self.forest_type = self.CLASSIFICATION
            elif isinstance(model, RandomForestRegressor):
                self.forest_type = self.REGRESSION
            else:
                raise RuntimeError('Invalid type of forest.')

            self.forest_adapter = self._get_forest_adapter(self.model)
            self.color_palette = self._type_specific('_get_color_palette')()
            self._draw_trees()

            self.dataset = model.instances
            # this bit is important for the regression classifier
            if self.dataset is not None and \
                    self.dataset.domain != model.domain:
                self.clf_dataset = Table.from_table(
                    self.model.domain, self.dataset)
            else:
                self.clf_dataset = self.dataset

            self._update_info_box()
            self._type_specific('_update_target_class_combo')()
            self._update_depth_slider()

            self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        # Drop the references to the previous model's data and palette so
        # they are not kept alive (or accidentally reused) once it is gone.
        self.dataset = None
        self.clf_dataset = None
        self.color_palette = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    # CONTROL AREA CALLBACKS
    def max_depth_changed(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def target_colors_changed(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_has_changed()

    def size_calc_changed(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self.grid.clear()
            self._draw_trees()
            # Keep the selected item
            if self.selected_tree_index != -1:
                self.grid_items[self.selected_tree_index].setSelected(True)
            self.max_depth_changed()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        max_size = self._calculate_zoom(self.zoom)
        for item in self.grid_items:
            item.set_max_size(max_size)

        self._reflow_grid()

    # MODEL CHANGED METHODS
    def _update_info_box(self):
        """Show the number of trees of the current forest."""
        self.ui_info.setText(
            'Trees: {}'.format(len(self.forest_adapter.get_trees()))
        )

    def _update_depth_slider(self):
        """Enable the depth slider and set it to the deepest drawn tree."""
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    # MODEL CLEARED METHODS
    def _clear_info_box(self):
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    # HELPFUL METHODS
    def _get_max_depth(self):
        """Return the largest depth among the drawn trees (0 if none)."""
        # `default` keeps a forest without any trees from raising ValueError
        return max(
            (tree.tree_adapter.max_depth for tree in self.ptrees), default=0)

    def _get_forest_adapter(self, model):
        return SklRandomForestAdapter(model)

    def _reflow_grid(self):
        """Lay the grid out again for the currently visible view width."""
        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    def _draw_trees(self):
        """Build a tree viewer and grid item for each tree of the forest."""
        # The size combo stays disabled while the trees are being built
        self.ui_size_calc_combo.setEnabled(False)
        self.grid_items, self.ptrees = [], []

        trees = self.forest_adapter.get_trees()
        with self.progressBar(len(trees)) as prg:
            for tree in trees:
                ptree = PythagorasTreeViewer(
                    None, tree,
                    node_color_func=self._type_specific('_get_node_color'),
                    interactive=False, padding=100)
                self.grid_items.append(GridItem(
                    ptree, self.grid, max_size=self._calculate_zoom(self.zoom)
                ))
                self.ptrees.append(ptree)
                prg.advance()
        self.grid.set_items(self.grid_items)
        # This is necessary when adding items for the first time
        if self.grid:
            self._reflow_grid()
        self.ui_size_calc_combo.setEnabled(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        selected_items = self.scene.selectedItems()
        if not selected_items:
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        self.selected_tree_index = self.grid_items.index(selected_items[0])
        obj = self.model.trees[self.selected_tree_index]
        # Attach the data and the current display settings to the tree
        # model before sending it.
        obj.instances = self.dataset
        obj.meta_target_class_index = self.target_class_index
        obj.meta_size_calc_idx = self.size_calc_idx
        obj.meta_size_log_scale = self.size_log_scale
        obj.meta_depth_limit = self.depth_limit

        self.send('Tree', obj)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def resizeEvent(self, ev):
        """Reflow the grid to the new width, then run the base handler."""
        self._reflow_grid()

        super().resizeEvent(ev)

    def _type_specific(self, method):
        """A best effort method getter that somewhat separates logic specific
        to classification and regression trees.
        This relies on conventional naming of specific methods, e.g.
        a method name _get_tooltip would need to be defined like so:
        _classification_get_tooltip and _regression_get_tooltip, since they are
        both specific.

        Parameters
        ----------
        method : str
            Method name that we would like to call.

        Returns
        -------
        callable or None

        """
        if self.forest_type == self.CLASSIFICATION:
            return getattr(self, '_classification' + method)
        elif self.forest_type == self.REGRESSION:
            return getattr(self, '_regression' + method)
        else:
            return None

    # CLASSIFICATION FOREST SPECIFIC METHODS
    def _classification_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItem('None')
        values = [c.title() for c in
                  self.model.domain.class_vars[0].values]
        self.ui_target_class_combo.addItems(values)

    def _classification_get_color_palette(self):
        return [QColor(*c) for c in self.model.domain.class_var.colors]

    def _classification_get_node_color(self, adapter, tree_node):
        """Return the colour of a node of a classification tree.

        With a target class selected (combo index > 0; index 0 is 'None')
        the node is coloured by that class' share of the node distribution,
        otherwise by the share of the majority class.
        """
        # this is taken almost directly from the existing classification tree
        # viewer
        colors = self.color_palette
        distribution = adapter.get_distribution(tree_node.label)[0]
        # `or 1` avoids dividing by zero for an empty distribution
        total = np.sum(distribution) or 1

        # QColor.lighter takes an integer factor, hence the explicit int()
        if self.target_class_index:
            class_idx = self.target_class_index - 1
            p = distribution[class_idx] / total
            color = colors[class_idx].lighter(int(200 - 100 * p))
        else:
            modus = np.argmax(distribution)
            p = distribution[modus] / total
            color = colors[int(modus)].lighter(int(400 - 300 * p))
        return color

    # REGRESSION FOREST SPECIFIC METHODS
    def _regression_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItems(
            list(zip(*self.REGRESSION_COLOR_CALC))[0])
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _regression_get_color_palette(self):
        return ContinuousPaletteGenerator(
            *self.forest_adapter.domain.class_var.colors)

    def _regression_get_node_color(self, adapter, tree_node):
        """Delegate to the colouring method chosen in the target combo."""
        return self.REGRESSION_COLOR_CALC[self.target_class_index][1](
            adapter, tree_node
        )

    def _color_class_mean(self, adapter, tree_node):
        """Colour a node by the mean target value of its instances."""
        # calculate node colors relative to the mean of the node samples
        min_mean = np.min(self.clf_dataset.Y)
        max_mean = np.max(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        mean = np.mean(instances.Y)

        # A constant target gives a zero range; `or 1` avoids dividing by it
        span = (max_mean - min_mean) or 1
        return self.color_palette[(mean - min_mean) / span]

    def _color_stddev(self, adapter, tree_node):
        """Colour a node by the standard deviation of its instances."""
        # calculate node colors relative to the standard deviation in the node
        # samples
        min_mean, max_mean = 0, np.std(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        std = np.std(instances.Y)

        # A constant target gives a zero range; `or 1` avoids dividing by it
        span = (max_mean - min_mean) or 1
        return self.color_palette[(std - min_mean) / span]
Ejemplo n.º 2
0
class OWSilhouettePlot(widget.OWWidget):
    """Widget that plots per-instance silhouette scores for a chosen
    discrete (cluster) variable and outputs the selected instances.
    """

    name = "Silhouette Plot"
    description = "Visually assess cluster quality and " \
                  "the degree of cluster membership."

    icon = "icons/SilhouettePlot.svg"
    priority = 300
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        selected_data = Output("Selected Data", Orange.data.Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    replaces = [
        "orangecontrib.prototypes.widgets.owsilhouetteplot.OWSilhouettePlot",
        "Orange.widgets.unsupervised.owsilhouetteplot.OWSilhouettePlot"
    ]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Distance metric index
    distance_idx = settings.Setting(0)
    #: Group/cluster variable index
    cluster_var_idx = settings.ContextSetting(0)
    #: Annotation variable index
    annotation_var_idx = settings.ContextSetting(0)
    #: Group the (displayed) silhouettes by cluster
    group_by_cluster = settings.Setting(True)
    #: A fixed size for an instance bar
    bar_size = settings.Setting(3)
    #: Add silhouette scores to output data
    add_scores = settings.Setting(False)
    auto_commit = settings.Setting(True)

    Distances = [("Euclidean", Orange.distance.Euclidean),
                 ("Manhattan", Orange.distance.Manhattan),
                 ("Cosine", Orange.distance.Cosine)]

    graph_name = "scene"
    buttons_area_orientation = Qt.Vertical

    class Error(widget.OWWidget.Error):
        need_two_clusters = Msg("Need at least two non-empty clusters")
        singleton_clusters_all = Msg("All clusters are singletons")
        memory_error = Msg("Not enough memory")
        value_error = Msg("Distances could not be computed: '{}'")

    class Warning(widget.OWWidget.Warning):
        missing_cluster_assignment = Msg(
            "{} instance{s} omitted (missing cluster assignment)")
        nan_distances = Msg("{} instance{s} omitted (undefined distances)")
        ignoring_categorical = Msg("Ignoring categorical features")

    def __init__(self):
        super().__init__()
        #: The input data
        self.data = None         # type: Optional[Orange.data.Table]
        #: Distance matrix computed from data
        self._matrix = None      # type: Optional[Orange.misc.DistMatrix]
        #: An bool mask (size == len(data)) indicating missing group/cluster
        #: assignments
        self._mask = None        # type: Optional[np.ndarray]
        #: An array of cluster/group labels for instances with valid group
        #: assignment
        self._labels = None      # type: Optional[np.ndarray]
        #: An array of silhouette scores for instances with valid group
        #: assignment
        self._silhouette = None  # type: Optional[np.ndarray]
        self._silplot = None     # type: Optional[SilhouettePlot]

        gui.comboBox(
            self.controlArea, self, "distance_idx", box="Distance",
            items=[name for name, _ in OWSilhouettePlot.Distances],
            orientation=Qt.Horizontal, callback=self._invalidate_distances)

        box = gui.vBox(self.controlArea, "Cluster Label")
        self.cluster_var_cb = gui.comboBox(
            box, self, "cluster_var_idx", contentsLength=14, addSpace=4,
            callback=self._invalidate_scores
        )
        gui.checkBox(
            box, self, "group_by_cluster", "Group by cluster",
            callback=self._replot)
        self.cluster_var_model = itemmodels.VariableListModel(parent=self)
        self.cluster_var_cb.setModel(self.cluster_var_model)

        box = gui.vBox(self.controlArea, "Bars")
        gui.widgetLabel(box, "Bar width:")
        gui.hSlider(
            box, self, "bar_size", minValue=1, maxValue=10, step=1,
            callback=self._update_bar_size, addSpace=6)
        gui.widgetLabel(box, "Annotations:")
        self.annotation_cb = gui.comboBox(
            box, self, "annotation_var_idx", contentsLength=14,
            callback=self._update_annotations)
        self.annotation_var_model = itemmodels.VariableListModel(parent=self)
        self.annotation_var_model[:] = ["None"]
        self.annotation_cb.setModel(self.annotation_var_model)
        ibox = gui.indentedBox(box, 5)
        self.ann_hidden_warning = warning = gui.widgetLabel(
            ibox, "(increase the width to show)")
        ibox.setFixedWidth(ibox.sizeHint().width())
        warning.setVisible(False)

        gui.rubber(self.controlArea)

        gui.separator(self.buttonsArea)
        box = gui.vBox(self.buttonsArea, "Output")
        # Thunk the call to commit to call conditional commit
        gui.checkBox(box, self, "add_scores", "Add silhouette scores",
                     callback=lambda: self.commit())
        gui.auto_commit(
            box, self, "auto_commit", "Commit",
            auto_label="Auto commit", box=False)
        # Ensure that the controlArea is not narrower than buttonsArea
        self.controlArea.layout().addWidget(self.buttonsArea)

        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the control area's size hint, at least 600 x 720."""
        sh = self.controlArea.sizeHint()
        return sh.expandedTo(QSize(600, 720))

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """
        Set the input dataset.

        Data without any discrete variable (attribute, class or meta) with
        at least two values is rejected with an error message.
        """
        self.closeContext()
        self.clear()
        error_msg = ""
        warning_msg = ""
        candidatevars = []
        if data is not None:
            candidatevars = [
                v for v in data.domain.variables + data.domain.metas
                if v.is_discrete and len(v.values) >= 2]
            if not candidatevars:
                error_msg = "Input does not have any suitable labels."
                data = None

        self.data = data
        if data is not None:
            self.cluster_var_model[:] = candidatevars
            # Default to the class variable when it is a valid candidate
            if data.domain.class_var in candidatevars:
                self.cluster_var_idx = \
                    candidatevars.index(data.domain.class_var)
            else:
                self.cluster_var_idx = 0

            annotvars = [var for var in data.domain.metas if var.is_string]
            self.annotation_var_model[:] = ["None"] + annotvars
            self.annotation_var_idx = 1 if len(annotvars) else 0
            self.openContext(Orange.data.Domain(candidatevars))

        self.error(error_msg)
        self.warning(warning_msg)

    def handleNewSignals(self):
        """Recompute and redraw for the new input, then commit."""
        if self.data is not None:
            self._update()
            self._replot()

        self.unconditional_commit()

    def clear(self):
        """
        Clear the widget state.
        """
        self.data = None
        self._matrix = None
        self._mask = None
        self._silhouette = None
        self._labels = None
        self.cluster_var_model[:] = []
        self.annotation_var_model[:] = ["None"]
        self._clear_scene()
        self.Error.clear()
        self.Warning.clear()

    def _clear_scene(self):
        # Clear the graphics scene and associated objects
        self.scene.clear()
        self.scene.setSceneRect(QRectF())
        self._silplot = None

    def _invalidate_distances(self):
        # Invalidate the computed distance matrix and recompute the silhouette.
        self._matrix = None
        self._invalidate_scores()

    def _invalidate_scores(self):
        # Invalidate and recompute the current silhouette scores.
        self._labels = self._silhouette = self._mask = None
        self._update()
        self._replot()
        if self.data is not None:
            self.commit()

    def _update(self):
        # Update/recompute the distances/scores as required
        self._clear_messages()

        if self.data is None or not len(self.data):
            self._reset_all()
            return

        # The distance matrix is reused until it is explicitly invalidated
        if self._matrix is None:
            _, metric = self.Distances[self.distance_idx]
            data = self.data
            if not metric.supports_discrete and any(
                    a.is_discrete for a in data.domain.attributes):
                self.Warning.ignoring_categorical()
                data = Orange.distance.remove_discrete_features(data)
            try:
                self._matrix = np.asarray(metric(data))
            except MemoryError:
                self.Error.memory_error()
                return
            except ValueError as err:
                self.Error.value_error(str(err))
                return

        self._update_labels()

    def _reset_all(self):
        self._mask = None
        self._silhouette = None
        self._labels = None
        self._matrix = None
        self._clear_scene()

    def _clear_messages(self):
        self.Error.clear()
        self.Warning.clear()

    def _update_labels(self):
        """Recompute the mask, labels and silhouette scores.

        Instances with a missing cluster label, or whose distances are all
        NaN, are masked out. If fewer than two clusters remain, or every
        remaining cluster is a singleton, an error is shown and the mask,
        labels and scores are all set to None.
        """
        labelvar = self.cluster_var_model[self.cluster_var_idx]
        labels, _ = self.data.get_column_view(labelvar)
        labels = np.asarray(labels, dtype=float)
        cluster_mask = np.isnan(labels)
        dist_mask = np.isnan(self._matrix).all(axis=0)
        mask = cluster_mask | dist_mask
        # Drop the masked rows *before* the cast: NaN has no integer value,
        # so casting it is undefined and emits a runtime warning.
        labels = labels[~mask].astype(int)

        labels_unq = np.unique(labels)

        if len(labels_unq) < 2:
            self.Error.need_two_clusters()
            labels = silhouette = mask = None
        elif len(labels_unq) == len(labels):
            self.Error.singleton_clusters_all()
            labels = silhouette = mask = None
        else:
            silhouette = sklearn.metrics.silhouette_samples(
                self._matrix[~mask, :][:, ~mask], labels, metric="precomputed")
        self._mask = mask
        self._labels = labels
        self._silhouette = silhouette

        if mask is not None:
            count_missing = np.count_nonzero(cluster_mask)
            if count_missing:
                self.Warning.missing_cluster_assignment(
                    count_missing, s="s" if count_missing > 1 else "")
            count_nandist = np.count_nonzero(dist_mask)
            if count_nandist:
                self.Warning.nan_distances(
                    count_nandist, s="s" if count_nandist > 1 else "")

    def _set_bar_height(self):
        # Row names are only shown once the bars are at least 5 units high
        visible = self.bar_size >= 5
        self._silplot.setBarHeight(self.bar_size)
        self._silplot.setRowNamesVisible(visible)
        self.ann_hidden_warning.setVisible(
            not visible and self.annotation_var_idx > 0)

    def _replot(self):
        # Clear and replot/initialize the scene
        self._clear_scene()
        if self._silhouette is not None and self._labels is not None:
            var = self.cluster_var_model[self.cluster_var_idx]
            self._silplot = silplot = SilhouettePlot()
            self._set_bar_height()

            if self.group_by_cluster:
                silplot.setScores(self._silhouette, self._labels, var.values,
                                  var.colors)
            else:
                # Ungrouped: one unnamed group with a single fixed colour
                silplot.setScores(
                    self._silhouette,
                    np.zeros(len(self._silhouette), dtype=int),
                    [""], np.array([[63, 207, 207]])
                )

            self.scene.addItem(silplot)
            self._update_annotations()
            silplot.selectionChanged.connect(self.commit)
            silplot.layout().activate()
            self._update_scene_rect()
            silplot.geometryChanged.connect(self._update_scene_rect)

    def _update_bar_size(self):
        if self._silplot is not None:
            self._set_bar_height()

    def _update_annotations(self):
        """Apply the selected annotation variable as the plot's row names."""
        # Index 0 is the "None" entry of the annotation model
        if 0 < self.annotation_var_idx < len(self.annotation_var_model):
            annot_var = self.annotation_var_model[self.annotation_var_idx]
        else:
            annot_var = None
        self.ann_hidden_warning.setVisible(
            self.bar_size < 5 and annot_var is not None)

        if self._silplot is not None:
            if annot_var is not None:
                column, _ = self.data.get_column_view(annot_var)
                if self._mask is not None:
                    assert column.shape == self._mask.shape
                    # pylint: disable=invalid-unary-operand-type
                    column = column[~self._mask]
                self._silplot.setRowNames(
                    [annot_var.str_val(value) for value in column])
            else:
                self._silplot.setRowNames(None)

    def _update_scene_rect(self):
        # There is nothing to fit the scene to once the plot is cleared
        if self._silplot is not None:
            self.scene.setSceneRect(self._silplot.geometry())

    def commit(self):
        """
        Commit/send the current selection to the output.

        With `add_scores` enabled, a "Silhouette (<variable>)" meta column
        is appended to both outputs; instances without a score get NaN.
        """
        selected = indices = data = None
        if self.data is not None:
            selectedmask = np.full(len(self.data), False, dtype=bool)
            if self._silplot is not None:
                indices = self._silplot.selection()
                assert (np.diff(indices) > 0).all(), "strictly increasing"
                if self._mask is not None:
                    # pylint: disable=invalid-unary-operand-type
                    # Map plot row indices back to rows of the input data
                    indices = np.flatnonzero(~self._mask)[indices]
                selectedmask[indices] = True

            if self._silhouette is None:
                # No scores were computed (error state): output NaN rather
                # than writing None into the score column.
                scores = np.full(shape=selectedmask.shape,
                                 fill_value=np.nan)
            elif self._mask is not None:
                scores = np.full(shape=selectedmask.shape,
                                 fill_value=np.nan)
                # pylint: disable=invalid-unary-operand-type
                scores[~self._mask] = self._silhouette
            else:
                scores = self._silhouette

            silhouette_var = None
            if self.add_scores:
                var = self.cluster_var_model[self.cluster_var_idx]
                silhouette_var = Orange.data.ContinuousVariable(
                    "Silhouette ({})".format(escape(var.name)))
                domain = Orange.data.Domain(
                    self.data.domain.attributes,
                    self.data.domain.class_vars,
                    self.data.domain.metas + (silhouette_var, ))
                data = self.data.transform(domain)
            else:
                domain = self.data.domain
                data = self.data

            if np.count_nonzero(selectedmask):
                selected = self.data.from_table(
                    domain, self.data, np.flatnonzero(selectedmask))

            if self.add_scores:
                if selected is not None:
                    selected[:, silhouette_var] = np.c_[scores[selectedmask]]
                data[:, silhouette_var] = np.c_[scores]

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(create_annotated_table(data, indices))

    def send_report(self):
        """Report the plot with a caption naming distance and variables."""
        if not len(self.cluster_var_model):
            return

        self.report_plot()
        caption = "Silhouette plot ({} distance), clustered by '{}'".format(
            self.Distances[self.distance_idx][0],
            self.cluster_var_model[self.cluster_var_idx])
        # The plot does not exist when the scores could not be computed
        if self.annotation_var_idx and self._silplot is not None \
                and self._silplot.rowNamesVisible():
            caption += ", annotated with '{}'".format(
                self.annotation_var_model[self.annotation_var_idx])
        self.report_caption(caption)

    def onDeleteWidget(self):
        """Clear the widget state before the base class clean-up."""
        self.clear()
        super().onDeleteWidget()
Ejemplo n.º 3
0
class OWVennDiagram(widget.OWWidget):
    """Widget that draws a Venn diagram over up to five input tables.

    The diagram is built either over rows (instances) or over columns
    (feature names) of the inputs; the selected diagram areas determine
    the "Selected Data" and annotated outputs.
    """
    name = "Venn Diagram"
    description = "A graphical visualization of the overlap of data instances " \
                  "from a collection of input datasets."
    icon = "icons/VennDiagram.svg"
    priority = 280
    keywords = []
    settings_version = 2

    class Inputs:
        # Multiple tables may be connected; `setData` receives each with a key.
        data = Input("Data", Table, multiple=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    class Error(widget.OWWidget.Error):
        # Shown in column-wise mode when the inputs do not share row ids.
        instances_mismatch = Msg("Data sets do not contain the same instances.")
        # Shown when a sixth input is connected.
        too_many_inputs = Msg("Venn diagram accepts at most five datasets.")

    class Warning(widget.OWWidget.Warning):
        # The placeholder receives a comma-separated list of variable names.
        renamed_vars = Msg("Some variables have been renamed "
                           "to avoid duplicates.\n{}")

    selection: list

    settingsHandler = settings.DomainContextHandler()
    # Indices of selected disjoint areas
    selection = settings.Setting([], schema_only=True)
    #: Output unique items (one output row for every unique instance `key`)
    #: or preserve all duplicates in the output.
    output_duplicates = settings.Setting(False)
    autocommit = settings.Setting(True)
    # Truthy: item sets are built over rows; falsy: over attribute names.
    rowwise = settings.Setting(True)
    # Variable used to match rows; None means matching by instance ids.
    selected_feature = settings.ContextSetting(None)

    want_control_area = False
    graph_name = "scene"
    # Domain parts handled when merging tables, in processing order.
    atr_types = ['attributes', 'metas', 'class_vars']
    # Domain part -> name of the table attribute holding its values.
    atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
    # Domain part -> name of the row-instance attribute holding its values.
    row_vals = {'attributes': 'x', 'class_vars': 'y', 'metas': 'metas'}

    def __init__(self):
        """Initialise the widget state and build the main-area GUI.

        Sets up the bookkeeping containers, the graphics scene/view with the
        Venn diagram item, and the "Elements" and "Output" control boxes.
        """
        super().__init__()

        # Diagram update is in progress
        self._updating = False
        # Input update is in progress
        self._inputUpdate = False
        # Input datasets in the order they were 'connected'.
        self.data = {}
        # Extracted input item sets in the order they were 'connected'
        self.itemsets = {}
        # A list with 2 ** len(self.data) elements that store item sets
        # belonging to each area
        self.disjoint = []
        # A list with  2 ** len(self.data) elements that store keys of tables
        # intersected in each area
        self.area_keys = []

        # Main area view
        self.scene = QGraphicsScene(self)
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.setBackgroundRole(QPalette.Window)
        self.view.setFrameStyle(QGraphicsView.StyledPanel)

        self.mainArea.layout().addWidget(self.view)
        self.vennwidget = VennDiagram()
        self._resize()
        # Keep edited set titles and the area selection in sync with the
        # widget state.
        self.vennwidget.itemTextEdited.connect(self._on_itemTextEdited)
        self.scene.selectionChanged.connect(self._on_selectionChanged)

        self.scene.addItem(self.vennwidget)

        # Controls live in the main area (want_control_area is False).
        controls = gui.hBox(self.mainArea)
        box = gui.radioButtonsInBox(
            controls, self, 'rowwise',
            ["Columns (features)", "Rows (instances), matched by", ],
            box="Elements", callback=self._on_matching_changed
        )
        gui.comboBox(
            gui.indentedBox(box), self, "selected_feature",
            model=itemmodels.VariableListModel(placeholder="Instance identity"),
            callback=self._on_inputAttrActivated
            )
        box.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)

        self.outputs_box = box = gui.vBox(controls, "Output")
        self.output_duplicates_cb = gui.checkBox(
            box, self, "output_duplicates", "Output duplicates",
            callback=lambda: self.commit())  # pylint: disable=unnecessary-lambda
        gui.auto_send(box, self, "autocommit", box=False)
        # Duplicates can only be output when the diagram is built over rows.
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        self._queue = []

    def resizeEvent(self, event):
        """Forward the event to the base class, then refit the diagram."""
        super().resizeEvent(event)
        self._resize()

    def showEvent(self, event):
        """Forward the event to the base class, then refit the diagram."""
        super().showEvent(event)
        self._resize()

    def _resize(self):
        # vennwidget draws so that the diagram fits into its geometry,
        # while labels take further 120 pixels, hence -120 in below formula
        size = max(200, min(self.view.width(), self.view.height()) - 120)
        self.vennwidget.resize(size, size)
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    @Inputs.data
    @check_sql_input
    def setData(self, data, key=None):
        """Add, replace or remove the input table connected under `key`.

        `None` for an already known `key` removes that input and clears the
        widget warnings; a table for a known `key` replaces the stored name
        and table.  A table for an unknown `key` is added unless five inputs
        are already stored, in which case the `too_many_inputs` error is
        shown and the call returns without refreshing the matching
        attributes.
        """
        self.Error.too_many_inputs.clear()
        if not self._inputUpdate:
            self._inputUpdate = True

        known = key in self.data
        if known and data is None:
            # Remove the input and clear possible warnings.
            self.Warning.clear()
            del self.data[key]
        elif known:
            # Update existing item
            self.data[key] = self.data[key]._replace(
                name=data.name, table=data)
        elif data is not None:
            # TODO: Allow setting more than 5 inputs and let the user
            # select the 5 to display.
            if len(self.data) == 5:
                self.Error.too_many_inputs()
                return
            # Add a new input
            self.data[key] = _InputData(key, data.name, data)
        self._setInterAttributes()

    def data_equality(self):
        """Return True when every input table holds the same set of ids.

        With no inputs the answer is True.  Otherwise the ids common to all
        tables must be as numerous as the ids of the largest table.
        """
        entries = self.data.values()
        if len(entries) == 0:
            return True
        id_sets = [set(entry.table.ids) for entry in entries]
        shared = set.intersection(*id_sets)
        largest = max(len(ids) for ids in id_sets)
        return len(shared) == largest

    def settings_compatible(self):
        """Check that the inputs can be shown with the current settings.

        Column-wise mode requires all inputs to contain the same instances;
        when they do not, the diagram and the item sets are cleared, the
        `instances_mismatch` error is shown and False is returned.
        """
        self.Error.instances_mismatch.clear()
        if self.rowwise or self.data_equality():
            return True

        self.vennwidget.clear()
        self.Error.instances_mismatch()
        self.itemsets = {}
        return False

    def handleNewSignals(self):
        """Rebuild the diagram after all pending input changes were set.

        When the settings are incompatible with the inputs the output is
        invalidated and the method returns early, before the item sets and
        the diagram are rebuilt.
        """
        self._inputUpdate = False
        self.vennwidget.clear()
        if not self.settings_compatible():
            self.invalidateOutput()
            return

        self._createItemsets()
        self._createDiagram()
        # If autocommit is enabled, _createDiagram already outputs data
        # If not, call unconditional_commit from here
        if not self.autocommit:
            self.unconditional_commit()

        self._updateInfo()
        super().handleNewSignals()

    def intersectionStringAttrs(self):
        """Return the string attributes shared by the domains of all inputs.

        An empty set is returned when there are no inputs.
        """
        domains = (entry.table.domain for entry in self.data.values())
        attr_sets = [set(string_attributes(domain)) for domain in domains]
        if not attr_sets:
            return set()
        return reduce(set.intersection, attr_sets)

    def _setInterAttributes(self):
        model = self.controls.selected_feature.model()
        model[:] = [None] + list(self.intersectionStringAttrs())
        if self.selected_feature:
            names = (var.name for var in model if var)
            if self.selected_feature.name not in names:
                self.selected_feature = model[0]

    def _itemsForInput(self, key):
        """
        Calculates input for venn diagram, according to user's settings.
        """
        table = self.data[key].table
        attr = self.selected_feature
        if attr:
            return [str(inst[attr]) for inst in table
                    if not np.isnan(inst[attr])]
        else:
            return list(table.ids)

    def _createItemsets(self):
        """
        Rebuild `self.itemsets` over rows or columns of the input tables.

        In row-wise mode the items come from `_itemsForInput`; otherwise they
        are the names of the table's attributes.  A title from the previous
        item set is kept when the input under the same key still has the
        same name.
        """
        previous = dict(self.itemsets)
        self.itemsets.clear()

        for key, entry in self.data.items():
            if self.rowwise:
                items = self._itemsForInput(key)
            else:
                items = [var.name for var in entry.table.domain.attributes]

            old = previous.get(key)
            if old is not None and old.name == entry.name:
                # Reuse the title (which might have been changed by the user)
                title = old.title
            else:
                title = entry.name

            self.itemsets[key] = _ItemSet(
                key=key, name=entry.name, title=title, items=items)

    def _createDiagram(self):
        """Rebuild the diagram items, area counts and tooltips.

        Recomputes `self.disjoint` and `self.area_keys` from the current
        item sets, creates one `VennSetItem` per item set, labels every area
        with its item count and re-selects the areas whose indices are in
        `self.selection`.  Selection-change handling is suppressed through
        `self._updating` while the items are rebuilt and is triggered once
        at the end; the flag is reset even if the rebuild raises.
        """
        max_listed = 32  # number of items spelled out in an area tooltip

        self._updating = True
        try:
            oldselection = set(self.selection)

            n = len(self.itemsets)
            self.disjoint, self.area_keys = self.get_disjoint(
                set(s.items) for s in self.itemsets.values())

            vennitems = []
            colors = colorpalettes.LimitedDiscretePalette(n, force_hsv=True)

            for i, item in enumerate(self.itemsets.values()):
                cnt = len(set(item.items))
                cnt_all = len(item.items)
                # Show the total as well when the input holds duplicates.
                if cnt != cnt_all:
                    fmt = '{} <i>(all: {})</i>'
                else:
                    fmt = '{}'
                counts = fmt.format(cnt, cnt_all)
                gr = VennSetItem(text=item.title, informativeText=counts)
                color = colors[i]
                color.setAlpha(100)
                gr.setBrush(QBrush(color))
                gr.setPen(QPen(Qt.NoPen))
                vennitems.append(gr)

            self.vennwidget.setItems(vennitems)

            for i, area in enumerate(self.vennwidget.vennareas()):
                area_items = [str(element) for element in self.disjoint[i]]
                # Area 0 gets no count label.
                if i:
                    area.setText("{0}".format(len(area_items)))

                label = disjoint_set_label(i, n, simplify=False)
                head = "<h4>|{}| = {}</h4>".format(label, len(area_items))
                if len(area_items) > max_listed:
                    items_str = ", ".join(
                        map(escape, area_items[:max_listed]))
                    hidden = len(area_items) - max_listed
                    # Well-formed markup: a real line break and a closed span.
                    tooltip = (
                        "{}<span>{}, ...<br/>({} items not shown)</span>"
                        .format(head, items_str, hidden))
                elif area_items:
                    tooltip = "{}<span>{}</span>".format(
                        head,
                        ", ".join(map(escape, area_items))
                    )
                else:
                    tooltip = head

                area.setToolTip(tooltip)

                area.setPen(QPen(QColor(10, 10, 10, 200), 1.5))
                area.setFlag(QGraphicsPathItem.ItemIsSelectable, True)
                area.setSelected(i in oldselection)
        finally:
            # Never leave selection handling disabled after a failure.
            self._updating = False

        self._on_selectionChanged()

    def _updateInfo(self):
        """Warn about inputs without identifiers when no feature is selected.

        Inputs are referred to by their 1-based position in `self.data`.
        """
        # Clear all warnings
        self.warning()

        if self.selected_feature is not None:
            return

        no_idx = []
        for position, key in enumerate(self.data, start=1):
            if not source_attributes(self.data[key].table.domain):
                no_idx.append("#{}".format(position))

        if len(no_idx) == 1:
            self.warning("Dataset {} has no suitable identifiers."
                         .format(no_idx[0]))
        elif no_idx:
            leading, last = ", ".join(no_idx[:-1]), no_idx[-1]
            self.warning("Datasets {} and {} have no suitable identifiers."
                         .format(leading, last))

    def _on_selectionChanged(self):
        if self._updating:
            return

        areas = self.vennwidget.vennareas()
        self.selection = [i for i, area in enumerate(areas) if area.isSelected()]
        self.invalidateOutput()

    def _on_matching_changed(self):
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        if not self.settings_compatible():
            self.invalidateOutput()
            return
        self._createItemsets()
        self._createDiagram()
        self._updateInfo()

    def _on_inputAttrActivated(self):
        """Switch to row-wise matching when a matching feature is chosen."""
        # 1 is presumably the index of the "Rows (instances), matched by"
        # radio button created in __init__ -- confirm against the GUI setup.
        self.rowwise = 1
        self._on_matching_changed()

    def _on_itemTextEdited(self, index, text):
        text = str(text)
        key = list(self.itemsets)[index]
        self.itemsets[key] = self.itemsets[key]._replace(title=text)

    def invalidateOutput(self):
        """Recompute the outputs by calling `commit`."""
        self.commit()

    def merge_data(self, domain, values, ids=None):
        """Build a table from per-part variable lists and value blocks.

        Parameters:
            domain: dict mapping 'attributes', 'metas' and 'class_vars' to
                lists of variables.  Variables whose names clash are replaced
                in place by renamed copies and reported through the
                `renamed_vars` warning.
            values: dict mapping the same part names to lists of arrays that
                are stacked horizontally; parts may be missing.
            ids: optional row ids assigned to the resulting table.

        Returns:
            The table created with `Table.from_numpy`.
        """
        X, metas, class_vars = None, None, None
        renamed = []
        for variables in domain.values():
            names = [var.name for var in variables]
            unique_names = get_unique_names_duplicates(names)
            for idx, (name, unique, var) in enumerate(
                    zip(names, unique_names, variables)):
                if name != unique:
                    variables[idx] = var.copy(name=unique)
                    renamed.append(name)
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))

        # Row count used only when there are no attribute columns.  It is
        # kept apart from the loop variables above and falls back to the
        # number of ids (or zero), so it is always defined even when
        # `values` holds neither metas nor class variables.
        n_rows = 0 if ids is None else len(ids)
        if 'attributes' in values:
            X = np.hstack(values['attributes'])
        if 'metas' in values:
            metas = np.hstack(values['metas'])
            n_rows = len(metas)
        if 'class_vars' in values:
            class_vars = np.hstack(values['class_vars'])
            n_rows = len(class_vars)
        if X is None:
            X = np.empty((n_rows, 0))
        table = Table.from_numpy(Domain(**domain), X, class_vars, metas)
        if ids is not None:
            table.ids = ids
        return table

    def extract_columnwise(self, var_dict, columns=None):
        """Assemble the output table for column-wise (feature) matching.

        Parameters:
            var_dict: dict mapping a domain part name to a dict whose keys
                are variables and whose values are
                `[is_different, [(variable, table_key), ...]]`.
            columns: optional collection of selected column names.  When it
                is non-empty, every output attribute gets a 'Selected' entry
                in its `attributes` dict.

        A variable flagged as different is copied once per source table and
        renamed to '<name> (<input position>)'; otherwise only the copy from
        the first source table is used.  Each renamed variable is listed
        once in the `renamed_vars` warning.

        Returns:
            The table built by `merge_data`.
        """
        domain = {type_: [] for type_ in self.atr_types}
        values = defaultdict(list)
        renamed = []
        table_keys = list(self.data)
        for atr_type, vars_dict in var_dict.items():
            mark_selected = bool(columns) and atr_type == 'attributes'
            for var_name, var_data in vars_dict.items():
                is_selected = bool(columns) and var_name.name in columns
                different, sources = var_data[0], var_data[1]
                if different:
                    # Columns are different: copy all of them and rename
                    # each copy after the position of its source input.
                    for var, table_key in sources:
                        idx = table_keys.index(table_key) + 1
                        new_atr = var.copy(name=f'{var_name.name} ({idx})')
                        if mark_selected:
                            new_atr.attributes['Selected'] = is_selected
                        domain[atr_type].append(new_atr)
                        column = getattr(
                            self.data[table_key].table[:, var_name],
                            self.atr_vals[atr_type])
                        values[atr_type].append(column.reshape(-1, 1))
                    # Report the name once, not once per source table.
                    renamed.append(var_name.name)
                else:
                    var, table_key = sources[0]
                    new_atr = var.copy()
                    if mark_selected:
                        new_atr.attributes['Selected'] = is_selected
                    domain[atr_type].append(new_atr)
                    column = getattr(
                        self.data[table_key].table[:, var_name],
                        self.atr_vals[atr_type])
                    values[atr_type].append(column.reshape(-1, 1))
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        return self.merge_data(domain, values)

    def curry_merge(self, table_key, atr_type, ids=None, selection=False):
        """Return a reducer that merges one table's variables into a dict.

        The equality check is chosen when this method is called: the
        row-wise comparison in row-wise mode, the column-wise one otherwise.
        """
        check_equality = (self.arrays_equal_rows if self.rowwise
                          else self.arrays_equal_cols)

        def inner(new_atrs, atr):
            """
            new_atrs - dictionary where the key is a variable and the value
                is [is_different:bool, [(variable, table_key), ...]];
                is_different is set to True if we are outputting duplicates,
                but the value is arbitrary
            atr - the variable of the current table to merge in
            """
            entry = new_atrs.get(atr)
            if entry is None:
                new_atrs[atr] = [False, [(atr, table_key)]]
                return new_atrs

            if self.output_duplicates and not selection:
                # When outputting duplicates only the variable identity is
                # checked, so the columns are marked as different.
                entry[0] = True
            elif not entry[0]:
                data_attr = self.atr_vals[atr_type]
                if any(not check_equality(table_key, key, atr.name,
                                          data_attr, type(var), ids)
                       for var, key in entry[1]):
                    entry[0] = True
            entry[1].append((atr, table_key))
            return new_atrs

        return inner

    def arrays_equal_rows(self, key1, key2, name, data_type, type_, ids):
        """Compare column `name` of two tables on the rows they share.

        `ids` maps each table key to a dict from item to row index; only the
        items present for both tables are compared, via `arrays_equal`.
        """
        t1 = self.data[key1].table
        t2 = self.data[key2].table
        rows1, rows2 = ids[key1], ids[key2]
        shared = set(rows1) & set(rows2)
        t1_rows = [rows1[item] for item in shared]
        t2_rows = [rows2[item] for item in shared]
        column1 = getattr(t1[t1_rows, name], data_type)
        column2 = getattr(t2[t2_rows, name], data_type)
        return arrays_equal(
            column1.reshape(-1, 1), column2.reshape(-1, 1), type_)

    def arrays_equal_cols(self, key1, key2, name, data_type, type_, _ids=None):
        """Compare the whole column `name` of two tables via `arrays_equal`.

        `_ids` is accepted only for signature parity with
        `arrays_equal_rows` and is not used.
        """
        first = getattr(self.data[key1].table[:, name], data_type)
        second = getattr(self.data[key2].table[:, name], data_type)
        return arrays_equal(first, second, type_)

    def create_from_columns(self, columns, relevant_keys, get_selected):
        """
        Build the column-wise output from the tables in `relevant_keys`.

        Columns are duplicated only if values differ (even if only in order
        of values); the input slot is then added to the column name.  With
        `get_selected` only the attributes named in `columns` are taken and
        no 'Selected' annotation is requested; otherwise all attributes are
        taken and `columns` is passed on for annotation.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            only_selected = get_selected and atr_type == 'attributes'
            merged = {}
            for table_key in relevant_keys:
                domain = self.data[table_key].table.domain
                if only_selected:
                    atrs = [var for var in domain.attributes
                            if var.name in columns]
                else:
                    atrs = getattr(domain, atr_type)
                merge_vars = self.curry_merge(table_key, atr_type)
                merged = reduce(merge_vars, atrs, merged)
            var_dict[atr_type] = merged

        annotate_with = None if get_selected else columns
        return self.extract_columnwise(var_dict, annotate_with)

    def extract_rowwise(self, var_dict, ids=None, selection=False):
        """
        Assemble the output table for row-wise matching without duplicates.

        var_dict: dict with keys 'attributes', 'metas' and 'class_vars';
            each value is a dict where the key is a variable and the value
            is [is_different:bool, [(variable, table_key), ...]]
        ids: dict with ids for each table (table key -> {item: row index});
            despite the default it must be given, since it is iterated
        selection: when true, the result is annotated with a mask marking
            the rows whose id is in `self.selected_items`

        Each renamed variable is listed once in the `renamed_vars` warning.
        """
        all_ids = sorted(reduce(set.union, [set(val) for val in ids.values()], set()))

        permutations = {}
        for table_key, dict_ in ids.items():
            permutations[table_key] = get_perm(list(dict_), all_ids)

        domain = {type_: [] for type_ in self.atr_types}
        values = defaultdict(list)
        renamed = []
        for atr_type, vars_dict in var_dict.items():
            for var_name, var_data in vars_dict.items():
                different = var_data[0]
                if different:
                    # Columns are different, copy and rename them.
                    # Renaming is done here to mark appropriately the source table.
                    # Additional strange clashes are checked later in merge_data
                    # Report the name once, not once per source table.
                    renamed.append(var_name.name)
                    for var, table_key in var_data[1]:
                        temp = self.data[table_key].table
                        idx = list(self.data).index(table_key) + 1
                        domain[atr_type].append(var.copy(name='{} ({})'.format(var_name, idx)))
                        v = getattr(temp[list(ids[table_key].values()), var_name],
                                    self.atr_vals[atr_type])
                        perm = permutations[table_key]
                        # NOTE(review): `perm` is a gather index here
                        # (`v[perm]`) but a scatter index in the branch below
                        # (`value[perm] = v`); confirm against `get_perm`
                        # that both are intended.
                        if len(v) < len(all_ids):
                            values[atr_type].append(pad_columns(v, perm, len(all_ids)))
                        else:
                            values[atr_type].append(v[perm].reshape(-1, 1))
                else:
                    value = np.full((len(all_ids), 1), np.nan)
                    domain[atr_type].append(var_data[1][0][0].copy())
                    for _, table_key in var_data[1]:
                        # different tables have different part of the same attribute vector
                        perm = permutations[table_key]
                        v = getattr(self.data[table_key].table[list(ids[table_key].values()),
                                                               var_name],
                                    self.atr_vals[atr_type]).reshape(-1, 1)
                        value = value.astype(v.dtype, copy=False)
                        value[perm] = v
                    values[atr_type].append(value)

        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        # Row ids are kept only when rows are matched by instance identity.
        ids = None if self.selected_feature else np.array(all_ids)
        table = self.merge_data(domain, values, ids)
        if selection:
            mask = [idx in self.selected_items for idx in all_ids]
            return create_annotated_table(table, mask)
        return table

    def get_indices(self, table, selection):
        """Returns mappings of ids (be it row id or string) to indices in tables

        Without a selected feature the table's row ids map to consecutive
        positions.  With one, the unique values of that meta column map to
        the index of their first occurrence, or, when duplicates are output
        and `selection` is true, to the array of all their row indices.
        With `selection` only items in `self.selected_items` are kept.
        """
        feature = self.selected_feature
        if not feature:
            items, ids = table.ids, range(len(table))
        else:
            column = table[:, feature].metas
            if self.output_duplicates and selection:
                items, inverse = np.unique(column, return_inverse=True)
                ids = [np.nonzero(inverse == position)[0]
                       for position in range(len(items))]
            else:
                items, ids = np.unique(column, return_index=True)

        pairs = zip(items, ids)
        if selection:
            return {item: idx for item, idx in pairs
                    if item in self.selected_items}
        return dict(pairs)

    def get_indices_to_match_by(self, relevant_keys, selection=False):
        """Map every key in `relevant_keys` to its table's `get_indices` result."""
        return {
            key: self.get_indices(self.data[key].table, selection)
            for key in relevant_keys
        }

    def create_from_rows(self, relevant_ids, selection=False):
        """Build the row-wise output table for the tables in `relevant_ids`.

        `relevant_ids` maps table keys to their item -> row index dicts.
        The variables of every table are merged per domain part; duplicates
        are preserved only when `output_duplicates` is set and `selection`
        is false.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            merged = {}
            for table_key in relevant_ids:
                merge_vars = self.curry_merge(
                    table_key, atr_type, relevant_ids, selection)
                atrs = getattr(self.data[table_key].table.domain, atr_type)
                merged = reduce(merge_vars, atrs, merged)
            var_dict[atr_type] = merged

        if selection or not self.output_duplicates:
            return self.extract_rowwise(var_dict, relevant_ids, selection)
        return self.extract_rowwise_duplicates(var_dict, relevant_ids)

    def expand_table(self, table, atrs, metas, cv):
        """Spread the values of `table` over the given full variable lists.

        `table` is either a table or a single `RowInstance`.  For each domain
        part an array with one column per variable in the corresponding full
        list is created, NaN-filled, and the columns the table actually has
        are written at the positions given by `get_perm`.

        Returns a tuple (attributes, metas, class_vars, ids) of arrays.
        """
        if isinstance(table, RowInstance):
            n = 1
            ids = table.id.reshape(-1, 1)
            atr_vals = self.row_vals
        else:
            n = len(table)
            ids = table.ids.reshape(-1, 1)
            atr_vals = self.atr_vals

        exp = []
        for all_el, atr_type in zip([atrs, metas, cv], self.atr_types):
            present = getattr(table.domain, atr_type)
            block = np.full((n, len(all_el)), np.nan)
            if present:
                perm = get_perm(present, all_el)
                source = getattr(table, atr_vals[atr_type])
                source = source.reshape(len(block), len(perm))
                # Adopt the source dtype before writing the columns.
                block = block.astype(source.dtype, copy=False)
                block[:, perm] = source
            exp.append(block)
        return (*exp, ids)

    def extract_rowwise_duplicates(self, var_dict, ids):
        """Assemble the row-wise output table, keeping duplicates.

        For every id, in sorted order, the matching rows of each table that
        contains it are expanded to the full, name-sorted variable lists and
        stacked vertically; the stacked row ids are passed to `merge_data`.
        """
        all_ids = sorted(set().union(*ids.values()))
        by_name = attrgetter("name")
        all_atrs = sorted(var_dict['attributes'], key=by_name)
        all_metas = sorted(var_dict['metas'], key=by_name)
        all_cv = sorted(var_dict['class_vars'], key=by_name)

        blocks = {'attributes': [], 'metas': [], 'class_vars': []}
        row_ids = []
        for idx in all_ids:
            # visit every table that contains this id
            for table_key, t_indices in ids.items():
                if idx in t_indices:
                    rows = self.data[table_key].table[t_indices[idx]]
                    # pylint: disable=unbalanced-tuple-unpacking
                    x, m, y, t_ids = self.expand_table(
                        rows, all_atrs, all_metas, all_cv)
                    blocks['attributes'].append(x)
                    blocks['metas'].append(m)
                    blocks['class_vars'].append(y)
                    row_ids.append(t_ids)

        domain = {'attributes': all_atrs, 'metas': all_metas, 'class_vars': all_cv}
        values = {kind: [np.vstack(parts)] for kind, parts in blocks.items()}
        return self.merge_data(domain, values, np.vstack(row_ids))

    def commit(self):
        """Compute and send the selected and the annotated output tables.

        With no diagram areas or no inputs both outputs receive `None`.
        Otherwise the items and table keys of the selected areas are
        collected and the outputs are built row-wise or column-wise,
        depending on `self.rowwise`.
        """
        if not (self.vennwidget.vennareas() and self.data):
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            return

        chosen = self.selection
        self.selected_items = set().union(
            *(self.disjoint[index] for index in chosen))
        selected_keys = set().union(
            *(self.area_keys[area] for area in chosen))

        selected = None
        if self.rowwise:
            if self.selected_items:
                selected_ids = self.get_indices_to_match_by(
                    selected_keys, bool(self.selection))
                selected = self.create_from_rows(selected_ids, False)
            annotated_ids = self.get_indices_to_match_by(self.data)
            annotated = self.create_from_rows(annotated_ids, True)
        else:
            annotated = self.create_from_columns(
                self.selected_items, self.data, False)
            if self.selected_items:
                selected = self.create_from_columns(
                    self.selected_items, selected_keys, True)

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        """Add the current plot to the report."""
        self.report_plot()

    def get_disjoint(self, sets):
        """
        Return all disjoint subsets.

        For every index `i` below 2 ** len(sets), `setkey(i, n)` tells which
        sets take part.  The area holds the items common to the included
        sets and absent from every excluded one; with nothing included it is
        empty.  Alongside, the keys of `self.data` at the included positions
        are returned for each area.
        """
        sets = list(sets)
        n = len(sets)
        disjoint_sets, included_tables = [], []
        for i in range(2 ** n):
            key = setkey(i, n)
            included, excluded = [], []
            for candidate, inc in zip(sets, key):
                (included if inc else excluded).append(candidate)

            if included:
                area = reduce(set.intersection, included)
                area = reduce(set.difference, excluded, area)
            else:
                area = set()
            disjoint_sets.append(area)
            included_tables.append(
                [table_key for table_key, inc in zip(self.data, key) if inc])

        return disjoint_sets, included_tables
Ejemplo n.º 4
0
class OWSilhouettePlot(widget.OWWidget):
    """
    Widget showing a silhouette plot for data grouped by a discrete variable.

    Takes a data table on input and sends out the selected instances and
    the annotated input data.
    """
    # Widget metadata
    name = "Silhouette Plot"
    description = "Visually assess cluster quality and " \
                  "the degree of cluster membership."

    icon = "icons/SilhouettePlot.svg"
    priority = 300
    keywords = []

    # Input signals
    class Inputs:
        data = Input("Data", Orange.data.Table)

    # Output signals
    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    # Qualified names of widget classes that this widget replaces
    replaces = [
        "orangecontrib.prototypes.widgets.owsilhouetteplot.OWSilhouettePlot",
        "Orange.widgets.unsupervised.owsilhouetteplot.OWSilhouettePlot"
    ]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Distance metric index (into `Distances`)
    distance_idx = settings.Setting(0)
    #: Group/cluster variable index
    cluster_var_idx = settings.ContextSetting(0)
    #: Annotation variable index (0 stands for the "None" entry)
    annotation_var_idx = settings.ContextSetting(0)
    #: Group the (displayed) silhouettes by cluster
    group_by_cluster = settings.Setting(True)
    #: A fixed size for an instance bar
    bar_size = settings.Setting(3)
    #: Add silhouette scores to output data
    add_scores = settings.Setting(False)
    #: Commit the output automatically on every change
    auto_commit = settings.Setting(True)

    #: Available distance metrics as (display name, metric) pairs
    Distances = [("Euclidean", Orange.distance.Euclidean),
                 ("Manhattan", Orange.distance.Manhattan)]

    # Name of the attribute holding the scene used for saving/reporting
    graph_name = "scene"
    buttons_area_orientation = Qt.Vertical

    # Error messages raised while computing distances/scores
    class Error(widget.OWWidget.Error):
        need_two_clusters = Msg("Need at least two non-empty clusters")
        singleton_clusters_all = Msg("All clusters are singletons")
        memory_error = Msg("Not enough memory")
        value_error = Msg("Distances could not be computed: '{}'")

    # Warning shown when some instances lack a cluster assignment
    class Warning(widget.OWWidget.Warning):
        missing_cluster_assignment = Msg(
            "{} instance{s} omitted (missing cluster assignment)")

    def __init__(self):
        """
        Initialise the widget state and build the GUI.

        The control area holds the distance, cluster label and bar options,
        followed by the output (commit) controls; the main area holds the
        graphics view on which the silhouette plot is drawn.
        """
        super().__init__()
        #: The input data
        self.data = None  # type: Optional[Orange.data.Table]
        #: Distance matrix computed from data
        self._matrix = None  # type: Optional[Orange.misc.DistMatrix]
        #: A bool mask (size == len(data)) indicating missing group/cluster
        #: assignments
        self._mask = None  # type: Optional[np.ndarray]
        #: An array of cluster/group labels for instances with valid group
        #: assignment
        self._labels = None  # type: Optional[np.ndarray]
        #: An array of silhouette scores for instances with valid group
        #: assignment
        self._silhouette = None  # type: Optional[np.ndarray]
        #: The plot item currently displayed in the scene (if any)
        self._silplot = None  # type: Optional[SilhouettePlot]

        # Distance metric selection; a change invalidates the distances
        gui.comboBox(self.controlArea,
                     self,
                     "distance_idx",
                     box="Distance",
                     items=[name for name, _ in OWSilhouettePlot.Distances],
                     orientation=Qt.Horizontal,
                     callback=self._invalidate_distances)

        # Cluster label variable selection and grouping option
        box = gui.vBox(self.controlArea, "Cluster Label")
        self.cluster_var_cb = gui.comboBox(box,
                                           self,
                                           "cluster_var_idx",
                                           contentsLength=14,
                                           addSpace=4,
                                           callback=self._invalidate_scores)
        gui.checkBox(box,
                     self,
                     "group_by_cluster",
                     "Group by cluster",
                     callback=self._replot)
        self.cluster_var_model = itemmodels.VariableListModel(parent=self)
        self.cluster_var_cb.setModel(self.cluster_var_model)

        # Bar width and row annotation controls
        box = gui.vBox(self.controlArea, "Bars")
        gui.widgetLabel(box, "Bar width:")
        gui.hSlider(box,
                    self,
                    "bar_size",
                    minValue=1,
                    maxValue=10,
                    step=1,
                    callback=self._update_bar_size,
                    addSpace=6)
        gui.widgetLabel(box, "Annotations:")
        self.annotation_cb = gui.comboBox(box,
                                          self,
                                          "annotation_var_idx",
                                          contentsLength=14,
                                          callback=self._update_annotations)
        self.annotation_var_model = itemmodels.VariableListModel(parent=self)
        # The first entry always stands for 'no annotation'
        self.annotation_var_model[:] = ["None"]
        self.annotation_cb.setModel(self.annotation_var_model)
        ibox = gui.indentedBox(box, 5)
        # Hint displayed when annotations are selected but the bars are too
        # narrow for the row names to be shown
        self.ann_hidden_warning = warning = gui.widgetLabel(
            ibox, "(increase the width to show)")
        # Fix the width while the label is still visible so that toggling
        # its visibility later does not change the layout
        ibox.setFixedWidth(ibox.sizeHint().width())
        warning.setVisible(False)

        gui.rubber(self.controlArea)

        gui.separator(self.buttonsArea)
        box = gui.vBox(self.buttonsArea, "Output")
        # Thunk the call to commit to call conditional commit
        gui.checkBox(box,
                     self,
                     "add_scores",
                     "Add silhouette scores",
                     callback=lambda: self.commit())
        gui.auto_commit(box,
                        self,
                        "auto_commit",
                        "Commit",
                        auto_label="Auto commit",
                        box=False)
        # Ensure that the controlArea is not narrower than buttonsArea
        self.controlArea.layout().addWidget(self.buttonsArea)

        # Main area: the graphics scene/view displaying the plot
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def sizeHint(self):
        """
        Return the control area's size hint, expanded to at least 600x720.
        """
        minimum = QSize(600, 720)
        return self.controlArea.sizeHint().expandedTo(minimum)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """
        Set the input dataset.

        Only discrete variables (attributes, class variables or metas) with
        at least two values can serve as cluster labels; data without any
        such variable is rejected with an error. String meta variables are
        offered as row annotations.
        """
        self.closeContext()
        self.clear()

        label_vars = []
        error_text = ""
        if data is not None:
            domain = data.domain
            label_vars = [
                var for var in domain.variables + domain.metas
                if var.is_discrete and len(var.values) >= 2
            ]
            if not label_vars:
                error_text = "Input does not have any suitable labels."
                data = None

        self.data = data
        if data is not None:
            self.cluster_var_model[:] = label_vars
            # Prefer the class variable as the default cluster label
            class_var = data.domain.class_var
            self.cluster_var_idx = (
                label_vars.index(class_var) if class_var in label_vars else 0)

            string_metas = [
                var for var in data.domain.metas if var.is_string]
            self.annotation_var_model[:] = ["None"] + string_metas
            # Index 0 is the "None" entry; pick the first string meta if any
            self.annotation_var_idx = 1 if string_metas else 0
            self.openContext(Orange.data.Domain(label_vars))

        self.error(error_text)
        self.warning("")

    def handleNewSignals(self):
        """
        Recompute and redraw the plot for new input, then commit the output.
        """
        has_data = self.data is not None
        if has_data:
            self._update()
            self._replot()

        self.unconditional_commit()

    def clear(self):
        """
        Reset the widget: drop the data and everything derived from it,
        empty the variable models and the scene, and clear all messages.
        """
        self.data = None
        self._matrix = self._mask = None
        self._silhouette = self._labels = None

        self.cluster_var_model[:] = []
        self.annotation_var_model[:] = ["None"]

        self._clear_scene()

        self.Error.clear()
        self.Warning.clear()

    def _clear_scene(self):
        """Remove all items from the scene and forget the current plot."""
        scene = self.scene
        scene.clear()
        scene.setSceneRect(QRectF())
        self._silplot = None

    def _invalidate_distances(self):
        """
        Drop the cached distance matrix and recompute the silhouette scores.
        """
        self._matrix = None
        self._invalidate_scores()

    def _invalidate_scores(self):
        """
        Discard the current labels/scores, recompute and redraw them, and
        commit the output when there is input data.
        """
        self._labels = None
        self._silhouette = None
        self._mask = None

        self._update()
        self._replot()

        if self.data is None:
            return
        self.commit()

    def _update(self):
        """
        Recompute the distance matrix (when missing) and the scores.

        With no or empty data the computed state is reset instead. A
        failure to compute the distances is reported through the matching
        error message and leaves the scores untouched.
        """
        self._clear_messages()

        if self.data is None or len(self.data) == 0:
            self._reset_all()
            return

        if self._matrix is None:
            metric = self.Distances[self.distance_idx][1]
            try:
                self._matrix = np.asarray(metric(self.data))
            except MemoryError:
                self.Error.memory_error()
                return
            except ValueError as err:
                self.Error.value_error(str(err))
                return

        self._update_labels()

    def _reset_all(self):
        """Forget all computed state and empty the scene."""
        self._mask = self._silhouette = None
        self._labels = self._matrix = None
        self._clear_scene()

    def _clear_messages(self):
        """
        Clear all errors and the missing cluster assignment warning.
        """
        self.Warning.missing_cluster_assignment.clear()
        self.Error.clear()

    def _update_labels(self):
        """
        Extract the cluster labels and compute the silhouette scores.

        Instances with a missing value in the selected cluster variable are
        masked out. Sets `_mask`, `_labels` and `_silhouette`; all three
        are None when fewer than two clusters are present or when every
        instance forms its own cluster (the matching error is shown).
        A warning reports the number of omitted instances.
        """
        labelvar = self.cluster_var_model[self.cluster_var_idx]
        column, _ = self.data.get_column_view(labelvar)
        column = np.asarray(column, dtype=float)
        mask = np.isnan(column)
        # Drop the missing assignments *before* the integer cast; casting
        # NaN to int is undefined and raises a RuntimeWarning in NumPy.
        labels = column[~mask].astype(int)

        n_clusters = len(np.unique(labels))

        if n_clusters < 2:
            self.Error.need_two_clusters()
            labels = silhouette = mask = None
        elif n_clusters == len(labels):
            self.Error.singleton_clusters_all()
            labels = silhouette = mask = None
        else:
            # Restrict the distance matrix to instances with valid labels
            silhouette = sklearn.metrics.silhouette_samples(
                self._matrix[~mask, :][:, ~mask], labels, metric="precomputed")
        self._mask = mask
        self._labels = labels
        self._silhouette = silhouette

        if labels is not None:
            count_missing = np.count_nonzero(mask)
            if count_missing:
                self.Warning.missing_cluster_assignment(
                    count_missing, s="s" if count_missing > 1 else "")

    def _set_bar_height(self):
        """
        Apply `bar_size` to the plot; row names are shown only for bars of
        size 5 or more, otherwise a hint is displayed when an annotation
        variable is selected.
        """
        names_visible = self.bar_size >= 5
        plot = self._silplot
        plot.setBarHeight(self.bar_size)
        plot.setRowNamesVisible(names_visible)
        show_hint = self.annotation_var_idx > 0 and not names_visible
        self.ann_hidden_warning.setVisible(show_hint)

    def _replot(self):
        """
        Clear the scene and, when scores are available, build a new plot.

        With grouping enabled the scores are split by cluster label using
        the cluster variable's values and colors; otherwise all scores are
        shown as a single unnamed group.
        """
        self._clear_scene()
        if self._silhouette is None or self._labels is None:
            return

        var = self.cluster_var_model[self.cluster_var_idx]
        plot = SilhouettePlot()
        self._silplot = plot
        self._set_bar_height()

        if self.group_by_cluster:
            groups = self._labels
            names = var.values
            colors = var.colors
        else:
            groups = np.zeros(len(self._silhouette), dtype=int)
            names = [""]
            colors = np.array([[63, 207, 207]])
        plot.setScores(self._silhouette, groups, names, colors)

        self.scene.addItem(plot)
        self._update_annotations()
        plot.selectionChanged.connect(self.commit)
        plot.layout().activate()
        self._update_scene_rect()
        plot.geometryChanged.connect(self._update_scene_rect)

    def _update_bar_size(self):
        """Apply a changed bar size to the plot, if one exists."""
        if self._silplot is None:
            return
        self._set_bar_height()

    def _update_annotations(self):
        """
        Update the row names of the plot from the annotation variable.

        Index 0 (or an out of range index) means no annotation. The hint
        label is shown when an annotation is selected but the bars are
        smaller than 5.
        """
        idx = self.annotation_var_idx
        if 0 < idx < len(self.annotation_var_model):
            annot_var = self.annotation_var_model[idx]
        else:
            annot_var = None
        has_annotation = annot_var is not None
        self.ann_hidden_warning.setVisible(self.bar_size < 5
                                           and has_annotation)

        plot = self._silplot
        if plot is None:
            return
        if not has_annotation:
            plot.setRowNames(None)
            return

        column, _ = self.data.get_column_view(annot_var)
        if self._mask is not None:
            assert column.shape == self._mask.shape
            # Keep only the rows that are displayed in the plot
            column = column[~self._mask]
        names = [annot_var.str_val(value) for value in column]
        plot.setRowNames(names)

    def _update_scene_rect(self):
        """Make the scene rectangle match the plot's geometry."""
        geometry = self._silplot.geometry()
        self.scene.setSceneRect(geometry)

    def commit(self):
        """
        Commit/send the current selection to the output.

        Sends the selected instances (None without a selection) and the
        input data annotated with the selection. When `add_scores` is set,
        a "Silhouette (<cluster variable>)" meta column is appended to both
        outputs; instances without a score get a missing value (NaN).
        """
        selected = indices = data = None
        if self.data is not None:
            selectedmask = np.full(len(self.data), False, dtype=bool)
            if self._silplot is not None:
                indices = self._silplot.selection()
                assert (np.diff(indices) > 0).all(), "strictly increasing"
                if self._mask is not None:
                    # Map indices of the displayed rows back to data rows
                    indices = np.flatnonzero(~self._mask)[indices]
                selectedmask[indices] = True

            if self._silhouette is None:
                # Scores could not be computed (e.g. fewer than two
                # clusters); output an all-missing column instead of
                # passing None on to the table assignment below.
                scores = np.full(shape=selectedmask.shape, fill_value=np.nan)
            elif self._mask is not None:
                scores = np.full(shape=selectedmask.shape, fill_value=np.nan)
                scores[~self._mask] = self._silhouette
            else:
                scores = self._silhouette

            silhouette_var = None
            if self.add_scores:
                var = self.cluster_var_model[self.cluster_var_idx]
                silhouette_var = Orange.data.ContinuousVariable(
                    "Silhouette ({})".format(escape(var.name)))
                domain = Orange.data.Domain(
                    self.data.domain.attributes, self.data.domain.class_vars,
                    self.data.domain.metas + (silhouette_var, ))
                data = self.data.transform(domain)
            else:
                domain = self.data.domain
                data = self.data

            if np.count_nonzero(selectedmask):
                selected = self.data.from_table(domain, self.data,
                                                np.flatnonzero(selectedmask))

            if self.add_scores:
                if selected is not None:
                    selected[:, silhouette_var] = np.c_[scores[selectedmask]]
                data[:, silhouette_var] = np.c_[scores]

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(create_annotated_table(data, indices))

    def send_report(self):
        """
        Report the plot with a caption naming the distance, the cluster
        variable and, when row names are displayed, the annotation variable.

        Nothing is reported when there is no cluster variable to choose from.
        """
        if not len(self.cluster_var_model):
            return

        self.report_plot()
        caption = "Silhouette plot ({} distance), clustered by '{}'".format(
            self.Distances[self.distance_idx][0],
            self.cluster_var_model[self.cluster_var_idx])
        # The plot does not exist when the scores could not be computed,
        # so guard against dereferencing None.
        annotated = (self.annotation_var_idx
                     and self._silplot is not None
                     and self._silplot.rowNamesVisible())
        if annotated:
            caption += ", annotated with '{}'".format(
                self.annotation_var_model[self.annotation_var_idx])
        self.report_caption(caption)

    def onDeleteWidget(self):
        """Clear the widget state before handing over to the base class."""
        self.clear()
        super().onDeleteWidget()
Ejemplo n.º 5
0
class EditLinksDialog(QDialog):
    """
    A dialog for editing links.

    >>> dlg = EditLinksDialog()
    >>> dlg.setNodes(source_node, sink_node)
    >>> dlg.setLinks([(source_node.output_channel("Data"),
    ...                sink_node.input_channel("Data"))])
    >>> if dlg.exec_() == EditLinksDialog.Accepted:
    ...     new_links = dlg.links()
    ...
    """
    def __init__(self, parent=None, **kwargs):
        # type: (Optional[QWidget], Any) -> None
        """Create a modal dialog; `parent` and `kwargs` go to `QDialog`."""
        super().__init__(parent, **kwargs)

        self.setModal(True)

        self.__setupUi()

    def __setupUi(self):
        """Build the link editor view and the dialog button box."""
        layout = QVBoxLayout()

        # Scene with the link editor.
        self.scene = LinksEditScene()
        self.view = QGraphicsView(self.scene)
        self.view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.view.setRenderHint(QPainter.Antialiasing)

        # Resize the view whenever the editor widget changes its geometry.
        self.scene.editWidget.geometryChanged.connect(self.__onGeometryChanged)

        # Ok/Cancel/Clear All buttons.
        buttons = QDialogButtonBox(QDialogButtonBox.Ok |
                                   QDialogButtonBox.Cancel |
                                   QDialogButtonBox.Reset,
                                   Qt.Horizontal)

        # The standard 'Reset' button is repurposed as 'Clear All'.
        clear_button = buttons.button(QDialogButtonBox.Reset)
        clear_button.setText(self.tr("Clear All"))

        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        clear_button.clicked.connect(self.scene.editWidget.clearLinks)

        layout.addWidget(self.view)
        layout.addWidget(buttons)

        self.setLayout(layout)
        # The dialog size follows its contents and cannot be changed by
        # the user.
        layout.setSizeConstraint(QVBoxLayout.SetFixedSize)

        self.setSizeGripEnabled(False)

    def setNodes(self, source_node, sink_node):
        # type: (SchemeNode, SchemeNode) -> None
        """
        Set the source/sink nodes (:class:`.SchemeNode` instances)
        between which to edit the links.

        .. note:: This should be called before :func:`setLinks`.

        """
        self.scene.editWidget.setNodes(source_node, sink_node)

    def setLinks(self, links):
        # type: (List[IOPair]) -> None
        """
        Set a list of links to display between the source and sink
        nodes. The `links` is a list of (`OutputSignal`, `InputSignal`)
        tuples where the first element is an output signal of the source
        node and the second an input signal of the sink node.

        """
        self.scene.editWidget.setLinks(links)

    def links(self):
        # type: () -> List[IOPair]
        """
        Return the links between the source and sink node.
        """
        return self.scene.editWidget.links()

    def __onGeometryChanged(self):
        """
        Fit the view to the editor widget: its size plus the dialog's
        contents margins and 4 extra pixels in each direction.
        """
        size = self.scene.editWidget.size()
        # `contentsMargins()` replaces the obsolete `getContentsMargins()`.
        margins = self.contentsMargins()
        extra = QSize(margins.left() + margins.right() + 4,
                      margins.top() + margins.bottom() + 4)
        self.view.setFixedSize(size.toSize() + extra)
        self.view.setSceneRect(self.scene.editWidget.geometry())
Ejemplo n.º 6
0
class OWVennDiagram(widget.OWWidget):
    """
    Widget drawing a Venn diagram of the items in up to five input tables.
    """
    # Widget metadata
    name = "Venn Diagram"
    description = "A graphical visualization of the overlap of data instances " "from a collection of input data sets."
    icon = "icons/VennDiagram.svg"
    priority = 280

    # A multiple-input "Data" channel handled by `setData`, and one output
    inputs = [("Data", Orange.data.Table, "setData", widget.Multiple)]
    outputs = [("Selected Data", Orange.data.Table)]

    # Selected disjoint subset indices
    selection = settings.Setting([])
    #: Stored input set hints
    #: {(index, inputname, attributes): (selectedattrname, itemsettitle)}
    #: The 'selectedattrname' can be None
    inputhints = settings.Setting({})
    #: Use identifier columns for instance matching
    useidentifiers = settings.Setting(True)
    #: Output unique items (one output row for every unique instance `key`)
    #: or preserve all duplicates in the output.
    output_duplicates = settings.Setting(False)
    #: Send the output automatically on every change
    autocommit = settings.Setting(True)

    # Name of the attribute holding the scene used for saving/reporting
    graph_name = "scene"

    def __init__(self):
        """
        Initialise the widget state and build the GUI.

        The control area holds the info label, the instance matching options
        (with one identifier combo box per possible input) and the output
        controls; the main area holds the view displaying the Venn diagram.
        """
        super().__init__()

        # Diagram update is in progress
        self._updating = False
        # Input update is in progress
        self._inputUpdate = False
        # All input tables have the same domain.
        self.samedomain = True
        # Input datasets in the order they were 'connected'.
        self.data = OrderedDict()
        # Extracted input item sets in the order they were 'connected'
        self.itemsets = OrderedDict()

        # GUI
        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, "No data on input.\n")

        # Radio buttons selecting how instances are matched across inputs.
        self.identifiersBox = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "useidentifiers",
            [],
            box="Data Instance Identifiers",
            callback=self._on_useidentifiersChanged,
        )
        self.useequalityButton = gui.appendRadioButton(self.identifiersBox, "Use instance equality")
        self.useidentifiersButton = rb = gui.appendRadioButton(self.identifiersBox, "Use identifiers")
        self.inputsBox = gui.indentedBox(self.identifiersBox, sep=gui.checkButtonOffsetHint(rb))
        self.inputsBox.setEnabled(bool(self.useidentifiers))

        # One (initially disabled) identifier selection box for each of the
        # five supported inputs.
        for i in range(5):
            box = gui.vBox(self.inputsBox, "Data set #%i" % (i + 1), addSpace=False)
            box.setFlat(True)
            model = itemmodels.VariableListModel(parent=self)
            cb = QComboBox(minimumContentsLength=12, sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
            cb.setModel(model)
            cb.activated[int].connect(self._on_inputAttrActivated)
            box.setEnabled(False)
            # Store the combo in the box for later use.
            box.combo_box = cb
            box.layout().addWidget(cb)

        gui.rubber(self.controlArea)

        box = gui.vBox(self.controlArea, "Output")
        # The lambda defers the `commit` attribute lookup to call time.
        gui.checkBox(box, self, "output_duplicates", "Output duplicates", callback=lambda: self.commit())
        gui.auto_commit(box, self, "autocommit", "Send Selection", "Send Automatically", box=False)

        # Main area view
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.setBackgroundRole(QPalette.Window)
        self.view.setFrameStyle(QGraphicsView.StyledPanel)

        self.mainArea.layout().addWidget(self.view)
        self.vennwidget = VennDiagram()
        self.vennwidget.resize(400, 400)
        self.vennwidget.itemTextEdited.connect(self._on_itemTextEdited)
        self.scene.selectionChanged.connect(self._on_selectionChanged)

        self.scene.addItem(self.vennwidget)

        self.resize(self.controlArea.sizeHint().width() + 550, max(self.controlArea.sizeHint().height(), 550))

        # Pending (key, incremental) input invalidations; processed and
        # emptied in `handleNewSignals`.
        self._queue = []

    @check_sql_input
    def setData(self, data, key=None):
        """
        Add, update or remove the input table identified by `key`.

        `data` set to None removes a known input. At most five inputs are
        accepted; any further one is rejected with an error.
        """
        self.error()
        if not self._inputUpdate:
            # Store hints only on the first setData call.
            self._storeHints()
            self._inputUpdate = True

        known = key in self.data
        if known and data is None:
            # Remove the input
            self._remove(key)
        elif known:
            # Update existing item
            self._update(key, data)
        elif data is not None:
            # TODO: Allow setting more than 5 inputs and let the user
            # select the 5 to display.
            if len(self.data) == 5:
                self.error("Venn diagram accepts at most five data sets.")
                return
            # Add a new input
            self._add(key, data)

    def handleNewSignals(self):
        """
        Process all queued input changes and rebuild the diagram.

        Updates the matching options to what the current inputs support,
        refreshes (or recreates) the item sets, redraws the diagram and
        updates the info label.
        """
        self._inputUpdate = False

        # Check if all inputs are from the same domain.
        domains = [entry.table.domain for entry in self.data.values()]
        samedomain = all(domain_eq(a, b) for a, b in pairwise(domains))
        self.samedomain = samedomain

        has_identifiers = all(source_attributes(entry.table.domain) for entry in self.data.values())
        has_any_identifiers = any(source_attributes(entry.table.domain) for entry in self.data.values())
        self.useequalityButton.setEnabled(samedomain)
        self.useidentifiersButton.setEnabled(has_any_identifiers or len(self.data) == 0)
        self.inputsBox.setEnabled(has_any_identifiers)

        # Switch to the only matching mode the inputs allow.
        if not samedomain and has_any_identifiers and not self.useidentifiers:
            self.useidentifiers = 1
        elif samedomain and not has_identifiers:
            self.useidentifiers = 0

        only_updates = all(incremental for _, incremental in self._queue)
        if not only_updates:
            # Links were removed and/or added.
            self._createItemsets()
            self._restoreHints()
        # In both cases the item sets are refreshed from the current data.
        self._updateItemsets()

        del self._queue[:]

        self._createDiagram()
        if not self.data:
            self.info.setText("No data on input\n")
        else:
            self.info.setText("{} data sets on input.\n".format(len(self.data)))

        self._updateInfo()
        super().handleNewSignals()

    def _invalidate(self, keys=None, incremental=True):
        """
        Queue the inputs `keys` (all current inputs when None) for an
        update, each tagged with the `incremental` flag.
        """
        targets = list(self.data.keys()) if keys is None else keys
        self._queue.extend([(key, incremental) for key in targets])

    def itemsetAttr(self, key):
        """
        Return the variable selected in the combo box of input `key`, or
        None when the combo box has no current item.
        """
        position = list(self.data.keys()).index(key)
        combo = self._controlAtIndex(position)[1]
        model = combo.model()
        current = combo.currentIndex()
        return model[current] if current >= 0 else None

    def _controlAtIndex(self, index):
        """
        Return the (group box, combo box) pair at position `index` in the
        inputs box layout.
        """
        box = self.inputsBox.layout().itemAt(index).widget()
        return box, box.combo_box

    def _setAttributes(self, index, attrs):
        """
        Set the variables offered by the combo box at `index`.

        `attrs` set to None empties the combo box and disables its box;
        otherwise the model is replaced only when its contents differ and
        the box is enabled.
        """
        box, combo = self._controlAtIndex(index)
        model = combo.model()
        enabled = attrs is not None

        if not enabled:
            model[:] = []
        elif model[:] != attrs:
            model[:] = attrs

        box.setEnabled(enabled)

    def _add(self, key, table):
        """
        Register `table` as a new input under `key`, using the next free
        input slot for its identifier controls.
        """
        slot = len(self.data)
        title = table.name
        identifiers = source_attributes(table.domain)

        self.data[key] = _InputData(key, title, table)
        self._setAttributes(slot, identifiers)
        # Adding a link requires the item sets to be recreated.
        self._invalidate([key], incremental=False)

        box = self.inputsBox.layout().itemAt(slot).widget()
        box.setTitle("Data set: {}".format(title))

    def _remove(self, key):
        """
        Remove the input `key`, move its controls to the end of the inputs
        box and renumber/retitle all the boxes.
        """
        slot = list(self.data).index(key)

        # Clear possible warnings.
        self.warning()

        self._setAttributes(slot, None)
        del self.data[key]

        # Move the freed box after the boxes that remain in use.
        layout = self.inputsBox.layout()
        layout.addItem(layout.takeAt(slot))

        remaining = list(self.data.values())
        for position in range(5):
            box = self._controlAtIndex(position)[0]
            if position < len(remaining):
                box.setTitle("Data set: {}".format(remaining[position].name))
            else:
                box.setTitle("Data set #{}".format(position + 1))

        self._invalidate([key], incremental=False)

    def _update(self, key, table):
        """
        Replace the table of the existing input `key` and refresh its
        identifier controls and box title.
        """
        slot = list(self.data).index(key)
        title = table.name
        identifiers = source_attributes(table.domain)

        self.data[key] = self.data[key]._replace(name=title, table=table)
        self._setAttributes(slot, identifiers)
        self._invalidate([key])

        box = self.inputsBox.layout().itemAt(slot).widget()
        box.setTitle("Data set: {}".format(title))

    def _itemsForInput(self, key):
        """
        Return the list of items representing the instances of input `key`.

        When matching by identifiers (also forced when the inputs differ in
        domain) the items are the string values of the selected identifier
        variable, skipping missing values; no selected variable gives an
        empty list. Otherwise every instance is wrapped in a
        `ComparableInstance`.
        """
        entry = self.data[key]
        by_equality = not self.useidentifiers and self.samedomain
        if by_equality:
            return list(map(ComparableInstance, entry.table))

        attr = self.itemsetAttr(key)
        if attr is None:
            return []
        return [str(inst[attr]) for inst in entry.table if not numpy.isnan(inst[attr])]

    def _updateItemsets(self):
        """
        Refresh the items of every existing item set from the current data;
        a renamed input also resets the item set's name and title.
        """
        assert list(self.data.keys()) == list(self.itemsets.keys())
        for key, entry in list(self.data.items()):
            changes = {"items": self._itemsForInput(key)}
            if self.itemsets[key].name != entry.name:
                changes.update(name=entry.name, title=entry.name)
            self.itemsets[key] = self.itemsets[key]._replace(**changes)

    def _createItemsets(self):
        """
        Rebuild the item sets for all current inputs, keeping the title of
        a previous item set whose key and name are unchanged.
        """
        previous = dict(self.itemsets)
        self.itemsets.clear()

        for key, entry in self.data.items():
            items = self._itemsForInput(key)
            old = previous.get(key)
            if old is not None and old.name == entry.name:
                # Reuse the title (which might have been changed by the user)
                title = old.title
            else:
                title = entry.name

            self.itemsets[key] = _ItemSet(
                key=key, name=entry.name, title=title, items=items)

    def _storeHints(self):
        """
        Save the selected identifier name and the item set title of every
        input into `inputhints`; does nothing when there are no inputs.
        """
        if not self.data:
            return

        self.inputhints.clear()
        for position, (key, entry) in enumerate(self.data.items()):
            attr_names = tuple(
                attr.name for attr in source_attributes(entry.table.domain))
            chosen = self.itemsetAttr(key)
            chosen_name = None if chosen is None else chosen.name
            title = self.itemsets[key].title
            self.inputhints[(position, entry.name, attr_names)] = (chosen_name, title)

    def _restoreHints(self):
        """
        Restore the item set titles and identifier selections from
        `inputhints`, but only when every current input has a stored hint.
        """
        restored = []
        for position, entry in enumerate(self.data.values()):
            attr_names = tuple(
                attr.name for attr in source_attributes(entry.table.domain))
            hint = self.inputhints.get((position, entry.name, attr_names))
            if hint is None:
                # A single input without a hint cancels the restore.
                return
            attr_name, title = hint
            attr_index = -1 if attr_name is None else attr_names.index(attr_name)
            restored.append((attr_index, title))

        # all inputs match the stored hints
        for position, key in enumerate(self.itemsets):
            attr_index, title = restored[position]
            self.itemsets[key] = self.itemsets[key]._replace(title=title)
            combo = self._controlAtIndex(position)[1]
            combo.setCurrentIndex(attr_index)

    def _createDiagram(self):
        """
        Rebuild the Venn diagram from the current item sets.

        Creates one colored set item per input, labels every area with its
        item count and a tooltip listing (at most 32 of) its items, and
        re-applies the previous selection. Selection change handling is
        suppressed while rebuilding and triggered once at the end.
        """
        self._updating = True

        oldselection = list(self.selection)

        self.vennwidget.clear()
        n = len(self.itemsets)
        self.disjoint = disjoint(set(s.items) for s in self.itemsets.values())

        vennitems = []
        colors = colorpalette.ColorPaletteHSV(n)

        for i, item in enumerate(self.itemsets.values()):
            count = len(set(item.items))
            count_all = len(item.items)
            # Show the total count too when the input contains duplicates.
            if count != count_all:
                fmt = "{} <i>(all: {})</i>"
            else:
                fmt = "{}"
            counts = fmt.format(count, count_all)
            gr = VennSetItem(text=item.title, informativeText=counts)
            color = colors[i]
            color.setAlpha(100)
            gr.setBrush(QBrush(color))
            gr.setPen(QPen(Qt.NoPen))
            vennitems.append(gr)

        self.vennwidget.setItems(vennitems)

        for i, area in enumerate(self.vennwidget.vennareas()):
            area_items = list(map(str, list(self.disjoint[i])))
            # Area 0 gets no count label.
            if i:
                area.setText("{0}".format(len(area_items)))

            label = disjoint_set_label(i, n, simplify=False)
            head = "<h4>|{}| = {}</h4>".format(label, len(area_items))
            if len(area_items) > 32:
                items_str = ", ".join(map(escape, area_items[:32]))
                hidden = len(area_items) - 32
                # Well-formed markup: a real line break and a closed span
                # (previously '</br>' and an unclosed '<span>').
                tooltip = "{}<span>{}, ...<br/>({} items not shown)</span>".format(head, items_str, hidden)
            elif area_items:
                tooltip = "{}<span>{}</span>".format(head, ", ".join(map(escape, area_items)))
            else:
                tooltip = head

            area.setToolTip(tooltip)

            area.setPen(QPen(QColor(10, 10, 10, 200), 1.5))
            area.setFlag(QGraphicsPathItem.ItemIsSelectable, True)
            area.setSelected(i in oldselection)

        self._updating = False
        self._on_selectionChanged()

    def _updateInfo(self):
        """
        Refresh the info label and the "no suitable identifiers" warning.

        The label reports how many data sets are on input.  When
        ``self.useidentifiers`` is set, a warning names (by 1-based position)
        every input for which ``source_attributes`` finds nothing in the
        table's domain.
        """
        # Calling warning() without arguments clears all warnings.
        self.warning()

        input_count = len(self.data)
        if input_count:
            summary = "{0} data sets on input\n".format(input_count)
        else:
            summary = "No data on input\n"
        self.info.setText(summary)

        if not self.useidentifiers:
            return

        missing = []
        for position, key in enumerate(self.data):
            if not source_attributes(self.data[key].table.domain):
                missing.append("#{}".format(position + 1))

        if len(missing) == 1:
            self.warning("Data set {} has no suitable identifiers.".format(missing[0]))
        elif len(missing) > 1:
            leading = ", ".join(missing[:-1])
            self.warning(
                "Data sets {} and {} have no suitable identifiers.".format(leading, missing[-1])
            )

    def _on_selectionChanged(self):
        """
        Store the indices of the selected Venn areas and invalidate the output.

        Does nothing while ``self._updating`` is set.
        """
        if not self._updating:
            chosen = []
            for position, area in enumerate(self.vennwidget.vennareas()):
                if area.isSelected():
                    chosen.append(position)

            self.selection = chosen
            self.invalidateOutput()

    def _on_useidentifiersChanged(self):
        """
        React to a change of ``self.useidentifiers``.

        Enables the inputs box only when the setting equals 1, then
        invalidates all item sets and rebuilds item sets, diagram and info.
        """
        identifiers_enabled = self.useidentifiers == 1
        self.inputsBox.setEnabled(identifiers_enabled)

        # No argument: invalidate every item set.
        self._invalidate()
        self._updateItemsets()
        self._createDiagram()
        self._updateInfo()

    def _on_inputAttrActivated(self, attr_index):
        """
        Handle activation of an attribute in one of the input combo boxes.

        The combo box that sent the signal is matched against the controls
        returned by ``_controlAtIndex`` to find the input it belongs to
        (controls are reordered when inputs are removed).  That input's item
        set is invalidated and the item sets and diagram are rebuilt.
        ``attr_index`` itself is not used.
        """
        origin = self.sender()
        entries = list(self.data.items())

        position = None
        for candidate, _entry in enumerate(entries):
            _, control = self._controlAtIndex(candidate)
            if control is origin:
                position = candidate
                break

        assert position is not None

        changed_key = entries[position][0]

        self._invalidate([changed_key])
        self._updateItemsets()
        self._createDiagram()

    def _on_itemTextEdited(self, index, text):
        """Set the title of the item set at position ``index`` to ``text``."""
        new_title = str(text)
        ordered_keys = list(self.itemsets.keys())
        target = ordered_keys[index]
        current = self.itemsets[target]
        self.itemsets[target] = current._replace(title=new_title)

    def invalidateOutput(self):
        """Recompute and send the output by calling ``commit`` right away."""
        self.commit()

    def commit(self):
        """
        Build the table for the selected Venn areas and send it on
        "Selected Data".

        The selected items are the union of ``self.disjoint[index]`` over
        ``self.selection``.  For every input, rows matching a selected item
        are kept (matched through the item set attribute when
        ``self.useidentifiers`` is set, through ``ComparableInstance``
        otherwise).  Unless ``self.output_duplicates`` is set, each subset
        gets temporary "source" and "item_id" meta columns and the
        concatenated table is reshaped to wide form; otherwise the subsets are
        simply concatenated.  ``None`` is sent when nothing matches.
        """
        selected_subsets = []

        selected_items = reduce(set.union, [self.disjoint[index] for index in self.selection], set())

        def match(val):
            # Missing values never match a selected item.
            return not numpy.isnan(val) and str(val) in selected_items

        source_var = Orange.data.StringVariable("source")
        item_id_var = Orange.data.StringVariable("item_id")

        names = [itemset.title.strip() for itemset in self.itemsets.values()]
        names = uniquify(names)

        # The item -> id mapping depends only on the selection, so build it
        # once instead of once per input table.
        if not self.useidentifiers:
            item_ids = {item: str(i) for i, item in enumerate(selected_items)}

        for i, (key, table_input) in enumerate(self.data.items()):
            table = table_input.table
            if self.useidentifiers:
                attr = self.itemsetAttr(key)
                if attr is not None:
                    mask = [match(inst[attr]) for inst in table]
                else:
                    mask = [False] * len(table)

                # Bind attr now; it is only used for rows of this table.
                def instance_key(inst, attr=attr):
                    return str(inst[attr])

            else:
                mask = [ComparableInstance(inst) in selected_items for inst in table]

                def instance_key(inst):
                    return item_ids[ComparableInstance(inst)]

            mask = numpy.array(mask, dtype=bool)
            subset = table[mask]

            if len(subset) == 0:
                continue

            # add columns with source table id and set id

            if not self.output_duplicates:
                id_column = numpy.array([[instance_key(inst)] for inst in subset], dtype=object)
                source_names = numpy.array([[names[i]]] * len(subset), dtype=object)

                subset = append_column(subset, "M", source_var, source_names)
                subset = append_column(subset, "M", item_id_var, id_column)

            selected_subsets.append(subset)

        if selected_subsets and not self.output_duplicates:
            data = table_concat(selected_subsets)
            # Get all variables which are not constant between the same
            # item set
            varying = varying_between(data, [item_id_var])

            if source_var in varying:
                varying.remove(source_var)

            data = reshape_wide(data, varying, [item_id_var], [source_var])
            # remove the temporary item set id column
            data = drop_columns(data, [item_id_var])
        elif selected_subsets:
            data = table_concat(selected_subsets)
        else:
            data = None

        self.send("Selected Data", data)

    def getSettings(self, *args, **kwargs):
        """
        Store the current hints, then return the base class settings.

        All positional and keyword arguments are forwarded unchanged to the
        base class ``getSettings``.
        """
        self._storeHints()
        # super() is already bound to self; passing self again would shift
        # every positional argument by one.
        return super().getSettings(*args, **kwargs)

    def send_report(self):
        """Add the widget's plot to the report."""
        self.report_plot()
class TestAnchorLayout(QAppTestCase):
    """Interactive test of ``AnchorLayout`` on a small three-node scene."""

    def setUp(self):
        """Create a canvas scene and show it in a 600x400 graphics view."""
        QAppTestCase.setUp(self)
        self.scene = CanvasScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.show()
        self.view.resize(600, 400)

    def test_layout(self):
        """
        Link one source node to two sink nodes and check the anchor order.

        After the layout is activated, the first output anchor position of
        the source node must be greater than the second.  The two sink nodes
        are then moved along an ellipse while the application event loop
        runs, with the layout invalidated on every node move.
        """
        file_desc, disc_desc, bayes_desc = self.widget_desc()
        file_item = NodeItem()
        file_item.setWidgetDescription(file_desc)
        file_item.setPos(0, 150)
        self.scene.add_node_item(file_item)

        bayes_item = NodeItem()
        bayes_item.setWidgetDescription(bayes_desc)
        bayes_item.setPos(200, 0)
        self.scene.add_node_item(bayes_item)

        disc_item = NodeItem()
        disc_item.setWidgetDescription(disc_desc)
        disc_item.setPos(200, 300)
        self.scene.add_node_item(disc_item)

        link = LinkItem()
        link.setSourceItem(file_item)
        link.setSinkItem(disc_item)
        self.scene.add_link_item(link)

        link = LinkItem()
        link.setSourceItem(file_item)
        link.setSinkItem(bayes_item)
        self.scene.add_link_item(link)

        layout = AnchorLayout()
        self.scene.addItem(layout)
        self.scene.set_anchor_layout(layout)

        layout.invalidateNode(file_item)
        layout.activate()

        p1, p2 = file_item.outputAnchorItem.anchorPositions()
        self.assertGreater(p1, p2)

        self.scene.node_item_position_changed.connect(layout.invalidateNode)

        path = QPainterPath()
        path.addEllipse(125, 0, 50, 300)

        def advance():
            # time.clock() was removed in Python 3.8; perf_counter() gives a
            # monotonically increasing clock to drive the animation.
            t = time.perf_counter()
            bayes_item.setPos(path.pointAtPercent(t % 1.0))
            # Keep the second node half a lap away from the first.
            disc_item.setPos(path.pointAtPercent((t + 0.5) % 1.0))

            # Re-schedule itself to keep the nodes moving.
            self.singleShot(20, advance)

        advance()

        self.app.exec_()

    def widget_desc(self):
        """Return the (file, discretize, naive bayes) widget descriptions."""
        from ...registry.tests import small_testing_registry
        reg = small_testing_registry()

        file_desc = reg.widget("Orange.widgets.data.owfile.OWFile")

        discretize_desc = reg.widget(
            "Orange.widgets.data.owdiscretize.OWDiscretize")

        bayes_desc = reg.widget(
            "Orange.widgets.classify.ownaivebayes.OWNaiveBayes")

        return file_desc, discretize_desc, bayes_desc
Ejemplo n.º 8
0
class OWSilhouettePlot(widget.OWWidget):
    """
    Widget showing a silhouette plot for clustered input data.

    The silhouette scores are computed from a distance matrix over the input
    and the labels of a selected discrete variable.  The selected instances
    and the annotated input are sent to the outputs, optionally extended
    with a meta column holding the scores.
    """
    name = "Silhouette Plot"
    description = "Visually assess cluster quality and " \
                  "the degree of cluster membership."

    icon = "icons/SilhouettePlot.svg"
    priority = 300

    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Selected Data", Orange.data.Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)]

    replaces = [
        "orangecontrib.prototypes.widgets.owsilhouetteplot.OWSilhouettePlot",
        "Orange.widgets.unsupervised.owsilhouetteplot.OWSilhouettePlot"
    ]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Distance metric index
    distance_idx = settings.Setting(0)
    #: Group/cluster variable index
    cluster_var_idx = settings.ContextSetting(0)
    #: Annotation variable index
    annotation_var_idx = settings.ContextSetting(0)
    #: Group the silhouettes by cluster
    group_by_cluster = settings.Setting(True)
    #: A fixed size for an instance bar
    bar_size = settings.Setting(3)
    #: Add silhouette scores to output data
    add_scores = settings.Setting(False)
    auto_commit = settings.Setting(False)

    #: Available (name, distance constructor) pairs
    Distances = [("Euclidean", Orange.distance.Euclidean),
                 ("Manhattan", Orange.distance.Manhattan)]

    graph_name = "scene"
    buttons_area_orientation = Qt.Vertical

    class Error(widget.OWWidget.Error):
        need_two_clusters = Msg("Need at least two non-empty clusters")

    def __init__(self):
        """Initialise the widget state and build the control/main areas."""
        super().__init__()

        # Input data and the values derived from it
        self.data = None
        self._effective_data = None
        self._matrix = None
        self._silhouette = None
        self._labels = None
        self._silplot = None

        gui.comboBox(
            self.controlArea, self, "distance_idx", box="Distance",
            items=[name for name, _ in OWSilhouettePlot.Distances],
            orientation=Qt.Horizontal, callback=self._invalidate_distances)

        box = gui.vBox(self.controlArea, "Cluster Label")
        self.cluster_var_cb = gui.comboBox(
            box, self, "cluster_var_idx", addSpace=4,
            callback=self._invalidate_scores)
        gui.checkBox(
            box, self, "group_by_cluster", "Group by cluster",
            callback=self._replot)
        self.cluster_var_model = itemmodels.VariableListModel(parent=self)
        self.cluster_var_cb.setModel(self.cluster_var_model)

        box = gui.vBox(self.controlArea, "Bars")
        gui.widgetLabel(box, "Bar width:")
        gui.hSlider(
            box, self, "bar_size", minValue=1, maxValue=10, step=1,
            callback=self._update_bar_size, addSpace=6)
        gui.widgetLabel(box, "Annotations:")
        self.annotation_cb = gui.comboBox(
            box, self, "annotation_var_idx", callback=self._update_annotations)
        self.annotation_var_model = itemmodels.VariableListModel(parent=self)
        self.annotation_var_model[:] = ["None"]
        self.annotation_cb.setModel(self.annotation_var_model)
        ibox = gui.indentedBox(box, 5)
        self.ann_hidden_warning = warning = gui.widgetLabel(
            ibox, "(increase the width to show)")
        ibox.setFixedWidth(ibox.sizeHint().width())
        warning.setVisible(False)

        gui.rubber(self.controlArea)

        gui.separator(self.buttonsArea)
        box = gui.vBox(self.buttonsArea, "Output")
        # Thunk the call to commit to call conditional commit
        gui.checkBox(box, self, "add_scores", "Add silhouette scores",
                     callback=lambda: self.commit())
        gui.auto_commit(
            box, self, "auto_commit", "Commit",
            auto_label="Auto commit", box=False)
        # Ensure that the controlArea is not narrower than buttonsArea
        self.controlArea.layout().addWidget(self.buttonsArea)

        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the control area size hint expanded to at least 600x720."""
        sh = self.controlArea.sizeHint()
        return sh.expandedTo(QSize(600, 720))

    @check_sql_input
    def set_data(self, data):
        """
        Set the input data set.

        The input is rejected (treated as ``None`` and an error is shown)
        when it has no discrete variable with at least two values or no
        continuous attribute.  A warning is shown when some attributes are
        discrete.
        """
        self.closeContext()
        self.clear()
        error_msg = ""
        warning_msg = ""
        candidatevars = []
        if data is not None:
            candidatevars = [
                v for v in data.domain.variables + data.domain.metas
                if v.is_discrete and len(v.values) >= 2]
            if not candidatevars:
                error_msg = "Input does not have any suitable cluster labels."
                data = None

        if data is not None:
            ncont = sum(v.is_continuous for v in data.domain.attributes)
            ndiscrete = len(data.domain.attributes) - ncont
            if ncont == 0:
                data = None
                error_msg = "No continuous columns"
            elif ncont < len(data.domain.attributes):
                warning_msg = "{0} discrete columns will not be used for " \
                              "distance computation".format(ndiscrete)

        self.data = data
        if data is not None:
            self.cluster_var_model[:] = candidatevars
            # Default to the class variable when it is a candidate.
            if data.domain.class_var in candidatevars:
                self.cluster_var_idx = \
                    candidatevars.index(data.domain.class_var)
            else:
                self.cluster_var_idx = 0

            annotvars = [var for var in data.domain.metas if var.is_string]
            self.annotation_var_model[:] = ["None"] + annotvars
            self.annotation_var_idx = 1 if len(annotvars) else 0
            self._effective_data = Orange.distance._preprocess(data)
            self.openContext(Orange.data.Domain(candidatevars))

        self.error(error_msg)
        self.warning(warning_msg)

    def handleNewSignals(self):
        """Recompute and replot when data is available, then commit."""
        if self._effective_data is not None:
            self._update()
            self._replot()

        self.unconditional_commit()

    def clear(self):
        """
        Clear the widget state.
        """
        self.data = None
        self._effective_data = None
        self._matrix = None
        self._silhouette = None
        self._labels = None
        self.cluster_var_model[:] = []
        self.annotation_var_model[:] = ["None"]
        self._clear_scene()

    def _clear_scene(self):
        # Clear the graphics scene and associated objects
        self.scene.clear()
        self.scene.setSceneRect(QRectF())
        self._silplot = None

    def _invalidate_distances(self):
        # Invalidate the computed distance matrix and recompute the silhouette.
        self._matrix = None
        self._invalidate_scores()

    def _invalidate_scores(self):
        # Invalidate and recompute the current silhouette scores.
        self._labels = self._silhouette = None
        self._update()
        self._replot()
        if self.data is not None:
            self.commit()

    def _update(self):
        # Update/recompute the distances/scores as required
        if self.data is None:
            self._silhouette = None
            self._labels = None
            self._matrix = None
            self._clear_scene()
            return

        if self._matrix is None and self._effective_data is not None:
            _, metric = self.Distances[self.distance_idx]
            self._matrix = numpy.asarray(metric(self._effective_data))

        labelvar = self.cluster_var_model[self.cluster_var_idx]
        labels, _ = self.data.get_column_view(labelvar)
        labels = labels.astype(int)
        _, counts = numpy.unique(labels, return_counts=True)
        if numpy.count_nonzero(counts) >= 2:
            self.Error.need_two_clusters.clear()
            silhouette = sklearn.metrics.silhouette_samples(
                self._matrix, labels, metric="precomputed")
        else:
            # Scores are left unset; the error message tells the user why.
            self.Error.need_two_clusters()
            labels = silhouette = None

        self._labels = labels
        self._silhouette = silhouette

    def _set_bar_height(self):
        # Apply the bar size to the plot; row names are shown only for
        # bar sizes of at least 5. Callers ensure self._silplot is set.
        visible = self.bar_size >= 5
        self._silplot.setBarHeight(self.bar_size)
        self._silplot.setRowNamesVisible(visible)
        self.ann_hidden_warning.setVisible(
            not visible and self.annotation_var_idx > 0)

    def _replot(self):
        # Clear and replot/initialize the scene
        self._clear_scene()
        if self._silhouette is not None and self._labels is not None:
            var = self.cluster_var_model[self.cluster_var_idx]
            self._silplot = silplot = SilhouettePlot()
            self._set_bar_height()

            if self.group_by_cluster:
                silplot.setScores(self._silhouette, self._labels, var.values)
            else:
                # A single unnamed group holding every instance.
                silplot.setScores(
                    self._silhouette,
                    numpy.zeros(len(self._silhouette), dtype=int),
                    [""]
                )

            self.scene.addItem(silplot)
            self._update_annotations()

            silplot.resize(silplot.effectiveSizeHint(Qt.PreferredSize))
            silplot.selectionChanged.connect(self.commit)

            self.scene.setSceneRect(
                QRectF(QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_bar_size(self):
        # Apply a changed bar size to an existing plot and resize the scene.
        if self._silplot is not None:
            self._set_bar_height()
            self.scene.setSceneRect(
                QRectF(QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_annotations(self):
        # Index 0 is the "None" entry, i.e. no annotation variable.
        if 0 < self.annotation_var_idx < len(self.annotation_var_model):
            annot_var = self.annotation_var_model[self.annotation_var_idx]
        else:
            annot_var = None
        self.ann_hidden_warning.setVisible(
            self.bar_size < 5 and annot_var is not None)

        if self._silplot is not None:
            if annot_var is not None:
                column, _ = self.data.get_column_view(annot_var)
                self._silplot.setRowNames(
                    [annot_var.str_val(value) for value in column])
            else:
                self._silplot.setRowNames(None)

    def commit(self):
        """
        Commit/send the current selection to the output.

        When ``add_scores`` is set and silhouette scores are available, both
        outputs get an extra meta column with the scores.  Without scores
        (e.g. fewer than two clusters) the data is sent unchanged.
        """
        selected = indices = data = None
        if self.data is not None:
            selectedmask = numpy.full(len(self.data), False, dtype=bool)
            if self._silplot is not None:
                indices = self._silplot.selection()
                selectedmask[indices] = True
            scores = self._silhouette
            # There is nothing to add when the scores could not be computed.
            add_scores = self.add_scores and scores is not None
            silhouette_var = None
            if add_scores:
                var = self.cluster_var_model[self.cluster_var_idx]
                silhouette_var = Orange.data.ContinuousVariable(
                    "Silhouette ({})".format(escape(var.name)))
                domain = Orange.data.Domain(
                    self.data.domain.attributes,
                    self.data.domain.class_vars,
                    self.data.domain.metas + (silhouette_var, ))
                data = self.data.from_table(
                    domain, self.data)
            else:
                domain = self.data.domain
                data = self.data

            if numpy.count_nonzero(selectedmask):
                selected = self.data.from_table(
                    domain, self.data, numpy.flatnonzero(selectedmask))

            if add_scores:
                if selected is not None:
                    selected[:, silhouette_var] = numpy.c_[scores[selectedmask]]
                data[:, silhouette_var] = numpy.c_[scores]

        self.send("Selected Data", selected)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(data, indices))

    def send_report(self):
        """Report the plot with a caption naming distance and variables."""
        if not len(self.cluster_var_model):
            return

        self.report_plot()
        caption = "Silhouette plot ({} distance), clustered by '{}'".format(
            self.Distances[self.distance_idx][0],
            self.cluster_var_model[self.cluster_var_idx])
        # The plot does not exist when the scores could not be computed.
        if (self.annotation_var_idx and self._silplot is not None
                and self._silplot.rowNamesVisible()):
            caption += ", annotated with '{}'".format(
                self.annotation_var_model[self.annotation_var_idx])
        self.report_caption(caption)

    def onDeleteWidget(self):
        """Clear the widget state before the base class clean-up."""
        self.clear()
        super().onDeleteWidget()
Ejemplo n.º 9
0
class TestAnchorLayout(QAppTestCase):
    """Interactive test of ``AnchorLayout`` on a small three-node scene."""

    def setUp(self):
        """Create a canvas scene and show it in a 600x400 graphics view."""
        # Zero-argument super() does not depend on the module-level class
        # name, which is defined more than once in this file.
        super().setUp()
        self.scene = CanvasScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.show()
        self.view.resize(600, 400)

    def tearDown(self):
        """Clear and schedule deletion of the scene and view."""
        self.scene.clear()
        self.view.deleteLater()
        self.scene.deleteLater()
        del self.scene
        del self.view
        super().tearDown()

    def test_layout(self):
        """
        Link one source node to two sink nodes and check the anchor order.

        After the layout is activated, the first output anchor position of
        the source node must be greater than the second.  The two sink nodes
        are then moved along an ellipse on a 20 ms timer while the
        application event loop runs, with the layout invalidated on every
        node move.
        """
        one_desc, negate_desc, cons_desc = self.widget_desc()
        one_item = NodeItem()
        one_item.setWidgetDescription(one_desc)
        one_item.setPos(0, 150)
        self.scene.add_node_item(one_item)

        cons_item = NodeItem()
        cons_item.setWidgetDescription(cons_desc)
        cons_item.setPos(200, 0)
        self.scene.add_node_item(cons_item)

        negate_item = NodeItem()
        negate_item.setWidgetDescription(negate_desc)
        negate_item.setPos(200, 300)
        self.scene.add_node_item(negate_item)

        link = LinkItem()
        link.setSourceItem(one_item)
        link.setSinkItem(negate_item)
        self.scene.add_link_item(link)

        link = LinkItem()
        link.setSourceItem(one_item)
        link.setSinkItem(cons_item)
        self.scene.add_link_item(link)

        layout = AnchorLayout()
        self.scene.addItem(layout)
        self.scene.set_anchor_layout(layout)

        layout.invalidateNode(one_item)
        layout.activate()

        p1, p2 = one_item.outputAnchorItem.anchorPositions()
        # assertGreater reports both values on failure.
        self.assertGreater(p1, p2)

        self.scene.node_item_position_changed.connect(layout.invalidateNode)

        path = QPainterPath()
        path.addEllipse(125, 0, 50, 300)

        def advance():
            # time.clock() was removed in Python 3.8; perf_counter() gives a
            # monotonically increasing clock to drive the animation.
            t = time.perf_counter()
            cons_item.setPos(path.pointAtPercent(t % 1.0))
            # Keep the second node half a lap away from the first.
            negate_item.setPos(path.pointAtPercent((t + 0.5) % 1.0))

        timer = QTimer(negate_item, interval=20)
        timer.start()
        timer.timeout.connect(advance)
        self.app.exec_()

    def widget_desc(self):
        """Return the ("one", "negate", "cons") widget descriptions."""
        reg = small_testing_registry()
        one_desc = reg.widget("one")
        negate_desc = reg.widget("negate")
        cons_desc = reg.widget("cons")
        return one_desc, negate_desc, cons_desc
Ejemplo n.º 10
0
class OWSilhouettePlot(widget.OWWidget):
    """
    Widget showing a silhouette plot for clustered input data.

    The silhouette scores are computed from a distance matrix over the input
    and the labels of a selected discrete variable.  The selected instances
    and the annotated input are sent to the outputs, optionally extended
    with a meta column holding the scores.
    """
    name = "Silhouette Plot"
    description = "Visually assess cluster quality and " \
                  "the degree of cluster membership."

    icon = "icons/SilhouettePlot.svg"
    priority = 300

    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Selected Data", Orange.data.Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)]

    replaces = [
        "orangecontrib.prototypes.widgets.owsilhouetteplot.OWSilhouettePlot",
        "Orange.widgets.unsupervised.owsilhouetteplot.OWSilhouettePlot"
    ]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Distance metric index
    distance_idx = settings.Setting(0)
    #: Group/cluster variable index
    cluster_var_idx = settings.ContextSetting(0)
    #: Annotation variable index
    annotation_var_idx = settings.ContextSetting(0)
    #: Group the silhouettes by cluster
    group_by_cluster = settings.Setting(True)
    #: A fixed size for an instance bar
    bar_size = settings.Setting(3)
    #: Add silhouette scores to output data
    add_scores = settings.Setting(False)
    auto_commit = settings.Setting(False)

    #: Available (name, distance constructor) pairs
    Distances = [("Euclidean", Orange.distance.Euclidean),
                 ("Manhattan", Orange.distance.Manhattan)]

    graph_name = "scene"
    buttons_area_orientation = Qt.Vertical

    class Error(widget.OWWidget.Error):
        need_two_clusters = Msg("Need at least two non-empty clusters")

    def __init__(self):
        """Initialise the widget state and build the control/main areas."""
        super().__init__()

        # Input data and the values derived from it
        self.data = None
        self._effective_data = None
        self._matrix = None
        self._silhouette = None
        self._labels = None
        self._silplot = None

        gui.comboBox(self.controlArea,
                     self,
                     "distance_idx",
                     box="Distance",
                     items=[name for name, _ in OWSilhouettePlot.Distances],
                     orientation=Qt.Horizontal,
                     callback=self._invalidate_distances)

        box = gui.vBox(self.controlArea, "Cluster Label")
        self.cluster_var_cb = gui.comboBox(box,
                                           self,
                                           "cluster_var_idx",
                                           addSpace=4,
                                           callback=self._invalidate_scores)
        gui.checkBox(box,
                     self,
                     "group_by_cluster",
                     "Group by cluster",
                     callback=self._replot)
        self.cluster_var_model = itemmodels.VariableListModel(parent=self)
        self.cluster_var_cb.setModel(self.cluster_var_model)

        box = gui.vBox(self.controlArea, "Bars")
        gui.widgetLabel(box, "Bar width:")
        gui.hSlider(box,
                    self,
                    "bar_size",
                    minValue=1,
                    maxValue=10,
                    step=1,
                    callback=self._update_bar_size,
                    addSpace=6)
        gui.widgetLabel(box, "Annotations:")
        self.annotation_cb = gui.comboBox(box,
                                          self,
                                          "annotation_var_idx",
                                          callback=self._update_annotations)
        self.annotation_var_model = itemmodels.VariableListModel(parent=self)
        self.annotation_var_model[:] = ["None"]
        self.annotation_cb.setModel(self.annotation_var_model)
        ibox = gui.indentedBox(box, 5)
        self.ann_hidden_warning = warning = gui.widgetLabel(
            ibox, "(increase the width to show)")
        ibox.setFixedWidth(ibox.sizeHint().width())
        warning.setVisible(False)

        gui.rubber(self.controlArea)

        gui.separator(self.buttonsArea)
        box = gui.vBox(self.buttonsArea, "Output")
        # Thunk the call to commit to call conditional commit
        gui.checkBox(box,
                     self,
                     "add_scores",
                     "Add silhouette scores",
                     callback=lambda: self.commit())
        gui.auto_commit(box,
                        self,
                        "auto_commit",
                        "Commit",
                        auto_label="Auto commit",
                        box=False)
        # Ensure that the controlArea is not narrower than buttonsArea
        self.controlArea.layout().addWidget(self.buttonsArea)

        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the control area size hint expanded to at least 600x720."""
        sh = self.controlArea.sizeHint()
        return sh.expandedTo(QSize(600, 720))

    @check_sql_input
    def set_data(self, data):
        """
        Set the input data set.

        The input is rejected (treated as ``None`` and an error is shown)
        when it has no discrete variable with at least two values or no
        continuous attribute.  A warning is shown when some attributes are
        discrete.
        """
        self.closeContext()
        self.clear()
        error_msg = ""
        warning_msg = ""
        candidatevars = []
        if data is not None:
            candidatevars = [
                v for v in data.domain.variables + data.domain.metas
                if v.is_discrete and len(v.values) >= 2
            ]
            if not candidatevars:
                error_msg = "Input does not have any suitable cluster labels."
                data = None

        if data is not None:
            ncont = sum(v.is_continuous for v in data.domain.attributes)
            ndiscrete = len(data.domain.attributes) - ncont
            if ncont == 0:
                data = None
                error_msg = "No continuous columns"
            elif ncont < len(data.domain.attributes):
                warning_msg = "{0} discrete columns will not be used for " \
                              "distance computation".format(ndiscrete)

        self.data = data
        if data is not None:
            self.cluster_var_model[:] = candidatevars
            # Default to the class variable when it is a candidate.
            if data.domain.class_var in candidatevars:
                self.cluster_var_idx = \
                    candidatevars.index(data.domain.class_var)
            else:
                self.cluster_var_idx = 0

            annotvars = [var for var in data.domain.metas if var.is_string]
            self.annotation_var_model[:] = ["None"] + annotvars
            self.annotation_var_idx = 1 if len(annotvars) else 0
            self._effective_data = Orange.distance._preprocess(data)
            self.openContext(Orange.data.Domain(candidatevars))

        self.error(error_msg)
        self.warning(warning_msg)

    def handleNewSignals(self):
        """Recompute and replot when data is available, then commit."""
        if self._effective_data is not None:
            self._update()
            self._replot()

        self.unconditional_commit()

    def clear(self):
        """
        Clear the widget state.
        """
        self.data = None
        self._effective_data = None
        self._matrix = None
        self._silhouette = None
        self._labels = None
        self.cluster_var_model[:] = []
        self.annotation_var_model[:] = ["None"]
        self._clear_scene()

    def _clear_scene(self):
        # Clear the graphics scene and associated objects
        self.scene.clear()
        self.scene.setSceneRect(QRectF())
        self._silplot = None

    def _invalidate_distances(self):
        # Invalidate the computed distance matrix and recompute the silhouette.
        self._matrix = None
        self._invalidate_scores()

    def _invalidate_scores(self):
        # Invalidate and recompute the current silhouette scores.
        self._labels = self._silhouette = None
        self._update()
        self._replot()
        if self.data is not None:
            self.commit()

    def _update(self):
        # Update/recompute the distances/scores as required
        if self.data is None:
            self._silhouette = None
            self._labels = None
            self._matrix = None
            self._clear_scene()
            return

        if self._matrix is None and self._effective_data is not None:
            _, metric = self.Distances[self.distance_idx]
            self._matrix = numpy.asarray(metric(self._effective_data))

        labelvar = self.cluster_var_model[self.cluster_var_idx]
        labels, _ = self.data.get_column_view(labelvar)
        labels = labels.astype(int)
        _, counts = numpy.unique(labels, return_counts=True)
        if numpy.count_nonzero(counts) >= 2:
            self.Error.need_two_clusters.clear()
            silhouette = sklearn.metrics.silhouette_samples(
                self._matrix, labels, metric="precomputed")
        else:
            # Scores are left unset; the error message tells the user why.
            self.Error.need_two_clusters()
            labels = silhouette = None

        self._labels = labels
        self._silhouette = silhouette

    def _set_bar_height(self):
        # Apply the bar size to the plot; row names are shown only for
        # bar sizes of at least 5. Callers ensure self._silplot is set.
        visible = self.bar_size >= 5
        self._silplot.setBarHeight(self.bar_size)
        self._silplot.setRowNamesVisible(visible)
        self.ann_hidden_warning.setVisible(not visible
                                           and self.annotation_var_idx > 0)

    def _replot(self):
        # Clear and replot/initialize the scene
        self._clear_scene()
        if self._silhouette is not None and self._labels is not None:
            var = self.cluster_var_model[self.cluster_var_idx]
            self._silplot = silplot = SilhouettePlot()
            self._set_bar_height()

            if self.group_by_cluster:
                silplot.setScores(self._silhouette, self._labels, var.values)
            else:
                # A single unnamed group holding every instance.
                silplot.setScores(
                    self._silhouette,
                    numpy.zeros(len(self._silhouette), dtype=int), [""])

            self.scene.addItem(silplot)
            self._update_annotations()

            silplot.resize(silplot.effectiveSizeHint(Qt.PreferredSize))
            silplot.selectionChanged.connect(self.commit)

            self.scene.setSceneRect(
                QRectF(QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_bar_size(self):
        # Apply a changed bar size to an existing plot and resize the scene.
        if self._silplot is not None:
            self._set_bar_height()
            self.scene.setSceneRect(
                QRectF(QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_annotations(self):
        # Index 0 is the "None" entry, i.e. no annotation variable.
        if 0 < self.annotation_var_idx < len(self.annotation_var_model):
            annot_var = self.annotation_var_model[self.annotation_var_idx]
        else:
            annot_var = None
        self.ann_hidden_warning.setVisible(self.bar_size < 5
                                           and annot_var is not None)

        if self._silplot is not None:
            if annot_var is not None:
                column, _ = self.data.get_column_view(annot_var)
                self._silplot.setRowNames(
                    [annot_var.str_val(value) for value in column])
            else:
                self._silplot.setRowNames(None)

    def commit(self):
        """
        Commit/send the current selection to the output.

        When ``add_scores`` is set and silhouette scores are available, both
        outputs get an extra meta column with the scores.  Without scores
        (e.g. fewer than two clusters) the data is sent unchanged.
        """
        selected = indices = data = None
        if self.data is not None:
            selectedmask = numpy.full(len(self.data), False, dtype=bool)
            if self._silplot is not None:
                indices = self._silplot.selection()
                selectedmask[indices] = True
            scores = self._silhouette
            # There is nothing to add when the scores could not be computed.
            add_scores = self.add_scores and scores is not None
            silhouette_var = None
            if add_scores:
                var = self.cluster_var_model[self.cluster_var_idx]
                silhouette_var = Orange.data.ContinuousVariable(
                    "Silhouette ({})".format(escape(var.name)))
                domain = Orange.data.Domain(
                    self.data.domain.attributes, self.data.domain.class_vars,
                    self.data.domain.metas + (silhouette_var, ))
                data = self.data.from_table(domain, self.data)
            else:
                domain = self.data.domain
                data = self.data

            if numpy.count_nonzero(selectedmask):
                selected = self.data.from_table(
                    domain, self.data, numpy.flatnonzero(selectedmask))

            if add_scores:
                if selected is not None:
                    selected[:,
                             silhouette_var] = numpy.c_[scores[selectedmask]]
                data[:, silhouette_var] = numpy.c_[scores]

        self.send("Selected Data", selected)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(data, indices))

    def send_report(self):
        """Report the plot with a caption naming distance and variables."""
        if not len(self.cluster_var_model):
            return

        self.report_plot()
        caption = "Silhouette plot ({} distance), clustered by '{}'".format(
            self.Distances[self.distance_idx][0],
            self.cluster_var_model[self.cluster_var_idx])
        # The plot does not exist when the scores could not be computed.
        if (self.annotation_var_idx and self._silplot is not None
                and self._silplot.rowNamesVisible()):
            caption += ", annotated with '{}'".format(
                self.annotation_var_model[self.annotation_var_idx])
        self.report_caption(caption)

    def onDeleteWidget(self):
        """Clear the widget state before the base class clean-up."""
        self.clear()
        super().onDeleteWidget()
Ejemplo n.º 11
0
class EditLinksDialog(QDialog):
    """
    A modal dialog for editing the links between two nodes.

    >>> dlg = EditLinksDialog()
    >>> dlg.setNodes(file_node, test_learners_node)
    >>> dlg.setLinks([(file_node.output_channel("Data"),
    ...                test_learners_node.input_channel("Data"))])
    >>> if dlg.exec_() == EditLinksDialog.Accepted:
    ...     new_links = dlg.links()
    ...

    """
    def __init__(self, *args, **kwargs):
        QDialog.__init__(self, *args, **kwargs)
        self.setModal(True)
        self.__setupUi()

    def __setupUi(self):
        """Create the link editor view and the dialog's button row."""
        # Scene with the link editor, displayed without any scroll bars.
        self.scene = LinksEditScene()
        self.view = QGraphicsView(self.scene)
        self.view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.view.setRenderHint(QPainter.Antialiasing)

        editor = self.scene.editWidget
        editor.geometryChanged.connect(self.__onGeometryChanged)

        # Ok/Cancel/Clear All buttons ("Clear All" reuses the Reset button).
        standard_buttons = (QDialogButtonBox.Ok |
                            QDialogButtonBox.Cancel |
                            QDialogButtonBox.Reset)
        button_box = QDialogButtonBox(standard_buttons, Qt.Horizontal)

        reset_button = button_box.button(QDialogButtonBox.Reset)
        reset_button.setText(self.tr("Clear All"))

        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)
        reset_button.clicked.connect(editor.clearLinks)

        vbox = QVBoxLayout()
        vbox.addWidget(self.view)
        vbox.addWidget(button_box)

        self.setLayout(vbox)
        vbox.setSizeConstraint(QVBoxLayout.SetFixedSize)

        self.setSizeGripEnabled(False)

    def setNodes(self, source_node, sink_node):
        """
        Set the source and sink nodes (:class:`.SchemeNode` instances)
        whose connecting links are to be edited.

        .. note:: This should be called before :func:`setLinks`.

        """
        editor = self.scene.editWidget
        editor.setNodes(source_node, sink_node)

    def setLinks(self, links):
        """
        Set the links to display between the source and sink nodes.

        `links` is a list of (`OutputSignal`, `InputSignal`) tuples; the
        first element is an output signal of the source node and the
        second an input signal of the sink node.

        """
        editor = self.scene.editWidget
        editor.setLinks(links)

    def links(self):
        """
        Return the links between the source and sink node.
        """
        editor = self.scene.editWidget
        return editor.links()

    def __onGeometryChanged(self):
        """Resize the view to the editor's size plus the dialog margins."""
        editor_size = self.scene.editWidget.size()
        left, top, right, bottom = self.getContentsMargins()
        padding = QSize(left + right + 4, top + bottom + 4)
        self.view.setFixedSize(editor_size.toSize() + padding)
Ejemplo n.º 12
0
class OWPythagoreanForest(OWWidget):
    """Show every tree of a random forest as a Pythagorean tree in a grid.

    Selecting a tree in the grid sends that tree to the 'Tree' output.
    """
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    def __init__(self):
        super().__init__()
        self.model = None
        self.forest_adapter = None
        self.instances = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []
        # In some rare cases we need to prevent committing. Currently the
        # only such case is changing the size calculation: all the trees are
        # recomputed, but we don't want to output a new tree, to keep things
        # consistent with the other UI controls.
        self.__prevent_commit = False

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x + 1)),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info)

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(
            box_display, self, 'depth_limit', label='Depth', ticks=False,
            callback=self.update_depth)
        self.ui_target_class_combo = gui.comboBox(
            box_display, self, 'target_class_index', label='Target class',
            orientation=Qt.Horizontal, items=[], contentsLength=8,
            callback=self.update_colors)
        self.ui_size_calc_combo = gui.comboBox(
            box_display, self, 'size_calc_idx', label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8,
            callback=self.update_size_calc)
        self.ui_zoom_slider = gui.hSlider(
            box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20,
            maxValue=150, callback=self.zoom_changed, createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """When a different forest is given."""
        self.clear()
        self.model = model

        if model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self._draw_trees()
            self.color_palette = self.forest_adapter.get_trees()[0]

            self.instances = model.instances
            # this bit is important for the regression classifier
            if self.instances is not None and self.instances.domain != model.domain:
                self.clf_dataset = self.instances.transform(self.model.domain)
            else:
                self.clf_dataset = self.instances

            self._update_info_box()
            self._update_target_class_combo()
            self._update_depth_slider()

            self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        # Also drop data derived from the previous model, so nothing stale
        # is kept alive (or reused) once the forest is removed.
        self.instances = None
        self.clf_dataset = None
        self.color_palette = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    def update_depth(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def update_colors(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_changed(self.target_class_index)

    def update_size_calc(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            with self._prevent_commit():
                self.grid.clear()
                self._draw_trees()
                # Keep the selected item
                if self.selected_tree_index != -1:
                    self.grid_items[self.selected_tree_index].setSelected(True)
                self.update_depth()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        for item in self.grid_items:
            item.set_max_size(self._calculate_zoom(self.zoom))

        self._reflow_grid()

    def _reflow_grid(self):
        """Reflow the grid to the view's width minus its vertical scroll bar."""
        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    @contextmanager
    def _prevent_commit(self):
        """Suppress `commit` for the duration of the `with` block."""
        try:
            self.__prevent_commit = True
            yield
        finally:
            self.__prevent_commit = False

    def _update_info_box(self):
        self.ui_info.setText('Trees: {}'.format(len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    def _clear_info_box(self):
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    def _get_max_depth(self):
        """Return the largest `max_depth` of the drawn trees (0 if none)."""
        return max((tree.tree_adapter.max_depth for tree in self.ptrees),
                   default=0)

    def _get_forest_adapter(self, model):
        return SklRandomForestAdapter(model)

    @contextmanager
    def disable_ui(self):
        """Temporarily disable the UI while trees may be redrawn."""
        try:
            self.ui_size_calc_combo.setEnabled(False)
            self.ui_depth_slider.setEnabled(False)
            self.ui_target_class_combo.setEnabled(False)
            self.ui_zoom_slider.setEnabled(False)
            yield
        finally:
            self.ui_size_calc_combo.setEnabled(True)
            self.ui_depth_slider.setEnabled(True)
            self.ui_target_class_combo.setEnabled(True)
            self.ui_zoom_slider.setEnabled(True)

    def _draw_trees(self):
        """Create a tree viewer and grid item for each tree in the forest."""
        self.grid_items, self.ptrees = [], []

        num_trees = len(self.forest_adapter.get_trees())
        with self.progressBar(num_trees) as prg, self.disable_ui():
            for tree in self.forest_adapter.get_trees():
                ptree = PythagorasTreeViewer(
                    None, tree, interactive=False, padding=100,
                    target_class_index=self.target_class_index,
                    weight_adjustment=self.SIZE_CALCULATION[self.size_calc_idx][1]
                )
                grid_item = GridItem(
                    ptree, self.grid, max_size=self._calculate_zoom(self.zoom)
                )
                # We don't want to show flickering while the trees are being
                # drawn, so keep each item hidden until the grid is laid out.
                grid_item.setVisible(False)

                self.grid_items.append(grid_item)
                self.ptrees.append(ptree)
                prg.advance()

            self.grid.set_items(self.grid_items)
            # This is necessary when adding items for the first time
            if self.grid:
                self._reflow_grid()
                # After drawing is complete, we show the trees
                for grid_item in self.grid_items:
                    grid_item.setVisible(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        if self.__prevent_commit:
            return

        if not self.scene.selectedItems():
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        selected_item = self.scene.selectedItems()[0]
        self.selected_tree_index = self.grid_items.index(selected_item)
        tree = self.model.trees[self.selected_tree_index]
        # Attach the current display settings to the tree that is sent out.
        tree.instances = self.instances
        tree.meta_target_class_index = self.target_class_index
        tree.meta_size_calc_idx = self.size_calc_idx
        tree.meta_depth_limit = self.depth_limit

        self.send('Tree', tree)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def _update_target_class_combo(self):
        """Refill the target combo and relabel it for the model's class type."""
        self._clear_target_class_combo()
        label = [x for x in self.ui_target_class_combo.parent().children()
                 if isinstance(x, QLabel)][0]

        # `set_rf` allows the model to come without instances; fall back to
        # the model's own domain instead of dereferencing None.
        if self.instances is not None:
            domain = self.instances.domain
        else:
            domain = self.model.domain

        if domain.has_discrete_class:
            label_text = 'Target class'
            values = [c.title() for c in domain.class_vars[0].values]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.ui_target_class_combo.addItems(values)
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def resizeEvent(self, ev):
        """Reflow the grid to the new view width, then defer to the base."""
        self._reflow_grid()

        super().resizeEvent(ev)
Ejemplo n.º 13
0
class TestAnchorLayout(QAppTestCase):
    """Interactive test of `AnchorLayout` on a small three-node scene."""

    def setUp(self):
        QAppTestCase.setUp(self)
        self.scene = CanvasScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.show()
        self.view.resize(600, 400)

    def test_layout(self):
        """Link one source to two sinks and check the anchor ordering.

        After the assertion the two sink nodes are animated along an
        ellipse and the application event loop is entered.
        """
        file_desc, disc_desc, bayes_desc = self.widget_desc()
        file_item = NodeItem()
        file_item.setWidgetDescription(file_desc)
        file_item.setPos(0, 150)
        self.scene.add_node_item(file_item)

        bayes_item = NodeItem()
        bayes_item.setWidgetDescription(bayes_desc)
        bayes_item.setPos(200, 0)
        self.scene.add_node_item(bayes_item)

        disc_item = NodeItem()
        disc_item.setWidgetDescription(disc_desc)
        disc_item.setPos(200, 300)
        self.scene.add_node_item(disc_item)

        link = LinkItem()
        link.setSourceItem(file_item)
        link.setSinkItem(disc_item)
        self.scene.add_link_item(link)

        link = LinkItem()
        link.setSourceItem(file_item)
        link.setSinkItem(bayes_item)
        self.scene.add_link_item(link)

        layout = AnchorLayout()
        self.scene.addItem(layout)
        self.scene.set_anchor_layout(layout)

        layout.invalidateNode(file_item)
        layout.activate()

        p1, p2 = file_item.outputAnchorItem.anchorPositions()
        self.assertGreater(p1, p2)

        self.scene.node_item_position_changed.connect(layout.invalidateNode)

        path = QPainterPath()
        path.addEllipse(125, 0, 50, 300)

        def advance():
            # time.clock() was removed in Python 3.8; perf_counter() is a
            # monotonic clock in fractional seconds, which is all the
            # animation phase below needs.
            t = time.perf_counter()
            bayes_item.setPos(path.pointAtPercent(t % 1.0))
            disc_item.setPos(path.pointAtPercent((t + 0.5) % 1.0))

            # Re-schedule itself to keep the animation running.
            self.singleShot(20, advance)

        advance()

        self.app.exec_()

    def widget_desc(self):
        """Return the (file, discretize, naive Bayes) widget descriptions."""
        from ...registry.tests import small_testing_registry
        reg = small_testing_registry()

        file_desc = reg.widget(
            "Orange.widgets.data.owfile.OWFile"
        )

        discretize_desc = reg.widget(
            "Orange.widgets.data.owdiscretize.OWDiscretize"
        )

        bayes_desc = reg.widget(
            "Orange.widgets.classify.ownaivebayes.OWNaiveBayes"
        )

        return file_desc, discretize_desc, bayes_desc
Ejemplo n.º 14
0
class OWExplainPrediction(OWWidget, ConcurrentWidgetMixin):
    """Widget that explains a model's prediction for a single data instance.

    The explanation is computed by `run`, started through `self.start`, and
    is shown as a `StripePlot`; the per-feature scores are sent to the
    'Scores' output.
    """
    name = "Explain Prediction"
    description = "Prediction explanation widget."
    icon = "icons/ExplainPred.svg"
    priority = 110

    class Inputs:
        model = Input("Model", Model)
        background_data = Input("Background Data", Table)
        data = Input("Data", Table)

    class Outputs:
        scores = Output("Scores", Table)

    class Error(OWWidget.Error):
        domain_transform_err = Msg("{}")
        unknown_err = Msg("{}")

    class Information(OWWidget.Information):
        multiple_instances = Msg("Explaining prediction for the first "
                                 "instance in 'Data'.")

    settingsHandler = ClassValuesContextHandler()
    target_index = ContextSetting(0)
    stripe_len = Setting(10)

    graph_name = "scene"

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        self.__results = None  # type: Optional[Results]
        self.model = None  # type: Optional[Model]
        self.background_data = None  # type: Optional[Table]
        self.data = None  # type: Optional[Table]
        self._stripe_plot = None  # type: Optional[StripePlot]
        # Texts shown by the labels in the "Prediction info" box.
        self.mo_info = ""
        self.bv_info = ""
        self.setup_gui()

    def setup_gui(self):
        """Build the control area and the plot, and show 'no input'."""
        self._add_controls()
        self._add_plot()
        self.info.set_input_summary(self.info.NoInput)

    def _add_plot(self):
        """Create the graphics scene/view and add the view to the main area."""
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignVCenter | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def _add_controls(self):
        """Create the target combo, the zoom slider and the info labels."""
        box = gui.vBox(self.controlArea, "Target class")
        self._target_combo = gui.comboBox(box,
                                          self,
                                          "target_index",
                                          callback=self.__target_combo_changed,
                                          contentsLength=12)

        box = gui.hBox(self.controlArea, "Zoom")
        gui.hSlider(box,
                    self,
                    "stripe_len",
                    None,
                    minValue=1,
                    maxValue=500,
                    createLabel=False,
                    callback=self.__size_slider_changed)

        gui.rubber(self.controlArea)

        box = gui.vBox(self.controlArea, "Prediction info")
        gui.label(box, self, "%(mo_info)s")  # type: QLabel
        bv_label = gui.label(box, self, "%(bv_info)s")  # type: QLabel
        bv_label.setToolTip("The average prediction for selected class.")

    def __target_combo_changed(self):
        """Redraw the scene for the newly selected target."""
        self.update_scene()

    def __size_slider_changed(self):
        """Apply the slider value to the stripe plot, if one is shown."""
        if self._stripe_plot is not None:
            self._stripe_plot.set_height(self.stripe_len)

    @Inputs.data
    @check_sql_input
    def set_data(self, data: Optional[Table]):
        """Store the 'Data' input."""
        self.data = data

    @Inputs.background_data
    @check_sql_input
    def set_background_data(self, data: Optional[Table]):
        """Store the 'Background Data' input."""
        self.background_data = data

    @Inputs.model
    def set_model(self, model: Optional[Model]):
        """Store the model, rebuild the target combo and reopen the context.

        The context is opened on the model's class variable, or on None when
        there is no model.
        """
        self.closeContext()
        self.model = model
        self.setup_controls()
        self.openContext(self.model.domain.class_var if self.model else None)

    def setup_controls(self):
        """Fill the target combo from the model's class variable.

        For a discrete class the combo lists the class values; for a
        continuous class it is disabled and `target_index` is set to -1.

        Raises:
            NotImplementedError: if the model's domain has neither a
                discrete nor a continuous class.
        """
        self._target_combo.clear()
        self._target_combo.setEnabled(True)
        if self.model is not None:
            if self.model.domain.has_discrete_class:
                self._target_combo.addItems(self.model.domain.class_var.values)
                self.target_index = 0
            elif self.model.domain.has_continuous_class:
                self.target_index = -1
                self._target_combo.setEnabled(False)
            else:
                raise NotImplementedError

    def handleNewSignals(self):
        """Reset the widget, check the inputs and start a new `run` task."""
        self.clear()
        self.check_inputs()
        # Only the first instance is passed on (see `multiple_instances`).
        data = self.data and self.data[:1]
        self.start(run, data, self.background_data, self.model)

    def clear(self):
        """Drop results and info texts, cancel the task, clear scene and messages."""
        self.mo_info = ""
        self.bv_info = ""
        self.__results = None
        self.cancel()
        self.clear_scene()
        self.clear_messages()

    def check_inputs(self):
        """Show the multiple-instances note and update the input summary."""
        if self.data and len(self.data) > 1:
            self.Information.multiple_instances()

        summary, details, kwargs = self.info.NoInput, "", {}
        if self.data or self.background_data:
            n_data = len(self.data) if self.data else 0
            n_background_data = len(self.background_data) \
                if self.background_data else 0
            summary = f"{self.info.format_number(n_background_data)}, " \
                      f"{self.info.format_number(n_data)}"
            kwargs = {"format": Qt.RichText}
            details = format_multiple_summaries([("Background data",
                                                  self.background_data),
                                                 ("Data", self.data)])
        self.info.set_input_summary(summary, details, **kwargs)

    def clear_scene(self):
        """Remove all scene items, reset the scene rects and forget the plot."""
        self.scene.clear()
        self.scene.setSceneRect(QRectF())
        self.view.setSceneRect(QRectF())
        self._stripe_plot = None

    def update_scene(self):
        """Redraw the plot from the stored results and send the scores.

        Without results the scene stays empty and None is sent to 'Scores'.
        """
        self.clear_scene()
        self.mo_info = ""
        self.bv_info = ""
        scores = None
        if self.__results is not None:
            data = self.__results.transformed_data
            pred = self.__results.predictions
            base = self.__results.base_value
            values, _, labels, ranges = prepare_force_plot_data(
                self.__results.values, data, pred, self.target_index)

            # Only the first (and only explained) instance is plotted.
            index = 0
            HIGH, LOW = 0, 1
            plot_data = PlotData(high_values=values[index][HIGH],
                                 low_values=values[index][LOW][::-1],
                                 high_labels=labels[index][HIGH],
                                 low_labels=labels[index][LOW][::-1],
                                 value_range=ranges[index],
                                 model_output=pred[index][self.target_index],
                                 base_value=base[self.target_index])
            self.setup_plot(plot_data)

            self.mo_info = f"Model prediction: {_str(plot_data.model_output)}"
            self.bv_info = f"Base value: {_str(plot_data.base_value)}"

            assert isinstance(self.__results.values, list)
            scores = self.__results.values[self.target_index][0, :]
            names = [a.name for a in data.domain.attributes]
            scores = self.create_scores_table(scores, names)
        self.Outputs.scores.send(scores)

    def setup_plot(self, plot_data: PlotData):
        """Create a `StripePlot` for `plot_data` and add it to the scene."""
        self._stripe_plot = StripePlot()
        self._stripe_plot.set_data(plot_data, self.stripe_len)
        self._stripe_plot.layout().activate()
        self._stripe_plot.geometryChanged.connect(self.update_scene_rect)
        self.scene.addItem(self._stripe_plot)
        self.update_scene_rect()

    def update_scene_rect(self):
        """Set the scene and view rects to the stripe plot's geometry."""
        geom = self._stripe_plot.geometry()
        self.scene.setSceneRect(geom)
        self.view.setSceneRect(geom)

    @staticmethod
    def create_scores_table(scores: np.ndarray, names: List[str]) -> Table:
        """Return a 'Feature Scores' table built from `scores` and `names`.

        The table has one continuous "Score" attribute and one string
        "Feature" meta; `scores` and `names` each become a single column.
        """
        domain = Domain([ContinuousVariable("Score")],
                        metas=[StringVariable("Feature")])
        scores_table = Table(domain,
                             scores[:, None],
                             metas=np.array(names)[:, None])
        scores_table.name = "Feature Scores"
        return scores_table

    def on_partial_result(self, _):
        """Partial results are ignored."""
        pass

    def on_done(self, results: Optional[RunnerResults]):
        """Store the finished results and redraw the scene."""
        self.__results = results
        self.update_scene()

    def on_exception(self, ex: Exception):
        """Show `ex` as a domain-transformation or an unknown error."""
        if isinstance(ex, DomainTransformationError):
            self.Error.domain_transform_err(ex)
        else:
            self.Error.unknown_err(ex)

    def onDeleteWidget(self):
        """Shut the concurrent mixin down, then run the base clean-up."""
        self.shutdown()
        super().onDeleteWidget()

    def sizeHint(self) -> QSize:
        """Return the control area's size hint, expanded to at least 700x700."""
        sh = self.controlArea.sizeHint()
        return sh.expandedTo(QSize(700, 700))

    def send_report(self):
        """Report the target class and the plot; needs all three inputs."""
        if not self.data or not self.background_data or not self.model:
            return
        items = {"Target class": "None"}
        if self.model.domain.has_discrete_class:
            class_var = self.model.domain.class_var
            items["Target class"] = class_var.values[self.target_index]
        self.report_items(items)
        self.report_plot()
Ejemplo n.º 15
0
class OWPythagoreanForest(OWWidget):
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    size_log_scale = settings.Setting(2)
    zoom = settings.Setting(50)
    selected_tree_index = settings.ContextSetting(-1)

    CLASSIFICATION, REGRESSION = range(2)

    def __init__(self):
        """Set up the widget's state, its controls and the graphics view."""
        super().__init__()
        # Instance variables
        self.forest_type = self.CLASSIFICATION
        self.model = None
        self.forest_adapter = None
        self.dataset = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x * self.size_log_scale)),
        ]

        # (label, colour function) pairs
        # NOTE(review): presumably the node colouring choices for regression
        # forests -- confirm against `_get_node_color`.
        self.REGRESSION_COLOR_CALC = [
            ('None', lambda _, __: QColor(255, 255, 255)),
            ('Class mean', self._color_class_mean),
            ('Standard deviation', self._color_stddev),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info, label='')

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(box_display,
                                           self,
                                           'depth_limit',
                                           label='Depth',
                                           ticks=False,
                                           callback=self.max_depth_changed)
        self.ui_target_class_combo = gui.comboBox(
            box_display,
            self,
            'target_class_index',
            label='Target class',
            orientation=Qt.Horizontal,
            items=[],
            contentsLength=8,
            callback=self.target_colors_changed)
        self.ui_size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
            callback=self.size_calc_changed)
        self.ui_zoom_slider = gui.hSlider(box_display,
                                          self,
                                          'zoom',
                                          label='Zoom',
                                          ticks=False,
                                          minValue=20,
                                          maxValue=150,
                                          callback=self.zoom_changed,
                                          createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        # Selecting a tree in the scene triggers `commit`.
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        # Start from the "no forest on input" state.
        self.clear()

    def set_rf(self, model=None):
        """When a different forest is given."""
        self.clear()
        self.model = model

        if model is not None:
            if isinstance(model, RandomForestClassifier):
                self.forest_type = self.CLASSIFICATION
            elif isinstance(model, RandomForestRegressor):
                self.forest_type = self.REGRESSION
            else:
                raise RuntimeError('Invalid type of forest.')

            self.forest_adapter = self._get_forest_adapter(self.model)
            self.color_palette = self._type_specific('_get_color_palette')()
            self._draw_trees()

            self.dataset = model.instances
            # this bit is important for the regression classifier
            if self.dataset is not None and \
                    self.dataset.domain != model.domain:
                self.clf_dataset = Table.from_table(self.model.domain,
                                                    self.dataset)
            else:
                self.clf_dataset = self.dataset

            self._update_info_box()
            self._type_specific('_update_target_class_combo')()
            self._update_depth_slider()

            self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    # CONTROL AREA CALLBACKS
    def max_depth_changed(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def target_colors_changed(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_has_changed()

    def size_calc_changed(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self.grid.clear()
            self._draw_trees()
            # Keep the selected item
            if self.selected_tree_index != -1:
                self.grid_items[self.selected_tree_index].setSelected(True)
            self.max_depth_changed()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        for item in self.grid_items:
            item.set_max_size(self._calculate_zoom(self.zoom))

        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    # MODEL CHANGED METHODS
    def _update_info_box(self):
        self.ui_info.setText('Trees: {}'.format(
            len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    # MODEL CLEARED METHODS
    def _clear_info_box(self):
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    # HELPFUL METHODS
    def _get_max_depth(self):
        return max([tree.tree_adapter.max_depth for tree in self.ptrees])

    def _get_forest_adapter(self, model):
        """Wrap ``model`` in a ``SklRandomForestAdapter`` and return it."""
        adapter = SklRandomForestAdapter(model)
        return adapter

    def _draw_trees(self):
        self.ui_size_calc_combo.setEnabled(False)
        self.grid_items, self.ptrees = [], []

        with self.progressBar(len(self.forest_adapter.get_trees())) as prg:
            for tree in self.forest_adapter.get_trees():
                ptree = PythagorasTreeViewer(
                    None,
                    tree,
                    node_color_func=self._type_specific('_get_node_color'),
                    interactive=False,
                    padding=100)
                self.grid_items.append(
                    GridItem(ptree,
                             self.grid,
                             max_size=self._calculate_zoom(self.zoom)))
                self.ptrees.append(ptree)
                prg.advance()
        self.grid.set_items(self.grid_items)
        # This is necessary when adding items for the first time
        if self.grid:
            width = (self.view.width() - self.view.verticalScrollBar().width())
            self.grid.reflow(width)
            self.grid.setPreferredWidth(width)
        self.ui_size_calc_combo.setEnabled(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """Run the base class deletion hook, then clear the widget's data."""
        super().onDeleteWidget()
        # Drop the model, the drawn trees and reset the controls.
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        if len(self.scene.selectedItems()) == 0:
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        selected_item = self.scene.selectedItems()[0]
        self.selected_tree_index = self.grid_items.index(selected_item)
        obj = self.model.trees[self.selected_tree_index]
        obj.instances = self.dataset
        obj.meta_target_class_index = self.target_class_index
        obj.meta_size_calc_idx = self.size_calc_idx
        obj.meta_size_log_scale = self.size_log_scale
        obj.meta_depth_limit = self.depth_limit

        self.send('Tree', obj)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def resizeEvent(self, ev):
        """Reflow the grid to the new view width, then defer to the base class.

        Parameters
        ----------
        ev
            The resize event; it is passed on unchanged to the base class.

        """
        # The usable width is the view width without its vertical scroll bar.
        view_width = self.view.width()
        scrollbar_width = self.view.verticalScrollBar().width()
        available_width = view_width - scrollbar_width

        self.grid.reflow(available_width)
        self.grid.setPreferredWidth(available_width)

        super().resizeEvent(ev)

    def _type_specific(self, method):
        """A best effort method getter that somewhat separates logic specific
        to classification and regression trees.
        This relies on conventional naming of specific methods, e.g.
        a method name _get_tooltip would need to be defined like so:
        _classification_get_tooltip and _regression_get_tooltip, since they are
        both specific.

        Parameters
        ----------
        method : str
            Method name that we would like to call.

        Returns
        -------
        callable or None

        """
        if self.forest_type == self.CLASSIFICATION:
            return getattr(self, '_classification' + method)
        elif self.forest_type == self.REGRESSION:
            return getattr(self, '_regression' + method)
        else:
            return None

    # CLASSIFICATION FOREST SPECIFIC METHODS
    def _classification_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItem('None')
        values = [c.title() for c in self.model.domain.class_vars[0].values]
        self.ui_target_class_combo.addItems(values)

    def _classification_get_color_palette(self):
        """Return one ``QColor`` per entry in the class variable's colors."""
        palette = []
        for components in self.model.domain.class_var.colors:
            palette.append(QColor(*components))
        return palette

    def _classification_get_node_color(self, adapter, tree_node):
        # this is taken almost directly from the existing classification tree
        # viewer
        colors = self.color_palette
        distribution = adapter.get_distribution(tree_node.label)[0]
        total = np.sum(distribution)

        if self.target_class_index:
            p = distribution[self.target_class_index - 1] / total
            color = colors[self.target_class_index - 1].lighter(200 - 100 * p)
        else:
            modus = np.argmax(distribution)
            p = distribution[modus] / (total or 1)
            color = colors[int(modus)].lighter(400 - 300 * p)
        return color

    # REGRESSION FOREST SPECIFIC METHODS
    def _regression_update_target_class_combo(self):
        self._clear_target_class_combo()
        self.ui_target_class_combo.addItems(
            list(zip(*self.REGRESSION_COLOR_CALC))[0])
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _regression_get_color_palette(self):
        """Build a continuous palette from the class variable's colors."""
        class_colors = self.forest_adapter.domain.class_var.colors
        return ContinuousPaletteGenerator(*class_colors)

    def _regression_get_node_color(self, adapter, tree_node):
        return self.REGRESSION_COLOR_CALC[self.target_class_index][1](
            adapter, tree_node)

    def _color_class_mean(self, adapter, tree_node):
        # calculate node colors relative to the mean of the node samples
        min_mean = np.min(self.clf_dataset.Y)
        max_mean = np.max(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        mean = np.mean(instances.Y)

        return self.color_palette[(mean - min_mean) / (max_mean - min_mean)]

    def _color_stddev(self, adapter, tree_node):
        # calculate node colors relative to the standard deviation in the node
        # samples
        min_mean, max_mean = 0, np.std(self.clf_dataset.Y)
        instances = adapter.get_instances_in_nodes(self.clf_dataset,
                                                   tree_node.label)
        std = np.std(instances.Y)

        return self.color_palette[(std - min_mean) / (max_mean - min_mean)]
Ejemplo n.º 16
0
class OWPythagoreanForest(OWWidget):
    """Widget that shows every tree of a random forest as a Pythagorean tree.

    The trees are laid out in a grid inside a graphics scene. Selecting one
    of them sends the corresponding tree model to the 'Tree' output.
    """
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'

    priority = 1001

    inputs = [('Random forest', RandomForestModel, 'set_rf')]
    outputs = [('Tree', TreeModel)]

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    # Maximum depth up to which the trees are drawn.
    depth_limit = settings.ContextSetting(10)
    # Index of the selected entry in the target class / node color combo box.
    target_class_index = settings.ContextSetting(0)
    # Index into SIZE_CALCULATION.
    size_calc_idx = settings.Setting(0)
    # Value of the zoom slider.
    zoom = settings.Setting(50)
    # Index of the selected tree, -1 when no tree is selected.
    selected_tree_index = settings.ContextSetting(-1)

    def __init__(self):
        super().__init__()
        self.model = None
        self.forest_adapter = None
        self.instances = None
        self.clf_dataset = None
        # We need to store references to the trees and grid items
        self.grid_items, self.ptrees = [], []
        # In some rare cases, we need to prevent committing, the only one
        # that this currently helps is that when changing the size calculation
        # the trees are all recomputed, but we don't want to output a new tree
        # to keep things consistent with other ui controls.
        self.__prevent_commit = False

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x + 1)),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info)

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(box_display,
                                           self,
                                           'depth_limit',
                                           label='Depth',
                                           ticks=False,
                                           callback=self.update_depth)
        self.ui_target_class_combo = gui.comboBox(box_display,
                                                  self,
                                                  'target_class_index',
                                                  label='Target class',
                                                  orientation=Qt.Horizontal,
                                                  items=[],
                                                  contentsLength=8,
                                                  callback=self.update_colors)
        self.ui_size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
            callback=self.update_size_calc)
        self.ui_zoom_slider = gui.hSlider(box_display,
                                          self,
                                          'zoom',
                                          label='Zoom',
                                          ticks=False,
                                          minValue=20,
                                          maxValue=150,
                                          callback=self.zoom_changed,
                                          createLabel=False)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = QGraphicsScene(self)
        # Every selection change in the scene tries to commit a tree.
        self.scene.selectionChanged.connect(self.commit)
        self.grid = OWGrid()
        self.grid.geometryChanged.connect(self._update_scene_rect)
        self.scene.addItem(self.grid)

        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.view)

        self.resize(800, 500)

        self.clear()

    def set_rf(self, model=None):
        """When a different forest is given."""
        self.clear()
        self.model = model

        if model is not None:
            self.forest_adapter = self._get_forest_adapter(self.model)
            self._draw_trees()
            # NOTE(review): the "palette" is set to the first tree of the
            # forest, which also fails for a forest without trees - confirm
            # that this is intended.
            self.color_palette = self.forest_adapter.get_trees()[0]

            self.instances = model.instances
            # this bit is important for the regression classifier
            if self.instances is not None and self.instances.domain != model.domain:
                self.clf_dataset = Table.from_table(self.model.domain,
                                                    self.instances)
            else:
                self.clf_dataset = self.instances

            self._update_info_box()
            self._update_target_class_combo()
            self._update_depth_slider()

            self.selected_tree_index = -1

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.forest_adapter = None
        self.ptrees = []
        self.grid_items = []
        self.grid.clear()

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    def update_depth(self):
        """When the max depth slider is changed."""
        for tree in self.ptrees:
            tree.set_depth_limit(self.depth_limit)

    def update_colors(self):
        """When the target class or coloring method is changed."""
        for tree in self.ptrees:
            tree.target_class_changed(self.target_class_index)

    def update_size_calc(self):
        """When the size calculation of the trees is changed."""
        if self.model is not None:
            # Re-selecting the tree below would otherwise trigger a commit.
            with self._prevent_commit():
                self.grid.clear()
                self._draw_trees()
                # Keep the selected item
                if self.selected_tree_index != -1:
                    self.grid_items[self.selected_tree_index].setSelected(True)
                self.update_depth()

    def zoom_changed(self):
        """When we update the "Zoom" slider."""
        for item in self.grid_items:
            item.set_max_size(self._calculate_zoom(self.zoom))

        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

    @contextmanager
    def _prevent_commit(self):
        """Make `commit` a no-op for the duration of the `with` block."""
        try:
            self.__prevent_commit = True
            yield
        finally:
            self.__prevent_commit = False

    def _update_info_box(self):
        """Show the number of trees in the forest in the info label."""
        self.ui_info.setText('Trees: {}'.format(
            len(self.forest_adapter.get_trees())))

    def _update_depth_slider(self):
        """Enable the depth slider and set it to the maximum tree depth."""
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    def _clear_info_box(self):
        """Make the info label state that there is no forest on the input."""
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        """Empty the target class combo box and reset the index to 0."""
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        """Disable the depth slider's parent and set its maximum to 0."""
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    def _get_max_depth(self):
        """Return the depth of the deepest drawn tree, or 0 without trees."""
        # `default` keeps `max` from raising ValueError on an empty forest.
        return max((tree.tree_adapter.max_depth for tree in self.ptrees),
                   default=0)

    def _get_forest_adapter(self, model):
        """Wrap `model` in a `SklRandomForestAdapter`."""
        return SklRandomForestAdapter(model)

    @contextmanager
    def disable_ui(self):
        """Temporarily disable the UI while trees may be redrawn."""
        try:
            self.ui_size_calc_combo.setEnabled(False)
            self.ui_depth_slider.setEnabled(False)
            self.ui_target_class_combo.setEnabled(False)
            self.ui_zoom_slider.setEnabled(False)
            yield
        finally:
            self.ui_size_calc_combo.setEnabled(True)
            self.ui_depth_slider.setEnabled(True)
            self.ui_target_class_combo.setEnabled(True)
            self.ui_zoom_slider.setEnabled(True)

    def _draw_trees(self):
        """Create a tree viewer and a grid item for each tree in the forest.

        The items are kept hidden until all of them have been created and
        laid out in the grid.
        """
        self.grid_items, self.ptrees = [], []

        num_trees = len(self.forest_adapter.get_trees())
        with self.progressBar(num_trees) as prg, self.disable_ui():
            for tree in self.forest_adapter.get_trees():
                ptree = PythagorasTreeViewer(
                    None,
                    tree,
                    interactive=False,
                    padding=100,
                    target_class_index=self.target_class_index,
                    weight_adjustment=self.SIZE_CALCULATION[
                        self.size_calc_idx][1])
                grid_item = GridItem(ptree,
                                     self.grid,
                                     max_size=self._calculate_zoom(self.zoom))
                # We don't want to show flickering while the trees are being
                # drawn
                grid_item.setVisible(False)

                self.grid_items.append(grid_item)
                self.ptrees.append(ptree)
                prg.advance()

            self.grid.set_items(self.grid_items)
            # This is necessary when adding items for the first time
            if self.grid:
                width = (self.view.width() -
                         self.view.verticalScrollBar().width())
                self.grid.reflow(width)
                self.grid.setPreferredWidth(width)
                # After drawing is complete, we show the trees
                for grid_item in self.grid_items:
                    grid_item.setVisible(True)

    @staticmethod
    def _calculate_zoom(zoom_level):
        """Calculate the max size for grid items from zoom level setting."""
        return zoom_level * 5

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected tree to output."""
        if self.__prevent_commit:
            return

        if not self.scene.selectedItems():
            self.send('Tree', None)
            # The selected tree index should only reset when model changes
            if self.model is None:
                self.selected_tree_index = -1
            return

        selected_item = self.scene.selectedItems()[0]
        self.selected_tree_index = self.grid_items.index(selected_item)
        tree = self.model.trees[self.selected_tree_index]
        # Attach the data and the current display settings to the tree.
        tree.instances = self.instances
        tree.meta_target_class_index = self.target_class_index
        tree.meta_size_calc_idx = self.size_calc_idx
        tree.meta_depth_limit = self.depth_limit

        self.send('Tree', tree)

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_scene_rect(self):
        """Set the scene rectangle to the bounding rectangle of its items."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def _update_target_class_combo(self):
        """Fill the target class combo box to match the type of the forest.

        For a discrete class the combo lists 'None' and the class values and
        is labelled 'Target class'; otherwise it lists the node coloring
        methods and is labelled 'Node color'.
        """
        self._clear_target_class_combo()
        # The label widget is a sibling of the combo box in the same parent.
        label = [
            x for x in self.ui_target_class_combo.parent().children()
            if isinstance(x, QLabel)
        ][0]

        # Use the model's domain rather than the instances' one: `set_rf`
        # explicitly allows `model.instances` to be None, in which case
        # `self.instances.domain` would raise AttributeError.
        domain = self.model.domain
        if domain.has_discrete_class:
            label_text = 'Target class'
            values = [c.title() for c in domain.class_vars[0].values]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.ui_target_class_combo.addItems(values)
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def resizeEvent(self, ev):
        """Reflow the grid to the new view width, then defer to the base."""
        width = (self.view.width() - self.view.verticalScrollBar().width())
        self.grid.reflow(width)
        self.grid.setPreferredWidth(width)

        super().resizeEvent(ev)
Ejemplo n.º 17
0
class TestAnchorLayout(QAppTestCase):
    """Exercise ``AnchorLayout`` on a small scene of three linked nodes."""

    def setUp(self):
        super().setUp()
        self.scene = CanvasScene()
        self.view = view = QGraphicsView(self.scene)
        view.setRenderHint(QPainter.Antialiasing)
        view.show()
        view.resize(600, 400)

    def tearDown(self):
        self.scene.clear()
        self.view.deleteLater()
        self.scene.deleteLater()
        del self.scene
        del self.view
        super().tearDown()

    def _add_node(self, description, x, y):
        """Create a node item at (x, y), add it to the scene and return it."""
        node = NodeItem()
        node.setWidgetDescription(description)
        node.setPos(x, y)
        self.scene.add_node_item(node)
        return node

    def _add_link(self, source, sink):
        """Create a link item from `source` to `sink` and add it to the scene."""
        link = LinkItem()
        link.setSourceItem(source)
        link.setSinkItem(sink)
        self.scene.add_link_item(link)
        return link

    def test_layout(self):
        one_desc, negate_desc, cons_desc = self.widget_desc()

        # One source node on the left feeding two sink nodes on the right.
        one_item = self._add_node(one_desc, 0, 150)
        cons_item = self._add_node(cons_desc, 200, 0)
        negate_item = self._add_node(negate_desc, 200, 300)

        self._add_link(one_item, negate_item)
        self._add_link(one_item, cons_item)

        layout = AnchorLayout()
        self.scene.addItem(layout)
        self.scene.set_anchor_layout(layout)

        layout.invalidateNode(one_item)
        layout.activate()

        first_anchor, second_anchor = \
            one_item.outputAnchorItem.anchorPositions()
        self.assertTrue(first_anchor > second_anchor)

        self.scene.node_item_position_changed.connect(layout.invalidateNode)

        # Path along which the two sink nodes are moved by the timer below.
        orbit = QPainterPath()
        orbit.addEllipse(125, 0, 50, 300)

        def advance():
            now = time.process_time()
            cons_item.setPos(orbit.pointAtPercent(now % 1.0))
            negate_item.setPos(orbit.pointAtPercent((now + 0.5) % 1.0))

        timer = QTimer(negate_item, interval=20)
        timer.start()
        timer.timeout.connect(advance)
        self.app.exec_()

    def widget_desc(self):
        """Return the 'one', 'negate' and 'cons' widget descriptions."""
        registry = small_testing_registry()
        one_desc = registry.widget("one")
        negate_desc = registry.widget("negate")
        cons_desc = registry.widget("cons")
        return one_desc, negate_desc, cons_desc