Exemple #1
0
class OWPredictions(widget.OWWidget):
    """Widget that displays the predictions of models for an input data set.

    The widget keeps one input data table and a collection of predictors
    (keyed by input id), shows the predictions next to the data and sends
    a table with the predictions added, plus evaluation results.
    """
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display the predictions of models for an input data set."
    # Input channels; the third item of each entry names the handler method
    # of this class that receives the signal.
    inputs = [("Data", Orange.data.Table, "set_data"),
              ("Predictors", Model,
               "set_predictor", widget.Multiple)]
    # Output channels; both are sent from `commit`.
    outputs = [("Predictions", Orange.data.Table),
               ("Evaluation Results", Orange.evaluation.Results)]

    settingsHandler = settings.ClassValuesContextHandler()
    #: Display the full input dataset or only the target variable columns (if
    #: available)
    show_attrs = settings.Setting(True)
    #: Show predicted values (for discrete target variable)
    show_predictions = settings.Setting(True)
    #: Show predictions probabilities (for discrete target variable)
    show_probabilities = settings.Setting(True)
    #: List of selected class value indices in the "Show probabilities" list
    selected_classes = settings.ContextSetting([])
    #: Draw colored distribution bars
    draw_dist = settings.Setting(True)

    #: Keep the input data's attributes in the output table
    output_attrs = settings.Setting(True)
    #: Add predicted class values to the output (classification only)
    output_predictions = settings.Setting(True)
    #: Add predicted class probabilities to the output (classification only)
    output_probabilities = settings.Setting(True)

    def __init__(self):
        """Create the widget state, the control area and the two table views.

        The control area holds the info label, the classification display
        options, the data view option and the output options.  The main
        area holds a horizontal splitter with the data view on the left and
        the predictions view on the right; their vertical scroll bars and
        row heights are kept in sync.
        """
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: A dict mapping input ids to PredictorSlot
        self.predictors = OrderedDict()  # type: Dict[object, PredictorSlot]
        #: A class variable (prediction target)
        self.class_var = None  # type: Optional[Orange.data.Variable]
        #: List of (discrete) class variable's values
        self.class_values = []  # type: List[str]

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(
            box, "No data on input.\nPredictors: 0\nTask: N/A")
        self.infolabel.setMinimumWidth(150)
        gui.button(box, self, "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        self.classification_options = box = gui.vBox(
            self.controlArea, "Options (classification)", spacing=-1,
            addSpace=False)

        gui.checkBox(box, self, "show_predictions", "Show predicted class",
                     callback=self._update_prediction_delegate)
        b = gui.checkBox(box, self, "show_probabilities",
                         "Show predicted probabilities",
                         callback=self._update_prediction_delegate)
        ibox = gui.indentedBox(box, sep=gui.checkButtonOffsetHint(b),
                               addSpace=False)
        gui.listBox(ibox, self, "selected_classes", "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QtGui.QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box, self, "draw_dist", "Draw distribution bars",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "Data View")
        gui.checkBox(box, self, "show_attrs", "Show full data set",
                     callback=self._update_column_visibility)

        box = gui.vBox(self.controlArea, "Output", spacing=-1)
        # The "Original data" checkbox used to be stored in
        # `self.checkbox_class` as well and was immediately overwritten by
        # the "Predictions" checkbox below; it now has its own attribute.
        self.checkbox_attrs = gui.checkBox(
            box, self, "output_attrs", "Original data",
            callback=self.commit)
        self.checkbox_class = gui.checkBox(
            box, self, "output_predictions", "Predictions",
            callback=self.commit)
        self.checkbox_prob = gui.checkBox(
            box, self, "output_probabilities", "Probabilities",
            callback=self.commit)

        gui.rubber(self.controlArea)

        self.splitter = QtGui.QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        self.dataview = QtGui.QTableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QtGui.QTableView.ScrollPerPixel,
            selectionMode=QtGui.QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus
        )
        self.predictionsview = QtGui.QTableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QtGui.QTableView.ScrollPerPixel,
            selectionMode=QtGui.QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus,
            sortingEnabled=True,
        )

        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.predictionsview.verticalHeader().hide()

        # Keep the two views vertically aligned: scrolling either one
        # moves the other.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        # Mirror row height changes made in the data view onto the
        # predictions view so the rows stay aligned.
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size:
                self.predictionsview.verticalHeader()
                    .resizeSection(index, size)
        )

        self.splitter.addWidget(self.dataview)
        self.splitter.addWidget(self.predictionsview)

        self.mainArea.layout().addWidget(self.splitter)
        # (state, width) pair read and updated by `_update_spliter`.
        self.spliter_restore_state = int(self.show_attrs), 300

    @check_sql_input
    def set_data(self, data):
        """Store the new input table and rebuild the data view's model.

        With no data both views lose their models and the predictions view
        gets a fresh item delegate.  In either case the cached predictor
        results are invalidated afterwards.
        """
        self.data = data
        # Dropping the model first forces a full reset of the view's
        # HeaderView state, whether or not a new model follows.
        self.dataview.setModel(None)
        if data is not None:
            source = TableModel(data, parent=None)
            proxy = TableSortProxyModel()
            proxy.setSourceModel(source)
            self.dataview.setModel(proxy)
            self._update_column_visibility()
        else:
            self.predictionsview.setModel(None)
            self.predictionsview.setItemDelegate(PredictionsItemDelegate())

        self.invalidate_predictions()

    def set_predictor(self, predictor=None, id=None):
        if id in self.predictors:
            if predictor is not None:
                self.predictors[id] = self.predictors[id]._replace(
                    predictor=predictor, name=predictor.name, results=None)
            else:
                del self.predictors[id]
        elif predictor is not None:
            self.predictors[id] = \
                PredictorSlot(predictor, predictor.name, None)

        if predictor is not None:
            self.class_var = predictor.domain.class_var

    def handleNewSignals(self):
        """Recompute missing predictions and refresh the whole widget.

        Runs each predictor whose results are missing (or are the all-NaN
        placeholder of an earlier failure), then updates the classification
        options, the settings context, the predictions model and delegate,
        the consistency warning and the info label, and finally commits.
        """
        self.error(0)
        if self.data is not None:
            for inputid, pred in list(self.predictors.items()):
                if pred.results is None or numpy.isnan(pred.results[0]).all():
                    try:
                        results = self.predict(pred.predictor, self.data)
                    except ValueError as err:
                        err_msg = '{}:\n'.format(pred.predictor.name) + \
                                  str(err)
                        self.error(0, err_msg)
                        # Substitute NaN placeholders of matching length so
                        # the failed predictor still has a results pair.
                        # NOTE(review): the number of probability columns is
                        # taken from the data's class variable, not from the
                        # predictor's - confirm that this is intended.
                        n, m = len(self.data), 1
                        if self.data.domain.has_discrete_class:
                            m = len(self.data.domain.class_var.values)
                        probabilities = numpy.full((n, m), numpy.nan)
                        results = (numpy.full(n, numpy.nan), probabilities)
                    self.predictors[inputid] = pred._replace(results=results)

        # Without predictors there is no prediction target.
        if not self.predictors:
            self.class_var = None

        self.classification_options.setVisible(
            self.class_var is not None and self.class_var.is_discrete)

        # Reset the class-value dependent settings to "all selected" and
        # then open the context for the current discrete class variable.
        self.closeContext()
        if self.class_var is not None and self.class_var.is_discrete:
            self.class_values = list(self.class_var.values)
            self.selected_classes = list(range(len(self.class_values)))
            self.openContext(self.class_var)
        else:
            self.class_values = []
            self.selected_classes = []

        self._update_predictions_model()
        self._update_prediction_delegate()
        # Check for prediction target consistency
        target_vars = set([p.predictor.domain.class_var
                           for p in self.predictors.values()])

        if len(target_vars) > 1:
            self.warning(0, "Inconsistent class variables")
        else:
            self.warning(0)

        # Update the Info box text.
        info = []
        if self.data is not None:
            info.append("Data: {} instances.".format(len(self.data)))
        else:
            info.append("Data: N/A")

        if self.predictors:
            info.append("Predictors: {}".format(len(self.predictors)))
        else:
            info.append("Predictors: N/A")

        if self.class_var is not None:
            # The "Predictions" and "Probabilities" output checkboxes are
            # enabled only for a discrete target.
            if self.class_var.is_discrete:
                info.append("Task: Classification")
                self.checkbox_class.setEnabled(True)
                self.checkbox_prob.setEnabled(True)
            else:
                info.append("Task: Regression")
                self.checkbox_class.setEnabled(False)
                self.checkbox_prob.setEnabled(False)
        else:
            info.append("Task: N/A")

        self.infolabel.setText("\n".join(info))
        self.commit()

    def invalidate_predictions(self):
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _update_predictions_model(self):
        """Rebuild the predictions view's model from the predictor results.

        With data present, the per-predictor (values, probabilities) pairs
        are regrouped row by row and wrapped in a `PredictionsModel`;
        without data the sort proxy gets no source model.  The view's
        delegate, model and sort indicator are reset afterwards.
        """
        source = None
        if self.data is not None:
            slots = self.predictors.values()
            per_slot = []
            for slot in slots:
                values, prob = slot.results
                if slot.predictor.domain.class_var.is_discrete:
                    # Wrap raw class indices into Orange values of the
                    # predictor's class variable.
                    values = [
                        Orange.data.Value(slot.predictor.domain.class_var, v)
                        for v in values
                    ]
                per_slot.append((values, prob))
            # Regroup from "per predictor" to "per row".
            rows = list(zip(*(zip(*res) for res in per_slot)))
            column_names = [slot.name for slot in slots]
            source = PredictionsModel(rows, column_names)

        proxy = PredictionsSortProxyModel()
        proxy.setSourceModel(source)
        proxy.setDynamicSortFilter(True)

        view = self.predictionsview
        view.setItemDelegate(PredictionsItemDelegate())
        view.setModel(proxy)
        view.horizontalHeader().setSortIndicatorShown(False)
        proxy.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        view.resizeColumnsToContents()

    def _update_column_visibility(self):
        """Update data column visibility."""
        domain = self.data.domain
        first_attr = len(domain.class_vars) + len(domain.metas)
        if self.data is not None:
            for i in range(first_attr, first_attr + len(domain.attributes)):
                self.dataview.setColumnHidden(i, not self.show_attrs)
            if domain.class_var:
                self.dataview.setColumnHidden(0, False)
            self._update_spliter()

    def _update_data_sort_order(self):
        """Update data row order to match the current predictions view order"""
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            if predmodel is not None and predmodel.sortColumn() >= 0:
                sortind = numpy.argsort(
                    [predmodel.mapToSource(predmodel.index(i, 0)).row()
                     for i in range(n)])
                sortind = numpy.array(sortind, numpy.int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Reset the row sorting to original input order."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """Update the predicted probability visibility state.

        Builds the cell format string from the display settings, installs a
        new item delegate with it on the predictions view and, for a
        discrete target, passes the selected class indices to the view's
        model.  Nothing is changed when there is no class variable.
        """
        delegate = PredictionsItemDelegate()
        target = self.class_var
        if target is None:
            return

        colors = None
        if target.is_discrete:
            colors = [QtGui.QColor(*rgb) for rgb in target.colors]
            dist_fmt = pred_fmt = ""
            if self.show_probabilities:
                decimals = 2
                # One "{dist[i]:.2f}" placeholder per selected class value.
                shown = [i for i in range(len(target.values))
                         if i in self.selected_classes]
                dist_fmt = " : ".join(
                    "{{dist[{}]:.{}f}}".format(i, decimals) for i in shown)
            if self.show_predictions:
                pred_fmt = "{value!s}"
            if pred_fmt and dist_fmt:
                fmt = dist_fmt + " \N{RIGHTWARDS ARROW} " + pred_fmt
            else:
                fmt = dist_fmt or pred_fmt
        else:
            assert isinstance(target, ContinuousVariable)
            fmt = "{{value:.{}f}}".format(target.number_of_decimals)

        delegate.setFormat(fmt)
        if self.draw_dist and colors is not None:
            delegate.setColors(colors)
        self.predictionsview.setItemDelegate(delegate)
        self.predictionsview.resizeColumnsToContents()

        if target.is_discrete:
            proxy = self.predictionsview.model()
            if proxy is not None:
                proxy.setProbInd(numpy.array(self.selected_classes, dtype=int))

    def _update_spliter(self):
        """Re-balance the splitter between the data and predictions views.

        `self.spliter_restore_state` is a `(state, width)` pair: the branch
        for hidden attributes stores `(1, previous data view width)` when
        it finds state 0, and the branch for shown attributes always
        finishes by storing `(0, new data view width)`.  Does nothing
        without input data.
        """
        if self.data is None:
            return

        def width(view):
            # Total width of all header sections plus the row header.
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            return h_header.length() + v_header.width()

        def widthForColumns(view, start=0, end=None):
            # Width of the row header plus the sections in the
            # `[start:end]` slice of the column indices.
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            width = sum([h_header.sectionSize(i)
                         for i in range(h_header.count())[start: end]])
            return v_header.width() + width

        if not self.show_attrs:
            # Attributes hidden: shrink the data view to its contents and
            # cap its width there.
            w1, w2 = self.splitter.sizes()
            w = width(self.dataview) + 4
            self.splitter.setSizes([w, w1 + w2 - w])
            self.dataview.setMaximumWidth(w)

            state, w = self.spliter_restore_state
            if state == 0:
                # save dataview width on change from 'show all' to 'hide all'
                self.spliter_restore_state = 1, w1
        else:
            w1, w2 = self.splitter.sizes()
            state, w = self.spliter_restore_state
            if state == 1:
                # restore dataview on change from 'hide all' to 'show all'
                # extend the dataview to the saved width but no further
                # than 2/3 of the available space
                w = min(w, (w1 + w2) * 2 // 3)
            else:
                # shrink the dataview width if its contents are smaller than
                # its width; the estimate uses the last two columns, is
                # capped at half of the total width and then raised again
                # if the predictions view leaves room
                w1, w2 = self.splitter.sizes()
                w = widthForColumns(self.dataview, -2) + 4
                w = min(w,  (w1 + w2) // 2)
                predw = widthForColumns(self.predictionsview)
                w = max(w,  min(w1 + w2 - predw - 20, w1 + w2 - w))
            self.splitter.setSizes([w, w1 + w2 - w])
            # Lift the width cap set while the attributes were hidden.
            self.dataview.setMaximumWidth(QWIDGETSIZE_MAX)

            self.spliter_restore_state = 0, w

    def commit(self):
        """Build and send the "Predictions" and "Evaluation Results" outputs.

        Sends `None` on both outputs when there is no data or no predictor.
        Otherwise the predictions (and, for classification, probabilities)
        are appended to the data as new meta columns, and evaluation
        results are produced only when the data's class variable equals the
        first predictor's class variable.
        """
        if self.data is None or not self.predictors:
            self.send("Predictions", None)
            self.send("Evaluation Results", None)
            return

        # The first predictor's target decides between classification and
        # regression for all slots.
        predictor = next(iter(self.predictors.values())).predictor
        class_var = predictor.domain.class_var
        classification = class_var and class_var.is_discrete

        # New meta variables and their value columns, collected in the
        # same order so they can be stacked side by side below.
        newmetas = []
        newcolumns = []
        slots = list(self.predictors.values())

        if classification:
            if self.output_predictions:
                mc = [DiscreteVariable(name=p.name, values=class_var.values)
                      for p in slots]
                newmetas.extend(mc)
                newcolumns.extend(p.results[0].reshape((-1, 1))
                                  for p in slots)

            if self.output_probabilities:
                # One continuous meta per (predictor, class value) pair.
                for p in slots:
                    m = [ContinuousVariable(name="%s(%s)" % (p.name, value))
                         for value in class_var.values]
                    newmetas.extend(m)
                newcolumns.extend(p.results[1] for p in slots)

        else:
            # regression
            mc = [ContinuousVariable(name=p.name)
                  for p in self.predictors.values()]
            newmetas.extend(mc)
            newcolumns.extend(p.results[0].reshape((-1, 1))
                              for p in slots)

        if self.output_attrs:
            attrs = list(self.data.domain.attributes)
        else:
            attrs = []
        metas = list(self.data.domain.metas) + newmetas

        domain = Orange.data.Domain(attrs, self.data.domain.class_var,
                                    metas=metas)
        predictions = self.data.from_table(domain, self.data)

        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns]
            )
            # The new metas were appended last, so they occupy the
            # trailing meta columns.
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns

        results = None
        if self.data.domain.class_var == class_var:
            N = len(self.data)
            results = Orange.evaluation.Results(self.data, store_data=True)
            results.folds = None
            results.row_indices = numpy.arange(N)
            results.actual = self.data.Y.ravel()
            results.predicted = numpy.vstack(
                tuple(p.results[0] for p in slots))
            if classification:
                results.probabilities = numpy.array(
                    [p.results[1] for p in slots])
            results.learner_names = [p.name for p in slots]

        self.send("Predictions", predictions)
        self.send("Evaluation Results", results)

    @classmethod
    def predict(cls, predictor, data):
        class_var = predictor.domain.class_var
        if class_var:
            if class_var.is_discrete:
                return cls.predict_discrete(predictor, data)
            elif class_var.is_continuous:
                return cls.predict_continuous(predictor, data)

    @staticmethod
    def predict_discrete(predictor, data):
        """Call `predictor` on `data`, requesting `Model.ValueProbs`."""
        values_and_probs = predictor(data, Model.ValueProbs)
        return values_and_probs

    @staticmethod
    def predict_continuous(predictor, data):
        """Call `predictor` on `data`, requesting `Model.Value`.

        Returns a `(values, probabilities)` pair in which the second item
        is a list with one `None` per data row.
        """
        predicted = predictor(data, Model.Value)
        no_probabilities = [None for _ in range(len(data))]
        return predicted, no_probabilities
Exemple #2
0
class OWColor(widget.OWWidget):
    """Widget for setting the color legend of variables.

    Shows the discrete and the numeric variables of the input data in two
    tables and outputs the data with the variables replaced by the ones
    created from the edited descriptions.
    """
    name = "Color"
    description = "Set color legend for variables."
    icon = "icons/Colors.svg"

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    settingsHandler = settings.PerfectDomainContextHandler(
        match_values=settings.PerfectDomainContextHandler.MATCH_VALUES_ALL)
    # Descriptions of the discrete variables; rebuilt in `set_data`.
    disc_descs = settings.ContextSetting([])
    # Descriptions of the numeric variables; rebuilt in `set_data`.
    cont_descs = settings.ContextSetting([])
    selected_schema_index = settings.Setting(0)
    auto_apply = settings.Setting(True)

    # Contexts stored with a version below 2 are rejected by
    # `migrate_context`.
    settings_version = 2

    want_main_area = False

    def __init__(self):
        """Create the two variable tables and the row of buttons."""
        super().__init__()
        self.data = None
        self.orig_domain = None
        self.domain = None

        disc_box = gui.hBox(self.controlArea, "Discrete Variables")
        self.disc_model = DiscColorTableModel()
        self.disc_view = DiscreteTable(self.disc_model)
        self.disc_model.dataChanged.connect(self._on_data_changed)
        disc_box.layout().addWidget(self.disc_view)

        cont_box = gui.hBox(self.controlArea, "Numeric Variables")
        self.cont_model = ContColorTableModel()
        self.cont_view = ContinuousTable(self.cont_model)
        self.cont_model.dataChanged.connect(self._on_data_changed)
        cont_box.layout().addWidget(self.cont_view)

        button_box = gui.hBox(self.buttonsArea)
        actions = (("Save", self.save),
                   ("Load", self.load),
                   ("Reset", self.reset))
        for label, action in actions:
            gui.button(button_box, self, label, callback=action)
        gui.rubber(self.buttonsArea)
        gui.auto_apply(self.buttonsArea, self, "auto_apply")

    @staticmethod
    def sizeHint():  # pragma: no cover
        """Return the size hint of the widget, 500 by 570."""
        hint_width, hint_height = 500, 570
        return QSize(hint_width, hint_height)

    @Inputs.data
    def set_data(self, data):
        """Store the input data and rebuild the variable descriptions.

        The previous context is closed, fresh descriptions are created for
        every discrete and continuous variable of the data (none when the
        data is `None`), the context for the new data is opened and the
        output is committed immediately.
        """
        self.closeContext()
        self.disc_descs = []
        self.cont_descs = []
        self.data = data
        if data is None:
            self.domain = None
        else:
            for var in chain(data.domain.variables, data.domain.metas):
                if var.is_discrete:
                    self.disc_descs.append(DiscAttrDesc(var))
                elif var.is_continuous:
                    self.cont_descs.append(ContAttrDesc(var))

        self.disc_model.set_data(self.disc_descs)
        self.cont_model.set_data(self.cont_descs)
        self.openContext(data)
        for view in (self.disc_view, self.cont_view):
            view.resizeColumnsToContents()
        self.commit.now()

    def _on_data_changed(self):
        self.commit.deferred()

    def reset(self):
        self.disc_model.reset()
        self.cont_model.reset()
        # Reset button is in the same box as Load, which has commit.now,
        # and Apply, hence let Reset commit now, too.
        self.commit.now()

    def save(self):
        """Ask for a file name and save the variable definitions to it.

        The chosen directory is remembered under the
        "colorwidget/last-location" settings key.  Nothing happens when
        the dialog returns no file name.
        """
        fname, _selected_filter = QFileDialog.getSaveFileName(
            self, "File name", self._start_dir(),
            "Variable definitions (*.colors)")
        if fname:
            directory = os.path.dirname(fname)
            QSettings().setValue("colorwidget/last-location", directory)
            self._save_var_defs(fname)

    def _save_var_defs(self, fname):
        with open(fname, "w") as f:
            json.dump(
                {vartype: {
                    var.name: var_data
                    for var, var_data in (
                        (desc.var, desc.to_dict()) for desc in repo)
                    if var_data}
                 for vartype, repo in (("categorical", self.disc_descs),
                                       ("numeric", self.cont_descs))
                },
                f,
                indent=4)

    def load(self):
        """Ask for a definitions file and apply its content.

        Shows a "File error" message box when the file cannot be opened or
        when its content is not valid JSON in the expected format.
        """
        fname, _selected_filter = QFileDialog.getOpenFileName(
            self, "File name", self._start_dir(),
            "Variable definitions (*.colors)")
        if not fname:
            return

        try:
            with open(fname) as f:
                js = json.load(f)  #: dict
                self._parse_var_defs(js)
        except IOError:
            QMessageBox.critical(self, "File error", "File cannot be opened.")
        except (json.JSONDecodeError, InvalidFileFormat):
            QMessageBox.critical(self, "File error", "Invalid file format.")

    def _parse_var_defs(self, js):
        """Apply variable definitions loaded from a JSON file.

        `js` must be a dict with exactly the keys "categorical" and
        "numeric", each mapping variable names to definition dicts.
        Definitions for variables that are not in the current data are
        skipped and reported in a warning box, as are the warnings returned
        by the descriptions' `from_dict`.  On success the descriptions and
        table models are replaced and the output is committed.

        Raises:
            InvalidFileFormat: if `js` does not have the expected
                structure.
        """
        if not isinstance(js, dict) or set(js) != {"categorical", "numeric"}:
            raise InvalidFileFormat
        try:
            # Map of original variable name -> requested new name.
            renames = {
                var_name: desc["rename"]
                for repo in js.values() for var_name, desc in repo.items()
                if "rename" in desc
            }
        # js is an object coming from json file that can be manipulated by
        # the user, so there are too many things that can go wrong.
        # Catch all exceptions, therefore.
        except Exception as exc:
            raise InvalidFileFormat from exc
        if not all(isinstance(val, str)
                   for val in chain(renames, renames.values())):
            raise InvalidFileFormat
        # Names of all current variables after renaming; if two of them
        # collide, the set is smaller than the number of descriptions.
        renamed_vars = {
            renames.get(desc.var.name, desc.var.name)
            for desc in chain(self.disc_descs, self.cont_descs)
        }
        if len(renamed_vars) != len(self.disc_descs) + len(self.cont_descs):
            QMessageBox.warning(
                self,
                "Duplicated variable names",
                "Variables will not be renamed due to duplicated names.")
            # Strip every rename from the loaded definitions (in place).
            for repo in js.values():
                for desc in repo.values():
                    desc.pop("rename", None)

        # First, construct all descriptions; assign later, after we know
        # there won't be exceptions due to invalid file format
        unused_vars = []
        both_descs = []
        warnings = []
        for old_desc, repo, desc_type in (
                (self.disc_descs, "categorical", DiscAttrDesc),
                (self.cont_descs, "numeric", ContAttrDesc)):
            var_by_name = {desc.var.name: desc.var for desc in old_desc}
            new_descs = {}
            for var_name, var_data in js[repo].items():
                var = var_by_name.get(var_name)
                if var is None:
                    unused_vars.append(var_name)
                    continue
                # This can throw InvalidFileFormat
                new_descs[var_name], warn = desc_type.from_dict(var, var_data)
                warnings += warn
            both_descs.append(new_descs)
        if unused_vars:
            # Word the warning for one, up to five, or more than five
            # unused definitions; only the first four names are listed in
            # the last case.
            names = [f"'{name}'" for name in unused_vars]
            if len(unused_vars) == 1:
                warn = f'Definition for variable {names[0]}, which does not ' \
                       f'appear in the data, was ignored.\n'
            else:
                if len(unused_vars) <= 5:
                    warn = 'Definitions for variables ' \
                           f'{", ".join(names[:-1])} and {names[-1]}'
                else:
                    warn = f'Definitions for {", ".join(names[:4])} ' \
                           f'and {len(names) - 4} other variables'
                warn += ", which do not appear in the data, were ignored.\n"
            warnings.insert(0, warn)

        # Variables without a loaded definition keep their current
        # description.
        self.disc_descs = [both_descs[0].get(desc.var.name, desc)
                           for desc in self.disc_descs]
        self.cont_descs = [both_descs[1].get(desc.var.name, desc)
                           for desc in self.cont_descs]
        if warnings:
            QMessageBox.warning(
                self, "Invalid definitions", "\n".join(warnings))

        self.disc_model.set_data(self.disc_descs)
        self.cont_model.set_data(self.cont_descs)
        self.commit.now()

    def _start_dir(self):
        return self.workflowEnv().get("basedir") \
               or QSettings().value("colorwidget/last-location") \
               or os.path.expanduser(f"~{os.sep}")

    @gui.deferred
    def commit(self):
        """Send the data with variables rebuilt from their descriptions.

        Sends `None` when there is no input data.  Otherwise every
        variable that has a description is replaced by the variable that
        the description creates, and the data is transformed into the
        resulting domain.
        """
        if self.data is None:
            self.Outputs.data.send(None)
            return

        disc_dict = {desc.var.name: desc for desc in self.disc_descs}
        cont_dict = {desc.var.name: desc for desc in self.cont_descs}

        def replacement(var):
            # Look the variable up among the descriptions of its kind.
            lookup = disc_dict if var.is_discrete else cont_dict
            desc = lookup.get(var.name)
            if desc:
                return desc.create_variable()
            return var

        def make(variables):
            return [replacement(var) for var in variables]

        domain = self.data.domain
        new_domain = Orange.data.Domain(
            make(domain.attributes),
            make(domain.class_vars),
            make(domain.metas))
        self.Outputs.data.send(self.data.transform(new_domain))

    def send_report(self):
        """Send report"""
        def _report_variables(variables):
            def was(n, o):
                return n if n == o else f"{n} (was: {o})"

            max_values = max(
                (len(var.values) for var in variables if var.is_discrete),
                default=1)

            rows = ""
            disc_dict = {k.var.name: k for k in self.disc_descs}
            cont_dict = {k.var.name: k for k in self.cont_descs}
            for var in variables:
                if var.is_discrete:
                    desc = disc_dict[var.name]
                    value_cols = "    \n".join(
                        f"<td>{square(*color)} {was(value, old_value)}</td>"
                        for color, value, old_value in
                        zip(desc.colors, desc.values, var.values))
                elif var.is_continuous:
                    desc = cont_dict[var.name]
                    pal = colorpalettes.ContinuousPalettes[desc.palette_name]
                    value_cols = f'<td colspan="{max_values}">' \
                                 f'{pal.friendly_name}</td>'
                else:
                    continue
                names = was(desc.name, desc.var.name)
                rows += '<tr style="height: 2em">\n' \
                        f'  <th style="text-align: right">{names}</th>' \
                        f'  {value_cols}\n' \
                        '</tr>\n'
            return rows

        if not self.data:
            return
        dom = self.data.domain
        sections = (
            (name, _report_variables(variables))
            for name, variables in (
                ("Features", dom.attributes),
                ("Outcome" + "s" * (len(dom.class_vars) > 1), dom.class_vars),
                ("Meta attributes", dom.metas)))
        table = "".join(f"<tr><th>{name}</th></tr>{rows}"
                        for name, rows in sections if rows)
        if table:
            self.report_raw(f"<table>{table}</table>")

    @classmethod
    def migrate_context(cls, _, version):
        if not version or version < 2:
            raise IncompatibleContext
Exemple #3
0
class OWPredictions(OWWidget):
    """Widget that shows the predictions of input models for an input dataset.

    The user-visible widget metadata and signal names below are localized
    (Chinese) strings; the warning/error messages are English.
    """
    name = "预测"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "显示输入数据集的模型预测。"
    keywords = []

    # Input signals: one data table and any number of predictors
    # (``multiple=True``).
    class Inputs:
        data = Input("数据", Orange.data.Table)
        predictors = Input("预测模型", Model, multiple=True)

    # Output signals: the data table extended with predictions, and the
    # evaluation results object.
    class Outputs:
        predictions = Output("预测数据", Orange.data.Table)
        evaluation_results = Output("评估结果",
                                    Orange.evaluation.Results,
                                    dynamic=False)

    class Warning(OWWidget.Warning):
        # Raised by `set_data` when the input table has no rows.
        empty_data = Msg("Empty dataset")

    class Error(OWWidget.Error):
        # The `{}` placeholder receives the newline-joined failure messages
        # collected in `_set_errors`.
        predictor_failed = \
            Msg("One or more predictors failed (see more...)\n{}")
        predictors_target_mismatch = \
            Msg("Predictors do not have the same target.")
        data_target_mismatch = \
            Msg("Data does not have the same target as predictors.")

    settingsHandler = settings.ClassValuesContextHandler()
    #: Display the full input dataset or only the target variable columns (if
    #: available)
    show_attrs = settings.Setting(True)
    #: Show predicted values (for discrete target variable)
    show_predictions = settings.Setting(True)
    #: Show predictions probabilities (for discrete target variable)
    show_probabilities = settings.Setting(True)
    #: List of selected class value indices in the "Show probabilities" list
    selected_classes = settings.ContextSetting([])
    #: Draw colored distribution bars
    draw_dist = settings.Setting(True)

    #: Include the original attributes in the output table
    output_attrs = settings.Setting(True)
    #: Add a predicted-value column per predictor to the output table
    output_predictions = settings.Setting(True)
    #: Add per-class probability columns per predictor to the output table
    output_probabilities = settings.Setting(True)

    def __init__(self):
        """Create the widget state, the control-area options and the views.

        The main area holds a splitter with the predictions view on the left
        and the data view on the right; their vertical scroll bars and row
        heights are kept in sync.
        """
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: A dict mapping input ids to PredictorSlot
        self.predictors = OrderedDict()  # type: Dict[object, PredictorSlot]
        #: A class variable (prediction target)
        self.class_var = None  # type: Optional[Orange.data.Variable]
        #: List of (discrete) class variable's values
        self.class_values = []  # type: List[str]

        box = gui.vBox(self.controlArea, "信息")
        self.infolabel = gui.widgetLabel(
            box, "没有输入数据。\n预测因子: 0\n任务: N/A")
        self.infolabel.setMinimumWidth(150)
        gui.button(box, self, "恢复原始顺序",
                   callback=self._reset_order,
                   tooltip="按原始顺序显示行")

        # Options that only make sense for a discrete target; the whole box
        # is shown/hidden in `set_class_var`.
        self.classification_options = box = gui.vBox(
            self.controlArea, "显示", spacing=-1, addSpace=False)

        gui.checkBox(box, self, "show_predictions", "预测类",
                     callback=self._update_prediction_delegate)
        b = gui.checkBox(box, self, "show_probabilities",
                         "预测概率:",
                         callback=self._update_prediction_delegate)
        ibox = gui.indentedBox(box, sep=gui.checkButtonOffsetHint(b),
                               addSpace=False)
        gui.listBox(ibox, self, "selected_classes", "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box, self, "draw_dist", "绘制分布条",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "数据视图")
        gui.checkBox(box, self, "show_attrs", "显示完整数据集",
                     callback=self._update_column_visibility)

        box = gui.vBox(self.controlArea, "输出", spacing=-1)
        # This checkbox is not kept on `self`: it used to be bound to
        # `self.checkbox_class` and was immediately overwritten by the
        # "output_predictions" checkbox below, so the binding was dead.
        gui.checkBox(
            box, self, "output_attrs", "原始数据",
            callback=self.commit)
        # `checkbox_class` / `checkbox_prob` are enabled or disabled in
        # `_update_info` depending on the task type.
        self.checkbox_class = gui.checkBox(
            box, self, "output_predictions", "预测",
            callback=self.commit)
        self.checkbox_prob = gui.checkBox(
            box, self, "output_probabilities", "可能性",
            callback=self.commit)

        gui.rubber(self.controlArea)

        self.splitter = QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        self.dataview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus
        )
        self.predictionsview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus,
            sortingEnabled=True,
        )

        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.dataview.verticalHeader().hide()

        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        # Scrolling either view vertically scrolls the other one too.
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        # Mirror row-height changes of the data view in the predictions view.
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size:
            self.predictionsview.verticalHeader().resizeSection(index, size)
        )

        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.mainArea.layout().addWidget(self.splitter)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input dataset.

        A table with no rows is treated like no data at all and raises the
        ``empty_data`` warning. The data view gets a fresh sortable model
        and the cached results of all predictors are invalidated.
        """
        is_empty = data is not None and not len(data)
        if is_empty:
            data = None
            self.Warning.empty_data()
        else:
            self.Warning.empty_data.clear()

        self.data = data
        # Dropping the model first forces a full reset of the view's
        # HeaderView state; without data it simply leaves the view empty.
        self.dataview.setModel(None)
        if data is None:
            self.predictionsview.setModel(None)
            self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        else:
            source = TableModel(data, parent=None)
            sortable = TableSortProxyModel()
            sortable.setSourceModel(source)
            self.dataview.setModel(sortable)
            self._update_column_visibility()

        self._invalidate_predictions()

    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or remove the predictor connected on input ``id``.

        ``predictor=None`` removes a known slot (and is ignored for an
        unknown id); otherwise the slot is created or replaced with its
        cached results reset to ``None``.
        """
        known = id in self.predictors
        if predictor is None:
            if known:
                del self.predictors[id]
        elif known:
            self.predictors[id] = self.predictors[id]._replace(
                predictor=predictor, name=predictor.name, results=None)
        else:
            self.predictors[id] = PredictorSlot(
                predictor, predictor.name, None)

    def set_class_var(self):
        pred_classes = set(pred.predictor.domain.class_var
                           for pred in self.predictors.values())
        self.Error.predictors_target_mismatch.clear()
        self.Error.data_target_mismatch.clear()
        self.class_var = None
        if len(pred_classes) > 1:
            self.Error.predictors_target_mismatch()
        if len(pred_classes) == 1:
            self.class_var = pred_classes.pop()
            if self.data is not None and \
                    self.data.domain.class_var is not None and \
                    self.class_var != self.data.domain.class_var:
                self.Error.data_target_mismatch()
                self.class_var = None

        discrete_class = self.class_var is not None \
                         and self.class_var.is_discrete
        self.classification_options.setVisible(discrete_class)
        self.closeContext()
        if discrete_class:
            self.class_values = list(self.class_var.values)
            self.selected_classes = list(range(len(self.class_values)))
            self.openContext(self.class_var)
        else:
            self.class_values = []
            self.selected_classes = []

    def handleNewSignals(self):
        self.set_class_var()
        if self.data is not None:
            self._call_predictors()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._update_info()
        self.commit()

    def _call_predictors(self):
        for inputid, pred in self.predictors.items():
            if pred.results is None \
                    or isinstance(pred.results, str) \
                    or numpy.isnan(pred.results[0]).all():
                try:
                    results = self.predict(pred.predictor, self.data)
                except (ValueError, DomainTransformationError) as err:
                    results = "{}: {}".format(pred.predictor.name, err)
                self.predictors[inputid] = pred._replace(results=results)

    def _set_errors(self):
        errors = "\n".join(p.results for p in self.predictors.values()
                           if isinstance(p.results, str))
        if errors:
            self.Error.predictor_failed(errors)
        else:
            self.Error.predictor_failed.clear()

    def _update_info(self):
        info = []
        if self.data is not None:
            info.append("Data: {} instances.".format(len(self.data)))
        else:
            info.append("Data: N/A")

        n_predictors = len(self.predictors)
        n_valid = len(self._valid_predictors())
        if n_valid != n_predictors:
            info.append("Predictors: {} (+ {} failed)".format(
                n_valid, n_predictors - n_valid))
        else:
            info.append("Predictors: {}".format(n_predictors or "N/A"))

        if self.class_var is None:
            info.append("Task: N/A")
        elif self.class_var.is_discrete:
            info.append("Task: Classification")
            self.checkbox_class.setEnabled(True)
            self.checkbox_prob.setEnabled(True)
        else:
            info.append("Task: Regression")
            self.checkbox_class.setEnabled(False)
            self.checkbox_prob.setEnabled(False)

        self.infolabel.setText("\n".join(info))

    def _invalidate_predictions(self):
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _valid_predictors(self):
        if self.class_var is not None and \
                self.data is not None:
            return [p for p in self.predictors.values()
                    if p.results is not None and not isinstance(p.results, str)]
        else:
            return []

    def _update_predictions_model(self):
        """Update the prediction view model.

        Builds a `PredictionsModel` from the results of the valid predictors
        (or uses no source model when data or target is missing), wraps it
        in a fresh sort proxy, installs it on the predictions view and
        re-synchronizes the data view's row order.
        """
        if self.data is not None and self.class_var is not None:
            slots = self._valid_predictors()
            results = []
            class_var = self.class_var
            for p in slots:
                values, prob = p.results
                if self.class_var.is_discrete:
                    # if values were added to class_var between building the
                    # model and predicting, add zeros for new class values,
                    # which are always at the end
                    prob = numpy.c_[
                        prob,
                        numpy.zeros((prob.shape[0], len(class_var.values) - prob.shape[1]))]
                    values = [Value(class_var, v) for v in values]
                results.append((values, prob))
            # Regroup the per-predictor (values, prob) pairs into per-row
            # tuples holding one (value, prob) pair for each predictor.
            results = list(zip(*(zip(*res) for res in results)))
            headers = [p.name for p in slots]
            model = PredictionsModel(results, headers)
        else:
            model = None

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        # A fresh default delegate is installed together with the new model.
        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.predictionsview.setModel(predmodel)
        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        # Keep the data view's rows aligned whenever the predictions are
        # re-sorted, and align them once right away.
        predmodel.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        self.predictionsview.resizeColumnsToContents()

    def _update_column_visibility(self):
        """Update data column visibility."""
        if self.data is not None and self.class_var is not None:
            domain = self.data.domain
            first_attr = len(domain.class_vars) + len(domain.metas)

            for i in range(first_attr, first_attr + len(domain.attributes)):
                self.dataview.setColumnHidden(i, not self.show_attrs)
            if domain.class_var:
                self.dataview.setColumnHidden(0, False)

    def _update_data_sort_order(self):
        """Update data row order to match the current predictions view order.

        When the predictions proxy is sorted by a column, the data model is
        given explicit sort indices derived from the proxy's row mapping and
        the sort indicator is shown; otherwise the data model's sort indices
        are reset to ``None`` and the indicator is hidden.
        """
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            if predmodel is not None and predmodel.sortColumn() >= 0:
                sortind = numpy.argsort(
                    [predmodel.mapToSource(predmodel.index(i, 0)).row()
                     for i in range(n)])
                # `numpy.int` was a plain alias of the builtin `int` and no
                # longer exists in current NumPy releases; the builtin
                # yields the same dtype.
                sortind = numpy.array(sortind, int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Reset the row sorting to original input order."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """Update the predicted probability visibility state"""
        if self.class_var is not None:
            delegate = PredictionsItemDelegate()
            if self.class_var.is_continuous:
                self._setup_delegate_continuous(delegate)
            else:
                self._setup_delegate_discrete(delegate)
                proxy = self.predictionsview.model()
                if proxy is not None:
                    proxy.setProbInd(
                        numpy.array(self.selected_classes, dtype=int))
            self.predictionsview.setItemDelegate(delegate)
            self.predictionsview.resizeColumnsToContents()
        self._update_spliter()

    def _setup_delegate_discrete(self, delegate):
        colors = [QtGui.QColor(*rgb) for rgb in self.class_var.colors]
        fmt = []
        if self.show_probabilities:
            fmt.append(" : ".join("{{dist[{}]:.2f}}".format(i)
                                  for i in sorted(self.selected_classes)))
        if self.show_predictions:
            fmt.append("{value!s}")
        delegate.setFormat(" \N{RIGHTWARDS ARROW} ".join(fmt))
        if self.draw_dist and colors is not None:
            delegate.setColors(colors)
        return delegate

    def _setup_delegate_continuous(self, delegate):
        delegate.setFormat(
            "{{value:.{}f}}".format(self.class_var.number_of_decimals))

    def _update_spliter(self):
        if self.data is None:
            return

        def width(view):
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            return h_header.length() + v_header.width()

        w = width(self.predictionsview) + 4
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([w, w1 + w2 - w])

    def commit(self):
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        slots = self._valid_predictors()
        if not slots or self.data.domain.class_var is None:
            self.Outputs.evaluation_results.send(None)
            return

        class_var = self.class_var
        nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
        data = self.data[~nanmask]
        N = len(data)
        results = Orange.evaluation.Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(N)
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(p.results[0][~nanmask] for p in slots))
        if class_var and class_var.is_discrete:
            results.probabilities = numpy.array(
                [p.results[1][~nanmask] for p in slots])
        results.learner_names = [p.name for p in slots]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        slots = self._valid_predictors()
        if not slots:
            self.Outputs.predictions.send(None)
            return

        if self.class_var and self.class_var.is_discrete:
            newmetas, newcolumns = self._classification_output_columns()
        else:
            newmetas, newcolumns = self._regression_output_columns()

        attrs = list(self.data.domain.attributes) if self.output_attrs else []
        metas = list(self.data.domain.metas) + newmetas
        domain = \
            Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns
        self.Outputs.predictions.send(predictions)

    def _classification_output_columns(self):
        newmetas = []
        newcolumns = []
        slots = self._valid_predictors()
        if self.output_predictions:
            newmetas += [DiscreteVariable(name=p.name, values=self.class_values)
                         for p in slots]
            newcolumns += [p.results[0].reshape((-1, 1)) for p in slots]

        if self.output_probabilities:
            newmetas += [ContinuousVariable(name="%s (%s)" % (p.name, value))
                         for p in slots for value in self.class_values]
            newcolumns += [p.results[1] for p in slots]
        return newmetas, newcolumns

    def _regression_output_columns(self):
        slots = self._valid_predictors()
        newmetas = [ContinuousVariable(name=p.name) for p in slots]
        newcolumns = [p.results[0].reshape((-1, 1)) for p in slots]
        return newmetas, newcolumns

    def send_report(self):
        def merge_data_with_predictions():
            data_model = self.dataview.model()
            predictions_model = self.predictionsview.model()

            # use ItemDelegate to style prediction values
            style = lambda x: self.predictionsview.itemDelegate().displayText(x, QLocale())

            # iterate only over visible columns of data's QTableView
            iter_data_cols = list(filter(lambda x: not self.dataview.isColumnHidden(x),
                                         range(data_model.columnCount())))

            # print header
            yield [''] + \
                  [predictions_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in range(predictions_model.columnCount())] + \
                  [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical, Qt.DisplayRole)] + \
                      [style(predictions_model.data(predictions_model.index(i, j)))
                       for j in range(predictions_model.columnCount())] + \
                      [data_model.data(data_model.index(i, j))
                       for j in iter_data_cols]

        if self.data is not None and self.class_var is not None:
            text = self.infolabel.text().replace('\n', '<br>')
            if self.show_probabilities and self.selected_classes:
                text += '<br>Showing probabilities for: '
                text += ', '. join([self.class_values[i]
                                    for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions", merge_data_with_predictions(),
                              header_rows=1, header_columns=1)

    @classmethod
    def predict(cls, predictor, data):
        class_var = predictor.domain.class_var
        if class_var:
            if class_var.is_discrete:
                return cls.predict_discrete(predictor, data)
            elif class_var.is_continuous:
                return cls.predict_continuous(predictor, data)

    @staticmethod
    def predict_discrete(predictor, data):
        """Call ``predictor`` on ``data`` in ``Model.ValueProbs`` mode."""
        value_probs = predictor(data, Model.ValueProbs)
        return value_probs

    @staticmethod
    def predict_continuous(predictor, data):
        """Return the predicted values and a list of ``None`` placeholders.

        The second item has one ``None`` per data row, standing in for the
        probabilities that a discrete prediction would provide.
        """
        predicted = predictor(data, Model.Value)
        placeholders = [None] * len(data)
        return predicted, placeholders
class OWSilhouettePlot(widget.OWWidget):
    """Widget that plots per-instance silhouette scores for a clustering.

    The cluster labels come from a discrete variable of the input data and
    the distances from the metric selected in ``Distances``.
    """
    name = "Silhouette Plot"
    description = "Silhouette Plot"

    icon = "icons/Silhouette.svg"

    # Signals in the list-based declaration style: one data input handled
    # by `set_data`, and two data outputs (selected rows / all other rows).
    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Selected Data", Orange.data.Table, widget.Default),
               ("Other Data", Orange.data.Table)]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Distance metric index
    distance_idx = settings.Setting(0)
    #: Group/cluster variable index
    cluster_var_idx = settings.ContextSetting(0)
    #: Annotation variable index
    annotation_var_idx = settings.ContextSetting(0)
    #: Group the silhouettes by cluster
    group_by_cluster = settings.Setting(True)
    #: A fixed size for an instance bar
    bar_size = settings.Setting(3)
    #: Add silhouette scores to output data
    add_scores = settings.Setting(False)
    #: Commit automatically (setting bound to the "Commit" control)
    auto_commit = settings.Setting(False)

    #: Available (display name, distance constructor) pairs, indexed by
    #: ``distance_idx``
    Distances = [("Euclidean", Orange.distance.Euclidean),
                 ("Manhattan", Orange.distance.Manhattan)]

    def __init__(self):
        """Create the widget state, the settings/output controls and the
        graphics view that hosts the silhouette plot."""
        super().__init__()

        # Input table as accepted by `set_data`
        self.data = None
        # Result of `Orange.distance._preprocess` on the input data
        self._effective_data = None
        # Cached distance matrix (numpy array), computed in `_update`
        self._matrix = None
        # Per-instance silhouette scores and the integer cluster labels
        self._silhouette = None
        self._labels = None
        # The current SilhouettePlot item in the scene, if any
        self._silplot = None

        box = gui.widgetBox(
            self.controlArea,
            "Settings",
        )
        gui.comboBox(box,
                     self,
                     "distance_idx",
                     label="Distance",
                     items=[name for name, _ in OWSilhouettePlot.Distances],
                     callback=self._invalidate_distances)
        self.cluster_var_cb = gui.comboBox(box,
                                           self,
                                           "cluster_var_idx",
                                           label="Cluster",
                                           callback=self._invalidate_scores)
        # The combo box is backed by a variable list model that `set_data`
        # fills with the candidate cluster variables.
        self.cluster_var_model = itemmodels.VariableListModel(parent=self)
        self.cluster_var_cb.setModel(self.cluster_var_model)

        gui.spin(box,
                 self,
                 "bar_size",
                 minv=1,
                 maxv=10,
                 label="Bar Size",
                 callback=self._update_bar_size)

        gui.checkBox(box,
                     self,
                     "group_by_cluster",
                     "Group by cluster",
                     callback=self._replot)

        self.annotation_cb = gui.comboBox(box,
                                          self,
                                          "annotation_var_idx",
                                          label="Annotations",
                                          callback=self._update_annotations)
        # Index 0 of the annotation model is always the "None" entry.
        self.annotation_var_model = itemmodels.VariableListModel(parent=self)
        self.annotation_var_model[:] = ["None"]
        self.annotation_cb.setModel(self.annotation_var_model)

        gui.rubber(self.controlArea)

        box = gui.widgetBox(self.controlArea, "Output")
        gui.checkBox(
            box,
            self,
            "add_scores",
            "Add silhouette scores",
        )
        gui.auto_commit(box, self, "auto_commit", "Commit", box=False)

        self.scene = QtGui.QGraphicsScene()
        self.view = QtGui.QGraphicsView(self.scene)
        self.view.setRenderHint(QtGui.QPainter.Antialiasing, True)
        self.view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.mainArea.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the control area's size hint expanded to at least 600x720."""
        return self.controlArea.sizeHint().expandedTo(QtCore.QSize(600, 720))

    def set_data(self, data):
        """
        Set the input data set.
        """
        self.closeContext()
        self.clear()
        error_msg = ""
        warning_msg = ""
        candidatevars = []
        if data is not None:
            candidatevars = [
                v for v in data.domain.variables + data.domain.metas
                if v.is_discrete and len(v.values) >= 2
            ]
            if not candidatevars:
                error_msg = "Input does not have any suitable cluster labels."
                data = None

        if data is not None:
            ncont = sum(v.is_continuous for v in data.domain.attributes)
            ndiscrete = len(data.domain.attributes) - ncont
            if ncont == 0:
                data = None
                error_msg = "No continuous columns"
            elif ncont < len(data.domain.attributes):
                warning_msg = "{0} discrete columns will not be used for " \
                              "distance computation".format(ndiscrete)

        self.data = data
        if data is not None:
            self.cluster_var_model[:] = candidatevars
            if data.domain.class_var in candidatevars:
                self.cluster_var_idx = candidatevars.index(
                    data.domain.class_var)
            else:
                self.cluster_var_idx = 0

            annotvars = [var for var in data.domain.metas if var.is_string]
            self.annotation_var_model[:] = ["None"] + annotvars
            self.annotation_var_idx = 1 if len(annotvars) else 0
            self._effective_data = Orange.distance._preprocess(data)
            self.openContext(Orange.data.Domain(candidatevars))

        self.error(0, error_msg)
        self.warning(0, warning_msg)

    def handleNewSignals(self):
        if self._effective_data is not None:
            self._update()
            self._replot()

        self.unconditional_commit()

    def clear(self):
        """
        Clear the widget state.
        """
        self.data = None
        self._effective_data = None
        self._matrix = None
        self._silhouette = None
        self._labels = None
        self.cluster_var_model[:] = []
        self.annotation_var_model[:] = ["None"]
        self._clear_scene()

    def _clear_scene(self):
        """Remove all items from the scene and forget the current plot."""
        # Clear the graphics scene and associated objects
        self.scene.clear()
        # Reset the scene rect with a default-constructed rectangle.
        self.scene.setSceneRect(QRectF())
        self._silplot = None

    def _invalidate_distances(self):
        """Drop the cached distance matrix and recompute the scores.

        Used as the callback of the distance combo box.
        """
        # Invalidate the computed distance matrix and recompute the silhouette.
        self._matrix = None
        self._invalidate_scores()

    def _invalidate_scores(self):
        # Invalidate and recompute the current silhouette scores.
        self._labels = self._silhouette = None
        self._update()
        self._replot()
        if self.data is not None:
            self.commit()

    def _update(self):
        # Update/recompute the distances/scores as required
        if self.data is None:
            self._silhouette = None
            self._labels = None
            self._matrix = None
            self._clear_scene()
            return

        if self._matrix is None and self._effective_data is not None:
            _, metric = self.Distances[self.distance_idx]
            self._matrix = numpy.asarray(metric(self._effective_data))

        labelvar = self.cluster_var_model[self.cluster_var_idx]
        labels, _ = self.data.get_column_view(labelvar)
        labels = labels.astype(int)
        _, counts = numpy.unique(labels, return_counts=True)
        if numpy.count_nonzero(counts) >= 2:
            self.error(1, "")
            silhouette = sklearn.metrics.silhouette_samples(
                self._matrix, labels, metric="precomputed")
        else:
            self.error(1, "Need at least 2 clusters with non zero counts")
            labels = silhouette = None

        self._labels = labels
        self._silhouette = silhouette

    def _replot(self):
        # Clear and replot/initialize the scene
        self._clear_scene()
        if self._silhouette is not None and self._labels is not None:
            var = self.cluster_var_model[self.cluster_var_idx]
            silplot = SilhouettePlot()
            silplot.setBarHeight(self.bar_size)
            silplot.setRowNamesVisible(self.bar_size >= 5)

            if self.group_by_cluster:
                silplot.setScores(self._silhouette, self._labels, var.values)
            else:
                silplot.setScores(
                    self._silhouette,
                    numpy.zeros(len(self._silhouette), dtype=int), [""])

            self.scene.addItem(silplot)
            self._silplot = silplot
            self._update_annotations()

            silplot.resize(silplot.effectiveSizeHint(Qt.PreferredSize))
            silplot.selectionChanged.connect(self.commit)

            self.scene.setSceneRect(
                QRectF(QtCore.QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_bar_size(self):
        if self._silplot is not None:
            self._silplot.setBarHeight(self.bar_size)
            self._silplot.setRowNamesVisible(self.bar_size >= 5)

            self.scene.setSceneRect(
                QRectF(QtCore.QPointF(0, 0),
                       self._silplot.effectiveSizeHint(Qt.PreferredSize)))

    def _update_annotations(self):
        if 0 < self.annotation_var_idx < len(self.annotation_var_model):
            annot_var = self.annotation_var_model[self.annotation_var_idx]
        else:
            annot_var = None

        if self._silplot is not None:
            if annot_var is not None:
                column, _ = self.data.get_column_view(annot_var)
                self._silplot.setRowNames(
                    [annot_var.str_val(value) for value in column])
            else:
                self._silplot.setRowNames(None)

    def commit(self):
        """
        Commit/send the current selection to the output.
        """
        selected = other = None
        if self.data is not None:
            selectedmask = numpy.full(len(self.data), False, dtype=bool)
            if self._silplot is not None:
                indices = self._silplot.selection()
                selectedmask[indices] = True
            scores = self._silhouette
            silhouette_var = None
            if self.add_scores:
                var = self.cluster_var_model[self.cluster_var_idx]
                silhouette_var = Orange.data.ContinuousVariable(
                    "Silhouette ({})".format(escape(var.name)))
                domain = Orange.data.Domain(
                    self.data.domain.attributes, self.data.domain.class_vars,
                    self.data.domain.metas + (silhouette_var, ))
            else:
                domain = self.data.domain

            if numpy.count_nonzero(selectedmask):
                selected = self.data.from_table(
                    domain, self.data, numpy.flatnonzero(selectedmask))

            if numpy.count_nonzero(~selectedmask):
                other = self.data.from_table(domain, self.data,
                                             numpy.flatnonzero(~selectedmask))

            if self.add_scores:
                if selected is not None:
                    selected[:,
                             silhouette_var] = numpy.c_[scores[selectedmask]]
                if other is not None:
                    other[:, silhouette_var] = numpy.c_[scores[~selectedmask]]

        self.send("Selected Data", selected)
        self.send("Other Data", other)

    def onDeleteWidget(self):
        """Clear the widget state, then run the base class's handler."""
        self.clear()
        super().onDeleteWidget()
Exemple #5
0
class OWFreeViz(widget.OWWidget):
    """
    FreeViz visualization widget.

    Takes a data table on the "Data" input, shows a two dimensional
    projection of it and outputs the (optionally extended) data, the
    selected subset and the projection components.
    """
    name = "FreeViz"
    description = "FreeViz Visualization"
    icon = "icons/LinearProjection.svg"
    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Data", Orange.data.Table, widget.Default),
               ("Selected Data", Orange.data.Table),
               ("Components", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()
    #: Anchor initialization type (values of the `initialization` setting)
    Circular, Random = 0, 1
    #: Force law choices as (label, p); `p` is passed on to `freeviz`
    ForceLaw = [("Linear", 1), ("Square", 2)]

    #: Replot choices as (label, iterations per update); -1 stands for a
    #: single update after all `maxiter` iterations
    ReplotIntervals = [
        ("Every iteration", 1),
        ("Every 3 steps", 3),
        (
            "Every 5 steps",
            5,
        ),
        ("Every 10 steps", 10),
        ("Every 20 steps", 20),
        ("Every 50 steps", 50),
        ("Every 100 steps", 100),
        ("None", -1),
    ]
    #: Jitter choices as (label, factor applied to the jitter vectors)
    JitterAmount = [("None", 0), ("0.1%", 0.1), ("0.5%", 0.5), ("1%", 1.0),
                    ("2%", 2.0)]

    #: Output coordinate embedding domain role
    NoCoords, Attribute, Meta = 0, 1, 2

    #: Index into `ForceLaw`
    force_law = settings.Setting(0)
    #: Maximum number of optimization iterations
    maxiter = settings.Setting(300)
    #: Index into `ReplotIntervals`
    replot_interval = settings.Setting(3)
    #: One of `Circular`, `Random`
    initialization = settings.Setting(Circular)
    #: Anchors shorter than this (in 1/100 of the unit radius) are hidden
    min_anchor_radius = settings.Setting(0)
    #: One of `NoCoords`, `Attribute`, `Meta`
    embedding_domain_role = settings.Setting(Meta)
    autocommit = settings.Setting(True)

    #: Names of the variables used for point color/shape/size/label;
    #: an empty string means no variable is selected
    color_var = settings.ContextSetting("", exclude_metas=False)
    shape_var = settings.ContextSetting("", exclude_metas=False)
    size_var = settings.ContextSetting("", exclude_metas=False)
    label_var = settings.ContextSetting("", exclude_metas=False)

    #: Point opacity, used as the alpha component of the brush colors
    opacity = settings.Setting(255)
    point_size = settings.Setting(5)
    #: Index into `JitterAmount`
    jitter = settings.Setting(0)
    #: Show the class density image behind the points
    class_density = settings.Setting(False)

    def __init__(self):
        """
        Build the control area, the plot and the asynchronous update loop.

        The control area holds the "Optimization", "Plot", "Hide anchors",
        "Zoom/Select" and "Output" boxes; the main area holds the
        pyqtgraph plot with its legend.
        """
        super().__init__()

        #: Input data table (None when there is no usable input)
        self.data = None
        #: Current plot state; created by `_setup`, dropped by `_clear_plot`
        self.plotdata = None

        # -- Optimization controls -------------------------------------
        box = gui.widgetBox(self.controlArea, "Optimization", spacing=10)
        form = QtGui.QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QtGui.QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10)
        form.addRow(
            "Force law",
            gui.comboBox(box,
                         self,
                         "force_law",
                         items=[text for text, _ in OWFreeViz.ForceLaw],
                         callback=self.__reset_update_interval))
        form.addRow("Max iterations", gui.spin(box, self, "maxiter", 10,
                                               10**4))
        form.addRow(
            "Initialization",
            gui.comboBox(box,
                         self,
                         "initialization",
                         items=["Circular", "Random"],
                         callback=self.__reset_initialization))
        form.addRow(
            "Replot",
            gui.comboBox(box,
                         self,
                         "replot_interval",
                         items=[text for text, _ in OWFreeViz.ReplotIntervals],
                         callback=self.__reset_update_interval))
        box.layout().addLayout(form)

        self.start_button = gui.button(box, self, "Optimize",
                                       self._toogle_start)

        # Item models backing the color/shape/size/label combo boxes;
        # they are (re)populated in `clear` and `set_data`.
        self.color_varmodel = itemmodels.VariableListModel(parent=self)
        self.shape_varmodel = itemmodels.VariableListModel(parent=self)
        self.size_varmodel = itemmodels.VariableListModel(parent=self)
        self.label_varmodel = itemmodels.VariableListModel(parent=self)

        # -- Plot appearance controls ----------------------------------
        box = gui.widgetBox(self.controlArea, "Plot")
        form = QtGui.QFormLayout(
            formAlignment=Qt.AlignLeft,
            labelAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QtGui.QFormLayout.AllNonFixedFieldsGrow,
            spacing=8,
        )
        box.layout().addLayout(form)
        color_cb = gui.comboBox(box,
                                self,
                                "color_var",
                                sendSelectedValue=True,
                                emptyString="(Same color)",
                                contentsLength=10,
                                callback=self._update_color)

        color_cb.setModel(self.color_varmodel)
        form.addRow("Color", color_cb)
        opacity_slider = gui.hSlider(box,
                                     self,
                                     "opacity",
                                     minValue=50,
                                     maxValue=255,
                                     ticks=True,
                                     createLabel=False,
                                     callback=self._update_color)
        opacity_slider.setTickInterval(0)
        opacity_slider.setPageStep(10)
        form.addRow("Opacity", opacity_slider)

        shape_cb = gui.comboBox(box,
                                self,
                                "shape_var",
                                contentsLength=10,
                                sendSelectedValue=True,
                                emptyString="(Same shape)",
                                callback=self._update_shape)
        shape_cb.setModel(self.shape_varmodel)
        form.addRow("Shape", shape_cb)

        size_cb = gui.comboBox(box,
                               self,
                               "size_var",
                               contentsLength=10,
                               sendSelectedValue=True,
                               emptyString="(Same size)",
                               callback=self._update_size)
        size_cb.setModel(self.size_varmodel)
        form.addRow("Size", size_cb)
        size_slider = gui.hSlider(box,
                                  self,
                                  "point_size",
                                  minValue=3,
                                  maxValue=20,
                                  ticks=True,
                                  createLabel=False,
                                  callback=self._update_size)
        form.addRow(None, size_slider)

        label_cb = gui.comboBox(box,
                                self,
                                "label_var",
                                contentsLength=10,
                                sendSelectedValue=True,
                                emptyString="(No labels)",
                                callback=self._update_labels)
        label_cb.setModel(self.label_varmodel)
        form.addRow("Label", label_cb)

        form.addRow(
            "Jitter",
            gui.comboBox(box,
                         self,
                         "jitter",
                         items=[text for text, _ in self.JitterAmount],
                         callback=self._update_xy))
        self.class_density_cb = gui.checkBox(box,
                                             self,
                                             "class_density",
                                             "",
                                             callback=self._update_density)
        form.addRow("Class density", self.class_density_cb)

        # -- Anchor visibility -----------------------------------------
        box = gui.widgetBox(self.controlArea, "Hide anchors")
        rslider = gui.hSlider(box,
                              self,
                              "min_anchor_radius",
                              minValue=0,
                              maxValue=100,
                              step=5,
                              label="Hide radius",
                              createLabel=False,
                              ticks=True,
                              callback=self._update_anchor_visibility)
        rslider.setTickInterval(0)
        rslider.setPageStep(10)

        # -- Zoom/select tools -----------------------------------------
        box = gui.widgetBox(self.controlArea, "Zoom/Select")
        hlayout = QtGui.QHBoxLayout()
        box.layout().addLayout(hlayout)

        toolbox = PlotToolBox(self)
        hlayout.addWidget(toolbox.button(PlotToolBox.SelectTool))
        hlayout.addWidget(toolbox.button(PlotToolBox.ZoomTool))
        hlayout.addWidget(toolbox.button(PlotToolBox.PanTool))
        hlayout.addSpacing(2)
        hlayout.addWidget(toolbox.button(PlotToolBox.ZoomReset))
        # The lambdas look up `self.plot` lazily; it is created further
        # down, before any of these actions can be triggered.
        toolbox.standardAction(PlotToolBox.ZoomReset).triggered.connect(
            lambda: self.plot.setRange(QtCore.QRectF(-1.05, -1.05, 2.1, 2.1)))
        # ViewBox.scaleBy multiplies the visible range by the given
        # factors: a factor below 1 zooms in, a factor above 1 zooms out.
        # (Both factors used to be connected to ZoomIn, so they cancelled
        # each other and ZoomOut was never wired up.)
        toolbox.standardAction(PlotToolBox.ZoomIn).triggered.connect(
            lambda: self.plot.getViewBox().scaleBy((1 / 1.25, 1 / 1.25)))
        toolbox.standardAction(PlotToolBox.ZoomOut).triggered.connect(
            lambda: self.plot.getViewBox().scaleBy((1.25, 1.25)))
        selecttool = toolbox.plotTool(PlotToolBox.SelectTool)
        selecttool.selectionFinished.connect(self.__select_area)
        self.addActions(toolbox.actions())

        self.controlArea.layout().addStretch(1)

        # -- Output controls -------------------------------------------
        box = gui.widgetBox(self.controlArea, "Output")
        # Item order matches the NoCoords, Attribute, Meta constants.
        gui.comboBox(box,
                     self,
                     "embedding_domain_role",
                     items=[
                         "Original features only", "Coordinates as features",
                         "Coordinates as meta attributes"
                     ])
        gui.auto_commit(box,
                        self,
                        "autocommit",
                        "Commit",
                        box=False,
                        callback=self.commit)

        # -- Plot ------------------------------------------------------
        self.plot = pg.PlotWidget(enableMouse=False, enableMenu=False)
        self.plot.setFrameStyle(QtGui.QFrame.StyledPanel)
        self.plot.plotItem.hideAxis("bottom")
        self.plot.plotItem.hideAxis("left")
        self.plot.plotItem.hideButtons()
        self.plot.setAspectLocked(True)
        # Scene help events are routed to `eventFilter` (tooltips).
        self.plot.scene().installEventFilter(self)

        self.legend = linproj.LegendItem()
        self.legend.setParentItem(self.plot.getViewBox())
        self.legend.anchor((1, 0), (1, 0))

        self.plot.setRenderHint(QtGui.QPainter.Antialiasing, True)
        self.mainArea.layout().addWidget(self.plot)
        viewbox = self.plot.getViewBox()
        viewbox.grabGesture(Qt.PinchGesture)
        pinchtool = linproj.PlotPinchZoomTool(parent=self)
        pinchtool.setViewBox(viewbox)

        toolbox.setViewBox(viewbox)

        # -- Optimization loop -----------------------------------------
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

    def clear(self):
        """
        Reset the widget to its no-data state.

        Drops the data and the plot, cancels the update loop, shrinks the
        variable models to their placeholder entry and unsets the
        color/shape/size/label variable selections.
        """
        self.data = None
        self._clear_plot()
        self._loop.cancel()

        placeholders = (
            (self.color_varmodel, "(Same color)"),
            (self.shape_varmodel, "(Same shape)"),
            (self.size_varmodel, "(Same size)"),
            (self.label_varmodel, "(No labels)"),
        )
        for model, text in placeholders:
            model[:] = [text]

        self.color_var = ""
        self.shape_var = ""
        self.size_var = ""
        self.label_var = ""

    def set_data(self, data):
        """
        Set the input dataset.

        The data is rejected, and an error is shown, when it has no class
        variable or when its discrete class variable has fewer than two
        values.  For accepted data the color, shape, size and label
        variable models are repopulated, the class variable becomes the
        color variable and the stored context is reopened.

        Parameters
        ----------
        data : Orange.data.Table or None
            The new input data; None clears the widget.
        """
        self.closeContext()
        self.clear()
        error_msg = ""
        if data is not None:
            class_var = data.domain.class_var
            if class_var is None:
                error_msg = "Need a class variable"
                data = None
            elif class_var.is_discrete and len(class_var.values) < 2:
                error_msg = "Needs discrete class variable with at" \
                            " least 2 values"
                data = None

        self.data = data
        self.error(0, error_msg)
        if data is None:
            return

        separator = itemmodels.VariableListModel.Separator
        domain = data.domain

        def with_metas(entries, metas):
            # Append the meta variables after a separator, if there are any.
            if metas:
                entries += [separator] + metas
            return entries

        # One symbol is kept in reserve, hence the `- 1`.
        # NOTE(review): presumably reserved for unknown values - confirm.
        maxsymbols = len(linproj.ScatterPlotItem.Symbols) - 1

        def can_be_shape(var):
            return var.is_discrete and len(var.values) < maxsymbols

        def is_continuous(var):
            return var.is_continuous

        def is_string(var):
            return var.is_string

        # Color: every variable in the domain plus the primitive metas.
        self.color_varmodel[:] = with_metas(
            ["(Same color)"] + list(domain),
            [var for var in domain.metas if var.is_primitive()])
        self.color_var = domain.class_var.name

        # Shape: discrete variables with few enough values.
        self.shape_varmodel[:] = with_metas(
            ["(Same shape)"] + list(filter(can_be_shape, domain)),
            list(filter(can_be_shape, domain.metas)))

        # Size: continuous variables.
        self.size_varmodel[:] = with_metas(
            ["(Same size)"] + list(filter(is_continuous, domain)),
            list(filter(is_continuous, domain.metas)))

        # Label: string meta variables only.
        self.label_varmodel[:] = with_metas(
            ["(No labels)"], list(filter(is_string, domain.metas)))

        self.class_density_cb.setEnabled(domain.has_discrete_class)
        self.openContext(data)

    def handleNewSignals(self):
        """Reimplemented: set up the plot and start optimizing the data."""
        if self.data is None:
            return
        self._setup()
        self._start()

    def _toogle_start(self):
        """Start the optimization, or stop it when it is already running."""
        if not self._loop.isRunning():
            self._start()
            return

        # Stop: cancel the loop and put the button/progress bar back.
        self._loop.cancel()
        self.start_button.setText("Optimize")
        self.progressBarFinished(processEvents=False)

    def _clear_plot(self):
        """Clear the plot, drop the plot state and hide/empty the legend."""
        self.plot.clear()
        # `plotdata is None` is what the `_update_*` methods check for.
        self.plotdata = None
        self.legend.hide()
        self.legend.clear()

    def _setup(self):
        """
        Setup the plot.

        Computes the initial projection of the valid rows of `self.data`,
        adds the scatter plot, the anchor axes and the "hide" circle to the
        plot and stores everything in `self.plotdata`.  Expects `self.data`
        to be set.
        """
        X = self.data.X
        Y = self.data.Y
        # Rows with a NaN in any attribute or in the class are left out.
        mask = numpy.bitwise_or.reduce(numpy.isnan(X), axis=1)
        mask |= numpy.isnan(Y)
        valid = ~mask
        X = X[valid, :]
        Y = Y[valid]

        if self.data.domain.class_var.is_discrete:
            Y = Y.astype(int)
        # Center the columns, then divide by the column range; columns with
        # a zero range are left as they are to avoid dividing by zero.
        X = (X - numpy.mean(X, axis=0))
        span = numpy.ptp(X, axis=0)
        X[:, span > 0] /= span[span > 0].reshape(1, -1)

        # Initial anchors, one 2D point per attribute.
        if self.initialization == OWFreeViz.Circular:
            anchors = linproj.linproj.defaultaxes(X.shape[1]).T
        else:
            # uniformly random in [-1, 1)
            anchors = numpy.random.random((X.shape[1], 2)) * 2 - 1

        EX = numpy.dot(X, anchors)
        # The embedding is plotted scaled by its largest point norm.
        radius = numpy.max(numpy.linalg.norm(EX, axis=1))

        # Jitter offsets come from a fixed seed, so they are the same on
        # every call for the same number of points.
        jittervec = numpy.random.RandomState(4).rand(*EX.shape) * 2 - 1
        jittervec *= 0.01
        _, jitterfactor = self.JitterAmount[self.jitter]

        colorvar = self._color_var()
        shapevar = self._shape_var()
        sizevar = self._size_var()
        labelvar = self._label_var()

        if colorvar is not None:
            colors = plotutils.color_data(self.data, colorvar)[valid]
        else:
            # no color variable: the same gray for every point
            colors = numpy.array([[192, 192, 192]])
            colors = numpy.tile(colors, (X.shape[0], 1))

        # Pens use the darkened colors; brushes get `opacity` appended as
        # an extra (alpha) column.
        pendata = plotutils.pen_data(colors * 0.8)
        colors = numpy.hstack(
            [colors,
             numpy.full((colors.shape[0], 1), float(self.opacity))])
        brushdata = plotutils.brush_data(colors)

        shapedata = plotutils.shape_data(self.data, shapevar)[valid]
        sizedata = size_data(self.data, sizevar,
                             pointsize=self.point_size)[valid]
        if labelvar is not None:
            labeldata = plotutils.column_data(self.data, labelvar, valid)
            labeldata = [labelvar.str_val(val) for val in labeldata]
        else:
            labeldata = None

        coords = (EX / radius) + jittervec * jitterfactor
        # Each point carries its row index in `self.data` as its data.
        item = linproj.ScatterPlotItem(
            x=coords[:, 0],
            y=coords[:, 1],
            brush=brushdata,
            pen=pendata,
            symbols=shapedata,
            size=sizedata,
            data=numpy.flatnonzero(valid),
            antialias=True,
        )

        self.plot.addItem(item)
        self.plot.setRange(QtCore.QRectF(-1.05, -1.05, 2.1, 2.1))

        # minimum visible anchor radius
        minradius = self.min_anchor_radius / 100 + 1e-5
        axisitems = []
        for anchor, var in zip(anchors, self.data.domain.attributes):
            axitem = AxisItem(
                line=QtCore.QLineF(0, 0, *anchor),
                label=var.name,
            )
            axitem.setVisible(numpy.linalg.norm(anchor) > minradius)
            axitem.setPen(pg.mkPen((100, 100, 100)))
            axitem.setArrowVisible(False)
            self.plot.addItem(axitem)
            axisitems.append(axitem)

        # Circle marking the radius below which anchors are hidden.
        hidecircle = QtGui.QGraphicsEllipseItem()
        hidecircle.setRect(
            QtCore.QRectF(-minradius, -minradius, 2 * minradius,
                          2 * minradius))

        _pen = QtGui.QPen(Qt.lightGray, 1)
        _pen.setCosmetic(True)
        hidecircle.setPen(_pen)

        self.plot.addItem(hidecircle)

        # State shared with the `_update_*` methods.  Note that
        # `basecolors` holds the colors with the opacity column appended.
        self.plotdata = namespace(validmask=valid,
                                  embedding_coords=EX,
                                  jittervec=jittervec,
                                  anchors=anchors,
                                  mainitem=item,
                                  axisitems=axisitems,
                                  hidecircle=hidecircle,
                                  basecolors=colors,
                                  brushdata=brushdata,
                                  pendata=pendata,
                                  shapedata=shapedata,
                                  sizedata=sizedata,
                                  labeldata=labeldata,
                                  labelitems=[],
                                  densityimage=None,
                                  X=X,
                                  Y=Y,
                                  selectionmask=numpy.zeros_like(valid,
                                                                 dtype=bool))
        self._update_legend()
        self._update_labels()
        self._update_density()

    def _color_var(self):
        """Return the selected color variable, or None if none is set."""
        if self.color_var == "":
            return None
        return self.data.domain[self.color_var]

    def _update_color(self):
        """
        Recompute the point pens and brushes and refresh the legend.

        Uses the current color variable, opacity and selection; does
        nothing when there is no plot.
        """
        plotdata = self.plotdata
        if plotdata is None:
            return

        colorvar = self._color_var()
        valid = plotdata.validmask
        if colorvar is None:
            # no color variable: the same gray for every point
            gray = numpy.array([[192, 192, 192]])
            colors = numpy.tile(gray, (plotdata.X.shape[0], 1))
        else:
            colors = plotutils.color_data(self.data, colorvar)[valid]

        # Pens of selected points are styled differently.
        is_selected = plotdata.selectionmask[valid]
        pointstyle = numpy.where(is_selected, plotutils.Selected,
                                 plotutils.NoFlags)
        pendata = plotutils.pen_data(colors * 0.8, pointstyle)

        # Brushes get the opacity appended as an extra column.
        alpha = numpy.full((colors.shape[0], 1), float(self.opacity))
        brushdata = plotutils.brush_data(numpy.hstack([colors, alpha]))

        plotdata.pendata = pendata
        plotdata.brushdata = brushdata
        plotdata.mainitem.setPen(pendata)
        plotdata.mainitem.setBrush(brushdata)

        self._update_legend()

    def _shape_var(self):
        """Return the selected shape variable, or None if none is set."""
        if self.shape_var == "":
            return None
        return self.data.domain[self.shape_var]

    def _update_shape(self):
        """Recompute the point symbols and refresh the legend."""
        plotdata = self.plotdata
        if plotdata is None:
            return

        shapevar = self._shape_var()
        valid = plotdata.validmask
        symbols = plotutils.shape_data(self.data, shapevar)[valid]
        plotdata.shapedata = symbols
        plotdata.mainitem.setSymbol(symbols)
        self._update_legend()

    def _size_var(self):
        """Return the selected size variable, or None if none is set."""
        if self.size_var == "":
            return None
        return self.data.domain[self.size_var]

    def _update_size(self):
        """Recompute the point sizes from the size variable and slider."""
        plotdata = self.plotdata
        if plotdata is None:
            return

        sizevar = self._size_var()
        valid = plotdata.validmask
        sizes = size_data(self.data, sizevar, pointsize=self.point_size)
        sizes = sizes[valid]
        plotdata.sizedata = sizes
        plotdata.mainitem.setSize(sizes)

    def _label_var(self):
        """Return the selected label variable, or None if none is set."""
        if self.label_var == "":
            return None
        return self.data.domain[self.label_var]

    def _update_labels(self):
        """
        Replace the point labels with ones for the current label variable.

        Existing label items are removed from the plot first; no new items
        are created when no label variable is selected.
        """
        plotdata = self.plotdata
        if plotdata is None:
            return

        labelvar = self._label_var()
        labeldata = None
        if labelvar is not None:
            values = plotutils.column_data(self.data, labelvar,
                                           plotdata.validmask)
            labeldata = [labelvar.str_val(value) for value in values]

        if plotdata.labelitems:
            for stale in plotdata.labelitems:
                stale.setParentItem(None)
                self.plot.removeItem(stale)
            plotdata.labelitems = []

        if labeldata is None:
            return

        # Labels are placed at the embedding scaled by its largest norm.
        coords = plotdata.embedding_coords
        coords = coords / numpy.max(numpy.linalg.norm(coords, axis=1))
        for (x, y), text in zip(coords, labeldata):
            label = pg.TextItem(text, anchor=(0.5, 0), color=0.0)
            label.setPos(x, y)
            self.plot.addItem(label)
            plotdata.labelitems.append(label)

    def _update_legend(self):
        """
        Rebuild the legend for the current color and shape variables.

        The legend is always emptied; it is then refilled (and shown only
        if there is something to show) when a plot exists.
        """
        self.legend.clear()
        if self.plotdata is None:
            return

        legend_data = plotutils.legend_data(self._color_var(),
                                            self._shape_var())
        # The legend was emptied above, so it only needs to be shown or
        # hidden here before the new entries are added.
        self.legend.setVisible(bool(legend_data))

        for color, symbol, name in legend_data:
            sample = linproj.ScatterPlotItem(pen=color,
                                             brush=color,
                                             symbol=symbol,
                                             size=10)
            self.legend.addItem(sample, name)

    def _update_density(self):
        """
        Replace the class density image.

        The previous image, if any, is removed; a new one is added only
        when the data has a discrete class and `class_density` is on.
        """
        plotdata = self.plotdata
        if plotdata is None:
            return

        previous = plotdata.densityimage
        if previous is not None:
            self.plot.removeItem(previous)
            plotdata.densityimage = None

        if not (self.data.domain.has_discrete_class and self.class_density):
            return

        # Scale the embedding by its largest point norm.
        coords = plotdata.embedding_coords
        coords = coords / numpy.linalg.norm(coords, axis=1).max()
        xdata, ydata = coords.T
        class_colors = plotutils.color_data(self.data,
                                            self.data.domain.class_var)
        class_colors = class_colors[plotdata.validmask]

        # Same rectangle as the default plot range.
        low, high = -1.05, 1.05
        image = classdensity.class_density_image(low, high, low, high, 256,
                                                 xdata, ydata, class_colors)
        self.plot.addItem(image)
        plotdata.densityimage = image

    def _start(self):
        """
        Start the projection optimization.

        Wraps repeated `freeviz` runs in a generator and hands it to the
        update loop; does nothing when there is no plot.
        """
        if self.plotdata is None:
            return

        X, Y = self.plotdata.X, self.plotdata.Y
        start_anchors = self.plotdata.anchors
        _, p = OWFreeViz.ForceLaw[self.force_law]

        def update_freeviz(maxiter, itersteps, initial):
            # Run `freeviz` in chunks of at most `itersteps` iterations,
            # yielding the first two entries of every result, until the
            # anchors stop changing or the `maxiter` budget is used up.
            current = initial
            remaining = maxiter
            while True:
                res = freeviz(X,
                              Y,
                              scale=False,
                              center=False,
                              initial=current,
                              p=p,
                              maxiter=min(itersteps, remaining))
                _, updated = res[:2]
                yield res[:2]

                unchanged = numpy.isclose(current, updated,
                                          rtol=1e-5, atol=1e-4)
                if numpy.all(unchanged):
                    return

                remaining -= itersteps
                if remaining <= 0:
                    return
                current = updated

        _, interval = self.ReplotIntervals[self.replot_interval]
        # -1 ("None"): do all iterations in a single chunk.
        step = self.maxiter if interval == -1 else interval

        self._loop.setCoroutine(
            update_freeviz(self.maxiter, step, start_anchors))
        self.start_button.setText("Stop")
        self.progressBarInit(processEvents=False)

    def __reset_initialization(self):
        """
        Rebuild the plot with freshly initialized anchors.

        A running optimization is cancelled first and started again
        afterwards.
        """
        was_running = self._loop.isRunning()
        if was_running:
            self._loop.cancel()

        has_data = self.data is not None
        if has_data:
            self._clear_plot()
            self._setup()

        if was_running:
            self._start()

    def __reset_update_interval(self):
        """Restart a running optimization so new settings take effect."""
        if not self._loop.isRunning():
            return

        self._loop.cancel()
        if self.data is not None:
            self._start()

    def _update_xy(self):
        """
        Move the points, anchor lines and labels to the current embedding.

        The embedding is scaled by its largest point norm and, when a
        jitter amount is selected, offset by the scaled jitter vectors.
        """
        plotdata = self.plotdata
        if plotdata is None:
            return

        scatter = plotdata.mainitem
        embedding = plotdata.embedding_coords
        coords = embedding / numpy.max(numpy.linalg.norm(embedding, axis=1))
        if self.jitter > 0:
            _, factor = self.JitterAmount[self.jitter]
            coords = coords + plotdata.jittervec * factor

        scatter.setData(x=coords[:, 0],
                        y=coords[:, 1],
                        brush=plotdata.brushdata,
                        pen=plotdata.pendata,
                        size=plotdata.sizedata,
                        symbol=plotdata.shapedata,
                        data=numpy.flatnonzero(plotdata.validmask))

        for anchor, axis in zip(plotdata.anchors, plotdata.axisitems):
            axis.setLine(QtCore.QLineF(0, 0, *anchor))

        for (x, y), label in zip(coords, plotdata.labelitems):
            label.setPos(x, y)

    def _update_anchor_visibility(self):
        """Show only anchors longer than the hide radius; resize its circle."""
        plotdata = self.plotdata
        if plotdata is None:
            return

        threshold = self.min_anchor_radius / 100 + 1e-5
        for anchor, axis in zip(plotdata.anchors, plotdata.axisitems):
            axis.setVisible(numpy.linalg.norm(anchor) > threshold)

        circle_rect = QtCore.QRectF(-threshold, -threshold, 2 * threshold,
                                    2 * threshold)
        plotdata.hidecircle.setRect(circle_rect)

    def __set_projection(self, res):
        """
        Store a new (embedding coordinates, projection) pair and redraw.

        Connected to the update loop's `yielded` signal; also advances the
        progress bar by one replot interval.
        """
        assert self.plotdata is not None, "__set_projection call unexpected"

        _, increment = self.ReplotIntervals[self.replot_interval]
        if increment == -1:
            increment = self.maxiter
        self.progressBarAdvance(increment * 100. / self.maxiter,
                                processEvents=False)

        embedding_coords, projection = res
        self.plotdata.embedding_coords = embedding_coords
        self.plotdata.anchors = projection

        self._update_xy()
        self._update_anchor_visibility()
        self._update_density()

    def __freeviz_finished(self):
        """
        Handle the end of the optimization.

        Connected to the update loop's `finished` signal: restores the
        start button text, finishes the progress bar and sends the outputs.
        """
        # Projection optimization has finished
        self.start_button.setText("Optimize")
        self.progressBarFinished(processEvents=False)
        self.commit()

    def __on_error(self, err):
        """Pass an exception raised in the update loop to sys.excepthook."""
        tb = err.__traceback__
        sys.excepthook(type(err), err, tb)

    def __select_area(self, selectarea):
        """Select the instances whose points lie in the given plot area."""
        plotdata = self.plotdata
        scatter = None if plotdata is None else plotdata.mainitem
        if scatter is None:
            return

        # Every spot carries the row index of its instance as its data.
        hits = []
        for spot in scatter.points():
            if selectarea.contains(spot.pos()):
                hits.append(spot.data())

        self.select(numpy.array(hits, dtype=int),
                    QtGui.QApplication.keyboardModifiers())

    def select(self, indices, modifiers=Qt.NoModifier):
        """
        Update the selection with the instances given by `indices`.

        Without Control, Shift or Alt the previous selection is replaced.
        Alt removes `indices` from the selection, Control toggles them and
        otherwise (Shift) they are added.  The point colors are refreshed
        and the outputs are committed afterwards.

        Parameters
        ----------
        indices : (N,) int ndarray
            Indices of instances to select.
        modifiers : Qt.KeyboardModifier
            Keyboard modifiers.
        """
        plotdata = self.plotdata
        if plotdata is None:
            return

        keep_current = Qt.ControlModifier | Qt.ShiftModifier | Qt.AltModifier
        if modifiers & keep_current:
            mask = plotdata.selectionmask
        else:
            # no modifiers -> start from an empty selection
            mask = numpy.zeros_like(plotdata.validmask, dtype=bool)

        if modifiers & Qt.AltModifier:
            mask[indices] = False
        elif modifiers & Qt.ControlModifier:
            mask[indices] = ~mask[indices]
        else:
            mask[indices] = True

        plotdata.selectionmask = mask
        self._update_color()
        self.commit()

    def commit(self):
        """
        Commit/send the widget output signals.

        Sends on "Data" the input table, extended with the two embedding
        coordinates as features or metas depending on
        `embedding_domain_role`; on "Selected Data" the selected rows of
        that table (None when nothing is selected); and on "Components" a
        table of the current anchors.  All three outputs are None when
        there is no data or no plot state yet.
        """
        data = subset = components = None
        # `plotdata` can be missing while `data` is already set (e.g. when
        # `_setup` has not completed); send empty outputs in that case
        # instead of failing on the attribute access below.
        if self.data is not None and self.plotdata is not None:
            coords = self.plotdata.embedding_coords
            valid = self.plotdata.validmask
            selection = self.plotdata.selectionmask
            selectedindices = numpy.flatnonzero(valid & selection)

            C1Var = Orange.data.ContinuousVariable("Component1")
            C2Var = Orange.data.ContinuousVariable("Component2")

            attributes = self.data.domain.attributes
            classes = self.data.domain.class_vars
            metas = self.data.domain.metas
            if self.embedding_domain_role == OWFreeViz.Attribute:
                attributes = attributes + (C1Var, C2Var)
            elif self.embedding_domain_role == OWFreeViz.Meta:
                metas = metas + (C1Var, C2Var)

            domain = Orange.data.Domain(attributes, classes, metas)
            data = self.data.from_table(domain, self.data)

            # The coordinates exist for the valid rows only; they go into
            # the last two columns, which are the ones appended above.
            if self.embedding_domain_role == OWFreeViz.Attribute:
                data.X[valid, -2:] = coords
            elif self.embedding_domain_role == OWFreeViz.Meta:
                data.metas[valid, -2:] = coords

            if selectedindices.size:
                subset = data[selectedindices]

            # One row per projection axis, one column per input attribute.
            compdomain = Orange.data.Domain(
                self.data.domain.attributes,
                metas=[Orange.data.StringVariable(name='component')])

            component_names = numpy.array([["FreeViz 1"], ["FreeViz 2"]])
            components = Orange.data.Table(compdomain,
                                           self.plotdata.anchors.T,
                                           metas=component_names)
            components.name = 'components'

        self.send("Data", data)
        self.send("Selected Data", subset)
        self.send("Components", components)

    def sizeHint(self):
        """Reimplemented: return a 900 by 700 size hint."""
        # reimplemented
        return QtCore.QSize(900, 700)

    def eventFilter(self, recv, event):
        """
        Reimplemented: show a tooltip for help events on the plot's scene.

        Every other event is passed on to the base class implementation.
        """
        is_help = event.type() == QtCore.QEvent.GraphicsSceneHelp
        if is_help and recv is self.plot.scene():
            return self._tooltip(event)
        return super().eventFilter(recv, event)

    def _tooltip(self, event):
        """
        Handle a help event for the plot's scene.

        Shows a tooltip for the instances under the event position and
        returns True; returns False when there is no plot or no point at
        that position.
        """
        if self.plotdata is None:
            return False

        scatter = self.plotdata.mainitem
        hits = scatter.pointsAt(scatter.mapFromScene(event.scenePos()))
        # Every spot carries the row index of its instance as its data.
        rows = [spot.data() for spot in hits]
        if not rows:
            return False

        text = format_tooltip(self.data, columns=..., rows=rows)
        QtGui.QToolTip.showText(event.screenPos(), text, widget=self.plot)
        return True
class OWDisplayProfiles(widget.OWWidget):
    """
    Plot data instances as profiles (one curve per row) over the attributes.

    Rows are grouped by the values of a discrete class variable when there
    is one. For every group the widget draws the individual profiles, the
    group mean and quartile error bars. Profiles selected in the plot are
    sent on the "Selected Data" output.
    """
    name = "Data Profiles"
    description = "Visualization of data profiles (e.g., time series)."
    icon = "../widgets/icons/ExpressionProfiles.svg"
    priority = 1030

    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Selected Data", Orange.data.Table)]
    settingsHandler = settings.DomainContextHandler()

    #: List of selected class indices
    selected_classes = settings.ContextSetting([])
    #: Show individual profiles
    display_individual = settings.Setting(False)
    #: Show average profile
    #: NOTE(review): no control is bound to this setting in this class and
    #: `__update_visibility` does not consult it (the check is commented out).
    display_average = settings.Setting(True)
    #: Show data quartiles
    display_quartiles = settings.Setting(True)
    #: Profile label/id column
    annot_index = settings.ContextSetting(0)
    auto_commit = settings.Setting(True)

    def __init__(self, parent=None):
        super().__init__(parent)

        #: Labels of the class values shown in the "Classes" list box
        self.classes = []

        #: Input data table (or None)
        self.data = None
        #: Discrete/string variables that can be used to label the profiles
        self.annotation_variables = []
        #: Per-group plot items; None until `_setup_plot` has run
        self.__groups = None
        #: Row indices of the profiles currently selected in the plot
        self.__selected_data_indices = []

        # Setup GUI
        infobox = gui.widgetBox(self.controlArea, "Info")
        self.infoLabel = gui.widgetLabel(infobox, "No data on input.")
        displaybox = gui.widgetBox(self.controlArea, "Display")
        gui.checkBox(displaybox,
                     self,
                     "display_individual",
                     "Expression Profiles",
                     callback=self.__update_visibility)
        gui.checkBox(displaybox,
                     self,
                     "display_quartiles",
                     "Quartiles",
                     callback=self.__update_visibility)

        group_box = gui.widgetBox(self.controlArea, "Classes")
        self.group_listbox = gui.listBox(
            group_box,
            self,
            "selected_classes",
            "classes",
            selectionMode=QtGui.QListWidget.MultiSelection,
            callback=self.__on_class_selection_changed)
        self.unselectAllClassedQLB = gui.button(
            group_box, self, "Unselect all", callback=self.__select_all_toggle)

        self.annot_cb = gui.comboBox(self.controlArea,
                                     self,
                                     "annot_index",
                                     box="Profile Labels",
                                     callback=self.__update_tooltips)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, "auto_commit", "Commit")

        self.graph = pg.PlotWidget(background="w")
        self.graph.setRenderHint(QtGui.QPainter.Antialiasing, True)
        self.graph.scene().selectionChanged.connect(
            self.__on_curve_selection_changed)
        self.mainArea.layout().addWidget(self.graph)

    def sizeHint(self):
        """Return the preferred initial size of the widget (800 x 600)."""
        return QtCore.QSize(800, 600)

    def clear(self):
        """
        Clear/reset the widget state.

        Besides dropping the data and the plot items this also resets the
        class list and the info label, so that nothing describing the
        previous input remains visible once the input is removed.
        """
        self.group_listbox.clear()
        self.annot_cb.clear()
        self.classes = []
        self.infoLabel.setText("No data on input.")
        self.data = None
        self.annotation_variables = []
        self.__groups = None
        self.__selected_data_indices = []
        self.graph.clear()

    def set_data(self, data):
        """
        Set the input profile dataset.

        The previous context is closed and the widget state cleared before
        the new data (which may be None) is installed. The output is always
        recommitted at the end.
        """
        self.closeContext()
        self.clear()

        self.data = data
        if data is not None:
            n_instances = len(data)
            n_attrs = len(data.domain.attributes)
            self.infoLabel.setText("%i genes on input\n%i attributes" %
                                   (n_instances, n_attrs))

            if is_discrete(data.domain.class_var):
                class_vals = data.domain.class_var.values
            else:
                class_vals = []

            self.classes = list(class_vals)
            self.class_colors = \
                colorpalette.ColorPaletteGenerator(len(class_vals))
            # Default: all classes selected (may be overridden by the
            # context restored in openContext below).
            self.selected_classes = list(range(len(class_vals)))
            for i in range(len(class_vals)):
                item = self.group_listbox.item(i)
                item.setIcon(colorpalette.ColorPixmap(self.class_colors[i]))

            variables = data.domain.variables + data.domain.metas
            annotvars = [
                var for var in variables if is_discrete(var) or is_string(var)
            ]

            for var in annotvars:
                self.annot_cb.addItem(*gui.attributeItem(var))

            # Label the profiles with the class variable by default.
            if data.domain.class_var in annotvars:
                self.annot_index = annotvars.index(data.domain.class_var)
            self.annotation_variables = annotvars
            self.openContext(data)

            self._setup_plot()

        self.commit()

    def _setup_plot(self):
        """Setup the plot with new curve data."""
        assert self.data is not None

        data, domain = self.data, self.data.domain
        if is_discrete(domain.class_var):
            class_col_data, _ = data.get_column_view(domain.class_var)

            # One array of row indices per class value.
            group_indices = [
                np.flatnonzero(class_col_data == i)
                for i in range(len(domain.class_var.values))
            ]
        else:
            # No discrete class: all rows form a single group.
            group_indices = [np.arange(len(data))]

        X = np.arange(1, len(domain.attributes) + 1)
        groups = []

        for i, indices in enumerate(group_indices):
            if self.classes:
                color = self.class_colors[i]
            else:
                color = QColor(Qt.darkGray)
            group_data = data[indices, :]
            # All profiles of the group are drawn as one (disconnected) curve.
            plot_x, plot_y, connect = disconnected_curve_data(group_data.X,
                                                              x=X)

            color.setAlpha(200)
            lightcolor = QColor(color.lighter(factor=150))
            lightcolor.setAlpha(150)
            pen = QPen(color, 2)
            pen.setCosmetic(True)

            lightpen = QPen(lightcolor, 1)
            lightpen.setCosmetic(True)
            hoverpen = QPen(pen)
            hoverpen.setWidth(2)

            curve = pg.PlotCurveItem(
                x=plot_x,
                y=plot_y,
                connect=connect,
                pen=lightpen,
                symbolSize=2,
                antialias=True,
            )
            self.graph.addItem(curve)

            # One HoverCurve per profile; it carries the row index of the
            # instance it represents in `_data_index`.
            hovercurves = []
            for index, profile in zip(indices, group_data.X):
                hcurve = HoverCurve(x=X,
                                    y=profile,
                                    pen=hoverpen,
                                    antialias=True)
                hcurve.setToolTip('{}'.format(index))
                hcurve._data_index = index
                hovercurves.append(hcurve)
                self.graph.addItem(hcurve)

            mean = np.nanmean(group_data.X, axis=0)

            meancurve = pg.PlotDataItem(x=X,
                                        y=mean,
                                        pen=pen,
                                        size=5,
                                        symbol="o",
                                        pxMode=True,
                                        symbolSize=5,
                                        antialias=True)
            hoverpen = QPen(hoverpen)
            hoverpen.setWidth(5)

            # Hover curve over the mean; explicitly made non-selectable.
            hc = HoverCurve(x=X, y=mean, pen=hoverpen, antialias=True)
            hc.setFlag(QGraphicsItem.ItemIsSelectable, False)
            self.graph.addItem(hc)

            self.graph.addItem(meancurve)
            q1, q2, q3 = np.nanpercentile(group_data.X, [25, 50, 75], axis=0)
            # TODO: implement and use a box plot item
            # Error bars span q1..q3 around the mean; the extents are
            # clipped at 0 from below.
            errorbar = pg.ErrorBarItem(x=X,
                                       y=mean,
                                       bottom=np.clip(mean - q1, 0, mean - q1),
                                       top=np.clip(q3 - mean, 0, q3 - mean),
                                       beam=0.5)
            self.graph.addItem(errorbar)
            groups.append(
                namespace(data=group_data,
                          indices=indices,
                          profiles=curve,
                          hovercurves=hovercurves,
                          mean=meancurve,
                          boxplot=errorbar))

        self.__groups = groups
        self.__update_visibility()
        self.__update_tooltips()

    def __update_visibility(self):
        """Show/hide each group's plot items per the current settings."""
        if self.__groups is None:
            return

        if self.classes:
            selected = lambda i: i in self.selected_classes
        else:
            # Without classes there is a single group; it is always shown.
            selected = lambda i: True
        for i, group in enumerate(self.__groups):
            isselected = selected(i)
            group.profiles.setVisible(isselected and self.display_individual)
            group.mean.setVisible(isselected)  # and self.display_average)
            group.boxplot.setVisible(isselected and self.display_quartiles)
            for hc in group.hovercurves:
                hc.setVisible(isselected and self.display_individual)

    def __update_tooltips(self):
        """Set each profile's tooltip from the chosen annotation variable."""
        if self.__groups is None:
            return

        if 0 <= self.annot_index < len(self.annotation_variables):
            annotvar = self.annotation_variables[self.annot_index]
            column, _ = self.data.get_column_view(annotvar)
            column = [annotvar.str_val(val) for val in column]
        else:
            # No usable annotation variable: fall back to the row index.
            annotvar = None
            column = [str(i) for i in range(len(self.data))]

        for group in self.__groups:
            for hcurve in group.hovercurves:
                value = column[hcurve._data_index]
                hcurve.setToolTip(value)

    def __select_all_toggle(self):
        """Select all classes, or unselect all if all are selected."""
        allselected = len(self.selected_classes) == len(self.classes)
        if allselected:
            self.selected_classes = []
        else:
            self.selected_classes = list(range(len(self.classes)))

        self.__on_class_selection_changed()

    def __on_class_selection_changed(self):
        """Update the toggle button text and the visible groups."""
        mask = [
            i in self.selected_classes
            for i in range(self.group_listbox.count())
        ]
        self.unselectAllClassedQLB.setText(
            "Select all" if not all(mask) else "Unselect all")

        self.__update_visibility()

    def __on_annotation_index_changed(self):
        # NOTE(review): not referenced within this class; the combo box
        # callback is bound to `__update_tooltips` directly.
        self.__update_tooltips()

    def __on_curve_selection_changed(self):
        """Record the rows of the selected curves and commit them."""
        if self.data is not None:
            selected = self.graph.scene().selectedItems()
            indices = [item._data_index for item in selected]
            self.__selected_data_indices = np.array(indices, dtype=int)
            self.commit()

    def commit(self):
        """Send the selected rows (or None) on "Selected Data"."""
        subset = None
        if self.data is not None and len(self.__selected_data_indices) > 0:
            subset = self.data[self.__selected_data_indices]

        self.send("Selected Data", subset)
Exemple #7
0
# Widget that shows the distribution of values of one variable, optionally
# split by a discrete variable, with an optional fitted distribution curve.
class OWDistributions(OWWidget):
    name = "Distributions"
    description = "Display value distributions of a data feature in a graph."
    icon = "icons/Distribution.svg"
    priority = 120
    keywords = []

    class Inputs:
        data = Input("Data", Table, doc="Set the input dataset")

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        histogram_data = Output("Histogram Data", Table)

    class Error(OWWidget.Error):
        no_defined_values_var = \
            Msg("Variable '{}' does not have any defined values")
        no_defined_values_pair = \
            Msg("No data instances with '{}' and '{}' defined")

    class Warning(OWWidget.Warning):
        ignored_nans = Msg("Data instances with missing values are ignored")

    settingsHandler = settings.DomainContextHandler()
    # Variable whose distribution is shown
    var = settings.ContextSetting(None)
    # Discrete variable used to split the columns ("Split by"), or None
    cvar = settings.ContextSetting(None)
    selection = settings.ContextSetting(set(), schema_only=True)
    # number_of_bins must be a context setting because selection depends on it
    # (the value is used as an index into `self.binnings`)
    number_of_bins = settings.ContextSetting(5, schema_only=True)

    # Index into `Fitters`; 0 is the "None" entry
    fitted_distribution = settings.Setting(0)
    hide_bars = settings.Setting(False)
    show_probs = settings.Setting(False)
    stacked_columns = settings.Setting(False)
    cumulative_distr = settings.Setting(False)
    kde_smoothing = settings.Setting(10)

    auto_apply = settings.Setting(True)

    graph_name = "plot"

    # Each entry is (display name, distribution object, names of the fitted
    # parameters, labels used when describing the fitted parameters).
    # A label starting with "-" is reported in parentheses with the "-"
    # stripped, and an empty label is omitted (see `_fit_approximation`).
    Fitters = (("None", None, (),
                ()), ("Normal", norm, ("loc", "scale"),
                      ("μ", "σ²")), ("Beta", beta, ("a", "b", "loc", "scale"),
                                     ("α", "β", "-loc", "-scale")),
               ("Gamma", gamma, ("a", "loc", "scale"), ("α", "β", "-loc",
                                                        "-scale")),
               ("Rayleigh", rayleigh, ("loc", "scale"),
                ("-loc", "σ²")), ("Pareto", pareto, ("b", "loc", "scale"),
                                  ("α", "-loc", "-scale")),
               ("Exponential", expon, ("loc", "scale"),
                ("-loc", "λ")), ("Kernel density", AshCurve, ("a", ), ("", )))

    # Values for `self.drag_operation`
    DragNone, DragAdd, DragRemove = range(3)

    def __init__(self):
        """Initialize the widget state and build the control area and plot."""
        super().__init__()
        self.data = None
        # Values of `var` (and of `cvar`) restricted to the rows where they
        # are defined; filled in by `set_valid_data`.
        self.valid_data = self.valid_group_data = None
        self.bar_items = []
        self.curve_items = []
        self.curve_descriptions = None
        self.binnings = []

        self.last_click_idx = None
        self.drag_operation = self.DragNone
        self.key_operation = None
        # Bin setting chosen by the user, stored per variable
        # (written in `_on_bin_slider_released`).
        self._user_var_bins = {}

        gui.listView(self.controlArea,
                     self,
                     "var",
                     box="Variable",
                     model=DomainModel(valid_types=DomainModel.PRIMITIVE,
                                       separators=False),
                     callback=self._on_var_changed)

        box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
        # `self.binnings` is still empty here, so the initial maximum is 1.
        slider = gui.hSlider(box,
                             self,
                             "number_of_bins",
                             label="Bin width",
                             orientation=Qt.Horizontal,
                             minValue=0,
                             maxValue=max(1,
                                          len(self.binnings) - 1),
                             createLabel=False,
                             callback=self._on_bins_changed)
        self.bin_width_label = gui.widgetLabel(slider.box)
        self.bin_width_label.setFixedWidth(35)
        self.bin_width_label.setAlignment(Qt.AlignRight)
        # The slider callback fires while dragging; apply only on release.
        slider.sliderReleased.connect(self._on_bin_slider_released)
        gui.comboBox(box,
                     self,
                     "fitted_distribution",
                     label="Fitted distribution",
                     orientation=Qt.Horizontal,
                     items=(name[0] for name in self.Fitters),
                     callback=self._on_fitted_dist_changed)
        self.smoothing_box = gui.indentedBox(box, 40)
        gui.hSlider(self.smoothing_box,
                    self,
                    "kde_smoothing",
                    label="Smoothing",
                    orientation=Qt.Horizontal,
                    minValue=2,
                    maxValue=20,
                    callback=self.replot)
        # Hiding bars is only offered while a distribution is being fitted.
        gui.checkBox(box,
                     self,
                     "hide_bars",
                     "Hide bars",
                     stateWhenDisabled=False,
                     callback=self._on_hide_bars_changed,
                     disabled=not self.fitted_distribution)

        box = gui.vBox(self.controlArea, "Columns")
        gui.comboBox(box,
                     self,
                     "cvar",
                     label="Split by",
                     orientation=Qt.Horizontal,
                     model=DomainModel(
                         placeholder="(None)",
                         valid_types=(DiscreteVariable),
                     ),
                     callback=self._on_cvar_changed,
                     contentsLength=18)
        gui.checkBox(box,
                     self,
                     "stacked_columns",
                     "Stack columns",
                     callback=self.replot)
        gui.checkBox(box,
                     self,
                     "show_probs",
                     "Show probabilities",
                     callback=self._on_show_probabilities_changed)
        gui.checkBox(box,
                     self,
                     "cumulative_distr",
                     "Show cumulative distribution",
                     callback=self.replot)

        gui.auto_apply(self.controlArea, self, commit=self.apply)

        self._set_smoothing_visibility()
        self._setup_plots()
        self._setup_legend()

    def _setup_plots(self):
        """
        Create the plot view, the main plot item and two linked view boxes.

        Sets `self.plotview`, `self.ploti`, `self.plot` (the view box of
        `self.ploti`), `self.plot_pdf` and `self.plot_mark`.
        """
        def add_new_plot(zvalue):
            # Extra view box placed in the main plot item's scene; its x axis
            # is linked to the main plot item and it is stacked at `zvalue`.
            plot = pg.ViewBox(enableMouse=False, enableMenu=False)
            self.ploti.scene().addItem(plot)
            pg.AxisItem("right").linkToView(plot)
            plot.setXLink(self.ploti)
            plot.setZValue(zvalue)
            return plot

        self.plotview = DistributionWidget(background=None)
        self.plotview.item_clicked.connect(self._on_item_clicked)
        self.plotview.blank_clicked.connect(self._on_blank_clicked)
        self.plotview.mouse_released.connect(self._on_end_selecting)
        self.plotview.setRenderHint(QPainter.Antialiasing)
        self.mainArea.layout().addWidget(self.plotview)
        self.ploti = pg.PlotItem(
            enableMenu=False,
            enableMouse=False,
            axisItems={"bottom": ElidedAxisNoUnits("bottom")})
        self.plot = self.ploti.vb
        self.plot.setMouseEnabled(False, False)
        self.ploti.hideButtons()
        self.plotview.setCentralItem(self.ploti)

        # z-value 10 and -10 respectively; `plot_pdf` receives the fitted
        # curves when probabilities are not shown (see `_plot_approximations`).
        self.plot_pdf = add_new_plot(10)
        self.plot_mark = add_new_plot(-10)
        self.plot_mark.setYRange(0, 1)
        # Keep the extra view boxes aligned with the main one on resize.
        self.ploti.vb.sigResized.connect(self.update_views)
        self.update_views()

        # Draw the axes in the palette's text color.
        pen = QPen(self.palette().color(QPalette.Text))
        self.ploti.getAxis("bottom").setPen(pen)
        left = self.ploti.getAxis("left")
        left.setPen(pen)
        left.setStyle(stopAxisAtTick=(True, True))

    def _setup_legend(self):
        """Create the (initially hidden) legend, parented to `plot_pdf`."""
        legend = LegendItem()
        self._legend = legend
        legend.setParentItem(self.plot_pdf)
        legend.hide()
        legend.anchor((1, 0), (1, 0))

    # -----------------------------
    # Event and signal handlers

    def update_views(self):
        """Align the auxiliary view boxes with the main plot's view box."""
        for linked_view in (self.plot_pdf, self.plot_mark):
            main_view = self.plot
            linked_view.setGeometry(main_view.sceneBoundingRect())
            linked_view.linkedViewChanged(main_view, linked_view.XAxis)

    def onDeleteWidget(self):
        """Clear all three view boxes before the base class cleanup runs."""
        for view in (self.plot, self.plot_pdf, self.plot_mark):
            view.clear()
        super().onDeleteWidget()

    @Inputs.data
    def set_data(self, data):
        """
        Handle new input data.

        Resets the chosen variables, repopulates the variable models from the
        new domain, picks defaults (the variable at the position after the
        class variables, clamped to the end of the model, and the discrete
        class as the split variable), restores the context and then
        recomputes, replots and applies.
        """
        self.closeContext()
        self.var = None
        self.cvar = None
        self.data = data

        domain = None
        if self.data:
            domain = self.data.domain

        var_model = self.controls.var.model()
        group_model = self.controls.cvar.model()
        var_model.set_domain(domain)
        group_model.set_domain(domain)

        if var_model:
            default_index = min(len(domain.class_vars), len(var_model) - 1)
            self.var = var_model[default_index]
        if domain is not None and domain.has_discrete_class:
            self.cvar = domain.class_var

        self.reset_select()
        self._user_var_bins.clear()
        self.openContext(domain)
        self.set_valid_data()
        self.recompute_binnings()
        self.replot()
        self.apply()

    def _on_var_changed(self):
        self.reset_select()
        self.set_valid_data()
        self.recompute_binnings()
        self.replot()
        self.apply()

    def _on_cvar_changed(self):
        self.set_valid_data()
        self.replot()
        self.apply()

    def _on_bins_changed(self):
        self.reset_select()
        self._set_bin_width_slider_label()
        self.replot()
        # this is triggered when dragging, so don't call apply here;
        # apply is called on sliderReleased

    def _on_bin_slider_released(self):
        self._user_var_bins[self.var] = self.number_of_bins
        self.apply()

    def _on_fitted_dist_changed(self):
        self.controls.hide_bars.setDisabled(not self.fitted_distribution)
        self._set_smoothing_visibility()
        self.replot()

    def _on_hide_bars_changed(self):
        for bar in self.bar_items:  # pylint: disable=blacklisted-name
            bar.setHidden(self.hide_bars)
        self._set_curve_brushes()
        self.plot.update()

    def _set_smoothing_visibility(self):
        """Show the smoothing slider only for the `AshCurve` fitter."""
        chosen_fitter = self.Fitters[self.fitted_distribution][1]
        self.smoothing_box.setVisible(chosen_fitter is AshCurve)

    def _set_bin_width_slider_label(self):
        if self.number_of_bins < len(self.binnings):
            text = reduce(lambda s, rep: s.replace(*rep),
                          short_time_units.items(),
                          self.binnings[self.number_of_bins].width_label)
        else:
            text = ""
        self.bin_width_label.setText(text)

    def _on_show_probabilities_changed(self):
        label = self.controls.fitted_distribution.label
        if self.show_probs:
            label.setText("Fitted probability")
            label.setToolTip(
                "Chosen distribution is used to compute Bayesian probabilities"
            )
        else:
            label.setText("Fitted distribution")
            label.setToolTip("")
        self.replot()

    @property
    def is_valid(self):
        """Whether `set_valid_data` found any usable values to plot."""
        has_values = self.valid_data is not None
        return has_values

    def set_valid_data(self):
        """
        Extract the values of `var` (and `cvar`) from rows where they are
        defined.

        On success `valid_data` holds the finite values of `var` and, when a
        split variable is set, `valid_group_data` the matching values of
        `cvar`. Both stay None when no variable is chosen or no row is
        usable; in the latter case the corresponding error is shown. A
        warning is shown when some rows were dropped.
        """
        self.Error.no_defined_values_var.clear()
        self.Error.no_defined_values_pair.clear()
        self.Warning.ignored_nans.clear()

        self.valid_data = None
        self.valid_group_data = None
        if self.var is None:
            return

        values = self.data.get_column_view(self.var)[0].astype(float)
        usable = np.isfinite(values)
        if not usable.any():
            self.Error.no_defined_values_var(self.var.name)
            return

        if self.cvar:
            groups = self.data.get_column_view(self.cvar)[0].astype(float)
            # Keep only the rows where both variables are defined.
            usable &= np.isfinite(groups)
            if not usable.any():
                self.Error.no_defined_values_pair(self.var.name,
                                                  self.cvar.name)
                return
            self.valid_group_data = groups[usable]

        if not usable.all():
            self.Warning.ignored_nans()
        self.valid_data = values[usable]

    # -----------------------------
    # Plotting

    def replot(self):
        """Redraw the plot from scratch and re-mark the selection."""
        self._clear_plot()
        if self.is_valid:
            for step in (self._set_axis_names,
                         self._update_controls_state,
                         self._call_plotting,
                         self._display_legend):
                step()
        self.show_selection()

    def _clear_plot(self):
        self.plot.clear()
        self.plot_pdf.clear()
        self.plot_mark.clear()
        self.bar_items = []
        self.curve_items = []
        self._legend.clear()
        self._legend.hide()

    def _set_axis_names(self):
        assert self.is_valid  # called only from replot, so assumes data is OK
        bottomaxis = self.ploti.getAxis("bottom")
        bottomaxis.setLabel(self.var and self.var.name)
        bottomaxis.setShowUnit(not (self.var and self.var.is_time))

        leftaxis = self.ploti.getAxis("left")
        if self.show_probs and self.cvar:
            leftaxis.setLabel(
                f"Probability of '{self.cvar.name}' at given '{self.var.name}'"
            )
        else:
            leftaxis.setLabel("Frequency")
        leftaxis.resizeEvent()

    def _update_controls_state(self):
        assert self.is_valid  # called only from replot, so assumes data is OK
        self.continuous_box.setDisabled(self.var.is_discrete)
        self.controls.show_probs.setDisabled(self.cvar is None)
        self.controls.stacked_columns.setDisabled(self.cvar is None)

    def _call_plotting(self):
        assert self.is_valid  # called only from replot, so assumes data is OK
        self.curve_descriptions = None
        if self.var.is_discrete:
            if self.cvar:
                self._disc_split_plot()
            else:
                self._disc_plot()
        else:
            if self.cvar:
                self._cont_split_plot()
            else:
                self._cont_plot()
        self.plot.autoRange()

    def _add_bar(self,
                 x,
                 width,
                 padding,
                 freqs,
                 colors,
                 stacked,
                 expanded,
                 tooltip,
                 hidden=False):
        """Create a `DistributionBarItem` from the arguments, add it to the
        main plot and record it in `self.bar_items`."""
        bar_args = (x, width, padding, freqs, colors, stacked, expanded,
                    tooltip, hidden)
        bar = DistributionBarItem(*bar_args)
        self.plot.addItem(bar)
        self.bar_items.append(bar)

    def _disc_plot(self):
        """Draw one bar per value of the (discrete) variable."""
        var = self.var
        bottom_axis = self.ploti.getAxis("bottom")
        bottom_axis.setTicks([list(enumerate(var.values))])
        colors = [QColor(0, 128, 255)]
        dist = distribution.get_distribution(self.data, self.var)
        for i, freq in enumerate(dist):
            # Tooltip: value name, count and share among the valid rows.
            tooltip = "".join((
                "<p style='white-space:pre;'>",
                f"<b>{escape(var.values[i])}</b>: {int(freq)} ",
                f"({100 * freq / len(self.valid_data):.2f} %) ",
            ))
            # Bars are one unit wide and centered on the value's index.
            self._add_bar(i - 0.5, 1, 0.1, [freq], colors,
                          stacked=False,
                          expanded=False,
                          tooltip=tooltip)

    def _disc_split_plot(self):
        """
        Draw one (grouped or stacked) bar per value of the discrete
        variable, split by the values of `cvar`.

        The tooltip totals are computed relative to the number of valid
        rows (`self.valid_data`), i.e. the rows that are actually plotted,
        consistently with `_disc_plot` and `_cont_split_plot`. Using
        `len(self.data)` would count rows that are ignored because of
        missing values.
        """
        var = self.var
        self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
        gcolors = [QColor(*col) for col in self.cvar.colors]
        gvalues = self.cvar.values
        conts = contingency.get_contingency(self.data, self.cvar, self.var)
        # Only rows with both `var` and `cvar` defined take part in the plot.
        total = len(self.valid_data)
        for i, freqs in enumerate(conts):
            self._add_bar(i - 0.5,
                          1,
                          0.1,
                          freqs,
                          gcolors,
                          stacked=self.stacked_columns,
                          expanded=self.show_probs,
                          tooltip=self._split_tooltip(var.values[i],
                                                      np.sum(freqs), total,
                                                      gvalues, freqs))

    def _cont_plot(self):
        """Draw the histogram of a numeric variable and, if a fitter is
        chosen, the fitted curve."""
        self._set_cont_ticks()
        data = self.valid_data
        thresholds = self.binnings[self.number_of_bins].thresholds
        y, x = np.histogram(data, bins=thresholds)
        total = len(data)
        bar_color = QColor(0, 128, 255)
        if self.fitted_distribution:
            # Lighter bars so that the fitted curve stands out.
            bar_color = bar_color.lighter(130)
        colors = [bar_color]

        tot_freq = 0
        lasti = len(y) - 1
        for i, (x0, x1, freq) in enumerate(zip(x, x[1:], y)):
            tot_freq += freq
            tooltip = "".join((
                "<p style='white-space:pre;'>",
                f"<b>{escape(self.str_int(x0, x1, not i, i == lasti))}</b>: ",
                f"{freq} ({100 * freq / total:.2f} %)</p>",
            ))
            # In cumulative mode the bar shows the running total.
            height = tot_freq if self.cumulative_distr else freq
            self._add_bar(x0, x1 - x0, 0, [height], colors,
                          stacked=False,
                          expanded=False,
                          tooltip=tooltip,
                          hidden=self.hide_bars)

        if self.fitted_distribution:
            fitters = [self._fit_approximation(data)]
            self._plot_approximations(x[0], x[-1], fitters,
                                      [QColor(0, 0, 0)], (1, ))

    def _cont_split_plot(self):
        """Draw the histogram of a numeric variable split by the values of
        `cvar` and, if a fitter is chosen, one fitted curve per group."""
        self._set_cont_ticks()
        data = self.valid_data
        thresholds = self.binnings[self.number_of_bins].thresholds
        _, bins = np.histogram(data, bins=thresholds)
        gvalues = self.cvar.values
        varcolors = [QColor(*col) for col in self.cvar.colors]
        gcolors = varcolors
        if self.fitted_distribution:
            # Lighter bars so that the fitted curves stand out.
            gcolors = [c.lighter(130) for c in varcolors]

        ys, fitters, prior_sizes = [], [], []
        for val_idx in range(len(gvalues)):
            group_data = data[self.valid_group_data == val_idx]
            prior_sizes.append(len(group_data))
            ys.append(np.histogram(group_data, bins)[0])
            if self.fitted_distribution:
                fitters.append(self._fit_approximation(group_data))
        total = len(data)
        prior_sizes = np.array(prior_sizes)
        # Running per-group totals, used in cumulative mode.
        tot_freqs = np.zeros(len(ys))

        lasti = len(ys[0]) - 1
        padding = 0 if self.stacked_columns else 0.1
        for i, (x0, x1, freqs) in enumerate(zip(bins, bins[1:], zip(*ys))):
            tot_freqs += freqs
            if self.cumulative_distr:
                plotfreqs = tot_freqs.copy()
            else:
                plotfreqs = freqs
            tooltip = self._split_tooltip(
                self.str_int(x0, x1, not i, i == lasti),
                np.sum(plotfreqs), total, gvalues, plotfreqs)
            self._add_bar(x0, x1 - x0, padding, plotfreqs, gcolors,
                          stacked=self.stacked_columns,
                          expanded=self.show_probs,
                          hidden=self.hide_bars,
                          tooltip=tooltip)

        if fitters:
            # Each group's prior is its share of the valid rows.
            self._plot_approximations(bins[0], bins[-1], fitters, varcolors,
                                      prior_sizes / len(data))

    def _set_cont_ticks(self):
        axis = self.ploti.getAxis("bottom")
        if self.var and self.var.is_time:
            binning = self.binnings[self.number_of_bins]
            labels = np.array(binning.short_labels)
            thresholds = np.array(binning.thresholds)
            lengths = np.array([len(lab) for lab in labels])
            slengths = set(lengths)
            if len(slengths) == 1:
                ticks = [
                    list(zip(thresholds[::2], labels[::2])),
                    list(zip(thresholds[1::2], labels[1::2]))
                ]
            else:
                ticks = []
                for length in sorted(slengths, reverse=True):
                    idxs = lengths == length
                    ticks.append(list(zip(thresholds[idxs], labels[idxs])))
            axis.setTicks(ticks)
        else:
            axis.setTicks(None)

    def _fit_approximation(self, y):
        def join_pars(pairs):
            strv = self.var.str_val
            return ", ".join(f"{sname}={strv(val)}" for sname, val in pairs)

        def str_params():
            s = join_pars((sname, val)
                          for sname, val in zip(str_names, fitted)
                          if sname and sname[0] != "-")
            par = join_pars((sname[1:], val)
                            for sname, val in zip(str_names, fitted)
                            if sname and sname[0] == "-")
            if par:
                s += f" ({par})"
            return s

        if not y.size:
            return None, None
        _, dist, names, str_names = self.Fitters[self.fitted_distribution]
        fitted = dist.fit(y)
        params = dict(zip(names, fitted))
        return partial(dist.pdf, **params), str_params()

    def _plot_approximations(self, x0, x1, fitters, colors, prior_probs):
        """
        Draw the fitted curves over the interval from `x0` to `x1`.

        `fitters` is a sequence of `(fitter, description)` pairs; the
        descriptions are stored in `self.curve_descriptions`. A pair whose
        fitter is `None` contributes an all-zero curve. `colors` and
        `prior_probs` are iterated in parallel with the curves; curves with
        a falsy prior probability are not drawn.

        Created curve items are added to the plot and appended to
        `self.curve_items`.
        """
        # Every curve is evaluated at the same 100 points
        x = np.linspace(x0, x1, 100)
        # One row per fitter; rows of skipped fitters stay zero
        ys = np.zeros((len(fitters), 100))
        self.curve_descriptions = [s for _, s in fitters]
        for y, (fitter, _) in zip(ys, fitters):
            if fitter is None:
                continue
            if self.Fitters[self.fitted_distribution][1] is AshCurve:
                # AshCurve additionally takes a sigma derived from the
                # smoothing setting
                y[:] = fitter(x, sigma=(22 - self.kde_smoothing) / 40)
            else:
                y[:] = fitter(x)
            if self.cumulative_distr:
                y[:] = np.cumsum(y)
        # Point-wise sum over all curves
        tots = np.sum(ys, axis=0)

        # Probabilities are drawn on `ploti`, densities on `plot_pdf`
        show_probs = self.show_probs and self.cvar is not None
        plot = self.ploti if show_probs else self.plot_pdf

        for y, prior_prob, color in zip(ys, prior_probs, colors):
            if not prior_prob:
                continue
            if show_probs:
                # Weigh this curve by its prior and the remaining curves by
                # the complement, then normalize
                y_p = y * prior_prob
                tot = (y_p + (tots - y) * (1 - prior_prob))
                # Avoid division by zero; the numerator is zero there anyway
                tot[tot == 0] = 1
                y = y_p / tot
            curve = pg.PlotCurveItem(x=x,
                                     y=y,
                                     fillLevel=0,
                                     pen=pg.mkPen(width=5, color=color),
                                     shadowPen=pg.mkPen(
                                         width=8, color=color.darker(120)))
            plot.addItem(curve)
            self.curve_items.append(curve)
        if not show_probs:
            self.plot_pdf.autoRange()
        self._set_curve_brushes()

    def _set_curve_brushes(self):
        for curve in self.curve_items:
            if self.hide_bars:
                color = curve.opts['pen'].color().lighter(160)
                color.setAlpha(128)
                curve.setBrush(pg.mkBrush(color))
            else:
                curve.setBrush(None)

    @staticmethod
    def _split_tooltip(valname, tot_group, total, gvalues, freqs):
        div_group = tot_group or 1
        cs = "white-space:pre; text-align: right;"
        s = f"style='{cs} padding-left: 1em'"
        snp = f"style='{cs}'"
        return f"<table style='border-collapse: collapse'>" \
               f"<tr><th {s}>{escape(valname)}:</th>" \
               f"<td {snp}><b>{int(tot_group)}</b></td>" \
               "<td/>" \
               f"<td {s}><b>{100 * tot_group / total:.2f} %</b></td></tr>" + \
               f"<tr><td/><td/><td {s}>(in group)</td><td {s}>(overall)</td>" \
               "</tr>" + \
               "".join(
                   "<tr>"
                   f"<th {s}>{value}:</th>"
                   f"<td {snp}><b>{int(freq)}</b></td>"
                   f"<td {s}>{100 * freq / div_group:.2f} %</td>"
                   f"<td {s}>{100 * freq / total:.2f} %</td>"
                   "</tr>"
                   for value, freq in zip(gvalues, freqs)) + \
               "</table>"

    def _display_legend(self):
        assert self.is_valid  # called only from replot, so assumes data is OK
        if self.cvar is None:
            if not self.curve_descriptions or not self.curve_descriptions[0]:
                self._legend.hide()
                return
            self._legend.addItem(
                pg.PlotCurveItem(pen=pg.mkPen(width=5, color=0.0)),
                self.curve_descriptions[0])
        else:
            cvar_values = self.cvar.values
            colors = [QColor(*col) for col in self.cvar.colors]
            descriptions = self.curve_descriptions or repeat(None)
            for color, name, desc in zip(colors, cvar_values, descriptions):
                self._legend.addItem(
                    ScatterPlotItem(pen=color, brush=color, size=10,
                                    shape="s"),
                    escape(name + (f" ({desc})" if desc else "")))
        self._legend.show()

    # -----------------------------
    # Bins

    def recompute_binnings(self):
        """
        Recompute the candidate binnings for the current variable.

        Binnings are computed only for valid data with a continuous variable
        that has at least one finite value; in every other case the list of
        binnings is emptied. The bin-count slider's maximum and the current
        number of bins are then adjusted to the number of available binnings,
        and the bin width label is refreshed.
        """
        column = None
        if self.is_valid and self.var.is_continuous:
            # binning is computed on valid var data, ignoring any cvar nans
            column = self.data.get_column_view(self.var)[0].astype(float)
            if not np.any(np.isfinite(column)):
                # Nothing to bin; treat like a non-continuous variable
                # instead of leaving `max_bins` undefined
                column = None

        if column is None:
            self.binnings = []
            max_bins = 0
        else:
            if self.var.is_time:
                self.binnings = time_binnings(column, min_unique=5)
                self.bin_width_label.setFixedWidth(45)
            else:
                self.binnings = decimal_binnings(
                    column,
                    min_width=self.min_var_resolution(self.var),
                    add_unique=10,
                    min_unique=5)
                self.bin_width_label.setFixedWidth(35)
            max_bins = len(self.binnings) - 1

        self.controls.number_of_bins.setMaximum(max_bins)
        # Prefer the bin count the user chose for this variable, if any
        self.number_of_bins = min(
            max_bins, self._user_var_bins.get(self.var, self.number_of_bins))
        self._set_bin_width_slider_label()

    @staticmethod
    def min_var_resolution(var):
        """
        Return the smallest distinguishable step for `var`.

        For variables whose type is exactly `ContinuousVariable` this is
        `10 ** -number_of_decimals`; for any other type (subclasses
        included) it is 0.
        """
        # pylint: disable=unidiomatic-typecheck
        if type(var) is ContinuousVariable:
            return 10**-var.number_of_decimals
        return 0

    def str_int(self, x0, x1, first, last):
        """
        Describe the interval from `x0` to `x1` as a condition on the variable.

        `first` and `last` tell whether the interval starts at the first bar
        and ends at the last bar, respectively; such intervals are described
        as open-ended. In cumulative mode the description is always
        "name < upper bound". Intervals whose bounds print the same, or that
        are no wider than the variable's resolution, are described as an
        equality.
        """
        var = self.var
        name = var.name
        lower, upper = var.repr_val(x0), var.repr_val(x1)

        if self.cumulative_distr:
            return f"{name} < {upper}"
        if first and last:
            return f"{name} = {lower}"
        if first:
            return f"{name} < {upper}"
        if last:
            return f"{name} ≥ {lower}"
        if lower == upper or x1 - x0 <= self.min_var_resolution(var):
            return f"{name} = {lower}"
        return f"{lower} ≤ {name} < {upper}"

    # -----------------------------
    # Selection

    def _on_item_clicked(self, item, modifiers, drag):
        """
        Update the selection after a click on (or a drag over) a bar.

        Args:
            item: the bar item under the mouse, or `None`; `None` resets
                the selection
            modifiers: keyboard modifiers active during the event
            drag: true if the event is part of a dragging operation

        A plain click selects only the clicked bar, or deselects it if it
        was the only selected bar. Ctrl toggles the clicked bar. Shift and
        dragging add or remove the whole range between the previously
        clicked bar and this one, depending on the current drag operation.
        """
        def add_or_remove(idx, add):
            # Remember the operation so that subsequent dragging continues it
            self.drag_operation = [self.DragRemove, self.DragAdd][add]
            if add:
                self.selection.add(idx)
            else:
                if idx in self.selection:
                    # This can be False when removing with dragging and the
                    # mouse crosses unselected items
                    self.selection.remove(idx)

        def add_range(add):
            # Without a previous click there is no range: just add this bar
            if self.last_click_idx is None:
                add = True
                idx_range = {idx}
            else:
                from_idx, to_idx = sorted((self.last_click_idx, idx))
                idx_range = set(range(from_idx, to_idx + 1))
            self.drag_operation = [self.DragRemove, self.DragAdd][add]
            if add:
                self.selection |= idx_range
            else:
                self.selection -= idx_range

        self.key_operation = None
        if item is None:
            self.reset_select()
            return

        idx = self.bar_items.index(item)
        if drag:
            # Dragging has to add a range, otherwise fast dragging skips bars
            add_range(self.drag_operation == self.DragAdd)
        else:
            if modifiers & Qt.ShiftModifier:
                add_range(self.drag_operation == self.DragAdd)
            elif modifiers & Qt.ControlModifier:
                add_or_remove(idx, add=idx not in self.selection)
            else:
                if self.selection == {idx}:
                    # Clicking on a single selected bar  deselects it,
                    # but dragging from here will select
                    add_or_remove(idx, add=False)
                    self.drag_operation = self.DragAdd
                else:
                    self.selection.clear()
                    add_or_remove(idx, add=True)
        self.last_click_idx = idx

        self.show_selection()

    def _on_blank_clicked(self):
        """Clear the selection; handler for a click on an empty area."""
        self.reset_select()

    def reset_select(self):
        """
        Empty the selection, forget the state of mouse and keyboard
        selection operations, and redraw the selection marks.
        """
        self.selection.clear()
        self.last_click_idx = self.drag_operation = self.key_operation = None
        self.show_selection()

    def _on_end_selecting(self):
        """Send the outputs once a selection operation has ended."""
        self.apply()

    def show_selection(self):
        """
        Redraw the marks for the selected bars.

        All existing marks are removed and one rectangle is drawn for every
        group of adjacent selected bars. For continuous variables the
        rectangle gets a tooltip with the described interval, the number of
        instances in it and their share of the valid data.
        """
        self.plot_mark.clear()
        if not self.is_valid:  # though if it's not, selection is empty anyway
            return

        blue = QColor(Qt.blue)
        pen = QPen(QBrush(blue), 3)
        pen.setCosmetic(True)
        brush = QBrush(blue.lighter(190))

        for group in self.grouped_selection():
            group = list(group)
            left_idx, right_idx = group[0], group[-1]
            # Extend the mark beyond the outermost bars of the group
            left_pad, right_pad = self._determine_padding(left_idx, right_idx)
            x0 = self.bar_items[left_idx].x0 - left_pad
            x1 = self.bar_items[right_idx].x1 + right_pad
            item = QGraphicsRectItem(x0, 0, x1 - x0, 1)
            item.setPen(pen)
            item.setBrush(brush)
            if self.var.is_continuous:
                # `not left_idx` is true when the group starts at the first bar
                valname = self.str_int(x0, x1, not left_idx,
                                       right_idx == len(self.bar_items) - 1)
                inside = sum(np.sum(self.bar_items[i].freqs) for i in group)
                total = len(self.valid_data)
                item.setToolTip("<p style='white-space:pre;'>"
                                f"<b>{escape(valname)}</b>: "
                                f"{inside} ({100 * inside / total:.2f} %)")
            self.plot_mark.addItem(item)

    def _determine_padding(self, left_idx, right_idx):
        def _padding(i):
            return (self.bar_items[i + 1].x0 - self.bar_items[i].x1) / 2

        if len(self.bar_items) == 1:
            return 6, 6
        if left_idx == 0 and right_idx == len(self.bar_items) - 1:
            return (_padding(0), ) * 2

        if left_idx > 0:
            left_pad = _padding(left_idx - 1)
        if right_idx < len(self.bar_items) - 1:
            right_pad = _padding(right_idx)
        else:
            right_pad = left_pad
        if left_idx == 0:
            left_pad = right_pad
        return left_pad, right_pad

    def grouped_selection(self):
        """
        Return the selected indices as a list of runs of consecutive
        indices, e.g. `[[1, 2, 3], [7], [9, 10]]`, in increasing order.
        """
        runs = []
        for idx in sorted(self.selection):
            if runs and runs[-1][-1] == idx - 1:
                # Continues the current run
                runs[-1].append(idx)
            else:
                runs.append([idx])
        return runs

    def keyPressEvent(self, e):
        """
        Move or extend the selection with the left and right arrow keys.

        Without Shift, the selection is replaced by the single bar to the
        left of the first (or to the right of the last) selected bar. With
        Shift, the selection grows in the direction of the key, or shrinks
        from the opposite end if it was last extended with the opposite key.
        With nothing selected, Left selects the last bar and any other
        handled key selects the first one.

        Other keys, and all keys when there is no valid data or no bars,
        are passed to the base class. If the selection changes, it is
        redrawn and the outputs are sent.
        """
        def on_nothing_selected():
            if e.key() == Qt.Key_Left:
                self.last_click_idx = len(self.bar_items) - 1
            else:
                self.last_click_idx = 0
            self.selection.add(self.last_click_idx)

        # `first` and `last` used in the helpers below are bound in the
        # enclosing scope before the helpers are called
        def on_key_left():
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Right and first != last:
                    # Undo the previous extension to the right
                    self.selection.remove(last)
                    self.last_click_idx = last - 1
                elif first:
                    # Extend to the left unless already at the first bar
                    self.key_operation = Qt.Key_Left
                    self.selection.add(first - 1)
                    self.last_click_idx = first - 1
            else:
                self.selection.clear()
                self.last_click_idx = max(first - 1, 0)
                self.selection.add(self.last_click_idx)

        def on_key_right():
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Left and first != last:
                    # Undo the previous extension to the left
                    self.selection.remove(first)
                    self.last_click_idx = first + 1
                elif not self._is_last_bar(last):
                    self.key_operation = Qt.Key_Right
                    self.selection.add(last + 1)
                    self.last_click_idx = last + 1
            else:
                self.selection.clear()
                self.last_click_idx = min(last + 1, len(self.bar_items) - 1)
                self.selection.add(self.last_click_idx)

        if not self.is_valid or not self.bar_items \
                or e.key() not in (Qt.Key_Left, Qt.Key_Right):
            super().keyPressEvent(e)
            return

        # Kept for detecting whether the key press changed anything
        prev_selection = self.selection.copy()
        if not self.selection:
            on_nothing_selected()
        else:
            first, last = min(self.selection), max(self.selection)
            if e.key() == Qt.Key_Left:
                on_key_left()
            else:
                on_key_right()

        if self.selection != prev_selection:
            self.drag_operation = self.DragAdd
            self.show_selection()
            self.apply()

    def keyReleaseEvent(self, ev):
        """
        Forget the direction of the keyboard selection when Shift is
        released, then pass the event to the base class.
        """
        if ev.key() == Qt.Key_Shift:
            self.key_operation = None
        super().keyReleaseEvent(ev)

    # -----------------------------
    # Output

    def apply(self):
        """
        Compute and send the widget's outputs.

        For valid data, the selected instances (with a column describing
        the selected group) and the annotated data are sent if anything is
        selected; for continuous variables, data with a column describing
        the histogram bins is sent as well. Every output that is not
        available is sent as `None`.
        """
        data = self.data
        selected_data = annotated_data = histogram_data = None
        if self.is_valid:
            if not self.var.is_discrete:
                group_indices, values = self._get_output_indices_cont()
                hist_indices, hist_values = self._get_histogram_indices()
                histogram_data = create_groups_table(
                    data, hist_indices, values=hist_values)
            else:
                group_indices, values = self._get_output_indices_disc()
            # Group index 0 marks unselected instances
            selected = np.nonzero(group_indices)[0]
            if selected.size:
                selected_data = create_groups_table(
                    data, group_indices,
                    include_unselected=False, values=values)
                annotated_data = create_annotated_table(data, selected)

        outputs = self.Outputs
        outputs.selected_data.send(selected_data)
        outputs.annotated_data.send(annotated_data)
        outputs.histogram_data.send(histogram_data)

    def _get_output_indices_disc(self):
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        for group_idx, val_idx in enumerate(self.selection, start=1):
            group_indices[col == val_idx] = group_idx
        values = [self.var.values[i] for i in self.selection]
        return group_indices, values

    def _get_output_indices_cont(self):
        """
        Return group indices and group names for a continuous variable.

        Each run of adjacent selected bars forms one group. The first result
        is an array with the 1-based group index of every instance (0 for
        instances outside the selection), the second the list of interval
        descriptions, one per group.
        """
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        values = []
        for group_idx, group in enumerate(self.grouped_selection(), start=1):
            # x0 is the lower bound of the group's first bar, x1 the upper
            # bound of its last bar
            x0 = x1 = None
            for bar_idx in group:
                minx, maxx, mask = self._get_cont_baritem_indices(col, bar_idx)
                if x0 is None:
                    x0 = minx
                x1 = maxx
                group_indices[mask] = group_idx
            # After the loop `bar_idx` is the last bar of the group.
            # NOTE(review): `not bar_idx` is therefore true only if the group
            # is the first bar alone, whereas show_selection passes the
            # group's first index here - confirm which is intended.
            # pylint: disable=undefined-loop-variable
            values.append(
                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
        return group_indices, values

    def _get_histogram_indices(self):
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        values = []
        for bar_idx in range(len(self.bar_items)):
            x0, x1, mask = self._get_cont_baritem_indices(col, bar_idx)
            group_indices[mask] = bar_idx + 1
            values.append(
                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
        return group_indices, values

    def _get_cont_baritem_indices(self, col, bar_idx):
        bar_item = self.bar_items[bar_idx]
        minx = bar_item.x0
        maxx = bar_item.x1 + (bar_idx == len(self.bar_items) - 1)
        with np.errstate(invalid="ignore"):
            return minx, maxx, (col >= minx) * (col < maxx)

    def _is_last_bar(self, idx):
        return idx == len(self.bar_items) - 1

    # -----------------------------
    # Report

    def get_widget_name_extension(self):
        """Return the shown variable as the extension of the widget name."""
        return self.var

    def send_report(self):
        """
        Add the plot and its caption to the report.

        Nothing is reported if the widget has no valid data. The caption
        names the shown variable, tells whether the distribution is
        cumulative and, if the columns are split, names the variable they
        are split by.
        """
        self.plotview.scene().setSceneRect(self.plotview.sceneRect())
        if not self.is_valid:
            return
        self.report_plot()
        if self.cumulative_distr:
            # Spelling fixed: the caption used to read "Cummulative"
            text = f"Cumulative distribution of '{self.var.name}'"
        else:
            text = f"Distribution of '{self.var.name}'"
        if self.cvar:
            text += f" with columns split by '{self.cvar.name}'"
        self.report_caption(text)
Exemple #8
0
class OWImpute(OWWidget):
    """
    Widget for imputing missing values in a data table.

    The user chooses a default imputation method and may override it for
    individual variables. Imputation of each variable is run as a separate
    job on a thread executor; the resulting table is sent to the output
    when all jobs finish.
    """
    name = "Impute"
    description = "Impute missing values in the data table."
    icon = "icons/Impute.svg"
    priority = 2130

    class Inputs:
        data = Input("Data", Orange.data.Table)
        learner = Input("Learner", Learner)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    class Error(OWWidget.Error):
        imputation_failed = Msg("Imputation failed for '{}'")
        model_based_imputer_sparse = Msg(
            "Model based imputer does not work for sparse data")

    settingsHandler = settings.DomainContextHandler()

    # Index of the default imputation method (a `Method` value)
    _default_method_index = settings.Setting(int(Method.Leave))  # type: int
    # Per-variable imputation state (synced in storeSpecificSettings)
    _variable_imputation_state = settings.ContextSetting(
        {})  # type: VariableState

    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        super().__init__()
        self.data = None  # type: Optional[Orange.data.Table]
        self.learner = None  # type: Optional[Learner]
        # Learner used for model-based imputation when none is on the input
        self.default_learner = SimpleTreeLearner()
        self.modified = False
        self.executor = qconcurrent.ThreadExecutor(self)
        # The currently running imputation task, if any
        self.__task = None

        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(10, 10, 10, 10)
        self.controlArea.layout().addLayout(main_layout)

        # Radio buttons for choosing the default method
        box = QGroupBox(title=self.tr("Default Method"), flat=False)
        box_layout = QVBoxLayout(box)
        main_layout.addWidget(box)

        button_group = QButtonGroup()
        button_group.buttonClicked[int].connect(self.set_default_method)

        # The first and the last entry of METHODS are not offered as
        # the default method
        for method, _ in list(METHODS.items())[1:-1]:
            imputer = self.create_imputer(method)
            button = QRadioButton(imputer.name)
            button.setChecked(method == self.default_method_index)
            button_group.addButton(button, method)
            box_layout.addWidget(button)

        self.default_button_group = button_group

        # List of variables with controls for per-variable methods
        box = QGroupBox(title=self.tr("Individual Attribute Settings"),
                        flat=False)
        main_layout.addWidget(box)

        horizontal_layout = QHBoxLayout(box)
        # NOTE(review): `box` was already added to `main_layout` above;
        # this second call looks redundant - confirm before removing
        main_layout.addWidget(box)

        self.varview = QListView(selectionMode=QListView.ExtendedSelection)
        self.varview.setItemDelegate(DisplayFormatDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._on_var_selection_changed)
        self.selection = self.varview.selectionModel()

        horizontal_layout.addWidget(self.varview)

        method_layout = QVBoxLayout()
        horizontal_layout.addLayout(method_layout)

        # One radio button for every method, including "as default"
        button_group = QButtonGroup()
        for method in Method:
            imputer = self.create_imputer(method)
            button = QRadioButton(text=imputer.name)
            button_group.addButton(button, method)
            method_layout.addWidget(button)

        # Editors for the fixed value: a combo for discrete variables and
        # a spin box for numeric ones, stacked so only one is visible
        self.value_combo = QComboBox(
            minimumContentsLength=8,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            activated=self._on_value_selected)
        self.value_double = QDoubleSpinBox(
            editingFinished=self._on_value_selected,
            minimum=-1000.,
            maximum=1000.,
            singleStep=.1,
            decimals=3,
        )
        self.value_stack = value_stack = QStackedWidget()
        value_stack.addWidget(self.value_combo)
        value_stack.addWidget(self.value_double)
        method_layout.addWidget(value_stack)

        button_group.buttonClicked[int].connect(
            self.set_method_for_current_selection)

        method_layout.addStretch(2)

        reset_button = QPushButton("Restore All to Default",
                                   checked=False,
                                   checkable=False,
                                   clicked=self.reset_variable_state,
                                   default=False,
                                   autoDefault=False)
        method_layout.addWidget(reset_button)

        self.variable_button_group = button_group

        box = gui.auto_commit(self.controlArea,
                              self,
                              "autocommit",
                              "Apply",
                              orientation=Qt.Horizontal,
                              checkbox_label="Apply automatically")
        box.button.setFixedWidth(180)
        box.layout().insertStretch(0)

    def create_imputer(self, method, *args):
        # type: (Method, ...) -> impute.BaseImputeMethod
        """
        Create an imputer for the given `method`.

        The model-based imputer uses the input learner if there is one and
        the default learner otherwise. `Method.AsAboveSoBelow` gives an
        `AsDefault` instance wrapping an imputer for the current default
        method. For all other methods, `args` are passed to the
        constructor found in `METHODS`.
        """
        if method == Method.Model:
            if self.learner is not None:
                return impute.Model(self.learner)
            else:
                return impute.Model(self.default_learner)
        elif method == Method.AsAboveSoBelow:
            # The default itself must not be "as default": the call below
            # would recurse
            assert self.default_method_index != Method.AsAboveSoBelow
            default = self.create_imputer(Method(self.default_method_index))
            m = AsDefault()
            m.method = default
            return m
        else:
            return METHODS[method](*args)

    @property
    def default_method_index(self):
        """Index of the default imputation method."""
        return self._default_method_index

    @default_method_index.setter
    def default_method_index(self, index):
        # On change, check the corresponding button, refresh the variable
        # list and invalidate the output
        if self._default_method_index != index:
            assert index != Method.AsAboveSoBelow
            self._default_method_index = index
            self.default_button_group.button(index).setChecked(True)
            # update variable view
            self.update_varview()
            self._invalidate()

    def set_default_method(self, index):
        """Set the current selected default imputation method.
        """
        self.default_method_index = index

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """
        Set the input data.

        The variable list and the per-variable state are reset; for
        non-empty input they are then filled from the data's domain and
        from the stored context. The output is recomputed unconditionally.
        """
        self.closeContext()
        self.varmodel[:] = []
        self._variable_imputation_state = {}  # type: VariableState
        self.modified = False
        self.data = data

        if data is not None:
            self.varmodel[:] = data.domain.variables
            self.openContext(data.domain)
            # restore per variable imputation state
            self._restore_state(self._variable_imputation_state)

        self.update_varview()
        self.unconditional_commit()

    @Inputs.learner
    def set_learner(self, learner):
        """
        Set the learner for model-based imputation.

        A falsy `learner` is replaced by the default learner. The labels of
        the model-based radio buttons are updated, and if a learner is
        given (not `None`), model-based imputation becomes the default
        method.
        """
        self.learner = learner or self.default_learner
        imputer = self.create_imputer(Method.Model)
        button = self.default_button_group.button(Method.Model)
        button.setText(imputer.name)

        variable_button = self.variable_button_group.button(Method.Model)
        variable_button.setText(imputer.name)

        if learner is not None:
            self.default_method_index = Method.Model

        self.update_varview()
        self.commit()

    def get_method_for_column(self, column_index):
        # type: (int) -> impute.BaseImputeMethod
        """
        Return the imputation method for column by its index.

        Columns without a stored state use `Method.AsAboveSoBelow`, that
        is, the default method.
        """
        assert 0 <= column_index < len(self.varmodel)
        idx = self.varmodel.index(column_index, 0)
        state = idx.data(StateRole)
        if state is None:
            state = (Method.AsAboveSoBelow, ())
        return self.create_imputer(state[0], *state[1])

    def _invalidate(self):
        """Mark the output as outdated, cancel a running task and commit."""
        self.modified = True
        if self.__task is not None:
            self.cancel()
        self.commit()

    def commit(self):
        """
        Start the imputation of the current data.

        Any running task is cancelled and messages are cleared. If there is
        no data, no rows or no variables, the input is sent to the output
        as it is. Otherwise one job per variable is submitted to the
        executor; the output is composed and sent in `__commit_finish`
        once all jobs are done. The widget is blocked in the meantime.
        """
        self.cancel()
        self.warning()
        self.Error.imputation_failed.clear()
        self.Error.model_based_imputer_sparse.clear()

        if self.data is None or len(self.data) == 0 or len(self.varmodel) == 0:
            self.Outputs.data.send(self.data)
            self.modified = False
            return

        data = self.data
        impute_state = [(i, var, self.get_method_for_column(i))
                        for i, var in enumerate(self.varmodel)]
        # normalize to the effective method bypassing AsDefault
        impute_state = [(i, var, m.method if isinstance(m, AsDefault) else m)
                        for i, var, m in impute_state]

        def impute_one(method, var, data):
            # type: (impute.BaseImputeMethod, Variable, Table) -> Any
            """
            Apply `method` to `var`; run in a worker thread.

            Returns the result of calling the method, wrapped into
            `RowMask` for methods that drop instances. Raises
            `SparseNotSupported` for model-based imputation on sparse data
            and `VariableNotSupported` if the method cannot handle `var`.
            """
            if isinstance(method, impute.Model) and data.is_sparse():
                raise SparseNotSupported()
            elif isinstance(method, impute.DropInstances):
                return RowMask(method(data, var))
            elif not method.supports_variable(var):
                raise VariableNotSupported(var)
            else:
                return method(data, var)

        futures = []
        for _, var, method in impute_state:
            # Every job gets its own copy of the method
            f = self.executor.submit(impute_one, copy.deepcopy(method), var,
                                     data)
            futures.append(f)

        w = qconcurrent.FutureSetWatcher(futures)
        w.doneAll.connect(self.__commit_finish)
        w.progressChanged.connect(self.__progress_changed)
        self.__task = Task(futures, w)
        self.progressBarInit(processEvents=False)
        self.setBlocking(True)

    @Slot()
    def __commit_finish(self):
        """
        Collect the results of the finished jobs and send the output.

        Variables for which imputation is not supported are left out of the
        output domain and reported with an error or a warning. If any other
        exception occurred, an error is shown and `None` is sent. Rows
        marked by methods that drop instances are removed from the output.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        futures = self.__task.futures
        assert len(futures) == len(self.varmodel)
        assert self.data is not None

        self.__task = None
        self.setBlocking(False)
        self.progressBarFinished()

        data = self.data
        attributes = []
        class_vars = []
        # Rows to be removed; set by methods that return a RowMask
        drop_mask = np.zeros(len(self.data), bool)

        for i, (var, fut) in enumerate(zip(self.varmodel, futures)):
            assert fut.done()
            # Stays empty (the variable is dropped) if the job failed with
            # one of the two expected exceptions
            newvar = []
            try:
                res = fut.result()
            except SparseNotSupported:
                self.Error.model_based_imputer_sparse()
                # ?? break
            except VariableNotSupported:
                self.warning("Default method can not handle '{}'".format(
                    var.name))
            except Exception:  # pylint: disable=broad-except
                log = logging.getLogger(__name__)
                log.info("Error for %s", var, exc_info=True)
                self.Error.imputation_failed(var.name)
                # Signals below that no output can be composed
                attributes = class_vars = None
                break
            else:
                if isinstance(res, RowMask):
                    drop_mask |= res.mask
                    newvar = var
                else:
                    newvar = res

            if isinstance(newvar, Orange.data.Variable):
                newvar = [newvar]

            # varmodel lists attributes first, followed by class variables
            if i < len(data.domain.attributes):
                attributes.extend(newvar)
            else:
                class_vars.extend(newvar)

        if attributes is None:
            data = None
        else:
            domain = Orange.data.Domain(attributes, class_vars,
                                        data.domain.metas)
            try:
                data = self.data.from_table(domain, data[~drop_mask])
            except Exception:  # pylint: disable=broad-except
                log = logging.getLogger(__name__)
                log.info("Error", exc_info=True)
                self.Error.imputation_failed("Unknown")
                data = None

        self.Outputs.data.send(data)
        self.modified = False

    @Slot(int, int)
    def __progress_changed(self, n, d):
        """Set the progress bar to `n` out of `d`."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        self.progressBarSet(100. * n / d)

    def cancel(self):
        """
        Cancel the running task, if any.

        The task's signals are disconnected, its futures are waited for,
        and the progress bar and the blocking state are reset.
        """
        if self.__task is not None:
            task, self.__task = self.__task, None
            task.cancel()
            task.watcher.doneAll.disconnect(self.__commit_finish)
            task.watcher.progressChanged.disconnect(self.__progress_changed)
            concurrent.futures.wait(task.futures)
            task.watcher.flush()
            self.progressBarFinished()
            self.setBlocking(False)

    def onDeleteWidget(self):
        """Cancel the running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()

    def send_report(self):
        """
        Report the default method and the list of variables with
        individually set methods, if there are any.
        """
        specific = []
        for i, var in enumerate(self.varmodel):
            method = self.get_method_for_column(i)
            if not isinstance(method, AsDefault):
                specific.append("{} ({})".format(var.name, str(method)))

        default = self.create_imputer(Method.AsAboveSoBelow)
        if specific:
            self.report_items((("Default method", default.name),
                               ("Specific imputers", ", ".join(specific))))
        else:
            self.report_items((("Method", default.name), ))

    def _on_var_selection_changed(self):
        """
        Update the per-variable controls to reflect the selected variables.

        The radio button of the method is checked if all selected variables
        share the same state; otherwise no button is checked. Buttons of
        methods that do not support all selected variables are disabled,
        and the value editor is set up for the type of selected variables.
        """
        indexes = self.selection.selectedIndexes()
        defmethod = (Method.AsAboveSoBelow, ())
        methods = [index.data(StateRole) for index in indexes]
        methods = [m if m is not None else defmethod for m in methods]
        methods = set(methods)
        selected_vars = [self.varmodel[index.row()] for index in indexes]
        has_discrete = any(var.is_discrete for var in selected_vars)
        fixed_value = None
        value_stack_enabled = False
        current_value_widget = None

        if len(methods) == 1:
            method_type, parameters = methods.pop()
            for m in Method:
                if method_type == m:
                    self.variable_button_group.button(m).setChecked(True)

            if method_type == Method.Default:
                (fixed_value, ) = parameters

        elif self.variable_button_group.checkedButton() is not None:
            # Uncheck the current button
            self.variable_button_group.setExclusive(False)
            self.variable_button_group.checkedButton().setChecked(False)
            self.variable_button_group.setExclusive(True)
            assert self.variable_button_group.checkedButton() is None

        # Update variable methods GUI enabled state based on selection.
        for method in Method:
            # use a default constructed imputer to query support
            imputer = self.create_imputer(method)
            enabled = all(
                imputer.supports_variable(var) for var in selected_vars)
            button = self.variable_button_group.button(method)
            button.setEnabled(enabled)

        # Update the "Value" edit GUI.
        if not has_discrete:
            # no discrete variables -> allow mass edit for all (continuous vars)
            value_stack_enabled = True
            current_value_widget = self.value_double
        elif len(selected_vars) == 1:
            # single discrete var -> enable and fill the values combo
            value_stack_enabled = True
            current_value_widget = self.value_combo
            self.value_combo.clear()
            self.value_combo.addItems(selected_vars[0].values)
        else:
            # mixed type selection -> disable
            value_stack_enabled = False
            current_value_widget = None
            self.variable_button_group.button(Method.Default).setEnabled(False)

        self.value_stack.setEnabled(value_stack_enabled)
        if current_value_widget is not None:
            self.value_stack.setCurrentWidget(current_value_widget)
            if fixed_value is not None:
                # set current value
                if current_value_widget is self.value_combo:
                    self.value_combo.setCurrentIndex(fixed_value)
                elif current_value_widget is self.value_double:
                    self.value_double.setValue(fixed_value)
                else:
                    assert False

    def set_method_for_current_selection(self, method_index):
        # type: (Method) -> None
        """Set the imputation method for the selected variables."""
        indexes = self.selection.selectedIndexes()
        self.set_method_for_indexes(indexes, method_index)

    def set_method_for_indexes(self, indexes, method_index):
        # type: (List[QModelIndex], Method) -> None
        """
        Set the imputation method for the variables at `indexes`.

        `Method.AsAboveSoBelow` removes the stored state. For
        `Method.Default` the state also holds the value currently shown in
        the value editor: the index from the combo or the number from the
        spin box. The view is then updated and the output invalidated.
        """
        if method_index == Method.AsAboveSoBelow:
            for index in indexes:
                self.varmodel.setData(index, None, StateRole)
        elif method_index == Method.Default:
            current = self.value_stack.currentWidget()
            if current is self.value_combo:
                value = self.value_combo.currentIndex()
            else:
                value = self.value_double.value()
            for index in indexes:
                state = (int(Method.Default), (value, ))
                self.varmodel.setData(index, state, StateRole)
        else:
            state = (int(method_index), ())
            for index in indexes:
                self.varmodel.setData(index, state, StateRole)

        self.update_varview(indexes)
        self._invalidate()

    def update_varview(self, indexes=None):
        """
        Refresh the method shown for the variables at `indexes`, or for
        all variables if `indexes` is `None`.
        """
        if indexes is None:
            indexes = map(self.varmodel.index, range(len(self.varmodel)))

        for index in indexes:
            self.varmodel.setData(index,
                                  self.get_method_for_column(index.row()),
                                  DisplayMethodRole)

    def _on_value_selected(self):
        # The fixed 'Value' in the widget has been changed by the user.
        self.variable_button_group.button(Method.Default).setChecked(True)
        self.set_method_for_current_selection(Method.Default)

    def reset_variable_state(self):
        """Reset all variables to use the default method."""
        indexes = list(map(self.varmodel.index, range(len(self.varmodel))))
        self.set_method_for_indexes(indexes, Method.AsAboveSoBelow)
        self.variable_button_group.button(
            Method.AsAboveSoBelow).setChecked(True)

    def _store_state(self):
        # type: () -> VariableState
        """
        Save the current variable imputation state

        Returns a dictionary that maps the keys of variables (as given by
        `var_key`) to their state; variables without a state are omitted.
        """
        state = {}  # type: VariableState
        for i, var in enumerate(self.varmodel):
            index = self.varmodel.index(i)
            m = index.data(StateRole)
            if m is not None:
                state[var_key(var)] = m
        return state

    def _restore_state(self, state):
        # type: (VariableState) -> None
        """
        Restore the variable imputation state from the saved state

        Entries that are not a pair of a valid method index and a tuple of
        parameters are ignored.
        """
        def check(state):
            # check if state is a proper State
            if isinstance(state, tuple) and len(state) == 2:
                m, p = state
                if isinstance(m, int) and isinstance(p, tuple) and \
                        0 <= m < len(Method):
                    return True
            return False

        for i, var in enumerate(self.varmodel):
            m = state.get(var_key(var), None)
            if check(m):
                self.varmodel.setData(self.varmodel.index(i), m, StateRole)

    def storeSpecificSettings(self):
        """Sync the per-variable state into the context setting."""
        self._variable_imputation_state = self._store_state()
        super().storeSpecificSettings()
Exemple #9
0
class OWPredictions(OWWidget):
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display predictions of models for an input dataset."
    keywords = []

    class Inputs:
        # Input signals declared by the widget: a single "Data" table and,
        # because of `multiple=True`, any number of "Predictors" models.
        data = Input("Data", Orange.data.Table)
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        # Output signals declared by the widget: the "Predictions" table and
        # the "Evaluation Results" object; both are sent from `commit`.
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results", Results)

    class Warning(OWWidget.Warning):
        # Shown by `set_data` when the input table is not None but empty.
        empty_data = Msg("Empty dataset")
        # Shown by `_set_errors`; `{}` receives one line per model whose
        # results carry no probabilities (its target differs from the data's).
        wrong_targets = Msg(
            "Some model(s) predict a different target (see more ...)\n{}")

    class Error(OWWidget.Error):
        # Shown by `_set_errors`; `{}` receives one line per failed predictor.
        predictor_failed = Msg("Some predictor(s) failed (see more ...)\n{}")
        # Shown by `_update_scores`; `{}` receives the collected scorer errors.
        scorer_failed = Msg("Some scorer(s) failed (see more ...)\n{}")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: List of selected class value indices in the `class_values` list
    selected_classes = settings.ContextSetting([])
    selection = settings.Setting([], schema_only=True)

    def __init__(self):
        """
        Initialise the widget state and build its GUI.

        The control area gets the class-value list box and the
        "Restore Original Order" button. The main area gets a horizontal
        splitter holding the predictions view (left) and the data view
        (right), with the score table view placed below it.
        """
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: Maps input ids to PredictorSlot
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        #: Union of the discrete class values (see `_set_class_values`)
        self.class_values = []  # type: List[str]
        # Item delegates are kept here because the view does not own them
        # (see `_update_prediction_delegate`).
        self._delegates = []
        self.left_width = 10
        self.selection_store = None
        # The stored selection is applied the first time data arrives
        # (see `set_data`).
        self.__pending_selection = self.selection

        self._set_input_summary()
        self._set_output_summary(None)

        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.ExtendedSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        self.reset_button = gui.button(
            self.controlArea,
            self,
            "Restore Original Order",
            callback=self._reset_order,
            tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.ExtendedSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(sortingEnabled=True,
                                  verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Keep the vertical scroll position of both views in sync.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        # Mirror row-height changes from the data view to the predictions
        # view.
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.dataview.setItemDelegate(DataItemDelegate(self.dataview))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.splitterMoved.connect(self.splitter_resized)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)

    def get_selection_store(self, proxy):
        """
        Return the shared selection store, creating it on first use.

        Both proxies map the same, so it doesn't matter which one is used
        to initialize the SharedSelectionStore.
        """
        store = self.selection_store
        if store is None:
            store = self.selection_store = SharedSelectionStore(proxy)
        return store

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """
        Handle a new table on the "Data" input.

        Stores the table, drops the current selection store and either
        clears both views (no or empty data) or installs a fresh sortable
        model with a shared selection model on the data view. Cached
        predictor results are invalidated in all cases.
        """
        has_no_rows = not data
        self.Warning.empty_data(shown=data is not None and has_no_rows)
        self.data = data
        self.selection_store = None
        # Dropping the model forces a full reset of the view's HeaderView
        # state; it is also the end state when there is nothing to show.
        self.dataview.setModel(None)
        if has_no_rows:
            self.predictionsview.setModel(None)
        else:
            source_model = TableModel(data, parent=None)
            proxy = SortProxyModel()
            proxy.setSourceModel(source_model)
            self.dataview.setModel(proxy)
            selection_model = SharedSelectionModel(
                self.get_selection_store(proxy), proxy, self.dataview)
            self.dataview.setSelectionModel(selection_model)

            pending = self.__pending_selection
            if pending is not None:
                # Apply the stored selection once, before the change
                # handlers below are connected.
                self.selection = pending
                self.__pending_selection = None
                self.selection_store.select_rows(
                    set(self.selection), QItemSelectionModel.ClearAndSelect)
            selection_model.selectionChanged.connect(self.commit)
            selection_model.selectionChanged.connect(self._store_selection)

            on_sorted = partial(self._update_data_sort_order,
                                self.dataview, self.predictionsview)
            self.dataview.model().list_sorted.connect(on_sorted)

        self._invalidate_predictions()

    def _store_selection(self):
        self.selection = list(self.selection_store.rows)

    @property
    def class_var(self):
        return self.data and self.data.domain.class_var

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """
        Add, replace or remove the predictor registered under `id`.

        A None `predictor` removes a known slot and is otherwise ignored.
        A non-None `predictor` replaces a known slot (dropping its cached
        results) or creates a new slot with no results.
        """
        known = id in self.predictors
        if predictor is None:
            if known:
                del self.predictors[id]
            return

        if known:
            self.predictors[id] = self.predictors[id]._replace(
                predictor=predictor, name=predictor.name, results=None)
        else:
            self.predictors[id] = PredictorSlot(
                predictor, predictor.name, None)

    def _set_class_values(self):
        class_values = []
        for slot in self.predictors.values():
            class_var = slot.predictor.domain.class_var
            if class_var and class_var.is_discrete:
                for value in class_var.values:
                    if value not in class_values:
                        class_values.append(value)

        if self.class_var and self.class_var.is_discrete:
            values = self.class_var.values
            self.class_values = sorted(class_values,
                                       key=lambda val: val not in values)
            self.selected_classes = [
                i for i, name in enumerate(class_values) if name in values
            ]
        else:
            self.class_values = class_values  # This assignment updates listview
            self.selected_classes = []

    def handleNewSignals(self):
        """
        Refresh everything that depends on the inputs, in a fixed order:
        class values, predictor results, scores, the predictions model and
        its delegates, error/warning messages, the input summary, and
        finally the outputs.
        """
        self._set_class_values()
        self._call_predictors()
        self._update_scores()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._set_input_summary()
        self.commit()

    def _call_predictors(self):
        """
        Run every predictor that has no cached `Results` on the input data.

        Does nothing without data. Each slot's `results` is replaced either
        by a `Results` object or, when the predictor raises `ValueError` or
        `DomainTransformationError`, by an error string. `predicted` and
        `probabilities` are filled in only for predictors whose target
        equals the data's class variable; for the others they stay None.
        """
        if not self.data:
            return
        if self.class_var:
            # Predict on a copy of the data with the class column removed.
            domain = self.data.domain
            classless_data = self.data.transform(
                Domain(domain.attributes, None, domain.metas))
        else:
            classless_data = self.data

        for inputid, slot in self.predictors.items():
            # Slots that already hold Results are kept as they are.
            if isinstance(slot.results, Results):
                continue

            predictor = slot.predictor
            try:
                if predictor.domain.class_var.is_discrete:
                    pred, prob = predictor(classless_data, Model.ValueProbs)
                else:
                    pred = predictor(classless_data, Model.Value)
                    # No probabilities here: use an array with zero columns.
                    prob = numpy.zeros((len(pred), 0))
            except (ValueError, DomainTransformationError) as err:
                # Record the failure as a string; `_set_errors` reports it.
                self.predictors[inputid] = \
                    slot._replace(results=f"{predictor.name}: {err}")
                continue

            results = Results()
            results.data = self.data
            results.domain = self.data.domain
            results.row_indices = numpy.arange(len(self.data))
            results.folds = (Ellipsis, )
            results.actual = self.data.Y
            results.unmapped_probabilities = prob
            results.unmapped_predicted = pred
            results.probabilities = results.predicted = None
            self.predictors[inputid] = slot._replace(results=results)

            target = predictor.domain.class_var
            # A different target leaves predicted/probabilities as None.
            if target != self.class_var:
                continue

            # Equal but not the same object: map the predictor's values and
            # probabilities back through the predictor's backmappers.
            if target is not self.class_var and target.is_discrete:
                backmappers, n_values = predictor.get_backmappers(self.data)
                prob = predictor.backmap_probs(prob, n_values, backmappers)
                pred = predictor.backmap_value(pred, prob, n_values,
                                               backmappers)
            results.predicted = pred.reshape((1, len(self.data)))
            results.probabilities = prob.reshape((1, ) + prob.shape)

    def _update_scores(self):
        """
        Rebuild the score table from the current predictor results.

        One row is added per predictor whose results carry predictions.
        A scorer that raises leaves an empty cell with the exception text
        as tooltip; the text is also reported through
        `Error.scorer_failed` when that scorer is among the shown scores.
        The table view is hidden when there are no rows.
        """
        model = self.score_table.model
        model.clear()
        target = self.class_var
        scorers = usable_scorers(target) if target else []
        self.score_table.update_header(scorers)

        failures = []
        for slot in self.predictors.values():
            results = slot.results
            if not isinstance(results, Results) or results.predicted is None:
                continue
            row = [QStandardItem(learner_name(slot.predictor))]
            row.extend([QStandardItem("N/A"), QStandardItem("N/A")])
            for scorer in scorers:
                cell = QStandardItem()
                try:
                    score = scorer_caller(scorer, results)()[0]
                    cell.setText(f"{score:.3f}")
                except Exception as exc:  # pylint: disable=broad-except
                    message = str(exc)
                    cell.setToolTip(message)
                    if scorer.name in self.score_table.shown_scores:
                        failures.append(message)
                row.append(cell)
            self.score_table.model.appendRow(row)

        view = self.score_table.view
        n_rows = model.rowCount()
        if not n_rows:
            view.setVisible(False)
        else:
            view.setVisible(True)
            view.ensurePolished()
            header_height = view.horizontalHeader().height()
            row_height = view.verticalHeader().sectionSize(0)
            view.setFixedHeight(5 + header_height + row_height * n_rows)

        self.Error.scorer_failed("\n".join(failures), shown=bool(failures))

    def _set_errors(self):
        """
        Update the predictor-failure error and the wrong-target warning.

        Not all predictors are run every time, so the messages are rebuilt
        here from the stored slots instead of in `_call_predictors`.
        """
        failed_lines = [
            f"- {slot.predictor.name}: {slot.results}"
            for slot in self.predictors.values()
            if isinstance(slot.results, str) and slot.results
        ]
        errors = "\n".join(failed_lines)
        self.Error.predictor_failed(errors, shown=bool(errors))

        if not self.class_var:
            self.Warning.wrong_targets.clear()
            return

        # Results without probabilities mark predictors whose target did not
        # match the data's class variable.
        mismatched = [
            slot.predictor for slot in self.predictors.values()
            if isinstance(slot.results, Results)
            and slot.results.probabilities is None
        ]
        inv_targets = "\n".join(
            f"- {pred.name} predicts '{pred.domain.class_var.name}'"
            for pred in mismatched)
        self.Warning.wrong_targets(inv_targets, shown=bool(inv_targets))

    def _set_input_summary(self):
        """
        Publish the input summary: `NoInput` when there is neither data nor
        any predictor, otherwise the row count plus rich-text details.
        """
        has_data = bool(self.data)
        if not (has_data or self.predictors):
            self.info.set_input_summary(self.info.NoInput)
            return

        n_rows = len(self.data) if has_data else 0
        self.info.set_input_summary(
            n_rows, self._get_details(), format=Qt.RichText)

    def _get_details(self):
        """
        Return an HTML description of the inputs: a data section followed,
        after a horizontal rule, by the number of models (with the count of
        failed ones, if any) and a list of their names.
        """
        if self.data:
            data_part = \
                format_summary_details(self.data).replace('\n', '<br>')
        else:
            data_part = "No data on input."
        parts = ["Data:<br>", data_part, "<hr>"]

        names = [slot.name for slot in self.predictors.values()]
        total = len(self.predictors)
        if not total:
            parts.append("Model:<br>No model on input.")
        else:
            valid = len(self._non_errored_predictors())
            parts.append(plural("Model: {number} model{s}", total))
            if valid != total:
                parts.append(f" ({total - valid} failed)")
            parts.append("<ul>")
            parts.extend([f"<li>{name}</li>" for name in names])
            parts.append("</ul>")
        return "".join(parts)

    def _set_output_summary(self, output):
        """
        Publish the output summary: row count and details for a non-empty
        `output`, `NoOutput` with empty details otherwise.
        """
        if output:
            summary = len(output)
            details = format_summary_details(output)
        else:
            summary = self.info.NoOutput
            details = ""
        self.info.set_output_summary(summary, details)

    def _invalidate_predictions(self):
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _non_errored_predictors(self):
        """Return the slots whose `results` is a `Results` instance."""
        usable = []
        for slot in self.predictors.values():
            if isinstance(slot.results, Results):
                usable.append(slot)
        return usable

    def _reordered_probabilities(self, prediction):
        cur_values = prediction.predictor.domain.class_var.values
        new_ind = [self.class_values.index(x) for x in cur_values]
        probs = prediction.results.unmapped_probabilities
        new_probs = numpy.full((probs.shape[0], len(self.class_values)),
                               numpy.nan)
        new_probs[:, new_ind] = probs
        return new_probs

    def _update_predictions_model(self):
        """
        Rebuild the model shown in the predictions view.

        Collects `(values, probabilities)` for every non-errored predictor,
        wraps them in a `PredictionsModel` behind a sort proxy, installs the
        proxy together with a new shared selection model, and reconnects the
        sort-order synchronisation with the data view.
        """
        results = []
        headers = []
        for p in self._non_errored_predictors():
            values = p.results.unmapped_predicted
            target = p.predictor.domain.class_var
            if target.is_discrete:
                # order probabilities in order from Show prob. for
                prob = self._reordered_probabilities(p)
                values = [Value(target, v) for v in values]
            else:
                # No probabilities: an array with zero columns.
                prob = numpy.zeros((len(values), 0))
            results.append((values, prob))
            headers.append(p.predictor.name)

        if results:
            # Transpose from one (values, probs) pair per predictor to one
            # tuple per row holding a (value, probs) pair for each predictor.
            results = list(zip(*(zip(*res) for res in results)))
            model = PredictionsModel(results, headers)
        else:
            model = None

        # Detach the selection model that is about to be replaced.
        if self.selection_store is not None:
            self.selection_store.unregister(
                self.predictionsview.selectionModel())

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        self.predictionsview.setModel(predmodel)

        self.predictionsview.setSelectionModel(
            SharedSelectionModel(self.get_selection_store(predmodel),
                                 predmodel, self.predictionsview))

        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        self.predictionsview.model().list_sorted.connect(
            partial(self._update_data_sort_order, self.predictionsview,
                    self.dataview))

        self.predictionsview.resizeColumnsToContents()

    def _update_data_sort_order(self, sort_source_view, sort_dest_view):
        """
        Make `sort_dest_view` follow the row order of `sort_source_view`.

        When the source view's model is sorted on a column, the destination
        model receives the matching sort indices; otherwise its sort indices
        are reset to None. The sort indicator is shown only on the source
        view, and only while it is sorted. Outputs are re-sent at the end.

        Parameters
        ----------
        sort_source_view
            View whose model was just sorted.
        sort_dest_view
            View whose model (a `QSortFilterProxyModel` offering
            `setSortIndices`, or None) should mirror that order.
        """
        sort_dest = sort_dest_view.model()
        sort_source = sort_source_view.model()
        sortindicatorshown = False
        if sort_dest is not None:
            assert isinstance(sort_dest, QSortFilterProxyModel)
            n = sort_dest.rowCount()
            if sort_source is not None and sort_source.sortColumn() >= 0:
                # For each proxy row take the source row it displays, then
                # argsort that list.
                sortind = numpy.argsort([
                    sort_source.mapToSource(sort_source.index(i, 0)).row()
                    for i in range(n)
                ])
                # `numpy.int` was only an alias of the builtin `int` and was
                # removed in NumPy 1.24, so use the builtin directly.
                sortind = numpy.array(sortind, int)
                sortindicatorshown = True
            else:
                sortind = None

            sort_dest.setSortIndices(sortind)

        sort_dest_view.horizontalHeader().setSortIndicatorShown(False)
        sort_source_view.horizontalHeader().setSortIndicatorShown(
            sortindicatorshown)
        self.commit()

    def _reset_order(self):
        """
        Remove any sorting from both views' models and hide both sort
        indicators.
        """
        # Fetch both models up front, then reset data first, predictions
        # second.
        models = (self.dataview.model(), self.predictionsview.model())
        for model in models:
            if model is None:
                continue
            model.setSortIndices(None)
            model.sort(-1)
        for view in (self.predictionsview, self.dataview):
            view.horizontalHeader().setSortIndicatorShown(False)

    def _all_color_values(self):
        """
        Return, for each non-errored predictor with a discrete class
        variable, a two-item list `[colors, values]` sorted by value so the
        entries can be compared easily.

        When there is no such predictor, `[([], [])]` is returned.
        """
        color_values = []
        for slot in self._non_errored_predictors():
            target = slot.predictor.domain.class_var
            if not target.is_discrete:
                continue
            pairs = sorted(zip(target.colors, target.values),
                           key=itemgetter(1))
            color_values.append(list(zip(*pairs)))
        return color_values or [([], [])]

    @staticmethod
    def _colors_match(colors1, values1, color2, values2):
        """
        Test whether colors for values match. Colors matches when all
        values match for shorter list and colors match for shorter list.
        It is assumed that values will be sorted together with their colors.
        """
        shorter_length = min(len(colors1), len(color2))
        return (values1[:shorter_length] == values2[:shorter_length]
                and (numpy.array(colors1[:shorter_length]) == numpy.array(
                    color2[:shorter_length])).all())

    def _get_colors(self):
        """
        Choose the colors for `self.class_values`.

        If the colors of all models agree, the longest matching assignment
        is used, reordered to the widget's value order. If they disagree, or
        the assignment does not cover every class value, the palette of
        `LimitedDiscretePalette` is returned instead.
        """
        (base_color, base_values), *others = self._all_color_values()
        for other_colors, other_values in others:
            if not self._colors_match(base_color, base_values,
                                      other_colors, other_values):
                base_color = []
                break
            # keep the longer assignment as the new base
            if len(other_values) > len(base_color):
                base_color, base_values = other_colors, other_values

        n_classes = len(self.class_values)
        if len(base_color) != n_classes:
            return LimitedDiscretePalette(n_classes).palette

        # place each color at the position of its value in the widget
        ordered = [None] * n_classes
        for color, value in zip(base_color, base_values):
            ordered[self.class_values.index(value)] = color
        return ordered

    def _update_prediction_delegate(self):
        """
        Install a fresh `PredictionsItemDelegate` on every predictor column
        of the predictions view, then resize the columns, recompute the
        splitter and pass the selected classes on to the view's model.
        """
        self._delegates.clear()
        colors = self._get_colors()
        for col, slot in enumerate(self.predictors.values()):
            target = slot.predictor.domain.class_var
            if target.is_continuous:
                shown_probs = ()
                delegate_values = None
                number_format = target.format_str
            else:
                # Selected classes unknown to this predictor become None.
                shown_probs = [
                    ind if self.class_values[ind] in target.values else None
                    for ind in self.selected_classes
                ]
                delegate_values = self.class_values
                number_format = None
            delegate = PredictionsItemDelegate(
                delegate_values,
                colors,
                shown_probs,
                number_format,
                parent=self.predictionsview)
            # QAbstractItemView does not take ownership of delegates, so we must
            self._delegates.append(delegate)
            self.predictionsview.setItemDelegateForColumn(col, delegate)
            self.predictionsview.setColumnHidden(col, False)

        self.predictionsview.resizeColumnsToContents()
        self._recompute_splitter_sizes()
        predmodel = self.predictionsview.model()
        if predmodel is not None:
            predmodel.setProbInd(self.selected_classes)

    def _recompute_splitter_sizes(self):
        if not self.data:
            return
        view = self.predictionsview
        self.left_width = \
            view.horizontalHeader().length() + view.verticalHeader().width()
        self._update_splitter()

    def _update_splitter(self):
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([self.left_width, w1 + w2 - self.left_width])

    def splitter_resized(self):
        self.left_width = self.splitter.sizes()[0]

    def commit(self):
        """Send both outputs: predictions first, then evaluation results."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        """
        Send a combined `Results` object on the evaluation results output.

        Only predictors whose results carry predictions take part; None is
        sent when there are none. Rows whose class value is NaN are left
        out. Probabilities are included only for a discrete class variable.
        """
        scored = []
        for slot in self._non_errored_predictors():
            if slot.results.predicted is not None:
                scored.append(slot)
        if not scored:
            self.Outputs.evaluation_results.send(None)
            return

        target_column = self.data.get_column_view(self.class_var)[0]
        known = ~numpy.isnan(target_column)
        data = self.data[known]

        results = Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(slot.results.predicted[0][known] for slot in scored))
        if self.class_var and self.class_var.is_discrete:
            results.probabilities = numpy.array(
                [slot.results.probabilities[0][known] for slot in scored])
        results.learner_names = [slot.name for slot in scored]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """
        Send the input data extended with prediction columns.

        Without data, None is sent. Otherwise the new columns of every
        non-errored predictor are appended as meta attributes (renamed when
        needed so that names stay unique). If rows are selected only those
        rows are sent; if nothing is selected but a view is sorted, all rows
        are sent in the displayed order.
        """
        if not self.data:
            self._set_output_summary(None)
            self.Outputs.predictions.send(None)
            return

        # Collect the new meta variables and their columns per predictor.
        newmetas = []
        newcolumns = []
        for slot in self._non_errored_predictors():
            if slot.predictor.domain.class_var.is_discrete:
                self._add_classification_out_columns(slot, newmetas,
                                                     newcolumns)
            else:
                self._add_regression_out_columns(slot, newmetas, newcolumns)

        attrs = list(self.data.domain.attributes)
        metas = list(self.data.domain.metas)
        names = [
            var.name
            for var in chain(attrs, self.data.domain.class_vars, metas) if var
        ]
        # Rename new metas that clash with existing (or earlier new) names.
        uniq_newmetas = []
        for new_ in newmetas:
            uniq = get_unique_names(names, new_.name)
            if uniq != new_.name:
                new_ = new_.copy(name=uniq)
            uniq_newmetas.append(new_)
            names.append(uniq)

        metas += uniq_newmetas
        domain = Orange.data.Domain(attrs, self.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            # The new metas were appended last, so they fill the trailing
            # meta columns.
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns

        index = self.dataview.model().index
        map_to = self.dataview.model().mapToSource
        assert self.selection_store is not None
        rows = None
        if self.selection_store.rows:
            # Selected rows, in the order they are displayed.
            rows = [
                ind.row()
                for ind in self.dataview.selectionModel().selectedRows(0)
            ]
            rows.sort()
        elif self.dataview.model().isSorted() \
                or self.predictionsview.model().isSorted():
            # No selection, but a sorted view: output all rows as displayed.
            rows = list(range(len(self.data)))
        if rows:
            # Translate displayed (proxy) rows into rows of the source table.
            source_rows = [map_to(index(row, 0)).row() for row in rows]
            predictions = predictions[source_rows]
        self.Outputs.predictions.send(predictions)
        self._set_output_summary(predictions)

    @staticmethod
    def _add_classification_out_columns(slot, newmetas, newcolumns):
        """
        Append the output columns of a classification predictor.

        Adds to `newmetas` one discrete variable named after the predictor
        plus one continuous variable per class value, and to `newcolumns`
        the unmapped predictions (as a column) followed by the unmapped
        probabilities.
        """
        # Mapped or unmapped predictions?!
        # Or provide a checkbox so the user decides?
        predictor = slot.predictor
        name = predictor.name
        values = predictor.domain.class_var.values

        newmetas.append(DiscreteVariable(name=name, values=values))
        predicted_column = slot.results.unmapped_predicted.reshape(-1, 1)
        newcolumns.append(predicted_column)

        probability_vars = [
            ContinuousVariable(name=f"{name} ({value})") for value in values
        ]
        newmetas.extend(probability_vars)
        newcolumns.append(slot.results.unmapped_probabilities)

    @staticmethod
    def _add_regression_out_columns(slot, newmetas, newcolumns):
        """
        Append the output column of a regression predictor: one continuous
        variable named after the predictor and its unmapped predictions as a
        single column.
        """
        prediction_var = ContinuousVariable(name=slot.predictor.name)
        newmetas.append(prediction_var)
        predicted_column = slot.results.unmapped_predicted.reshape((-1, 1))
        newcolumns.append(predicted_column)

    def send_report(self):
        """
        Add the widget's content to the report.

        Nothing is reported without data. Otherwise the report gets an
        'Info' paragraph, a table merging predictions with the visible data
        columns, and the score table.
        """
        def merge_data_with_predictions():
            # Generator of table rows: a header row first, then one row per
            # data row with the prediction cells before the data cells.
            data_model = self.dataview.model()
            predictions_view = self.predictionsview
            predictions_model = predictions_view.model()

            # use ItemDelegate to style prediction values
            delegates = [
                predictions_view.itemDelegateForColumn(i)
                for i in range(predictions_model.columnCount())
            ]

            # iterate only over visible columns of data's QTableView
            iter_data_cols = list(
                filter(lambda x: not self.dataview.isColumnHidden(x),
                       range(data_model.columnCount())))

            # print header
            yield [''] + \
                  [predictions_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in range(predictions_model.columnCount())] + \
                  [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical, Qt.DisplayRole)] + \
                      [delegate.displayText(
                          predictions_model.data(predictions_model.index(i, j)),
                          QLocale())
                       for j, delegate in enumerate(delegates)] + \
                      [data_model.data(data_model.index(i, j))
                       for j in iter_data_cols]

        if self.data:
            text = self._get_details().replace('\n', '<br>')
            if self.selected_classes:
                # Name the classes whose probabilities are being shown.
                text += '<br>Showing probabilities for: '
                text += ', '.join(
                    [self.class_values[i] for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions",
                              merge_data_with_predictions(),
                              header_rows=1,
                              header_columns=1)

            self.report_table("Scores", self.score_table.view)

    def resizeEvent(self, event):
        """Run the base handler, then re-apply `left_width` to the splitter."""
        super().resizeEvent(event)
        self._update_splitter()

    def showEvent(self, event):
        """
        Run the base handler, then schedule `_update_splitter` through a
        zero-delay single-shot timer.
        """
        super().showEvent(event)
        QTimer.singleShot(0, self._update_splitter)
Exemple #10
0
class OWLinearProjection(widget.OWWidget):
    """
    Widget showing a linear projection of data onto a two-dimensional plane.

    The projection axes are either placed on a circle, computed with LDA or
    PCA, or taken from the table received on the "Projection" input (see
    `Placement`).
    """
    name = "Linear Projection"
    description = "A multi-axis projection of data onto " \
                  "a two-dimensional plane."
    icon = "icons/LinearProjection.svg"
    priority = 240

    # Saved selection; picked up in __init__ and restored in set_data
    selection_indices = settings.Setting(None, schema_only=True)

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        projection = Input("Projection", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    # The integer values double as indices into the placement radio buttons
    # (see _check_possible_opt) and into the report labels (see send_report).
    Placement = Enum("Placement",
                     dict(Circular=0,
                          LDA=1,
                          PCA=2,
                          Projection=3),
                     type=int,
                     qualname="OWLinearProjection.Placement")

    # Prefixes of the component names written to the "Components" output.
    # There is no entry for Placement.Projection: commit() forwards the input
    # projection unchanged in that case.
    Component_name = {Placement.Circular: "C", Placement.LDA: "LD", Placement.PCA: "PC"}
    # Prefixes of the names of the x/y variables created in _setup_plot
    Variable_name = {Placement.Circular: "circular",
                     Placement.LDA: "lda",
                     Placement.PCA: "pca",
                     Placement.Projection: "projection"}

    jitter_sizes = [0, 0.1, 0.5, 1.0, 2.0]

    settings_version = 3
    settingsHandler = settings.DomainContextHandler()

    # Encoded selected/other variable lists; written in closeContext
    variable_state = settings.ContextSetting({})
    placement = settings.Setting(Placement.Circular)
    # Value of the "Radius" slider (0-100), used by _get_min_radius
    radius = settings.Setting(0)
    auto_commit = settings.Setting(True)

    resolution = 256

    graph = settings.SettingProvider(OWLinProjGraph)
    # Custom event type posted by invalidate_plot and handled in customEvent
    ReplotRequest = QEvent.registerEventType()
    vizrank = settings.SettingProvider(LinearProjectionVizRank)
    graph_name = "graph.plot_widget.plotItem"

    class Warning(widget.OWWidget.Warning):
        no_cont_features = widget.Msg("Plotting requires numeric features")
        not_enough_components = widget.Msg("Input projection has less than 2 components")
        trivial_components = widget.Msg(
            "All components of the PCA are trivial (explain 0 variance). "
            "Input data is constant (or near constant).")

    class Error(widget.OWWidget.Error):
        proj_and_domain_match = widget.Msg("Projection and Data domains do not match")
        no_valid_data = widget.Msg("No projection due to invalid data")

    def __init__(self):
        """Initialize the widget state and build the plot and the controls."""
        super().__init__()

        self.data = None
        self.projection = None
        self.subset_data = None
        self._subset_mask = None
        self._selection = None
        self.__replot_requested = False
        self.n_cont_var = 0
        #: Remember the saved state to restore
        self.__pending_selection_restore = self.selection_indices
        self.selection_indices = None

        # Variables holding the projected coordinates; created in _setup_plot
        self.variable_x = None
        self.variable_y = None

        # Main area: the plot
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWLinProjGraph(self, box, "Plot", view_box=LinProjInteractiveViewBox)
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        SIZE_POLICY = (QSizePolicy.Minimum, QSizePolicy.Maximum)

        # Control area: lists of selected and remaining variables
        self.variables_selection = VariablesSelection()
        self.model_selected = VariableListModel(enable_dnd=True)
        self.model_other = VariableListModel(enable_dnd=True)
        self.variables_selection(self, self.model_selected, self.model_other)

        # The "Suggest Features" button is placed next to the add/remove controls
        self.vizrank, self.btn_vizrank = LinearProjectionVizRank.add_vizrank(
            self.controlArea, self, "Suggest Features", self._vizrank)
        self.variables_selection.add_remove.layout().addWidget(self.btn_vizrank)

        # The order of the labels matches the integer values of Placement
        box = gui.widgetBox(
            self.controlArea, "Placement", sizePolicy=SIZE_POLICY)
        self.radio_placement = gui.radioButtonsInBox(
            box, self, "placement",
            btnLabels=["Circular Placement",
                       "Linear Discriminant Analysis",
                       "Principal Component Analysis",
                       "Use input projection"],
            callback=self._change_placement
        )

        self.viewbox = plot.getViewBox()
        self.replot = None

        g = self.graph.gui
        box = g.point_properties_box(self.controlArea)
        self.models = g.points_models
        g.add_widget(g.JitterSizeSlider, box)
        box.setSizePolicy(*SIZE_POLICY)

        # Slider bound to `radius`; update_radius hides the axes that are
        # shorter than the resulting threshold
        box = gui.widgetBox(self.controlArea, "Hide axes", sizePolicy=SIZE_POLICY)
        self.rslider = gui.hSlider(
            box, self, "radius", minValue=0, maxValue=100,
            step=5, label="Radius", createLabel=False, ticks=True,
            callback=self.update_radius)
        self.rslider.setTickInterval(0)
        self.rslider.setPageStep(10)

        box = gui.vBox(self.controlArea, "Plot Properties")
        box.setSizePolicy(*SIZE_POLICY)

        g.add_widgets([g.ShowLegend,
                       g.ToolTipShowsAll,
                       g.ClassDensity,
                       g.LabelOnlySelected], box)

        box = self.graph.box_zoom_select(self.controlArea)
        box.setSizePolicy(*SIZE_POLICY)

        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)
        gui.auto_commit(self.controlArea, self, "auto_commit", "Send Selection",
                        auto_label="Send Automatically")
        self.graph.zoom_actions(self)

        # Start with empty plot data and bring the controls in line with the
        # stored placement
        self._new_plotdata()
        self._change_placement()
        self.graph.jitter_continuous = True

    def reset_graph_data(self):
        """Rescale the graph's data and redraw it with a reset view.

        Does nothing when the widget has no data.
        """
        if self.data is None:
            return
        self.graph.rescale_data()
        self._update_graph(reset_view=True)

    def keyPressEvent(self, event):
        """Handle the key press and pass the active modifiers to the graph's tooltip."""
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Handle the key release and pass the active modifiers to the graph's tooltip."""
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def _vizrank(self, attrs):
        """Callback of "Suggest Features": make `attrs` the selected variables.

        The suggested attributes replace the content of the selected list and
        are removed from the list of the remaining variables.
        """
        self.variables_selection.display_none()
        self.model_selected[:] = attrs[:]
        remaining = []
        for candidate in self.model_other:
            if candidate not in attrs:
                remaining.append(candidate)
        self.model_other[:] = remaining

    def _change_placement(self):
        """Adapt the controls to the current placement, then replot and commit."""
        current = self.placement
        circular = self.Placement.Circular
        # The variable lists are editable only for circular and LDA placement
        lists_enabled = current in [circular, self.Placement.LDA]
        self.variables_selection.set_enabled(lists_enabled)
        self._vizrank_color_change()
        # The radius slider is disabled for circular placement
        self.rslider.setEnabled(current != circular)
        self._setup_plot()
        self.commit()

    def _get_min_radius(self):
        """Return the axis length above which an axis is visible.

        The threshold is `radius` percent of the longest axis in
        `plotdata.axes`, plus a small constant.
        """
        longest_axis = np.max(np.linalg.norm(self.plotdata.axes, axis=1))
        return self.radius * longest_axis / 100 + 1e-5

    def update_radius(self):
        """Update the visibility of the axes and the size of the hide-circle.

        Does nothing when the plot has no hide-circle.
        """
        plotdata = self.plotdata
        assert plotdata is not None
        circle = plotdata.hidecircle
        if circle is None:
            return
        threshold = self._get_min_radius()
        # An axis is shown only when it is longer than the threshold
        for axis_item, anchor in zip(plotdata.axisitems, plotdata.axes):
            axis_item.setVisible(np.linalg.norm(anchor) > threshold)
        diameter = 2 * threshold
        circle.setRect(QRectF(-threshold, -threshold, diameter, diameter))

    def _new_plotdata(self):
        """Reset `self.plotdata` to an empty state."""
        # variables, valid_mask, embedding_coords and axes are filled in by
        # _setup_plot, axisitems and hidecircle by _anchor_circle and data
        # by _plot.
        self.plotdata = namespace(
            valid_mask=None,
            embedding_coords=None,
            axisitems=[],
            axes=[],
            variables=[],
            data=None,
            hidecircle=None
        )

    def _anchor_circle(self, variables):
        """Add an anchor item for every axis and, if applicable, the hide-circle.

        Anchors are created from `plotdata.axes` and labelled with the names
        of `variables`. The circle marking the minimum visible radius is
        added for every placement except the circular one.
        """
        # minimum visible anchor radius (radius)
        threshold = self._get_min_radius()
        anchors = []
        for axis, variable in zip(self.plotdata.axes, variables):
            item = AnchorItem(line=QLineF(0, 0, *axis), text=variable.name)
            item.setVisible(np.linalg.norm(axis) > threshold)
            item.setPen(pg.mkPen((100, 100, 100)))
            item.setArrowVisible(True)
            self.viewbox.addItem(item)
            anchors.append(item)
        self.plotdata.axisitems = anchors

        if self.placement == self.Placement.Circular:
            return

        diameter = 2 * threshold
        circle = QGraphicsEllipseItem()
        circle.setRect(QRectF(-threshold, -threshold, diameter, diameter))

        outline = QPen(Qt.lightGray, 1)
        outline.setCosmetic(True)
        circle.setPen(outline)

        self.viewbox.addItem(circle)
        self.plotdata.hidecircle = circle

    def update_colors(self):
        """Re-evaluate the state of the "Suggest Features" button."""
        # _vizrank_color_change depends on the graph's color attribute
        self._vizrank_color_change()

    def clear(self):
        """Drop the data, empty both variable lists and clear the plot and selection."""
        # Clear/reset the widget state
        self.data = None
        self.model_selected.clear()
        self.model_other.clear()
        self._clear_plot()
        self.selection_indices = None

    def _clear_plot(self):
        """Remove the axis items and the hide-circle and reset the plot data."""
        self.Warning.trivial_components.clear()
        stale = self.plotdata
        for item in stale.axisitems:
            self.viewbox.removeItem(item)
        if stale.hidecircle:
            self.viewbox.removeItem(stale.hidecircle)
        self._new_plotdata()
        self.graph.hide_axes()

    def invalidate_plot(self):
        """
        Schedule a delayed replot.

        A `ReplotRequest` event is posted to the widget unless one is
        already pending.
        """
        if self.__replot_requested:
            return
        self.__replot_requested = True
        request = QEvent(self.ReplotRequest)
        QApplication.postEvent(self, request, Qt.LowEventPriority - 10)

    def init_attr_values(self):
        """Pass the current data (possibly None) to the graph's `set_domain`."""
        self.graph.set_domain(self.data)

    def _vizrank_color_change(self):
        """Enable or disable the "Suggest Features" button and set its tooltip.

        The button is enabled only when there is data, the placement is
        circular or LDA, a color variable is selected (a discrete one for
        LDA) and at least three candidate variables are available. The
        tooltip gives the reason whenever the button is disabled.
        """
        if self.data is None:
            # Disable explicitly: otherwise the button keeps the enabled
            # state it had for the previous dataset.
            self.btn_vizrank.setToolTip("There is no data.")
            self.btn_vizrank.setEnabled(False)
            return

        domain = self.data.domain
        # NOTE(review): if `is_primitive` is a method rather than a property,
        # `v.is_primitive` is always true here - confirm against Variable.
        candidates = [v for v in chain(domain.variables, domain.metas)
                      if v.is_primitive and v is not self.graph.attr_color]
        self.n_cont_var = len(candidates)

        is_enabled = False
        if self.placement not in [self.Placement.Circular, self.Placement.LDA]:
            msg = "Suggest Features works only for Circular and " \
                  "Linear Discriminant Analysis Projection"
        elif self.graph.attr_color is None:
            msg = "Color variable has to be selected"
        elif self.graph.attr_color.is_continuous and self.placement == self.Placement.LDA:
            msg = "Suggest Features does not work for Linear Discriminant Analysis Projection " \
                  "when continuous color variable is selected."
        elif len(candidates) < 3:
            msg = "Not enough available continuous variables"
        else:
            is_enabled = True
            msg = ""
        self.btn_vizrank.setToolTip(msg)
        self.btn_vizrank.setEnabled(is_enabled)
        self.vizrank.stop_and_reset(is_enabled)

    @Inputs.projection
    def set_projection(self, projection):
        """
        Set the input projection.

        A projection with fewer than two rows (components) is rejected with
        a warning and treated as missing. A valid projection switches the
        placement to `Placement.Projection`.

        Args:
            projection (Orange.data.Table): table with projection components
        """
        self.Warning.not_enough_components.clear()
        # Test against None explicitly: truthiness would let an empty table
        # through without the warning.
        if projection is not None and len(projection) < 2:
            self.Warning.not_enough_components()
            projection = None
        if projection is not None:
            self.placement = self.Placement.Projection
        self.projection = projection

    @Inputs.data
    def set_data(self, data):
        """
        Set the input dataset.

        The previous context is closed and the widget state cleared. Data
        without any continuous variable is rejected with a warning. For
        non-empty data the variable lists are initialized from the domain and
        then updated with the state stored in the matching context.

        Args:
            data (Orange.data.table): data instances
        """
        def materialize_sql(table):
            # Download an SqlTable (a time-based sample of it when it has
            # roughly 4000 rows or more); any other input is returned as is.
            if not isinstance(table, SqlTable):
                return table
            if table.approx_len() < 4000:
                return Table(table)
            self.information("Data has been sampled")
            sample = table.sample_time(1, no_cache=True)
            sample.download_data(2000, partial=True)
            return Table(sample)

        def restore_variable_state(table):
            # get the default encoded state, replacing the position with Inf
            state = VariablesSelection.encode_var_state(
                [list(self.model_selected), list(self.model_other)]
            )
            state = {key: (source_ind, np.inf) for key, (source_ind, _) in state.items()}

            self.openContext(table.domain)

            if self.__pending_selection_restore is not None:
                self._selection = np.array(self.__pending_selection_restore, dtype=int)
                self.__pending_selection_restore = None

            # update the defaults state (the encoded state must contain
            # all variables in the input domain)
            state.update(self.variable_state)
            # ... and restore it with saved positions taking precedence over
            # the defaults
            selected, other = VariablesSelection.decode_var_state(
                state, [list(self.model_selected), list(self.model_other)])
            return selected, other

        self.closeContext()
        self.clear()
        self.Warning.no_cont_features.clear()
        self.information()
        data = materialize_sql(data)
        if data is not None:
            domain = data.domain
            has_continuous = any(var.is_continuous
                                 for var in chain(domain.variables, domain.metas))
            if not has_continuous:
                self.Warning.no_cont_features()
                data = None
        self.data = data
        self.init_attr_values()
        if data is not None and len(data):
            self._initialize(data)
            self.model_selected[:], self.model_other[:] = restore_variable_state(data)
            self.vizrank.stop_and_reset()
            self.vizrank.attrs = data.domain.attributes

    def _check_possible_opt(self):
        """Enable the placement options that are possible for the current inputs.

        LDA is disabled when the data has no discrete class with at least two
        values, and "Use input projection" when there is no projection; a
        placement that becomes unavailable falls back to circular. Without
        data all controls are disabled. The result is committed in both cases.
        """
        def toggle_controls(enabled):
            for button in self.radio_placement.buttons:
                button.setEnabled(enabled)
            self.variables_selection.set_enabled(enabled)

        circular = self.Placement.Circular
        lda = self.Placement.LDA
        from_input = self.Placement.Projection

        if not self.data:
            self.graph.new_data(None)
            self.rslider.setEnabled(False)
            toggle_controls(False)
            self.commit()
            return

        toggle_controls(True)
        domain = self.data.domain
        lda_possible = domain.has_discrete_class and len(domain.class_var.values) >= 2
        if not lda_possible:
            self.radio_placement.buttons[lda].setEnabled(False)
            if self.placement == lda:
                self.placement = circular
        if not self.projection:
            self.radio_placement.buttons[from_input].setEnabled(False)
            if self.placement == from_input:
                self.placement = circular
        self._setup_plot()
        self.commit()

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """
        Set the supplementary input subset dataset.

        Args:
            subset (Orange.data.table): subset of data instances
        """
        self.subset_data = subset
        # The mask is recomputed in handleNewSignals
        self._subset_mask = None
        # The alpha value control is enabled only while there is no subset
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Recompute the subset mask, then update the options, plot and outputs."""
        have_both = self.data is not None and self.subset_data is not None
        if have_both:
            # Update the plot's highlight items
            data_ids = self.data.ids.ravel()
            subset_ids = np.unique(self.subset_data.ids)
            self._subset_mask = np.in1d(data_ids, subset_ids, assume_unique=True)
        self._check_possible_opt()
        self._change_placement()
        self.commit()

    def customEvent(self, event):
        """Replot and commit on a `ReplotRequest`; defer other events to the base class."""
        if event.type() != OWLinearProjection.ReplotRequest:
            super().customEvent(event)
            return
        self.__replot_requested = False
        self._setup_plot()
        self.commit()

    def closeContext(self):
        """Store the encoded variable lists in `variable_state` and close the context."""
        current_lists = [list(self.model_selected), list(self.model_other)]
        self.variable_state = VariablesSelection.encode_var_state(current_lists)
        super().closeContext()

    def _initialize(self, data):
        """Fill the variable lists from the continuous variables of `data`.

        The first three continuous variables (metas first, then attributes)
        become the selected ones; the rest go to the other list.
        """
        domain = data.domain
        continuous = [var for var in chain(domain.metas, domain.attributes)
                      if var.is_continuous]
        initially_selected, remaining = continuous[:3], continuous[3:]
        self.model_other[:] = remaining
        self.model_selected[:] = initially_selected

    def prepare_plot_data(self, variables):
        """
        Project the columns of `variables` onto the axes of the current placement.

        Args:
            variables: variables whose columns are projected

        Returns:
            A tuple ``(valid_mask, embedding, axes)``, or
            ``(None, None, None)`` when no axes are available for the current
            placement. ``valid_mask`` marks the instances with no missing
            value in the used columns, ``embedding`` has one (x, y) row per
            valid instance and ``axes`` one 2D anchor per variable.
        """
        def projection(variables):
            # Pick the first two rows of the input projection, matching its
            # columns to `variables` by identity or, failing that, by name.
            if set(self.projection.domain.attributes).issuperset(variables):
                axes = self.projection[:2, variables].X
            elif set(f.name for f in
                     self.projection.domain.attributes).issuperset(f.name for f in variables):
                axes = self.projection[:2, [f.name for f in variables]].X
            else:
                self.Error.proj_and_domain_match()
                axes = None
            return axes

        def get_axes(variables):
            self.Error.proj_and_domain_match.clear()
            axes = None
            if self.placement == self.Placement.Circular:
                axes = LinProj.defaultaxes(len(variables))
            elif self.placement == self.Placement.LDA:
                axes = self._get_lda(self.data, variables)
            elif self.placement == self.Placement.Projection and self.projection:
                axes = projection(variables)
            return axes

        # One row per variable, one column per instance
        coords = [column_data(self.data, var, dtype=float) for var in variables]
        coords = np.vstack(coords)
        p, N = coords.shape
        # Both conditions are checked; written as `assert a, b` the second
        # one would only be the assertion message.
        assert N == len(self.data) and p == len(variables)

        axes = get_axes(variables)
        if axes is None:
            return None, None, None
        assert axes.shape == (2, p)

        # Drop the instances with a missing value in any of the columns
        valid_mask = ~np.isnan(coords).any(axis=0)
        coords = coords[:, valid_mask]

        X, Y = np.dot(axes, coords)
        if X.size and Y.size:
            X = normalized(X)
            Y = normalized(Y)

        return valid_mask, np.stack((X, Y), axis=1), axes.T

    def _setup_plot(self):
        """Clear the plot and rebuild it for the current data and placement.

        Creates the x/y variables with unique names, chooses the variables to
        project according to the placement, computes the embedding and, when
        it contains at least one valid instance, draws the anchors and the
        points.
        """
        self._clear_plot()
        if self.data is None:
            return
        self.__replot_requested = False

        domain = self.data.domain
        prefix = self.Variable_name[self.placement]
        taken = [v.name for v in chain(domain.variables, domain.metas)]
        proposed = ["{}-x".format(prefix), "{}-y".format(prefix)]
        names = get_unique_names(taken, proposed)
        self.variable_x = ContinuousVariable(names[0])
        self.variable_y = ContinuousVariable(names[1])

        Placement = self.Placement
        if self.placement in [Placement.Circular, Placement.LDA]:
            variables = list(self.model_selected)
        elif self.placement == Placement.Projection:
            variables = self.model_selected[:] + self.model_other[:]
        elif self.placement == Placement.PCA:
            variables = [var for var in domain.attributes if var.is_continuous]
        if not variables:
            self.graph.new_data(None)
            return

        if self.placement == Placement.PCA:
            valid_mask, ec, axes = self._get_pca()
            variables = self._pca.orig_domain.attributes
        else:
            valid_mask, ec, axes = self.prepare_plot_data(variables)

        plotdata = self.plotdata
        plotdata.variables = variables
        plotdata.valid_mask = valid_mask
        plotdata.embedding_coords = ec
        plotdata.axes = axes
        for part in (valid_mask, ec, axes):
            if part is None:
                return

        if not sum(valid_mask):
            self.Error.no_valid_data()
            self.graph.new_data(None, None)
            return
        self.Error.no_valid_data.clear()

        self._anchor_circle(variables=variables)
        self._plot()

    def _plot(self):
        """Send the data extended with the projected coordinates to the graph.

        The x/y variables are appended to the metas; the instances outside
        `plotdata.valid_mask` keep zero coordinates and are not passed to the
        graph.
        """
        domain = self.data.domain
        new_metas = domain.metas + (self.variable_x, self.variable_y)
        domain = Domain(attributes=domain.attributes, class_vars=domain.class_vars, metas=new_metas)
        valid_mask = self.plotdata.valid_mask
        # The builtin `float` replaces the alias `np.float`, which has been
        # removed from NumPy.
        array = np.zeros((len(self.data), 2), dtype=float)
        array[valid_mask] = self.plotdata.embedding_coords
        self.plotdata.data = data = self.data.transform(domain)
        data[:, self.variable_x] = array[:, 0].reshape(-1, 1)
        data[:, self.variable_y] = array[:, 1].reshape(-1, 1)
        subset_data = data[self._subset_mask & valid_mask]\
            if self._subset_mask is not None and len(self._subset_mask) else None
        self.graph.new_data(data[valid_mask], subset_data)
        if self._selection is not None:
            # The stored selection covers all instances; the graph only holds
            # the valid ones
            self.graph.selection = self._selection[valid_mask]
        self.graph.update_data(self.variable_x, self.variable_y, False)

    def _get_lda(self, data, variables):
        """Fit LDA on `variables` and return its first two scalings as axes.

        A degenerate (1, 1) result is replaced by the fixed axes
        ``[[1.], [0.]]``.
        """
        lda_domain = Domain(attributes=variables, class_vars=data.domain.class_vars)
        lda_data = data.transform(lda_domain)
        model = LinearDiscriminantAnalysis(solver='eigen', n_components=2)
        model.fit(lda_data.X, lda_data.Y)
        axes = model.scalings_[:, :2].T
        if axes.shape == (1, 1):
            return np.array([[1.], [0.]])
        return axes

    def _get_pca(self):
        """Compute a two-component PCA of the data.

        The fitted model is stored in `self._pca`. A warning is shown when
        the cumulative explained variance ratio is not finite.

        Returns:
            A tuple ``(valid_mask, coords, axes)``: the mask of instances
            without missing coordinates, the transformed coordinates and the
            component axes rescaled to the extent of the coordinates.
        """
        n_components = 2
        projector = PCA(n_components=n_components)
        projector.component = n_components
        projector.preprocessors = PCA.preprocessors + [Normalize()]

        model = projector(self.data)
        cumulative = np.cumsum(model.explained_variance_ratio_)

        self._pca = model
        if not np.isfinite(cumulative[-1]):
            self.Warning.trivial_components()

        coords = model(self.data).X
        valid_mask = ~np.isnan(coords).any(axis=1)
        # scale axes
        lower_extent = np.abs(np.min(coords, axis=0))
        upper_extent = np.max(coords, axis=0)
        max_radius = np.min([lower_extent, upper_extent])
        axes = model.components_.T.copy()
        longest_axis = np.max(np.linalg.norm(axes, axis=1))
        axes *= max_radius / longest_axis
        return valid_mask, coords, axes

    def _update_graph(self, reset_view=False):
        """Empty the zoom stack and, if the graph has data, redraw it.

        Args:
            reset_view (bool): passed on to the graph's `update_data`
        """
        graph = self.graph
        graph.zoomStack = []
        if graph.data is not None:
            graph.update_data(self.variable_x, self.variable_y, reset_view)

    def update_density(self):
        """Redraw the graph without resetting the view."""
        self._update_graph(reset_view=False)

    def selection_changed(self):
        """Store the graph's selection for all instances and commit.

        The graph's selection only covers the valid instances; it is expanded
        to the full length of the data, with zeros for the others.
        """
        graph_selection = self.graph.selection
        if graph_selection is None:
            self._selection = self.selection_indices = None
        else:
            expanded = np.zeros(len(self.data), dtype=np.uint8)
            expanded[self.plotdata.valid_mask] = graph_selection
            self._selection = expanded
            self.selection_indices = expanded.tolist()
        self.commit()

    def prepare_data(self):
        """Do nothing; the method is intentionally empty."""
        # NOTE(review): presumably kept as a hook expected by the graph or a
        # base class - confirm against the callers.
        pass

    def commit(self):
        """Send the selected data, the annotated data and the components.

        All three outputs are None when there is no data or nothing has been
        plotted.
        """
        def prepare_components():
            # The input projection is forwarded unchanged; for the other
            # placements a table with one row per component is built.
            if self.placement == self.Placement.Projection:
                return self.projection
            if self.placement == self.Placement.PCA:
                axes = self._pca.components_.T
                attrs = list(self._pca.orig_domain.attributes)
            else:
                # Circular and LDA placement
                attrs = list(self.model_selected)
                axes = self.plotdata.axes
            domain = Domain([ContinuousVariable(a.name, compute_value=lambda _: None)
                             for a in attrs],
                            metas=[StringVariable(name='component')])
            metas = np.array([["{}{}".format(self.Component_name[self.placement], i + 1)
                               for i in range(axes.shape[1])]],
                             dtype=object).T
            components = Table(domain, axes.T, metas=metas)
            components.name = 'components'
            return components

        selected = annotated = components = None
        if self.data is not None and self.plotdata.data is not None:
            components = prepare_components()

            graph = self.graph
            # Expand the graph's selection (valid instances only) to all
            # instances; invalid instances are never selected.
            mask = self.plotdata.valid_mask.astype(int)
            if graph.selection is not None:
                mask[mask == 1] = graph.selection
            else:
                mask[mask == 1] = 0

            selection = np.flatnonzero(mask)
            name = self.data.name
            data = self.plotdata.data
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes

            # Selection values above 1 denote several selection groups
            if graph.selection is not None and np.max(graph.selection) > 1:
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """Report the plot with a caption listing the projection and graph settings.

        Nothing is reported without data.
        """
        if self.data is None:
            return

        # Indexed by the integer value of the placement
        placement_labels = ("Circular Placement",
                            "Linear Discriminant Analysis",
                            "Principal Component Analysis",
                            "Input projection")
        graph = self.graph

        def name_of(var):
            return var and var.name

        jittering = graph.jitter_size != 0 and "{} %".format(graph.jitter_size)
        items = (
            ("Projection", placement_labels[self.placement]),
            ("Color", name_of(graph.attr_color)),
            ("Label", name_of(graph.attr_label)),
            ("Shape", name_of(graph.attr_shape)),
            ("Size", name_of(graph.attr_size)),
            ("Jittering", jittering))
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings to the current version, in place.

        Version 2 renamed "point_size" to "point_width"; version 3 moved the
        graph-related settings into the nested "graph" dictionary.
        """
        if version < 2:
            settings_["point_width"] = settings_["point_size"]
        if version < 3:
            settings_["graph"] = {
                "jitter_size": settings_["jitter_value"],
                "point_width": settings_["point_width"],
                "alpha_value": settings_["alpha_value"],
                "class_density": settings_["class_density"],
            }

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context to the current version, in place.

        Version 2 replaced the index-based color/shape/size settings with
        variable-based ones; version 3 moved them into the nested "graph"
        dictionary.
        """
        if version < 2:
            ordered = context.ordered_domain
            continuous = [entry for entry in ordered if entry[1] == 2]
            discrete = [entry for entry in ordered if entry[1] == 1]
            conversions = ((ordered, "color_index", "attr_color"),
                           (discrete, "shape_index", "attr_shape"),
                           (continuous, "size_index", "attr_size"))
            for pool, old_key, new_key in conversions:
                # The stored index is shifted by one; a position outside the
                # pool maps to None
                index = context.values[old_key][0] - 1
                if 0 <= index < len(pool):
                    entry = pool[index]
                    context.values[new_key] = (entry[0], entry[1] + 100)
                else:
                    context.values[new_key] = None
        if version < 3:
            graph_values = {}
            for key in ("attr_color", "attr_shape", "attr_size"):
                graph_values[key] = context.values[key]
            context.values["graph"] = graph_values
Exemple #11
0
class OWMap(OWDataProjectionWidget):
    """
    Scatter plot visualization of coordinates data with geographic maps for
    background.
    """

    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/GeoMap.svg"
    priority = 100

    replaces = [
        "Orange.widgets.visualize.owmap.OWMap",
    ]

    settings_version = 3

    # Variables holding the latitude and the longitude; chosen automatically
    # in init_attr_values and changeable through the combo boxes
    attr_lat = settings.ContextSetting(None)
    attr_lon = settings.ContextSetting(None)

    GRAPH_CLASS = OWScatterPlotMapGraph
    graph = settings.SettingProvider(OWScatterPlotMapGraph)
    embedding_variables_names = None

    class Error(OWDataProjectionWidget.Error):
        no_lat_lon_vars = Msg("Data has no latitude and longitude variables.")

    class Warning(OWDataProjectionWidget.Warning):
        # Shown when the coordinates are missing for every data point
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points")
        out_of_range = Msg(
            "Points with out of range latitude or longitude are not displayed."
        )
        no_internet = Msg("Cannot fetch map from the internet. "
                          "Displaying only cached parts.")

    class Information(OWDataProjectionWidget.Information):
        # Shown when the coordinates are missing for only some data points
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        """Initialize the widget and connect the graph's internet-error signal."""
        super().__init__()
        self._attr_lat, self._attr_lon = None, None
        # The graph reports (un)successful map fetching through this signal
        self.graph.show_internet_error.connect(self._show_internet_error)

    def _show_internet_error(self, show):
        """Show or clear the "no internet" warning so that it matches `show`.

        The warning is touched only when its state actually has to change.
        """
        warning = self.Warning.no_internet
        currently_shown = warning.is_shown()
        if show and not currently_shown:
            warning()
        elif currently_shown and not show:
            warning.clear()

    def _add_controls(self):
        """Add the map and coordinate combo boxes and the "Freeze map" check box.

        The combo boxes are created before the controls of the base class,
        the check box after them.
        """
        self.lat_lon_model = DomainModel(DomainModel.MIXED,
                                         valid_types=ContinuousVariable)

        lat_lon_box = gui.vBox(self.controlArea, True)
        # Options shared by all three combo boxes
        options = dict(labelWidth=75,
                       orientation=Qt.Horizontal,
                       sendSelectedValue=True,
                       valueType=str,
                       contentsLength=14)

        gui.comboBox(lat_lon_box, self, 'graph.tile_provider_key',
                     label='Map:',
                     items=list(TILE_PROVIDERS.keys()),
                     callback=self.graph.update_tile_provider,
                     **options)

        coordinate_combos = (('attr_lat', 'Latitude:'),
                             ('attr_lon', 'Longitude:'))
        for setting_name, caption in coordinate_combos:
            gui.comboBox(lat_lon_box, self, setting_name,
                         label=caption,
                         callback=self.setup_plot,
                         model=self.lat_lon_model,
                         **options,
                         searchable=True)

        super()._add_controls()

        gui.checkBox(
            self._plot_box,
            self,
            value="graph.freeze",
            label="Freeze map",
            tooltip="If checked, the map won't change position to fit new data."
        )

    def get_embedding(self):
        """Return the normalized map coordinates of the data points.

        `valid_data` is set to the mask of points whose latitude and
        longitude are both finite and within range. Messages are shown for
        missing and for out-of-range coordinates.

        Returns:
            An array with one (x, y) row per data point, or None when there
            is no data, a coordinate column is unavailable, or every point is
            out of range.
        """
        self.valid_data = None
        if self.data is None:
            return None

        latitudes = self.get_column(self.attr_lat, filter_valid=False)
        longitudes = self.get_column(self.attr_lon, filter_valid=False)
        if latitudes is None or longitudes is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        self.valid_data = np.isfinite(latitudes) & np.isfinite(longitudes)
        if self.valid_data is not None and not np.all(self.valid_data):
            # Some points missing: information; all points missing: warning
            if np.any(self.valid_data):
                channel = self.Information
            else:
                channel = self.Warning
            channel.missing_coords(self.attr_lat.name, self.attr_lon.name)

        lon_ok = (-MAX_LONGITUDE <= longitudes) & (longitudes <= MAX_LONGITUDE)
        lat_ok = (-MAX_LATITUDE <= latitudes) & (latitudes <= MAX_LATITUDE)
        # The xor is true exactly for the points that are finite but out of
        # range; its negation therefore keeps the in-range points and the
        # points with missing coordinates (those are reported above).
        not_out_of_range = ~np.bitwise_xor(lon_ok & lat_ok, self.valid_data)
        self.Warning.out_of_range.clear()
        n_kept = not_out_of_range.sum()
        if n_kept != len(longitudes):
            self.Warning.out_of_range()
        if n_kept == 0:
            return None
        self.valid_data &= not_out_of_range

        x, y = deg2norm(longitudes, latitudes)
        # invert y to increase from bottom to top
        y = 1 - y
        return np.vstack((x, y)).T

    def check_data(self):
        """Run the base-class checks and discard data with no rows or an empty domain."""
        super().check_data()

        data = self.data
        if data is None:
            return
        if len(data) == 0 or len(data.domain) == 0:
            self.data = None

    def init_attr_values(self):
        """Find the latitude and longitude variables and set up the combo model.

        Data in which the two variables cannot both be found is rejected with
        an error and dropped.
        """
        lat = lon = None
        if self.data is not None:
            lat, lon = find_lat_lon(self.data, filter_hidden=True)
            if lat is None or lon is None:
                # we either find both or we don't have valid data
                self.Error.no_lat_lon_vars()
                self.data = None
                lat = lon = None

        super().init_attr_values()

        domain = self.data.domain if self.data else None
        self.lat_lon_model.set_domain(domain)
        self.attr_lat, self.attr_lon = lat, lon

    @property
    def effective_variables(self):
        """The latitude and longitude variables, or an empty list if either is unset."""
        if self.attr_lat and self.attr_lon:
            return [self.attr_lat, self.attr_lon]
        return []

    def showEvent(self, ev):
        """Handle the show event and update the map's view range."""
        super().showEvent(ev)
        # reset the map on show event since before that we didn't know the
        # right resolution
        self.graph.update_view_range()

    def resizeEvent(self, ev):
        """Handle the resize event and update the map's view range."""
        super().resizeEvent(ev)
        # when resizing we need to constantly reset the map so that new
        # portions are drawn
        self.graph.update_view_range(match_data=False)

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Upgrade stored settings to version 3, in place.

        The old top-level settings that are present are converted into the
        nested "graph" dictionary.
        """
        if version >= 3:
            return

        graph = _settings["graph"] = {}
        if "tile_provider" in _settings:
            # "Watercolor" is replaced by the default provider
            if _settings["tile_provider"] == "Watercolor":
                _settings["tile_provider"] = DEFAULT_TILE_PROVIDER
            graph["tile_provider_key"] = _settings["tile_provider"]
        if "opacity" in _settings:
            graph["alpha_value"] = round(_settings["opacity"] * 2.55)
        if "zoom" in _settings:
            graph["point_width"] = round(_settings["zoom"] * 0.02)
        # Settings that are copied without conversion
        for old_key, new_key in (("jittering", "jitter_size"),
                                 ("show_legend", "show_legend")):
            if old_key in _settings:
                graph[new_key] = _settings[old_key]

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context to the current version, in place.

        Version 2 converted the settings stored as strings to variables;
        version 3 renamed the "*_attr" settings to "attr_*".
        """
        if version < 2:
            single_placeholder = (("lat_attr", ""),
                                  ("lon_attr", ""),
                                  ("class_attr", "(None)"))
            for setting_name, placeholder in single_placeholder:
                settings.migrate_str_to_variable(context,
                                                 names=setting_name,
                                                 none_placeholder=placeholder)

            # those settings can have two none placeholder
            double_placeholder = [("color_attr", "(Same color)"),
                                  ("label_attr", "(No labels)"),
                                  ("shape_attr", "(Same shape)"),
                                  ("size_attr", "(Same size)")]
            for setting_name, placeholder in double_placeholder:
                stored = context.values[setting_name]
                if stored[0] == placeholder:
                    # Normalize the descriptive placeholder to the empty one
                    context.values[setting_name] = ("", stored[1])

                settings.migrate_str_to_variable(context,
                                                 names=setting_name,
                                                 none_placeholder="")
        if version < 3:
            for kind in ("lat", "lon", "color", "label", "shape", "size"):
                settings.rename_setting(context, kind + "_attr", "attr_" + kind)
Exemple #12
0
class OWImageViewer(widget.OWWidget):
    """
    Show the images referenced by a string attribute as a thumbnail grid.

    Image URLs are built from the selected string attribute (see
    `urlFromValue`) and fetched through `self.loader`; the data rows of
    the selected thumbnails are sent on the "Data" output.
    """
    name = "Image Viewer"
    description = "View images referred to in the data."
    icon = "icons/ImageViewer.svg"
    priority = 4050

    inputs = [("Data", Orange.data.Table, "setData")]
    outputs = [("Data", Orange.data.Table, )]

    settingsHandler = settings.DomainContextHandler()

    #: Index (into `stringAttrs`) of the attribute holding image filenames
    imageAttr = settings.ContextSetting(0)
    #: Index (into `allAttrs`) of the attribute used for thumbnail titles
    titleAttr = settings.ContextSetting(0)

    #: Zoom slider value (1..100), see `pixmapSize`
    zoom = settings.Setting(25)
    autoCommit = settings.Setting(False)

    graph_name = "scene"

    def __init__(self):
        super().__init__()

        # Initialise the data-dependent state up front so that callbacks
        # running before the first input do not hit missing attributes.
        #: Input data table
        self.data = None
        #: All variables and metas of the input domain
        self.allAttrs = []
        #: String variables of the input domain
        self.stringAttrs = []

        self.info = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info"),
            "Waiting for input\n"
        )

        self.imageAttrCB = gui.comboBox(
            self.controlArea, self, "imageAttr",
            box="Image Filename Attribute",
            tooltip="Attribute with image filenames",
            callback=[self.clearScene, self.setupScene],
            contentsLength=12,
            addSpace=True,
        )

        self.titleAttrCB = gui.comboBox(
            self.controlArea, self, "titleAttr",
            box="Title Attribute",
            tooltip="Attribute with image title",
            callback=self.updateTitles,
            contentsLength=12,
            addSpace=True
        )

        gui.hSlider(
            self.controlArea, self, "zoom",
            box="Zoom", minValue=1, maxValue=100, step=1,
            callback=self.updateZoom,
            createLabel=False
        )

        gui.separator(self.controlArea)
        gui.auto_commit(self.controlArea, self, "autoCommit",
                        "Commit", "Auto commit")

        gui.rubber(self.controlArea)

        self.scene = GraphicsScene()
        self.sceneView = QGraphicsView(self.scene, self)
        self.sceneView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.sceneView.setRenderHint(QPainter.Antialiasing, True)
        self.sceneView.setRenderHint(QPainter.TextAntialiasing, True)
        self.sceneView.setFocusPolicy(Qt.WheelFocus)
        self.sceneView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        # Resize events of the view are handled in `eventFilter`.
        self.sceneView.installEventFilter(self)
        self.mainArea.layout().addWidget(self.sceneView)

        self.scene.selectionChanged.connect(self.onSelectionChanged)
        self.scene.selectionRectPointChanged.connect(
            self.onSelectionRectPointChanged, Qt.QueuedConnection
        )
        self.resize(800, 600)

        self.thumbnailWidget = None
        self.sceneLayout = None
        self.selectedIndices = []

        #: List of _ImageItems
        self.items = []

        #: Number of finished image requests that failed / succeeded
        self._errcount = 0
        self._successcount = 0

        self.loader = ImageLoader(self)

    def setData(self, data):
        """
        Set the input data table (or None) and rebuild the scene.
        """
        self.closeContext()
        self.clear()

        self.data = data

        if data is not None:
            domain = data.domain
            self.allAttrs = domain.variables + domain.metas
            self.stringAttrs = [a for a in self.allAttrs if a.is_string]

            # Attributes carrying a "type" annotation are listed first.
            self.stringAttrs = sorted(
                self.stringAttrs,
                key=lambda attr: 0 if "type" in attr.attributes else 1
            )

            # Preselect the first attribute tagged with type=image.
            indices = [i for i, var in enumerate(self.stringAttrs)
                       if var.attributes.get("type") == "image"]
            if indices:
                self.imageAttr = indices[0]

            self.imageAttrCB.setModel(VariableListModel(self.stringAttrs))
            self.titleAttrCB.setModel(VariableListModel(self.allAttrs))

            self.openContext(data)

            # Clamp the (possibly restored) indices into the valid range.
            self.imageAttr = max(min(self.imageAttr, len(self.stringAttrs) - 1), 0)
            self.titleAttr = max(min(self.titleAttr, len(self.allAttrs) - 1), 0)

            if self.stringAttrs:
                self.setupScene()
        else:
            self.info.setText("Waiting for input\n")

    def clear(self):
        """
        Drop the current data, messages, combo box contents and scene.
        """
        self.data = None
        self.information(0)
        self.error(0)
        self.imageAttrCB.clear()
        self.titleAttrCB.clear()
        self.clearScene()

    def setupScene(self):
        """
        Create one thumbnail per data instance and start loading the images.
        """
        self.information(0)
        self.error(0)
        if self.data:
            attr = self.stringAttrs[self.imageAttr]
            titleAttr = self.allAttrs[self.titleAttr]
            instances = [inst for inst in self.data
                         if numpy.isfinite(inst[attr])]
            widget = ThumbnailWidget()
            layout = widget.layout()

            self.scene.addItem(widget)

            for i, inst in enumerate(instances):
                url = self.urlFromValue(inst[attr])
                title = str(inst[titleAttr])

                thumbnail = GraphicsThumbnailWidget(
                    QPixmap(), title=title, parent=widget
                )

                thumbnail.setToolTip(url.toString())
                thumbnail.instance = inst
                # Five thumbnails per row. Grid indices must be integers,
                # so use floor division (`i / 5` is a float on Python 3).
                layout.addItem(thumbnail, i // 5, i % 5)

                if url.isValid():
                    future = self.loader.get(url)
                    watcher = _FutureWatcher(parent=thumbnail)
                    # watcher = FutureWatcher(future, parent=thumbnail)

                    # `thumb` and `future` are bound as defaults so each
                    # callback keeps the values of its own loop iteration.
                    def set_pixmap(thumb=thumbnail, future=future):
                        if future.cancelled():
                            return
                        if future.exception():
                            # Should be some generic error image.
                            pixmap = QPixmap()
                            thumb.setToolTip(thumb.toolTip() + "\n" +
                                             str(future.exception()))
                        else:
                            pixmap = QPixmap.fromImage(future.result())

                        thumb.setPixmap(pixmap)
                        if not pixmap.isNull():
                            thumb.setThumbnailSize(self.pixmapSize(pixmap))

                        self._updateStatus(future)

                    watcher.finished.connect(set_pixmap, Qt.QueuedConnection)
                    watcher.setFuture(future)
                else:
                    future = None
                self.items.append(_ImageItem(i, thumbnail, url, future))

            widget.show()
            widget.geometryChanged.connect(self._updateSceneRect)

            self.info.setText("Retrieving...\n")
            self.thumbnailWidget = widget
            self.sceneLayout = layout

        if self.sceneLayout:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())
            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)
            self.sceneLayout.activate()

    def urlFromValue(self, value):
        """
        Return the QUrl for `value`, resolved against the variable's
        "origin" attribute (if any); a missing scheme defaults to "file".
        """
        variable = value.variable
        origin = variable.attributes.get("origin", "")
        if origin and QDir(origin).exists():
            origin = QUrl.fromLocalFile(origin)
        elif origin:
            origin = QUrl(origin)
            if not origin.scheme():
                origin.setScheme("file")
        else:
            origin = QUrl("")
        # Ensure a trailing slash so `resolved` appends to the origin path.
        base = origin.path()
        if base.strip() and not base.endswith("/"):
            origin.setPath(base + "/")

        name = QUrl(str(value))
        url = origin.resolved(name)
        if not url.scheme():
            url.setScheme("file")
        return url

    def pixmapSize(self, pixmap):
        """
        Return the preferred pixmap size based on the current `zoom` value.

        The result is never smaller than 16x16.
        """
        scale = 2 * self.zoom / 100.0
        size = QSizeF(pixmap.size()) * scale
        return size.expandedTo(QSizeF(16, 16))

    def clearScene(self):
        """
        Cancel pending image requests and remove all thumbnails.
        """
        for item in self.items:
            if item.future:
                item.future._reply.close()
                item.future.cancel()

        self.items = []
        self._errcount = 0
        self._successcount = 0

        self.scene.clear()
        self.thumbnailWidget = None
        self.sceneLayout = None

    def thumbnailItems(self):
        """Return the thumbnail widgets of all current items."""
        return [item.widget for item in self.items]

    def updateZoom(self):
        """
        Resize all thumbnails for the current `zoom` and reflow the grid.
        """
        for item in self.thumbnailItems():
            item.setThumbnailSize(self.pixmapSize(item.pixmap()))

        if self.thumbnailWidget:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())

            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)

        if self.sceneLayout:
            self.sceneLayout.activate()

    def updateTitles(self):
        """
        Set every thumbnail's title from the selected title attribute.
        """
        if not self.allAttrs:
            # No input domain yet; there is nothing to retitle.
            return
        titleAttr = self.allAttrs[self.titleAttr]
        for item in self.items:
            item.widget.setTitle(str(item.widget.instance[titleAttr]))

    def onSelectionChanged(self):
        """Record the indices of the selected thumbnails and commit."""
        selected = [item for item in self.items if item.widget.isSelected()]
        self.selectedIndices = [item.index for item in selected]
        self.commit()

    def onSelectionRectPointChanged(self, point):
        """Keep `point` visible in the view (with a 5 px margin)."""
        self.sceneView.ensureVisible(QRectF(point, QSizeF(1, 1)), 5, 5)

    def commit(self):
        """
        Send the selected data rows (or None) on the "Data" output.
        """
        if self.data:
            if self.selectedIndices:
                selected = self.data[self.selectedIndices]
            else:
                selected = None
            self.send("Data", selected)
        else:
            self.send("Data", None)

    def _updateStatus(self, future):
        """
        Update the success/error counters and the info label for a
        finished (non-cancelled) image request.
        """
        if future.cancelled():
            return

        if future.exception():
            self._errcount += 1
            _log.debug("Error: %r", future.exception())
        else:
            self._successcount += 1

        count = len([item for item in self.items if item.future is not None])
        self.info.setText(
            "Retrieving:\n" +
            "{} of {} images".format(self._successcount, count))

        if self._errcount + self._successcount == count:
            # All requests have finished.
            if self._errcount:
                self.info.setText(
                    "Done:\n" +
                    "{} images, {} errors".format(count, self._errcount)
                )
            else:
                self.info.setText(
                    "Done:\n" +
                    "{} images".format(count)
                )
            attr = self.stringAttrs[self.imageAttr]
            if self._errcount == count and "type" not in attr.attributes:
                self.error(0,
                           "No images found! Make sure the '%s' attribute "
                           "is tagged with 'type=image'" % attr.name)

    def _updateSceneRect(self):
        """Fit the scene rect to the bounding rect of its items."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def onDeleteWidget(self):
        """
        Abort and cancel the image requests that are still tracked.
        """
        for item in self.items:
            # Items created for invalid URLs have no future.
            if item.future is not None:
                item.future._reply.abort()
                item.future.cancel()

    def eventFilter(self, receiver, event):
        """
        Reflow the thumbnails whenever the scene view is resized.
        """
        if receiver is self.sceneView and event.type() == QEvent.Resize \
                and self.thumbnailWidget:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())

            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)

        return super(OWImageViewer, self).eventFilter(receiver, event)
Exemple #13
0
class OWDistanceMap(widget.OWWidget):
    """
    Display a distance matrix as a color map, optionally reordered by
    hierarchical clustering, with dendrograms and row/column labels.
    """
    name = "Distance Map"
    description = "Visualize a distance matrix."
    icon = "icons/DistanceMap.svg"
    priority = 1200
    keywords = []

    class Inputs:
        distances = Input("Distances", Orange.misc.DistMatrix)

    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)
        features = Output("Features", widget.AttributeList, dynamic=False)

    settingsHandler = settings.PerfectDomainContextHandler()

    #: type of ordering to apply to matrix rows/columns
    NoOrdering, Clustering, OrderedClustering = 0, 1, 2

    sorting = settings.Setting(NoOrdering)

    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    color_gamma = settings.Setting(0.0)
    color_low = settings.Setting(0.0)
    color_high = settings.Setting(1.0)

    annotation_idx = settings.ContextSetting(0)
    pending_selection = settings.Setting(None, schema_only=True)

    autocommit = settings.Setting(True)

    graph_name = "grid_widget"

    # Disable clustering for inputs bigger than this
    _MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    _MaxOrderedClustering = 2000

    def __init__(self):
        super().__init__()

        self.matrix = None
        self._matrix_range = 0.
        self._tree = None
        self._ordered_tree = None
        self._sorted_matrix = None
        self._sort_indices = None
        self._selection = None
        # Initialised here so methods reading them before the first input
        # (e.g. `commit`) do not hit missing attributes.
        self.items = None
        self._labels = None

        self.sorting_cb = gui.comboBox(
            self.controlArea,
            self,
            "sorting",
            box="Element Sorting",
            items=["None", "Clustering", "Clustering with ordered leaves"],
            callback=self._invalidate_ordering)

        box = gui.vBox(self.controlArea, "Colors")
        self.color_map_widget = cmw = ColorGradientSelection(
            thresholds=(self.color_low, self.color_high), )
        model = itemmodels.ContinuousPalettesModel(parent=self)
        cmw.setModel(model)
        # Restore the stored palette if the model still offers it.
        idx = cmw.findData(self.palette_name, model.KeyRole)
        if idx != -1:
            cmw.setCurrentIndex(idx)

        cmw.activated.connect(self._update_color)

        def _set_thresholds(low, high):
            self.color_low, self.color_high = low, high
            self._update_color()

        cmw.thresholdsChanged.connect(_set_thresholds)
        box.layout().addWidget(self.color_map_widget)

        self.annot_combo = gui.comboBox(self.controlArea,
                                        self,
                                        "annotation_idx",
                                        box="Annotations",
                                        contentsLength=12,
                                        searchable=True,
                                        callback=self._invalidate_annotations)
        self.annot_combo.setModel(itemmodels.VariableListModel())
        self.annot_combo.model()[:] = ["None", "Enumeration"]
        gui.rubber(self.controlArea)

        gui.auto_send(self.buttonsArea, self, "autocommit")

        self.view = pg.GraphicsView(background="w")
        self.mainArea.layout().addWidget(self.view)

        self.grid_widget = pg.GraphicsWidget()
        self.grid = QGraphicsGridLayout()
        self.grid_widget.setLayout(self.grid)

        # Grid layout (row, column): legend (0, 1), top dendrogram (1, 1),
        # left dendrogram (2, 0), matrix view box (2, 1),
        # right labels (2, 2), bottom labels (3, 1).
        self.gradient_legend = GradientLegendWidget(0, 1, self._color_map())
        self.gradient_legend.setSizePolicy(QSizePolicy.Preferred,
                                           QSizePolicy.Fixed)
        self.gradient_legend.setMaximumWidth(250)
        self.grid.addItem(self.gradient_legend, 0, 1)
        self.viewbox = pg.ViewBox(enableMouse=False, enableMenu=False)
        self.viewbox.setAcceptedMouseButtons(Qt.NoButton)
        self.viewbox.setAcceptHoverEvents(False)
        self.grid.addItem(self.viewbox, 2, 1)

        self.left_dendrogram = DendrogramWidget(
            self.grid_widget,
            orientation=DendrogramWidget.Left,
            selectionMode=DendrogramWidget.NoSelection,
            hoverHighlightEnabled=False)
        self.left_dendrogram.setAcceptedMouseButtons(Qt.NoButton)
        self.left_dendrogram.setAcceptHoverEvents(False)

        self.top_dendrogram = DendrogramWidget(
            self.grid_widget,
            orientation=DendrogramWidget.Top,
            selectionMode=DendrogramWidget.NoSelection,
            hoverHighlightEnabled=False)
        self.top_dendrogram.setAcceptedMouseButtons(Qt.NoButton)
        self.top_dendrogram.setAcceptHoverEvents(False)

        self.grid.addItem(self.left_dendrogram, 2, 0)
        self.grid.addItem(self.top_dendrogram, 1, 1)

        self.right_labels = TextList(alignment=Qt.AlignLeft | Qt.AlignVCenter,
                                     sizePolicy=QSizePolicy(
                                         QSizePolicy.Fixed,
                                         QSizePolicy.Expanding))
        self.bottom_labels = TextList(
            orientation=Qt.Horizontal,
            alignment=Qt.AlignRight | Qt.AlignVCenter,
            sizePolicy=QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed))

        self.grid.addItem(self.right_labels, 2, 2)
        self.grid.addItem(self.bottom_labels, 3, 1)

        self.view.setCentralItem(self.grid_widget)

        # Everything stays hidden until a matrix is displayed.
        self.gradient_legend.hide()
        self.left_dendrogram.hide()
        self.top_dendrogram.hide()
        self.right_labels.hide()
        self.bottom_labels.hide()

        self.matrix_item = None
        self.dendrogram = None

        self.settingsAboutToBePacked.connect(self.pack_settings)

    def pack_settings(self):
        """
        Store the current matrix selection in `pending_selection`.
        """
        if self.matrix_item is not None:
            self.pending_selection = self.matrix_item.selections()
        else:
            self.pending_selection = None

    @Inputs.distances
    def set_distances(self, matrix):
        """
        Set the input distance matrix (or None).

        A matrix with fewer than two rows is rejected with an error. The
        clustering options of the sorting combo box are disabled when the
        matrix exceeds `_MaxClustering` / `_MaxOrderedClustering` rows.
        """
        self.closeContext()
        self.clear()
        self.error()
        if matrix is not None:
            N, _ = matrix.shape
            if N < 2:
                self.error("Empty distance matrix.")
                matrix = None

        self.matrix = matrix
        if matrix is not None:
            self._matrix_range = numpy.nanmax(matrix)
            self.set_items(matrix.row_items, matrix.axis)
        else:
            self._matrix_range = 0.
            self.set_items(None)

        if matrix is not None:
            N, _ = matrix.shape
        else:
            N = 0

        model = self.sorting_cb.model()
        item = model.item(2)

        msg = None
        if N > OWDistanceMap._MaxOrderedClustering:
            item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            if self.sorting == OWDistanceMap.OrderedClustering:
                self.sorting = OWDistanceMap.Clustering
                msg = "Cluster ordering was disabled due to the input " \
                      "matrix being to big"
        else:
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        item = model.item(1)
        if N > OWDistanceMap._MaxClustering:
            item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            if self.sorting == OWDistanceMap.Clustering:
                self.sorting = OWDistanceMap.NoOrdering
            msg = "Clustering was disabled due to the input " \
                  "matrix being to big"
        else:
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        self.information(msg)

    def set_items(self, items, axis=1):
        """
        Store `items` and populate the annotation combo box accordingly.

        For a data table (with a truthy `axis`) the combo lists the
        table's visible variables and metas and the context is opened on
        its domain.
        """
        self.items = items
        model = self.annot_combo.model()
        if items is None:
            model[:] = ["None", "Enumeration"]
        elif not axis:
            model[:] = ["None", "Enumeration", "Attribute names"]
        elif isinstance(items, Orange.data.Table):
            annot_vars = list(filter_visible(items.domain.variables)) + list(
                items.domain.metas)
            model[:] = ["None", "Enumeration"] + annot_vars
            self.annotation_idx = 0
            self.openContext(items.domain)
        elif isinstance(items, list) and \
                all(isinstance(item, Orange.data.Variable) for item in items):
            model[:] = ["None", "Enumeration", "Name"]
        else:
            model[:] = ["None", "Enumeration"]
        # Keep the index valid for the (possibly shorter) new model.
        self.annotation_idx = min(self.annotation_idx, len(model) - 1)

    def clear(self):
        """
        Reset the matrix, cached clustering results, selection and plot.
        """
        self.matrix = None
        self._tree = None
        self._ordered_tree = None
        self._sorted_matrix = None
        # Drop the stale ordering together with the sorted matrix.
        self._sort_indices = None
        self._selection = []
        self._clear_plot()

    def handleNewSignals(self):
        """
        Rebuild the display for the new input and commit the outputs.
        """
        if self.matrix is not None:
            self._update_ordering()
            self._setup_scene()
            self._update_labels()
            if self.pending_selection is not None:
                self.matrix_item.set_selections(self.pending_selection)
                self.pending_selection = None
        self.commit.now()

    def _clear_plot(self):
        """
        Remove the matrix item and hide dendrograms, labels and legend.
        """
        def remove(item):
            item.setParentItem(None)
            item.scene().removeItem(item)

        if self.matrix_item is not None:
            self.matrix_item.selectionChanged.disconnect(
                self._invalidate_selection)
            remove(self.matrix_item)
            self.matrix_item = None

        self._set_displayed_dendrogram(None)
        self._set_labels(None)
        self.gradient_legend.hide()

    def _cluster_tree(self):
        """Return the clustering tree of the matrix, computing it once."""
        if self._tree is None:
            self._tree = hierarchical.dist_matrix_clustering(self.matrix)
        return self._tree

    def _ordered_cluster_tree(self):
        """Return the leaf-ordered clustering tree, computing it once."""
        if self._ordered_tree is None:
            tree = self._cluster_tree()
            self._ordered_tree = \
                hierarchical.optimal_leaf_ordering(tree, self.matrix)
        return self._ordered_tree

    def _setup_scene(self):
        """
        Create the matrix item for `_sorted_matrix` and show the
        dendrogram that matches the current sorting.
        """
        self._clear_plot()
        self.matrix_item = DistanceMapItem(self._sorted_matrix)
        # Scale the y axis to compensate for pg.ViewBox's y axis invert
        self.matrix_item.setTransform(QTransform.fromScale(1, -1), )
        self.viewbox.addItem(self.matrix_item)
        # Set fixed view box range.
        h, w = self._sorted_matrix.shape
        self.viewbox.setRange(QRectF(0, -h, w, h), padding=0)

        self.matrix_item.selectionChanged.connect(self._invalidate_selection)

        if self.sorting == OWDistanceMap.NoOrdering:
            tree = None
        elif self.sorting == OWDistanceMap.Clustering:
            tree = self._cluster_tree()
        elif self.sorting == OWDistanceMap.OrderedClustering:
            tree = self._ordered_cluster_tree()

        self._set_displayed_dendrogram(tree)

        self._update_color()

    def _set_displayed_dendrogram(self, root):
        """
        Show `root` in both dendrograms, or hide them when it is None.
        """
        self.left_dendrogram.set_root(root)
        self.top_dendrogram.set_root(root)
        self.left_dendrogram.setVisible(root is not None)
        self.top_dendrogram.setVisible(root is not None)

        # 0 collapses the hidden dendrograms; -1 is used when shown.
        constraint = 0 if root is None else -1  # 150
        self.left_dendrogram.setMaximumWidth(constraint)
        self.top_dendrogram.setMaximumHeight(constraint)

    def _invalidate_ordering(self):
        """
        Recompute the ordering and redraw after the sorting changed.
        """
        self._sorted_matrix = None
        if self.matrix is not None:
            self._update_ordering()
            self._setup_scene()
            self._update_labels()
            self._invalidate_selection()

    def _update_ordering(self):
        """
        Compute `_sorted_matrix` and `_sort_indices` for the current
        sorting.
        """
        if self.sorting == OWDistanceMap.NoOrdering:
            self._sorted_matrix = self.matrix
            self._sort_indices = None
        else:
            if self.sorting == OWDistanceMap.Clustering:
                tree = self._cluster_tree()
            elif self.sorting == OWDistanceMap.OrderedClustering:
                tree = self._ordered_cluster_tree()

            leaves = hierarchical.leaves(tree)
            indices = numpy.array([leaf.value.index for leaf in leaves])
            X = self.matrix
            # Reorder rows and columns by the same leaf order.
            self._sorted_matrix = X[indices[:, numpy.newaxis],
                                    indices[numpy.newaxis, :]]
            self._sort_indices = indices

    def _invalidate_annotations(self):
        """Refresh the labels after the annotation choice changed."""
        if self.matrix is not None:
            self._update_labels()

    def _update_labels(self):
        """
        Build the labels for the selected annotation and display them.

        No labels are shown when the annotation choice does not match the
        kind of `items`.
        """
        # Default to no labels so that an unmatched combination cannot
        # leave `labels` unbound.
        labels = None
        if self.annotation_idx == 0:  # None
            pass
        elif self.annotation_idx == 1:  # Enumeration
            labels = [str(i + 1) for i in range(self.matrix.shape[0])]
        elif self.annot_combo.model()[
                self.annotation_idx] == "Attribute names":
            attr = self.matrix.row_items.domain.attributes
            labels = [str(attr[i]) for i in range(self.matrix.shape[0])]
        elif self.annotation_idx == 2 and \
                isinstance(self.items, (list, widget.AttributeList)):
            # `set_items` offers "Name" for any list of variables, not
            # only for AttributeList instances.
            labels = [v.name for v in self.items]
        elif isinstance(self.items, Orange.data.Table):
            var = self.annot_combo.model()[self.annotation_idx]
            column, _ = self.items.get_column_view(var)
            labels = [var.str_val(value) for value in column]

        self._set_labels(labels)

    def _set_labels(self, labels):
        """
        Show `labels` (in the current sort order) on both label lists, or
        hide the lists when `labels` is empty or None.
        """
        self._labels = labels

        if labels and self.sorting != OWDistanceMap.NoOrdering:
            sortind = self._sort_indices
            labels = [labels[i] for i in sortind]

        for textlist in [self.right_labels, self.bottom_labels]:
            textlist.setItems(labels or [])
            textlist.setVisible(bool(labels))

        # 0 collapses the hidden label lists; -1 is used when shown.
        constraint = -1 if labels else 0
        self.right_labels.setMaximumWidth(constraint)
        self.bottom_labels.setMaximumHeight(constraint)

    def _color_map(self) -> GradientColorMap:
        """
        Return the color map for the selected palette and thresholds,
        spanning ``0 .. _matrix_range``.
        """
        palette = self.color_map_widget.currentData()
        return GradientColorMap(palette.lookup_table(),
                                thresholds=(self.color_low,
                                            max(self.color_high,
                                                self.color_low)),
                                span=(0., self._matrix_range))

    def _update_color(self):
        """
        Store the selected palette name and recolor the matrix and legend.
        """
        palette = self.color_map_widget.currentData()
        self.palette_name = palette.name
        if self.matrix_item:
            cmap = self._color_map().replace(span=(0., 1.))
            colors = cmap.apply(numpy.arange(256) / 255.)
            self.matrix_item.setLookupTable(colors)
            self.gradient_legend.show()
            self.gradient_legend.setRange(0, self._matrix_range)
            self.gradient_legend.setColorMap(self._color_map())

    def _invalidate_selection(self):
        """
        Recompute `_selection` from the matrix item and schedule a commit.
        """
        ranges = self.matrix_item.selections()
        # The selections are nested two levels deep; flatten them twice.
        ranges = reduce(iadd, ranges, [])
        indices = reduce(iadd, ranges, [])
        if self.sorting != OWDistanceMap.NoOrdering:
            # Map displayed positions back to original matrix indices.
            sortind = self._sort_indices
            indices = [sortind[i] for i in indices]
        self._selection = list(sorted(set(indices)))
        self.commit.deferred()

    @gui.deferred
    def commit(self):
        """
        Send the selected rows or columns, the annotated data and the
        selected features to the outputs.
        """
        datasubset = None
        featuresubset = None

        if not self._selection:
            pass
        elif isinstance(self.items, Orange.data.Table):
            indices = self._selection
            if self.matrix.axis == 1:
                datasubset = self.items.from_table_rows(self.items, indices)
            elif self.matrix.axis == 0:
                domain = Orange.data.Domain(
                    [self.items.domain[i] for i in indices],
                    self.items.domain.class_vars, self.items.domain.metas)
                datasubset = self.items.transform(domain)
        elif isinstance(self.items, widget.AttributeList):
            subset = [self.items[i] for i in self._selection]
            featuresubset = widget.AttributeList(subset)

        self.Outputs.selected_data.send(datasubset)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.items, self._selection))
        self.Outputs.features.send(featuresubset)

    def onDeleteWidget(self):
        super().onDeleteWidget()
        self.clear()

    def send_report(self):
        """
        Report the sorting and annotation settings and, when a matrix is
        set, the plot.
        """
        annot = self.annot_combo.currentText()
        if self.annotation_idx <= 1:
            annot = annot.lower()
        self.report_items((("Sorting", self.sorting_cb.currentText().lower()),
                           ("Annotations", annot)))
        if self.matrix is not None:
            self.report_plot()
Exemple #14
0
class OWTestAndScore(OWWidget):
    """Widget that tests learners on data with a selectable resampling scheme.

    Inputs are train data, optional separate test data, one or more learners
    and an optional preprocessor. Outputs are a predictions table and the
    evaluation results.
    """
    name = "Test and Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100
    keywords = ['Cross Validation', 'CV']
    replaces = ["Orange.widgets.evaluate.owtestlearners.OWTestLearners"]

    class Inputs:
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        learner = MultiInput(
            "Learner", Learner, filter_none=True
        )
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    # Version of the stored settings layout; older versions are upgraded in
    # migrate_settings.
    settings_version = 3
    buttons_area_orientation = None
    UserAdviceMessages = [
        widget.Message(
            "Click on the table header to select shown columns",
            "click_header")]

    settingsHandler = settings.PerfectDomainContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    #: (stored as an index into NFolds)
    n_folds = settings.Setting(2)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    #: (stored as an index into NRepeats)
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    #: (stored as an index into SampleSizes)
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    #: Whether "Cross validation by feature" was the selected resampling
    #: (updated in _invalidate, restored in set_train_data)
    fold_feature_selected = settings.ContextSetting(False)

    #: Whether a region of practical equivalence is used in model comparison
    use_rope = settings.Setting(False)
    #: Value passed to the pairwise comparison when use_rope is set
    rope = settings.Setting(0.1)
    #: Index of the scorer used for the model comparison table
    comparison_criterion = settings.Setting(0, schema_only=True)

    #: Combo entry meaning "no specific target class"
    TARGET_AVERAGE = "(None, show average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    class Error(OWWidget.Error):
        test_data_empty = Msg("Test dataset is empty.")
        class_required_test = Msg("Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train datasets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        test_data_incompatible = Msg(
            "Test data may be incompatible with train data.")
        # Generic slot; the concrete text is supplied in set_train_data
        train_data_error = Msg("{}")

    class Warning(OWWidget.Warning):
        # The placeholder is filled with " ", " train " or " test "
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")
        cant_stratify = \
            Msg("Can't run stratified {}-fold cross validation; "
                "the least common class has only {} instances.")

    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")
        test_data_transformed = Msg(
            "Test data has been transformed to match the train data.")
        cant_stratify_numeric = Msg("Stratification is ignored for regression")

    def __init__(self):
        """Initialise the widget state and build the control and main areas.

        The control area holds the resampling radio buttons with their
        parameters. The main area holds the score table and the pairwise
        model comparison table.
        """
        super().__init__()

        # Current inputs and per-input bookkeeping
        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # Stored comparison criterion, applied once scorers are known
        # (see _update_scorers)
        self.__pending_comparison_criterion = self.comparison_criterion
        self.__id_gen = count()
        self._learner_inputs = []  # type: List[Tuple[Any, Learner]]
        #: An Ordered dictionary with current inputs and their testing results
        #: (keyed by ids generated by __id_gen).
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        # --- Control area: resampling method selection ---
        sbox = gui.vBox(self.controlArea, box=True)
        rbox = gui.radioButtons(
            sbox, self, "resampling", callback=self._param_changed)

        # K-fold cross validation and its parameters
        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_folds", label="Number of folds: ",
            items=[str(x) for x in self.NFolds],
            orientation=Qt.Horizontal, callback=self.kfold_changed)
        gui.checkBox(
            ibox, self, "cv_stratified", "Stratified",
            callback=self.kfold_changed)
        # Cross validation with folds defined by a discrete meta feature
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(
            ibox, self, "fold_feature", model=self.feature_model,
            orientation=Qt.Horizontal, searchable=True,
            callback=self.fold_feature_changed)

        # Repeated random train/test splits and their parameters
        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_repeats", label="Repeat train/test: ",
            items=[str(x) for x in self.NRepeats], orientation=Qt.Horizontal,
            callback=self.shuffle_split_changed
        )
        gui.comboBox(
            ibox, self, "sample_size", label="Training set size: ",
            items=["{} %".format(x) for x in self.SampleSizes],
            orientation=Qt.Horizontal, callback=self.shuffle_split_changed
        )
        gui.checkBox(
            ibox, self, "shuffle_stratified", "Stratified",
            callback=self.shuffle_split_changed)

        # Remaining options have no parameters; the button order matches the
        # resampling constants (LeaveOneOut, TestOnTrain, TestOnTest)
        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        gui.rubber(self.controlArea)

        # --- Main area: score table with target class selector ---
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        self.results_box = gui.vBox(self.mainArea, box=True)
        self.cbox = gui.hBox(self.results_box)
        self.class_selection_combo = gui.comboBox(
            self.cbox, self, "class_selection", items=[],
            label="Evaluation results for target", orientation=Qt.Horizontal,
            sendSelectedValue=True, searchable=True, contentsLength=25,
            callback=self._on_target_class_changed
        )
        self.cbox.layout().addStretch(100)
        self.class_selection_combo.setMaximumContentsLength(30)
        self.results_box.layout().addWidget(self.score_table.view)

        # --- Main area: pairwise model comparison ---
        gui.separator(self.mainArea, 16)
        self.compbox = box = gui.vBox(self.mainArea, box=True)
        cbox = gui.comboBox(
            box, self, "comparison_criterion", label="Compare models by:",
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed),
            orientation=Qt.Horizontal, callback=self.update_comparison_table).box

        gui.separator(cbox, 8)
        gui.checkBox(cbox, self, "use_rope", "Negligible diff.: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(cbox, self, "rope", validator=QDoubleValidator(),
                     controlWidth=50, callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # The rope edit is only usable while the checkbox is ticked
        self.controls.rope.setEnabled(self.use_rope)

        # Read-only, non-selectable table of pairwise probabilities
        table = self.comparison_table = QTableWidget(
            wordWrap=False, editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        # Column widths are bounded in units of the average character width
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        box.layout().addWidget(QLabel(
            "<small>Table shows probabilities that the score for the model in "
            "the row is higher than that of the model in the column. "
            "Small numbers show the probability that the difference is "
            "negligible.</small>", wordWrap=True))

    def sizeHint(self):
        """Return the size hint: width 780 with the inherited height."""
        inherited = super().sizeHint()
        return QSize(780, inherited.height())

    def _update_controls(self):
        """Refresh the fold-feature controls for the current train data.

        The feature model is reset and then filled from the data domain when
        data is present. The "by feature" radio button and combo are enabled
        only when the model is non-empty; if that mode was selected but is no
        longer available, resampling falls back to K-fold.
        """
        model = self.feature_model
        self.fold_feature = None
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                # Default to the first candidate feature
                self.fold_feature = model[0]

        has_fold_features = bool(model)
        fold_button = \
            self.controls.resampling.buttons[OWTestAndScore.FeatureFold]
        fold_button.setEnabled(has_fold_features)
        self.features_combo.setEnabled(has_fold_features)
        if not has_fold_features \
                and self.resampling == OWTestAndScore.FeatureFold:
            self.resampling = OWTestAndScore.KFold

    @Inputs.learner
    def set_learner(self, index: int, learner: Learner):
        """
        Replace the learner connected at position `index`.

        The slot keeps its key; its results are dropped and the key is
        invalidated so the learner gets re-tested.

        Parameters
        ----------
        index: int
        learner: Orange.base.Learner
        """
        key = self._learner_inputs[index][0]
        self.learners[key] = self.learners[key]._replace(
            learner=learner, results=None)
        self._invalidate([key])

    @Inputs.learner.insert
    def insert_learner(self, index: int, learner: Learner):
        """
        Register a new learner input at position `index`.

        A fresh key is generated for the input, and `self.learners` is rebuilt
        so that its iteration order follows `self._learner_inputs`.

        Parameters
        ----------
        index: int
        learner: Orange.base.Learner
        """
        key = next(self.__id_gen)
        self._learner_inputs.insert(index, (key, learner))
        self.learners[key] = InputLearner(learner, None, None, key)

        reordered = {}
        for input_key, _ in self._learner_inputs:
            reordered[input_key] = self.learners[input_key]
        self.learners = reordered

        self._invalidate([key])

    @Inputs.learner.remove
    def remove_learner(self, index: int):
        """
        Remove the learner input at position `index`.

        The key is invalidated first, while its slot still exists, and only
        then dropped from both the input list and the learners mapping.

        Parameters
        ----------
        index: int
        """
        key = self._learner_inputs[index][0]
        self._invalidate([key])
        del self._learner_inputs[index]
        del self.learners[key]

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Handle a new training dataset.

        The data is validated and treated as missing (with an error shown)
        when a check fails. SQL tables are downloaded or sampled, and rows
        with an unknown target are removed. The context, scorers and controls
        are then refreshed and all results are invalidated.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.cancel()
        self.Information.data_sampled.clear()
        self.Error.train_data_error.clear()

        if data is not None:
            # Every condition is evaluated up front; only the first failing
            # check, in this order, is reported.
            checks = [
                (len(data) == 0, "Train dataset is empty."),
                (not data.domain.class_vars,
                 "Train data input requires a target variable."),
                (len(data.domain.class_vars) > 1,
                 "Too many target variables."),
                (np.isnan(data.Y).all(), "Target variable has no values."),
                (data.domain.has_discrete_class and len(unique(data.Y)) < 2,
                 "Target variable has only one value."),
                (data.X.shape[1] == 0, "Data has no features to learn from."),
            ]
            failure = next(
                (message for failed, message in checks if failed), None)
            if failure is not None:
                self.Error.train_data_error(failure)
                data = None

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                # Too large to fetch whole: work on a downloaded sample
                self.Information.data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)

        has_missing = data is not None and np.isnan(data.Y).any()
        self.train_data_missing_vals = has_missing
        if has_missing or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # Restore "by feature" resampling if the context asked for it
            # and a fold feature is available
            if self.fold_feature_selected and bool(self.feature_model):
                self.resampling = OWTestAndScore.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Handle a new separate testing dataset.

        Empty tables and tables without a target variable are rejected with
        an error and treated as missing. SQL tables are downloaded or
        sampled, and rows with an unknown target are removed. Results are
        invalidated only when "Test on test data" is the current resampling.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not data:
            self.Error.test_data_empty()
            data = None

        lacks_target = bool(data) and not data.domain.class_var
        if lacks_target:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                # Too large to fetch whole: work on a downloaded sample
                self.Information.test_data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)

        has_missing = data is not None and np.isnan(data.Y).any()
        self.test_data_missing_vals = has_missing
        if self.train_data_missing_vals or has_missing:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        if self.resampling == OWTestAndScore.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {(True, True): " ",  # both, don't specify
                (True, False): " train ",
                (False, True): " test "}[(self.train_data_missing_vals,
                                          self.test_data_missing_vals)]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        """Refresh the list of usable scorers and the comparison combo.

        The combo is rebuilt only when the scorer list actually changes. A
        comparison criterion pending from stored settings is applied once,
        provided it is a valid index, and is then discarded.
        """
        new_scorers = []
        if self.data and self.data.domain.class_var:
            new_scorers = usable_scorers(self.data.domain.class_var)

        # Don't unnecessarily reset the combo because this would always reset
        # comparison_criterion; we also set it explicitly, though, for clarity
        if new_scorers != self.scorers:
            self.scorers = new_scorers
            combo = self.controls.comparison_criterion
            combo.clear()
            labels = [scorer.long_name or scorer.name
                      for scorer in self.scorers]
            combo.addItems(labels)
            if self.scorers:
                self.comparison_criterion = 0

        if self.__pending_comparison_criterion is None:
            return
        # Check for the unlikely case that some scorers have been removed
        # from modules
        if self.__pending_comparison_criterion < len(self.scorers):
            self.comparison_criterion = self.__pending_comparison_criterion
        self.__pending_comparison_criterion = None

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Set the input preprocessor to apply on the training data.

        Stores the preprocessor and invalidates the results of all learners.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals.

        Refreshes the class selector, the score table header, the enabled
        state of the views and the score rows, then starts an update if any
        results were invalidated.
        """
        self._update_class_selection()
        self.score_table.update_header(self.scorers)
        self._update_view_enabled()
        self.update_stats_model()
        # Set by _invalidate; only re-test when something actually changed
        if self.__needupdate:
            self.__update()

    def kfold_changed(self):
        """Switch to K-fold resampling and re-run; callback of the K-fold controls."""
        self.resampling = OWTestAndScore.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Switch to by-feature resampling and re-run; callback of the feature combo."""
        self.resampling = OWTestAndScore.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Switch to random sampling and re-run; callback of the sampling controls."""
        self.resampling = OWTestAndScore.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """React to a changed resampling parameter.

        Refreshes the enabled state of the views, invalidates all learner
        results and immediately runs the update.
        """
        self._update_view_enabled()
        self._invalidate()
        self.__update()

    def _update_view_enabled(self):
        self.compbox.setEnabled(
            self.resampling == OWTestAndScore.KFold
            and len(self.learners) > 1
            and self.data is not None)
        self.score_table.view.setEnabled(
            self.data is not None)

    def update_stats_model(self):
        """Rebuild the rows of the score table from the current learner slots.

        Each learner gets one row: its name, train and test times when its
        results succeeded, and one cell per scorer. When a specific target
        class is selected for a discrete class variable, the scores are
        recomputed here one-vs-rest for that class; otherwise the scores
        stored in the slot are shown. Failures are collected into the widget
        error, and a warning is shown when a displayed score is missing.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.score_table.model
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # Index of the selected target class; None means average over classes
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        names = []
        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            names.append(name)
            head = QStandardItem(name)
            # Every item carries the learner key so rows can be mapped back
            # to self.learners
            head.setData(key, Qt.UserRole)
            results = slot.results
            if results is not None and results.success:
                train = QStandardItem("{:.3f}".format(results.value.train_time))
                train.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                train.setData(key, Qt.UserRole)
                test = QStandardItem("{:.3f}".format(results.value.test_time))
                test.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                test.setData(key, Qt.UserRole)
                row = [head, train, test]
            else:
                row = [head]
            if isinstance(results, Try.Fail):
                # Mark the failed learner in the table and report the cause
                head.setToolTip(str(results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                if isinstance(results.exception, DomainTransformationError) \
                        and self.resampling == self.TestOnTest:
                    self.Error.test_data_incompatible()
                    self.Information.test_data_transformed.clear()
                else:
                    errors.append("{name} failed with error:\n"
                                  "{exc.__class__.__name__}: {exc!s}"
                                  .format(name=name, exc=slot.results.exception)
                                  )

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(
                        slot.results.value, target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [Try(scorer_caller(scorer, ovr_results, target=1))
                             for scorer in self.scorers]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat, scorer in zip(stats, self.scorers):
                    item = QStandardItem()
                    item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                    if stat.success:
                        item.setData(float(stat.value[0]), Qt.DisplayRole)
                    else:
                        item.setToolTip(str(stat.exception))
                        # Only warn about scores the user can actually see
                        if scorer.name in self.score_table.shown_scores:
                            has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        # Resort rows based on current sorting
        header = self.score_table.view.horizontalHeader()
        model.sort(
            header.sortIndicatorSection(),
            header.sortIndicatorOrder()
        )
        self._set_comparison_headers(names)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _on_use_rope_changed(self):
        """Enable the rope edit to match the checkbox and refresh the comparison."""
        self.controls.rope.setEnabled(self.use_rope)
        self.update_comparison_table()

    def update_comparison_table(self):
        """Clear and refill the pairwise model comparison table.

        Nothing beyond clearing happens when there are no successful learners
        or no scorers. The headers are always set to the successful learners'
        names; the cells are filled only for K-fold resampling.
        """
        self.comparison_table.clearContents()
        slots = self._successful_slots()
        if not slots or not self.scorers:
            return

        names = [learner_name(slot.learner) for slot in slots]
        self._set_comparison_headers(names)
        if self.resampling != OWTestAndScore.KFold:
            return
        fold_scores = self._scores_by_folds(slots)
        self._fill_table(names, fold_scores)

    def _successful_slots(self):
        """Return the learner slots with successful results, in table order.

        The rows of the sorted proxy model are walked, and each row's learner
        key is read from the source model's `Qt.UserRole` data.
        """
        source = self.score_table.model
        proxy = self.score_table.sorted_model

        successful = []
        for row in range(proxy.rowCount()):
            source_index = proxy.mapToSource(proxy.index(row, 0))
            key = source.data(source_index, Qt.UserRole)
            slot = self.learners[key]
            if slot.results is not None and slot.results.success:
                successful.append(slot)
        return successful

    def _set_comparison_headers(self, names):
        """Resize the comparison table to `names` and label both headers.

        Columns stretch when there are more than two names and are fixed
        otherwise. Widget updates are suspended while the table is changed
        and are always re-enabled afterwards.
        """
        table = self.comparison_table
        try:
            # Prevent glitching during update
            table.setUpdatesEnabled(False)
            size = len(names)
            mode = QHeaderView.Stretch if size > 2 else QHeaderView.Fixed
            table.horizontalHeader().setSectionResizeMode(mode)
            table.setRowCount(size)
            table.setColumnCount(size)
            table.setVerticalHeaderLabels(names)
            table.setHorizontalHeaderLabels(names)
        finally:
            table.setUpdatesEnabled(True)

    def _scores_by_folds(self, slots):
        """Return per-fold scores of the comparison scorer for each slot.

        For a binary scorer, the selected target class is passed as `target`,
        or `average='weighted'` when no specific class is selected. A slot
        whose scores cannot be computed yields None, and a warning is shown.
        """
        scorer = self.scorers[self.comparison_criterion]()
        scorer_kwargs = {}
        if scorer.is_binary:
            if self.class_selection == self.TARGET_AVERAGE:
                scorer_kwargs = dict(average='weighted')
            else:
                values = self.data.domain.class_var.values
                scorer_kwargs = dict(
                    target=values.index(self.class_selection))

        def fold_scores_of(results):
            # Deferred so that Try can capture any exception raised
            def compute():
                by_folds = scorer.scores_by_folds(
                    results.value, **scorer_kwargs)
                return by_folds.flatten()

            return compute

        attempts = [Try(fold_scores_of(slot.results)) for slot in slots]
        scores = []
        for attempt in attempts:
            scores.append(attempt.value if attempt.success else None)
        # `None in scores doesn't work -- these are np.arrays)
        if any(score is None for score in scores):
            self.Warning.scores_not_computed()
        return scores

    def _fill_table(self, names, scores):
        """Fill the comparison table with pairwise probabilities.

        Only pairs below the diagonal are visited, and each comparison fills
        both the (row, col) and the (col, row) cell. Pairs with missing scores
        are left empty, and pairs whose comparison yields NaN are marked "NA".
        When rope is enabled and non-zero, each cell also shows the
        probability that the difference is negligible.
        """
        table = self.comparison_table
        for row, (row_name, row_scores) in enumerate(zip(names, scores)):
            pairs_before = zip(range(row), zip(names, scores))
            for col, (col_name, col_scores) in pairs_before:
                if row_scores is None or col_scores is None:
                    continue

                with_rope = self.use_rope and self.rope
                if with_rope:
                    p0, rope, p1 = baycomp.two_on_single(
                        row_scores, col_scores, self.rope)
                    probabilities = (p0, rope, p1)
                else:
                    p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                    probabilities = (p0, p1)

                if any(np.isnan(p) for p in probabilities):
                    self._set_cells_na(table, row, col)
                    continue

                # One comparison gives both directions of the pair
                cells = ((row, col, row_name, col_name, p0),
                         (col, row, col_name, row_name, p1))
                for r, c, first, second, prob in cells:
                    if with_rope:
                        label = f"{prob:.3f}<br/><small>{rope:.3f}</small>"
                        tooltip = (f"p({first} > {second}) = {prob:.3f}\n"
                                   f"p({first} = {second}) = {rope:.3f}")
                    else:
                        label = f"{prob:.3f}"
                        tooltip = f"p({first} > {second}) = {prob:.3f}"
                    self._set_cell(table, r, c, label, tooltip)

    @classmethod
    def _set_cells_na(cls, table, row, col):
        """Mark both the (row, col) and the (col, row) cell as "NA"."""
        for r, c in ((row, col), (col, row)):
            cls._set_cell(table, r, c, "NA", "comparison cannot be computed")

    @staticmethod
    def _set_cell(table, row, col, label, tooltip):
        """Put a centered label with a tooltip into cell (row, col) of `table`."""
        cell = QLabel(label)
        cell.setAlignment(Qt.AlignCenter)
        cell.setToolTip(tooltip)
        table.setCellWidget(row, col, cell)

    def _update_class_selection(self):
        """Repopulate the target class combo for the current train data.

        The combo is always emptied first. For a discrete class it lists the
        "average" entry followed by the class values and keeps the previous
        selection if it is still a valid value. Otherwise the selector row is
        hidden.
        """
        combo = self.class_selection_combo
        combo.setCurrentIndex(-1)
        combo.clear()
        if not self.data:
            return

        if not self.data.domain.has_discrete_class:
            self.cbox.setVisible(False)
            return

        self.cbox.setVisible(True)
        values = self.data.domain.class_var.values
        items = (self.TARGET_AVERAGE, ) + values
        combo.addItems(items)

        # Offset by one because the "average" entry occupies index 0
        if self.class_selection in values:
            selected = values.index(self.class_selection) + 1
        else:
            selected = 0

        combo.setCurrentIndex(selected)
        self.class_selection = items[selected]

    def _on_target_class_changed(self):
        """Refresh the score table and the comparison table for the new target."""
        self.update_stats_model()
        self.update_comparison_table()

    def _invalidate(self, which=None):
        """Drop stored results for the given learner keys and flag an update.

        Any running evaluation is cancelled first. For every invalidated key
        the slot's results and stats are reset, and the matching score table
        row (if any) has its display text and tooltips cleared. The comparison
        table is emptied as well.

        Parameters
        ----------
        which : Optional[Iterable]
            Keys of the learners to invalidate; None invalidates all.
        """
        self.cancel()
        self.fold_feature_selected = \
            self.resampling == OWTestAndScore.FeatureFold
        # Invalidate learner results for `which` input keys
        # (if None then all learner results are invalidated)
        keys = self.learners.keys() if which is None else which

        model = self.score_table.model
        row_keys = [model.item(row, 0).data(Qt.UserRole)
                    for row in range(model.rowCount())]

        for key in keys:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)

            if key not in row_keys:
                continue
            row = row_keys.index(key)
            # Column 0 holds the learner name and is left intact
            for column in range(1, model.columnCount()):
                cell = model.item(row, column)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.comparison_table.clearContents()

        self.__needupdate = True

    def commit(self):
        """
        Send the merged evaluation results and the predictions to the outputs.

        Only learners with successful results contribute. Both outputs are
        None when there are none. If building the predictions table runs out
        of memory, an error is shown and predictions are sent as None.
        """
        self.Error.memory_error.clear()
        successful = [
            slot for slot in self.learners.values()
            if slot.results is not None and slot.results.success
        ]
        combined = predictions = None
        if successful:
            # Evaluation results
            combined = results_merge(
                [slot.results.value for slot in successful])
            combined.learner_names = [
                learner_name(slot.learner) for slot in successful]

            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """Report on the testing schema and results"""
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".
                      format(stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [("Sampling type",
                      "{}Shuffle split, {} random samples with {}% data "
                      .format(stratified, self.NRepeats[self.n_repeats],
                              self.SampleSizes[self.sample_size]))]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.score_table.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        if version < 2:
            if settings_["resampling"] > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')]

    @Slot(float)
    def setProgressValue(self, value):
        """Slot that forwards `value` to the widget's progress bar."""
        self.progressBarSet(value)

    def __update(self):
        """
        Validate the current inputs/settings and start a new evaluation.

        Cancels a running task, resets the messages from the previous run,
        checks the preconditions of the selected resampling method (leaving
        the widget in the Waiting state and committing when they are not
        met) and finally submits a testing function, covering the learners
        that have no results yet, through `__submit`.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        # Reset the messages left over from the previous run.
        self.Warning.test_data_unused.clear()
        self.Error.test_data_incompatible.clear()
        self.Warning.test_data_missing.clear()
        self.Warning.cant_stratify.clear()
        self.Information.cant_stratify_numeric.clear()
        # Shown only when testing on separate test data whose attributes
        # differ from those of the training data.
        self.Information.test_data_transformed(
            shown=self.resampling == self.TestOnTest
            and self.data is not None
            and self.test_data is not None
            and self.data.domain.attributes != self.test_data.domain.attributes)
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early or show warnings
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestAndScore.KFold:
            k = self.NFolds[self.n_folds]
            if len(self.data) < k:
                self.Error.too_many_folds()
                self.__state = State.Waiting
                self.commit()
                return
            # `do_stratify` is only assigned here and only read when the
            # KFold sampler is constructed further down.
            do_stratify = self.cv_stratified
            if do_stratify:
                if self.data.domain.class_var.is_discrete:
                    # size of the smallest non-empty class
                    least = min(filter(None,
                                       np.bincount(self.data.Y.astype(int))))
                    if least < k:
                        self.Warning.cant_stratify(k, least)
                        do_stratify = False
                else:
                    self.Information.cant_stratify_numeric()
                    do_stratify = False

        elif self.resampling == OWTestAndScore.TestOnTest:
            if self.test_data is None:
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            self.Warning.test_data_unused()

        # fixed seed, passed as `random_state` to the samplers below
        rstate = 42
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # `replace_learners` below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestAndScore.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData(
                    store_data=True, store_models=True),
                self.data, self.test_data, learners_c, self.preprocessor
            )
        else:
            if self.resampling == OWTestAndScore.KFold:
                sampler = Orange.evaluation.CrossValidation(
                    k=self.NFolds[self.n_folds],
                    random_state=rstate,
                    stratified=do_stratify)
            elif self.resampling == OWTestAndScore.FeatureFold:
                sampler = Orange.evaluation.CrossValidationFeature(
                    feature=self.fold_feature)
            elif self.resampling == OWTestAndScore.LeaveOneOut:
                sampler = Orange.evaluation.LeaveOneOut()
            elif self.resampling == OWTestAndScore.ShuffleSplit:
                sampler = Orange.evaluation.ShuffleSplit(
                    n_resamples=self.NRepeats[self.n_repeats],
                    train_size=self.SampleSizes[self.sample_size] / 100,
                    test_size=None,
                    stratified=self.shuffle_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestAndScore.TestOnTrain:
                sampler = Orange.evaluation.TestOnTrainingData(
                    store_models=True)
            else:
                assert False, "self.resampling %s" % self.resampling

            sampler.store_data = True
            test_f = partial(
                sampler, self.data, learners_c, self.preprocessor)

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation, then swap the deep-copied learners in the
            # results for the original learner objects.
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[[float], None]], Results]) -> None
        """
        Start evaluating `testfunc` on the widget's executor.

        Must NOT be called while an evaluation is pending/running; cancel
        the existing task first.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            A callable accepting a single `callback` keyword argument and
            returning a Results instance.
        """
        assert self.__state != State.Running

        pending = TaskState()

        def report_progress(finished):
            # Handed to `testfunc` as `callback`: aborts the evaluation when
            # an interruption was requested, otherwise publishes the
            # progress scaled by 100.
            if pending.is_interruption_requested():
                raise UserInterrupt()
            pending.set_progress_value(100 * finished)

        pending.start(self.__executor,
                      partial(testfunc, callback=report_progress))

        pending.progress_changed.connect(self.setProgressValue)
        pending.watcher.finished.connect(self.__task_complete)

        # The previously sent outputs are stale until the task finishes.
        outputs = self.Outputs
        outputs.evaluations_results.invalidate()
        outputs.predictions.invalidate()
        self.progressBarInit()
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = pending

    @Slot(object)
    def __task_complete(self, f: 'Future[Results]'):
        """
        Handle a finished evaluation task.

        Resets the progress/status display, marks the widget as Done and,
        unless retrieving the result raised, stores the per-learner results
        and scores, refreshes the score views and commits.

        Parameters
        ----------
        f : Future[Results]
            The completed future of the current task.
        """
        assert self.thread() is QThread.currentThread()
        assert self.__task is not None and self.__task.future is f
        self.progressBarFinished()
        self.setStatusMessage("")
        assert f.done()
        self.__task = None
        self.__state = State.Done
        try:
            results = f.result()    # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:  # pylint: disable=broad-except
            log.exception("testing error (in __task_complete):",
                          exc_info=True)
            message = "\n".join(
                traceback.format_exception_only(type(er), er))
            self.error(message)
            return

        # Map every learner back to the key of its input slot.
        key_of = {}
        for slot_key, slot in self.learners.items():
            key_of[slot.learner] = slot_key
        assert all(learner in key_of for learner in learners)

        # Update the results for individual learners
        target = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            if not target.is_primitive():
                stats = None
            else:
                failure = result.failed[0]
                if failure:
                    # a failed learner gets the failure for every scorer
                    stats = [Try.Fail(failure)] * len(self.scorers)
                    result = Try.Fail(failure)
                else:
                    stats = [Try(scorer_caller(scorer, result))
                             for scorer in self.scorers]
                    result = Try.Success(result)
            slot_key = key_of.get(learner)
            current = self.learners[slot_key]
            self.learners[slot_key] = current._replace(
                results=result, stats=stats)

        self.score_table.update_header(self.scorers)
        self.update_stats_model()
        self.update_comparison_table()

        self.commit()

    def cancel(self):
        """
        Abort the evaluation that is currently pending/running, if any.

        Does nothing when no task is active.
        """
        if self.__task is None:
            return

        assert self.__state == State.Running
        self.__state = State.Cancelled
        running = self.__task
        self.__task = None
        running.cancel()
        # Detach the task so it can no longer update this widget.
        running.progress_changed.disconnect(self.setProgressValue)
        running.watcher.finished.disconnect(self.__task_complete)

        self.progressBarFinished()
        self.setStatusMessage("")

    def onDeleteWidget(self):
        """
        Cancel any running evaluation and shut the executor down (without
        waiting) before delegating to the base class implementation.
        """
        self.cancel()
        self.__executor.shutdown(wait=False)
        super().onDeleteWidget()

    def copy_to_clipboard(self):
        """Delegate to the score table's `copy_selection_to_clipboard`."""
        self.score_table.copy_selection_to_clipboard()
Exemple #15
0
class OWDiscretize(widget.OWWidget):
    """
    Widget that discretizes the numeric (continuous) variables of a table.

    A default discretization method applies to every variable; variables
    selected in the list view can be given an individual method instead.
    The per-variable choices are kept in `saved_var_states` (a context
    setting) and re-applied in `set_data` after the context is opened.
    """
    name = "Discretize"
    description = "Discretize the numeric data features."
    icon = "icons/Discretize.svg"
    inputs = [
        InputSignal("Data",
                    Orange.data.Table,
                    "set_data",
                    doc="Input data table")
    ]
    outputs = [
        OutputSignal("Data",
                     Orange.data.Table,
                     doc="Table with discretized features")
    ]

    settingsHandler = settings.DomainContextHandler()
    #: Saved per variable states, keyed by `variable_key(var)`
    saved_var_states = settings.ContextSetting({})

    #: Index of the checked button in the default method box. That box has
    #: no "Default" entry, so this is the method constant minus one.
    default_method = settings.Setting(2)
    #: Number of intervals used by the default equal frequency/width methods
    default_k = settings.Setting(3)
    autosend = settings.Setting(True)

    #: Discretization methods
    Default, Leave, MDL, EqualFreq, EqualWidth, Remove, Custom = range(7)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        """Build the default-method box, the per-variable box and the
        commit controls."""
        super().__init__()

        #: input data
        self.data = None
        #: Class variable of the current input data (set in `_initialize`)
        self.class_var = None
        #: Current variable discretization state
        self.var_state = {}
        #: Saved variable discretization settings (context setting)
        self.saved_var_states = {}
        #: Cut points of the selected variable's Custom method. Defined here
        #: so `_current_method` can never hit a missing attribute; it is
        #: overwritten in `_var_selection_changed`.
        self.cutpoints = ()

        self.method = 0
        self.k = 5

        box = gui.vBox(self.controlArea, self.tr("Default Discretization"))
        self.default_bbox = rbox = gui.radioButtons(
            box, self, "default_method", callback=self._default_disc_changed)

        options = self.options = [
            self.tr("Default"),
            self.tr("Leave numeric"),
            self.tr("Entropy-MDL discretization"),
            self.tr("Equal-frequency discretization"),
            self.tr("Equal-width discretization"),
            self.tr("Remove numeric variables")
        ]

        # The default box offers every option except "Default" itself.
        for opt in options[1:5]:
            gui.appendRadioButton(rbox, opt)

        s = gui.hSlider(gui.indentedBox(rbox),
                        self,
                        "default_k",
                        minValue=2,
                        maxValue=10,
                        label="Num. of intervals:",
                        callback=self._default_disc_changed)
        s.setTracking(False)

        gui.appendRadioButton(rbox, options[-1])

        hlayout = QHBoxLayout()
        box = gui.widgetBox(self.controlArea,
                            "Individual Attribute Settings",
                            orientation=hlayout,
                            spacing=8)

        # List view with all attributes
        self.varview = QListView(selectionMode=QListView.ExtendedSelection)
        self.varview.setItemDelegate(DiscDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._var_selection_changed)

        hlayout.addWidget(self.varview)
        # Controls for individual attr settings
        self.bbox = controlbox = gui.radioButtons(
            box, self, "method", callback=self._disc_method_changed)
        hlayout.addWidget(controlbox)

        for opt in options[:5]:
            gui.appendRadioButton(controlbox, opt)

        s = gui.hSlider(gui.indentedBox(controlbox),
                        self,
                        "k",
                        minValue=2,
                        maxValue=10,
                        label="Num. of intervals:",
                        callback=self._disc_method_changed)
        s.setTracking(False)

        gui.appendRadioButton(controlbox, "Remove attribute")

        gui.rubber(controlbox)
        # Enabled only while variables are selected
        # (see `_var_selection_changed`).
        controlbox.setEnabled(False)

        self.controlbox = controlbox

        box = gui.auto_commit(self.controlArea,
                              self,
                              "autosend",
                              "Apply",
                              orientation=Qt.Horizontal,
                              checkbox_label="Send data after every change")
        box.layout().insertSpacing(0, 20)
        box.layout().insertWidget(0, self.report_button)

    def set_data(self, data):
        """Set the input `data` (or None), restore the saved per-variable
        settings for it and send the output."""
        self.closeContext()
        self.data = data
        if self.data is not None:
            self._initialize(data)
            self.openContext(data)
            # Restore the per variable discretization settings
            self._restore(self.saved_var_states)
            # Complete the induction of cut points
            self._update_points()
        else:
            self._clear()
        self.unconditional_commit()

    def _initialize(self, data):
        """Initialize the default variable states for new `data`."""
        self.class_var = data.domain.class_var
        cvars = [var for var in data.domain if var.is_continuous]
        self.varmodel[:] = cvars

        has_disc_class = data.domain.has_discrete_class

        # The default box lacks the "Default" button, hence the `- 1`.
        self.default_bbox.buttons[self.MDL - 1].setEnabled(has_disc_class)
        self.bbox.buttons[self.MDL].setEnabled(has_disc_class)

        # If the newly disabled MDL button is checked then change it
        if not has_disc_class and self.default_method == self.MDL - 1:
            self.default_method = 0
        if not has_disc_class and self.method == self.MDL:
            self.method = 0

        # Reset (initialize) the variable discretization states.
        self._reset()

    def _restore(self, saved_state):
        """Restore variable states from the `saved_state` dictionary.

        States saved with the Default method are rebuilt around the
        currently selected default method.
        """
        def_method = self._current_default_method()
        for i, var in enumerate(self.varmodel):
            key = variable_key(var)
            if key in saved_state:
                state = saved_state[key]
                if isinstance(state.method, Default):
                    state = DState(Default(def_method), None, None)
                self._set_var_state(i, state)

    def _reset(self):
        """Reset every variable's state to the current default method."""
        def_method = self._current_default_method()
        self.var_state = {}
        for i in range(len(self.varmodel)):
            state = DState(Default(def_method), None, None)
            self._set_var_state(i, state)

    def _set_var_state(self, index, state):
        """Set the state of the variable at `index` to `state` and store it
        in the model under `Qt.UserRole`."""
        self.var_state[index] = state
        self.varmodel.setData(self.varmodel.index(index), state, Qt.UserRole)

    def _clear(self):
        """Drop the data, all variable states and re-enable the MDL
        buttons."""
        self.data = None
        self.varmodel[:] = []
        self.var_state = {}
        self.saved_var_states = {}
        self.default_bbox.buttons[self.MDL - 1].setEnabled(True)
        self.bbox.buttons[self.MDL].setEnabled(True)

    def _update_points(self):
        """
        Update the induced cut points.

        Only states with neither points nor a discretized variable are
        (re)computed; afterwards the output is committed.
        """
        def induce_cuts(method, data, var):
            # Returns a (points, discretized variable) pair for `var`.
            dvar = _dispatch[type(method)](method, data, var)
            if dvar is None:
                # removed
                return [], None
            elif dvar is var:
                # no transformation took place
                return None, var
            elif is_discretized(dvar):
                return dvar.compute_value.points, dvar
            else:
                assert False

        for i, var in enumerate(self.varmodel):
            state = self.var_state[i]
            if state.points is None and state.disc_var is None:
                points, dvar = induce_cuts(state.method, self.data, var)
                new_state = state._replace(points=points, disc_var=dvar)
                self._set_var_state(i, new_state)
        self.commit()

    def _method_index(self, method):
        """Return the position of `method`'s type in `METHODS`."""
        return METHODS.index((type(method), ))

    def _current_default_method(self):
        """Return the method instance selected in the default box."""
        # `default_method` is the button index in a box without the
        # "Default" entry; shift it onto the method constants.
        method = self.default_method + 1
        k = self.default_k
        if method == OWDiscretize.Leave:
            def_method = Leave()
        elif method == OWDiscretize.MDL:
            def_method = MDL()
        elif method == OWDiscretize.EqualFreq:
            def_method = EqualFreq(k)
        elif method == OWDiscretize.EqualWidth:
            def_method = EqualWidth(k)
        elif method == OWDiscretize.Remove:
            def_method = Remove()
        else:
            assert False
        return def_method

    def _current_method(self):
        """Return the method instance selected in the per-variable box."""
        if self.method == OWDiscretize.Default:
            method = Default(self._current_default_method())
        elif self.method == OWDiscretize.Leave:
            method = Leave()
        elif self.method == OWDiscretize.MDL:
            method = MDL()
        elif self.method == OWDiscretize.EqualFreq:
            method = EqualFreq(self.k)
        elif self.method == OWDiscretize.EqualWidth:
            method = EqualWidth(self.k)
        elif self.method == OWDiscretize.Remove:
            method = Remove()
        elif self.method == OWDiscretize.Custom:
            method = Custom(self.cutpoints)
        else:
            assert False
        return method

    def _default_disc_changed(self):
        """Apply a changed default method to every variable that uses the
        Default method, then recompute the cut points."""
        method = self._current_default_method()
        state = DState(Default(method), None, None)
        for i, _ in enumerate(self.varmodel):
            if isinstance(self.var_state[i].method, Default):
                self._set_var_state(i, state)
        self._update_points()

    def _disc_method_changed(self):
        """Apply the per-variable method to the selected variables, then
        recompute the cut points."""
        indices = self.selected_indices()
        method = self._current_method()
        state = DState(method, None, None)
        for idx in indices:
            self._set_var_state(idx, state)
        self._update_points()

    def _var_selection_changed(self, *args):
        """Sync the per-variable controls with the selected variables."""
        indices = self.selected_indices()
        # set of all methods for the current selection
        methods = [self.var_state[i].method for i in indices]
        mset = set(methods)
        self.controlbox.setEnabled(len(mset) > 0)
        if len(mset) == 1:
            method = mset.pop()
            self.method = self._method_index(method)
            if isinstance(method, (EqualFreq, EqualWidth)):
                self.k = method.k
            elif isinstance(method, Custom):
                self.cutpoints = method.points
        else:
            # deselect the current button
            self.method = -1
            bg = self.controlbox.group
            button_group_reset(bg)

    def selected_indices(self):
        """Return the row indices selected in the variable list view."""
        rows = self.varview.selectionModel().selectedRows()
        return [index.row() for index in rows]

    def discretized_var(self, source):
        """Return the variable replacing `source` in the output.

        That is None when the variable has no discretized variable or is
        removed, `source` itself when it is left unchanged, and the
        discretized variable otherwise.
        """
        index = list(self.varmodel).index(source)
        state = self.var_state[index]
        if state.disc_var is None:
            return None
        elif state.disc_var is source:
            return source
        elif state.points == []:
            return None
        else:
            return state.disc_var

    def discretized_domain(self):
        """
        Return the current effective discretized domain.

        Returns None when there is no input data.
        """
        if self.data is None:
            return None

        def disc_var(source):
            # Only continuous variables are looked up; everything else
            # (including a missing class variable) passes through.
            if source and source.is_continuous:
                return self.discretized_var(source)
            else:
                return source

        attributes = [disc_var(v) for v in self.data.domain.attributes]
        attributes = [v for v in attributes if v is not None]

        class_var = disc_var(self.data.domain.class_var)

        domain = Orange.data.Domain(attributes,
                                    class_var,
                                    metas=self.data.domain.metas)
        return domain

    def commit(self):
        """Send the input data converted to the discretized domain, or None
        when there is no input."""
        output = None
        if self.data is not None:
            domain = self.discretized_domain()
            output = self.data.from_table(domain, self.data)
        self.send("Data", output)

    def storeSpecificSettings(self):
        """Save the per-variable states without their computed points and
        discretized variables."""
        super().storeSpecificSettings()
        self.saved_var_states = {
            variable_key(var): self.var_state[i]._replace(points=None,
                                                          disc_var=None)
            for i, var in enumerate(self.varmodel)
        }

    def send_report(self):
        """Report the default method and the thresholds of each variable."""
        self.report_items(
            (("Default method", self.options[self.default_method + 1]), ))
        if self.varmodel:
            self.report_items(
                "Thresholds",
                [(var.name, DiscDelegate.cutsText(self.var_state[i])
                  or "leave numeric") for i, var in enumerate(self.varmodel)])
Exemple #16
0
class OWDistributions(widget.OWWidget):
    # Widget metadata and the single "Data" input, handled by `set_data`.
    name = "Distributions"
    description = "Display value distributions of a data feature in a graph."
    icon = "icons/Distribution.svg"
    priority = 120
    inputs = [InputSignal("Data", Orange.data.Table, "set_data",
                          doc="Set the input data set")]

    settingsHandler = settings.DomainContextHandler(
        match_values=settings.DomainContextHandler.MATCH_VALUES_ALL)
    #: Selected variable index
    variable_idx = settings.ContextSetting(-1)
    #: Selected group variable
    groupvar_idx = settings.ContextSetting(0)

    #: State of the "Show relative frequencies" check box
    relative_freq = settings.Setting(False)
    #: State of the "Bin continuous variables" check box
    disc_cont = settings.Setting(False)

    #: Slider position; used as an index into both `bins` and
    #: `smoothing_facs`
    smoothing_index = settings.Setting(5)
    #: Selected index of the "Show probabilities" combo box
    show_prob = settings.ContextSetting(0)

    graph_name = "plot"

    #: Passed as `m` to `ash_curve`
    ASH_HIST = 50

    #: Bin counts selectable with the slider when continuous variables are
    #: binned
    bins = [ 2, 3, 4, 5, 8, 10, 12, 15, 20, 30, 50 ]
    #: Smoothing factors passed to `ash_curve`; reversed, so the factor
    #: decreases as the slider index grows
    smoothing_facs = list(reversed([ 0.1, 0.2, 0.4, 0.6, 0.8, 1, 1.5, 2, 4, 6, 10 ]))

    def __init__(self):
        """Build the control area (variable list, precision and group-by
        controls) and the plot area with its secondary probability view."""
        super().__init__()
        self.data = None

        self.distributions = None
        self.contingencies = None
        self.var = self.cvar = None
        varbox = gui.vBox(self.controlArea, "Variable")

        self.varmodel = itemmodels.VariableListModel()
        self.groupvarmodel = []

        # Single-selection list of the variables that can be plotted.
        self.varview = QListView(
            selectionMode=QListView.SingleSelection)
        self.varview.setSizePolicy(
            QSizePolicy.Minimum, QSizePolicy.Expanding)
        self.varview.setModel(self.varmodel)
        self.varview.setSelectionModel(
            itemmodels.ListSingleSelectionModel(self.varmodel))
        self.varview.selectionModel().selectionChanged.connect(
            self._on_variable_idx_changed)
        varbox.layout().addWidget(self.varview)

        box = gui.vBox(self.controlArea, "Precision")

        gui.separator(self.controlArea, 4, 4)

        # Slider bound to `smoothing_index`; the labels on either side are
        # rewritten in `_setup_smoothing`.
        box2 = gui.hBox(box)
        self.l_smoothing_l = gui.widgetLabel(box2, "Smooth")
        gui.hSlider(box2, self, "smoothing_index",
                    minValue=0, maxValue=len(self.smoothing_facs) - 1,
                    callback=self._on_set_smoothing, createLabel=False)
        self.l_smoothing_r = gui.widgetLabel(box2, "Precise")

        self.cb_disc_cont = gui.checkBox(
            gui.indentedBox(box, sep=4),
            self, "disc_cont", "Bin continuous variables",
            callback=self._on_groupvar_idx_changed,
            tooltip="Show continuous variables as discrete.")

        box = gui.vBox(self.controlArea, "Group by")
        self.icons = gui.attributeIconDict
        self.groupvarview = gui.comboBox(box, self, "groupvar_idx",
             callback=self._on_groupvar_idx_changed, valueType=str,
             contentsLength=12)
        box2 = gui.indentedBox(box, sep=4)
        self.cb_rel_freq = gui.checkBox(
            box2, self, "relative_freq", "Show relative frequencies",
            callback=self._on_relative_freq_changed,
            tooltip="Normalize probabilities so that probabilities for each group-by value sum to 1.")
        gui.separator(box2)
        self.cb_prob = gui.comboBox(
            box2, self, "show_prob", label="Show probabilities:",
            orientation=Qt.Horizontal,
            callback=self._on_relative_freq_changed,
            tooltip="Show probabilities for a chosen group-by value (at each point probabilities for all group-by values sum to 1).")

        self.plotview = pg.PlotWidget(background=None)
        self.plotview.setRenderHint(QPainter.Antialiasing)
        self.mainArea.layout().addWidget(self.plotview)
        w = QLabel()
        w.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        # NOTE(review): if mainArea uses a QBoxLayout, the second positional
        # argument of addWidget is the stretch factor, not the alignment --
        # confirm that passing Qt.AlignCenter here is intended.
        self.mainArea.layout().addWidget(w, Qt.AlignCenter)
        self.ploti = pg.PlotItem()
        self.plot = self.ploti.vb
        self.ploti.hideButtons()
        self.plotview.setCentralItem(self.ploti)

        # Second view box for probabilities: linked to the right axis and to
        # the main plot's x axis.
        self.plot_prob = pg.ViewBox()
        self.ploti.hideAxis('right')
        self.ploti.scene().addItem(self.plot_prob)
        self.ploti.getAxis("right").linkToView(self.plot_prob)
        self.ploti.getAxis("right").setLabel("Probability")
        self.plot_prob.setZValue(10)
        self.plot_prob.setXLink(self.ploti)
        self.update_views()
        self.ploti.vb.sigResized.connect(self.update_views)
        self.plot_prob.setRange(yRange=[0,1])

        def disable_mouse(plot):
            # Turn off mouse interaction and the context menu of `plot`.
            plot.setMouseEnabled(False, False)
            plot.setMenuEnabled(False)

        disable_mouse(self.plot)
        disable_mouse(self.plot_prob)

        # (view box, item) pairs consulted by `help_event` for tooltips.
        self.tooltip_items = []
        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        pen = QPen(self.palette().color(QPalette.Text))
        for axis in ("left", "bottom"):
            self.ploti.getAxis(axis).setPen(pen)

        self._legend = LegendItem()
        self._legend.setParentItem(self.plot)
        self._legend.hide()
        self._legend.anchor((1, 0), (1, 0))

    def update_views(self):
        """Give the probability view box the scene geometry of the main
        view box and refresh its x-axis link."""
        self.plot_prob.setGeometry(self.plot.sceneBoundingRect())
        self.plot_prob.linkedViewChanged(self.plot, self.plot_prob.XAxis)

    def set_data(self, data):
        """
        Set the input `data` and rebuild the variable and group-by models.

        With no data the widget is only cleared; an empty table additionally
        raises a warning. Otherwise the models are filled, the context is
        opened and the plot is set up for the selected variable.
        """
        self.closeContext()
        self.clear()
        self.warning()
        self.data = data
        if self.data is None:
            return
        if not self.data:
            self.warning("Empty input data cannot be visualized")
            return

        domain = self.data.domain

        # Variables that can be plotted: the domain's own variables plus
        # the continuous/discrete metas.
        plottable = list(domain)
        plottable += [meta for meta in domain.metas
                      if meta.is_continuous or meta.is_discrete]
        self.varmodel[:] = plottable

        # Group-by candidates: a "(None)" placeholder followed by every
        # discrete variable and discrete meta.
        self.groupvarview.clear()
        discrete_vars = [var for var in domain if var.is_discrete]
        discrete_metas = [meta for meta in domain.metas if meta.is_discrete]
        self.groupvarmodel = ["(None)"] + discrete_vars + discrete_metas
        self.groupvarview.addItem("(None)")
        for group_var in self.groupvarmodel[1:]:
            self.groupvarview.addItem(self.icons[group_var], group_var.name)

        # Group by the class by default; an opened context may override it.
        if domain.has_discrete_class:
            self.groupvar_idx = self.groupvarmodel.index(domain.class_var, 1)
        self.openContext(domain)

        # Clamp the (possibly restored) indices to the valid ranges.
        last_var = len(self.varmodel) - 1
        self.variable_idx = min(max(self.variable_idx, 0), last_var)
        last_group = len(self.groupvarmodel) - 1
        self.groupvar_idx = min(max(self.groupvar_idx, 0), last_group)
        itemmodels.select_row(self.varview, self.variable_idx)
        self._setup()

    def clear(self):
        """Empty both plots, the variable and group-by models, the legend
        and the combo boxes, and reset the selection indices."""
        self.plot.clear()
        self.plot_prob.clear()
        self.varmodel[:] = []
        self.groupvarmodel = []
        # -1: no variable selected; 0: the "(None)" group
        self.variable_idx = -1
        self.groupvar_idx = 0
        self._legend.clear()
        self._legend.hide()
        self.groupvarview.clear()
        self.cb_prob.clear()

    def _setup_smoothing(self):
        if not self.disc_cont and self.var and self.var.is_continuous:
            self.cb_disc_cont.setText("Bin continuous variables")
            self.l_smoothing_l.setText("Smooth")
            self.l_smoothing_r.setText("Precise")
        else:
            self.cb_disc_cont.setText("Bin continuous variables into {} bins".
                                      format(self.bins[self.smoothing_index]))
            self.l_smoothing_l.setText(" " + str(self.bins[0]))
            self.l_smoothing_r.setText(" " + str(self.bins[-1]))

    def _setup(self):
        """
        Recompute and redraw the plot for the current selection.

        Resolves the selected variable and group variable, refills the
        "Show probabilities" combo box, optionally bins the data, and then
        displays either the contingency (with a group variable) or the plain
        distribution.
        """
        self.plot.clear()
        self.plot_prob.clear()
        self._legend.clear()
        self._legend.hide()

        selected = self.variable_idx
        self.var = self.cvar = None
        if selected >= 0:
            self.var = self.varmodel[selected]
        if self.groupvar_idx > 0:
            self.cvar = self.groupvarmodel[self.groupvar_idx]
            prob_combo = self.cb_prob
            prob_combo.clear()
            prob_combo.addItem("(None)")
            prob_combo.addItems(self.cvar.values)
            prob_combo.addItem("(All)")
            # valid indices: "(None)", one per value, "(All)"
            last_choice = len(self.cvar.values) + 1
            self.show_prob = min(max(self.show_prob, 0), last_choice)

        table = self.data
        self._setup_smoothing()
        if self.var is None:
            return

        if self.disc_cont:
            # Restrict the table to the plotted variables and bin them into
            # the selected number of equal-width intervals.
            columns = [self.var]
            if self.cvar:
                columns.append(self.cvar)
            table = Orange.data.Table(Orange.data.Domain(columns), table)
            n_bins = self.bins[self.smoothing_index]
            binning = Orange.preprocess.discretize.EqualWidth(n=n_bins)
            discretizer = Orange.preprocess.Discretize(
                method=binning, remove_const=False)
            table = discretizer(table)
            self.var = table.domain[0]

        self.set_left_axis_name()
        self.enable_disable_rel_freq()
        if not self.cvar:
            self.distributions = \
                distribution.get_distribution(table, self.var)
            self.display_distribution()
        else:
            self.contingencies = \
                contingency.get_contingency(table, self.var, self.cvar)
            self.display_contingency()
        self.plot.autoRange()

    def help_event(self, ev):
        """
        Show a tooltip for the plot items under the cursor.

        Parameters
        ----------
        ev
            The help event; its `scenePos()` and `screenPos()` are used.

        Returns
        -------
        bool
            True when a tooltip was shown, False otherwise.
        """
        # The scene position is the same for every item; fetch it once.
        scene_pos = ev.scenePos()
        ctooltip = []
        for vb, item in self.tooltip_items:
            # Curves are hit-tested against their mouse shape, bars against
            # their bounding rectangle, each in its own view box's
            # coordinates.
            if isinstance(item, pg.PlotCurveItem) and \
                    item.mouseShape().contains(vb.mapSceneToView(scene_pos)):
                ctooltip.append(item.tooltip)
            elif isinstance(item, DistributionBarItem) and \
                    item.boundingRect().contains(
                        vb.mapSceneToView(scene_pos)):
                ctooltip.append(item.tooltip)
        if ctooltip:
            QToolTip.showText(ev.screenPos(), "\n\n".join(ctooltip), widget=self.plotview)
            return True
        return False

    def display_distribution(self):
        """
        Draw the distribution stored in `self.distributions`.

        A continuous variable is drawn as a filled density curve computed
        by `ash_curve`; any other variable as one bar per value. Every
        drawn item is registered in `self.tooltip_items`.
        """
        dist = self.distributions
        var = self.var
        assert len(dist) > 0
        self.plot.clear()
        self.plot_prob.clear()
        self.ploti.hideAxis('right')
        self.tooltip_items = []

        bottomaxis = self.ploti.getAxis("bottom")
        bottomaxis.setLabel(var.name)
        bottomaxis.resizeEvent()

        self.set_left_axis_name()

        if not (var and var.is_continuous):
            # one bar per value, with the values as tick labels
            bottomaxis.setTicks([list(enumerate(var.values))])
            for index, freq in enumerate(dist):
                bar = DistributionBarItem(
                    QRectF(index - 0.33, 0, 0.66, freq),
                    [1.0], [QColor(128, 128, 128)])
                self.plot.addItem(bar)
                bar.tooltip = "Frequency for %s: %r" % (var.values[index], freq)
                self.tooltip_items.append((self.plot, bar))
            return

        bottomaxis.setTicks(None)
        if len(dist[0]) == 0:
            return
        factor = self.smoothing_facs[self.smoothing_index]
        edges, curve = ash_curve(dist, None, m=OWDistributions.ASH_HIST,
                                 smoothing_factor=factor)
        # shift the edges by half a step and drop the last one
        half_step = (edges[1] - edges[0])/2
        centers = (edges + half_step)[:-1]
        density = pg.PlotCurveItem()
        outline = QPen(QBrush(Qt.white), 3)
        outline.setCosmetic(True)
        density.setData(centers, curve, antialias=True, stepMode=False,
                        fillLevel=0, brush=QBrush(Qt.gray), pen=outline)
        self.plot.addItem(density)
        density.tooltip = "Density"
        self.tooltip_items.append((self.plot, density))

    def _on_relative_freq_changed(self):
        """
        Refresh the plot after the relative-frequency option changed.

        The left axis label is updated first, then the contingency view is
        redrawn when a discrete `self.cvar` is set (the plain distribution
        otherwise), and finally the plot is auto-ranged.
        """
        self.set_left_axis_name()
        grouped = self.cvar and self.cvar.is_discrete
        redraw = (self.display_contingency if grouped
                  else self.display_distribution)
        redraw()
        self.plot.autoRange()

    def display_contingency(self):
        """
        Redraw the plot from `self.contingencies`: the distribution of
        `self.var` split by the values of `self.cvar`.

        The main plot, the probability overlay, the legend and the tooltip
        registry are all rebuilt. For a continuous `self.var` one smoothed
        curve is drawn per class value that has data; for a discrete
        `self.var` one `DistributionBarItem` is drawn per value.

        `self.show_prob` selects the probability overlay: a false value
        disables it, ``k`` (1-based) shows the probability of class ``k - 1``
        only, and ``number_of_classes + 1`` shows all classes.
        """
        cont = self.contingencies
        var, cvar = self.var, self.cvar
        assert len(cont) > 0
        self.plot.clear()
        self.plot_prob.clear()
        self._legend.clear()
        self.tooltip_items = []

        if self.show_prob:
            self.ploti.showAxis('right')
        else:
            self.ploti.hideAxis('right')

        bottomaxis = self.ploti.getAxis("bottom")
        bottomaxis.setLabel(var.name)
        bottomaxis.resizeEvent()

        cvar_values = cvar.values
        colors = [QColor(*col) for col in cvar.colors]

        if var and var.is_continuous:
            bottomaxis.setTicks(None)

            # Keep only the class values that actually have data; colors and
            # legend entries are filtered the same way.
            weights, cols, cvar_values, curves = [], [], [], []
            for i, dist in enumerate(cont):
                v, W = dist
                if len(v):
                    weights.append(numpy.sum(W))
                    cols.append(colors[i])
                    cvar_values.append(cvar.values[i])
                    curves.append(ash_curve(
                        dist, cont, m=OWDistributions.ASH_HIST,
                        smoothing_factor=self.smoothing_facs[
                            self.smoothing_index]))
            weights = numpy.array(weights)
            sumw = numpy.sum(weights)
            weights /= sumw
            colors = cols
            # Scale every curve by the share of its class.
            curves = [(X, Y * w) for (X, Y), w in zip(curves, weights)]

            # From histogram edges to line coordinates: shift by half a bin
            # and drop the last edge.
            curvesline = []
            for (X, Y) in curves:
                X = X + (X[1] - X[0])/2
                X = X[:-1]
                X = numpy.array(X)
                Y = numpy.array(Y)
                curvesline.append((X, Y))

            # Draw all fills first and all outlines second; only the outline
            # items are registered for tooltips.
            for t in ["fill", "line"]:
                for (X, Y), color, w, cval in reversed(list(zip(
                        curvesline, colors, weights, cvar_values))):
                    item = pg.PlotCurveItem()
                    pen = QPen(QBrush(color), 3)
                    pen.setCosmetic(True)
                    color = QColor(color)
                    color.setAlphaF(0.2)
                    item.setData(X, Y/(w if self.relative_freq else 1),
                                 antialias=True, stepMode=False,
                                 fillLevel=0 if t == "fill" else None,
                                 brush=QBrush(color), pen=pen)
                    self.plot.addItem(item)
                    if t == "line":
                        item.tooltip = ("Normalized density " if self.relative_freq else "Density ") \
                            + "\n" + cvar.name + "=" + cval
                        self.tooltip_items.append((self.plot, item))

            # Without any curve there is nothing to interpolate, and
            # numpy.hstack would raise on the empty list.
            if self.show_prob and curvesline:
                all_X = numpy.array(numpy.unique(
                    numpy.hstack([X for X, _ in curvesline])))
                inter_X = numpy.array(
                    numpy.linspace(all_X[0], all_X[-1], len(all_X)*2))
                curvesinterp = [numpy.interp(inter_X, X, Y)
                                for (X, Y) in curvesline]
                sumprob = numpy.sum(curvesinterp, axis=0)
                # Probabilities are drawn only where the summed density is
                # above 5% of its maximum.
                legal = sumprob > 0.05 * numpy.max(sumprob)

                i = len(curvesinterp) + 1
                show_all = self.show_prob == i
                for Y, color, cval in reversed(list(zip(
                        curvesinterp, colors, cvar_values))):
                    i -= 1
                    if show_all or self.show_prob == i:
                        item = pg.PlotCurveItem()
                        pen = QPen(QBrush(color), 3, style=Qt.DotLine)
                        pen.setCosmetic(True)
                        prob = Y[legal] / sumprob[legal]
                        item.setData(inter_X[legal], prob, antialias=True,
                                     stepMode=False, fillLevel=None,
                                     brush=None, pen=pen)
                        self.plot_prob.addItem(item)
                        item.tooltip = "Probability that \n" + cvar.name + "=" + cval
                        self.tooltip_items.append((self.plot_prob, item))

        elif var and var.is_discrete:
            bottomaxis.setTicks([list(enumerate(var.values))])

            cont = numpy.array(cont)

            maxh = 0  # maximal column height
            maxrh = 0  # maximal relative column height
            scvar = cont.sum(axis=1)
            # a cvar with sum=0 will always have distribution counts 0,
            # therefore we can divide it by anything
            scvar[scvar == 0] = 1
            for _, dist in zip(var.values, cont.T):
                maxh = max(maxh, max(dist))
                maxrh = max(maxrh, max(dist/scvar))

            for i, (value, dist) in enumerate(zip(var.values, cont.T)):
                dsum = sum(dist)
                geom = QRectF(i - 0.333, 0, 0.666, maxrh
                              if self.relative_freq else maxh)
                item = DistributionBarItem(geom, dist/scvar/maxrh
                                           if self.relative_freq
                                           else dist/maxh, colors)
                self.plot.addItem(item)
                tooltip = "\n".join(
                    "%s: %.*f" % (n, 3 if self.relative_freq else 1, v)
                    for n, v in zip(cvar_values,
                                    dist/scvar if self.relative_freq
                                    else dist))
                # `value` is a value of `var` (the x-axis variable), so the
                # header names `var`; the lines below it list the classes.
                item.tooltip = ("Normalized frequency " if self.relative_freq else "Frequency ") \
                    + "(" + var.name + "=" + value + "):" \
                    + "\n" + tooltip
                self.tooltip_items.append((self.plot, item))

                if self.show_prob:
                    item.tooltip += "\n\nProbabilities:"
                    for ic, a in enumerate(dist):
                        if self.show_prob - 1 != ic and \
                                self.show_prob - 1 != len(dist):
                            continue
                        position = -0.333 + ((ic+0.5)*0.666/len(dist))
                        if dsum < 1e-6:
                            continue
                        prob = a / dsum
                        if not 1e-6 < prob < 1 - 1e-6:
                            continue
                        # 1.96 standard errors of the estimated proportion.
                        ci = 1.96 * sqrt(prob * (1 - prob) / dsum)
                        item.tooltip += "\n%s: %.3f ± %.3f" % (cvar_values[ic], prob, ci)
                        mark = pg.ScatterPlotItem()
                        bar = pg.ErrorBarItem()
                        pen = QPen(QBrush(QColor(0)), 1)
                        pen.setCosmetic(True)
                        # The error bar is clipped so that it stays within
                        # [0, 1].
                        bar.setData(x=[i+position], y=[prob],
                                    bottom=min(numpy.array([ci]), prob),
                                    top=min(numpy.array([ci]), 1 - prob),
                                    beam=numpy.array([0.05]),
                                    brush=QColor(1), pen=pen)
                        mark.setData([i+position], [prob], antialias=True,
                                     symbol="o", fillLevel=None, pxMode=True,
                                     size=10, brush=QColor(colors[ic]),
                                     pen=pen)
                        self.plot_prob.addItem(bar)
                        self.plot_prob.addItem(mark)

        for color, name in zip(colors, cvar_values):
            self._legend.addItem(
                ScatterPlotItem(pen=color, brush=color, size=10, shape="s"),
                escape(name)
            )
        self._legend.show()

    def set_left_axis_name(self):
        """
        Label the left axis according to the variable type and display mode.

        Continuous variables use the "Density" labels, everything else the
        "Frequency" labels; the "Relative ..." variant is picked when a group
        variable is set and `self.relative_freq` is on.
        """
        axis = self.ploti.getAxis("left")
        if self.var and self.var.is_continuous:
            labels = ("Density", "Relative density")
        else:
            labels = ("Frequency", "Relative frequency")
        # The flag is used directly as a 0/1 index into the label pair.
        relative = self.cvar is not None and self.relative_freq
        axis.setLabel(labels[relative])
        axis.resizeEvent()

    def enable_disable_rel_freq(self):
        """
        Disable the probability and relative-frequency controls unless both
        a variable and a group variable are set.
        """
        incomplete = self.var is None or self.cvar is None
        self.cb_prob.setDisabled(incomplete)
        self.cb_rel_freq.setDisabled(incomplete)

    def _on_variable_idx_changed(self):
        """Store the index selected in the variable view and rebuild."""
        chosen = selected_index(self.varview)
        self.variable_idx = chosen
        self._setup()

    def _on_groupvar_idx_changed(self):
        """Rebuild the widget state after the group variable changed."""
        self._setup()

    def _on_set_smoothing(self):
        """Rebuild the widget state after the smoothing option changed."""
        self._setup()

    def onDeleteWidget(self):
        """Clear the plot, then defer to the base class's handler."""
        self.plot.clear()
        super().onDeleteWidget()

    def get_widget_name_extension(self):
        """
        Return the entry of `self.varmodel` at the selected variable index,
        or None when no variable is selected (negative index).
        """
        if self.variable_idx < 0:
            return None
        return self.varmodel[self.variable_idx]

    def send_report(self):
        """
        Add the plot and a descriptive caption to the report.

        Nothing is reported when no variable is selected (negative
        `self.variable_idx`). The caption names the displayed variable and,
        when a group variable is chosen, how the data is grouped and which
        single class probability (if any) is shown.
        """
        if self.variable_idx < 0:
            return
        self.report_plot()
        text = "Distribution of '{}'".format(
            self.varmodel[self.variable_idx])
        if self.groupvar_idx:
            group_var = self.groupvarmodel[self.groupvar_idx]
            prob = self.cb_prob
            # True when the combo points at an entry that is neither the
            # first nor the last one, i.e. a single class is selected.
            indiv_probs = 0 < prob.currentIndex() < prob.count() - 1
            if not indiv_probs or self.relative_freq:
                text += " grouped by '{}'".format(group_var)
                if self.relative_freq:
                    text += " (relative frequencies)"
            if indiv_probs:
                # Spelling fixed: the caption used to read "probabilites".
                text += "; probabilities for '{}={}'".format(
                    group_var, prob.currentText())
        self.report_caption(text)
Exemple #17
0
class OWLiftCurve(widget.OWWidget):
    """Widget that plots lift curves for the classifiers in its input
    evaluation results."""
    name = "Lift Curve"
    description = ("Construct and display a lift curve "
                   "from the evaluation of classifiers.")
    icon = "icons/LiftCurve.svg"
    priority = 1020
    keywords = []

    class Inputs:
        evaluation_results = Input("Evaluation Results",
                                   Orange.evaluation.Results)

    settingsHandler = EvaluationResultsContextHandler()
    #: Index of the target class in the target combo box
    target_index = settings.ContextSetting(0)
    #: Indices of the classifiers selected in the list box
    selected_classifiers = settings.ContextSetting([])

    display_convex_hull = settings.Setting(False)
    display_cost_func = settings.Setting(True)

    fp_cost = settings.Setting(500)
    fn_cost = settings.Setting(500)
    target_prior = settings.Setting(50.0)

    graph_name = "plot"

    def __init__(self):
        super().__init__()

        self.results = None
        self.classifier_names = []
        self.colors = []
        #: Cache of plotted curves keyed by (target, classifier index)
        self._curve_data = {}

        plot_box = gui.vBox(self.controlArea, "Plot")

        target_box = gui.vBox(plot_box, "Target Class")
        target_box.setFlat(True)
        self.target_cb = gui.comboBox(
            target_box, self, "target_index",
            callback=self._on_target_changed,
            contentsLength=8,
            searchable=True)

        clf_box = gui.vBox(plot_box, "Classifiers")
        clf_box.setFlat(True)
        self.classifiers_list_box = gui.listBox(
            clf_box, self, "selected_classifiers", "classifier_names",
            selectionMode=QListView.MultiSelection,
            callback=self._on_classifiers_changed)

        gui.checkBox(plot_box, self, "display_convex_hull",
                     "Show lift convex hull", callback=self._replot)

        self.plotview = pg.GraphicsView(background="w")
        self.plotview.setFrameStyle(QFrame.StyledPanel)

        self.plot = pg.PlotItem(enableMenu=False)
        self.plot.setMouseEnabled(False, False)
        self.plot.hideButtons()

        axis_pen = QPen(self.palette().color(QPalette.Text))

        tick_font = QFont(self.font())
        tick_font.setPixelSize(max(int(tick_font.pixelSize() * 2 // 3), 11))

        # Both axes share the pen and the tick font; only the label differs.
        for side, label in (("bottom", "P Rate"), ("left", "TP Rate")):
            axis = self.plot.getAxis(side)
            axis.setTickFont(tick_font)
            axis.setPen(axis_pen)
            axis.setLabel(label)

        self.plot.showGrid(True, True, alpha=0.1)
        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)

        self.plotview.setCentralItem(self.plot)
        self.mainArea.layout().addWidget(self.plotview)

    @Inputs.evaluation_results
    def set_results(self, results):
        """Replace the displayed evaluation results with `results`."""
        self.closeContext()
        self.clear()
        self.results = check_results_adequacy(results, self.Error)
        if self.results is None:
            return
        self._initialize(results)
        self.openContext(self.results.domain.class_var,
                         self.classifier_names)
        self._setup_plot()

    def clear(self):
        """Reset the plot, the controls and all cached state."""
        self.plot.clear()
        self.results = None
        self.target_cb.clear()
        self.target_index = 0
        self.classifier_names = []
        self.colors = []
        self._curve_data = {}

    def _initialize(self, results):
        """Fill the classifier list and the target combo from `results`."""
        count = len(results.predicted)

        names = getattr(results, "learner_names", None)
        if names is None:
            # Fall back to "#1", "#2", ... when the results carry no names.
            names = ["#{}".format(number) for number in range(1, count + 1)]

        self.colors = colorpalettes.get_default_curve_colors(count)

        self.classifier_names = names
        self.selected_classifiers = list(range(count))
        for row in range(count):
            list_item = self.classifiers_list_box.item(row)
            list_item.setIcon(colorpalettes.ColorIcon(self.colors[row]))

        self.target_cb.addItems(results.data.domain.class_var.values)

    def plot_curves(self, target, clf_idx):
        """Return the cached `PlotCurve` for (target, clf_idx), building the
        curve and its plot items on first use."""
        key = (target, clf_idx)
        if key in self._curve_data:
            return self._curve_data[key]

        curve = liftCurve_from_results(self.results, clf_idx, target)
        line_pen = QPen(self.colors[clf_idx], 1)
        line_pen.setCosmetic(True)
        shadow_pen = QPen(line_pen.color().lighter(160), 2.5)
        shadow_pen.setCosmetic(True)
        curve_item = pg.PlotDataItem(
            curve.points[0], curve.points[1],
            pen=line_pen,
            shadowPen=shadow_pen,
            symbol="+",
            symbolSize=3,
            symbolPen=shadow_pen,
            antialias=True)
        hull_item = pg.PlotDataItem(
            curve.hull[0], curve.hull[1],
            pen=line_pen,
            antialias=True)
        self._curve_data[key] = PlotCurve(curve, curve_item, hull_item)
        return self._curve_data[key]

    def _setup_plot(self):
        """Add the selected classifiers' curves, the optional convex hull and
        the diagonal to the plot, and update the validity warning."""
        target = self.target_index
        curves = [self.plot_curves(target, clf_idx)
                  for clf_idx in self.selected_classifiers]

        for plotted in curves:
            self.plot.addItem(plotted.curve_item)

        if self.display_convex_hull:
            hull = convex_hull([c.curve.hull for c in curves])
            self.plot.plot(hull[0], hull[1], pen="y", antialias=True)

        diagonal_pen = QPen(QColor(100, 100, 100, 100), 1, Qt.DashLine)
        diagonal_pen.setCosmetic(True)
        self.plot.plot([0, 1], [0, 1], pen=diagonal_pen, antialias=True)

        validity = [c.curve.is_valid for c in curves]
        if all(validity):
            message = ""
        elif any(validity):
            message = "Some lift curves are undefined"
        else:
            message = "All lift curves are undefined"
        self.warning(message)

    def _replot(self):
        """Clear the plot and redraw it if results are present."""
        self.plot.clear()
        if self.results is None:
            return
        self._setup_plot()

    def _on_target_changed(self):
        """Redraw after a different target class was chosen."""
        self._replot()

    def _on_classifiers_changed(self):
        """Redraw after the classifier selection changed."""
        self._replot()

    def send_report(self):
        """Report the target class, the plot and a classifier legend."""
        if self.results is None:
            return
        legend = report.list_legend(self.classifiers_list_box,
                                    self.selected_classifiers)
        self.report_items((("Target class", self.target_cb.currentText()), ))
        self.report_plot()
        self.report_caption(legend)
Exemple #18
0
class OWMDS(widget.OWWidget):
    # Widget metadata.
    name = "MDS"
    description = "Two-dimensional data projection by multidimensional " \
                  "scaling constructed from a distance matrix."
    icon = "icons/MDS.svg"
    # Each input is (signal name, type, name of the handler method).
    inputs = [("Data", Orange.data.Table, "set_data"),
              ("Distances", Orange.misc.DistMatrix, "set_disimilarity")]
    outputs = [("Selected Data", Orange.data.Table, widget.Default),
               ("Data", Orange.data.Table)]

    #: Initialization type
    PCA, Random = 0, 1

    #: Refresh rate
    # (label, steps per GUI update) pairs; `refresh_rate` below is an index
    # into this list. NOTE(review): -1 presumably means "never refresh while
    # running" -- confirm against the update loop.
    RefreshRate = [("Every iteration", 1), ("Every 5 steps", 5),
                   ("Every 10 steps", 10), ("Every 25 steps", 25),
                   ("Every 50 steps", 50), ("None", -1)]

    # (label, amount) pairs offered in the "Jitter" combo box.
    JitterAmount = [("None", 0), ("0.1%", 0.1), ("0.5%", 0.5), ("1%", 1.0),
                    ("2%", 2.0)]
    #: Runtime state
    Running, Finished, Waiting = 1, 2, 3

    settingsHandler = settings.DomainContextHandler()

    max_iter = settings.Setting(300)
    initialization = settings.Setting(PCA)
    # Index into RefreshRate (3 -> "Every 25 steps").
    refresh_rate = settings.Setting(3)

    # output embedding role.
    NoRole, AttrRole, AddAttrRole, MetaRole = 0, 1, 2, 3

    # Defaults to 2, which equals AddAttrRole above.
    output_embedding_role = settings.Setting(2)
    autocommit = settings.Setting(True)

    # Per-domain (context) choices of the variables used for point
    # color, shape, size and label.
    color_value = settings.ContextSetting("")
    shape_value = settings.ContextSetting("")
    size_value = settings.ContextSetting("")
    label_value = settings.ContextSetting("")

    symbol_size = settings.Setting(8)
    symbol_opacity = settings.Setting(230)
    connected_pairs = settings.Setting(5)
    # Bound to the "Jitter" combo box built from JitterAmount.
    jitter = settings.Setting(0)

    legend_anchor = settings.Setting(((1, 0), (1, 0)))

    want_graph = True

    def __init__(self):
        """
        Build the widget: initialize the state attributes, create the
        control-area boxes ("MDS Optimization", "Graph", "Zoom/Select",
        "Output"), the plot in the main area and the plot interaction tools,
        then call `_initialize()`.
        """
        super().__init__()
        # Input signals and the data derived from them.
        self.matrix = None
        self.data = None
        self.matrix_data = None
        self.signal_data = None

        # Cached per-point visual data and plot items.
        self._pen_data = None
        self._shape_data = None
        self._size_data = None
        self._label_data = None
        self._similar_pairs = None
        self._scatter_item = None
        self._legend_item = None
        self._selection_mask = None
        self._invalidated = False
        self._effective_matrix = None

        # Private (name-mangled) run state; the widget starts out Waiting.
        self.__update_loop = None
        self.__state = OWMDS.Waiting
        self.__in_next_step = False
        self.__draw_similar_pairs = False

        # --- "MDS Optimization" box ---
        box = gui.widgetBox(self.controlArea, "MDS Optimization")
        form = QtGui.QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QtGui.QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10)

        form.addRow("Max iterations:",
                    gui.spin(box, self, "max_iter", 10, 10**4, step=1))

        form.addRow(
            "Initialization",
            gui.comboBox(box,
                         self,
                         "initialization",
                         items=["PCA (Torgerson)", "Random"],
                         callback=self.__invalidate_embedding))

        box.layout().addLayout(form)
        # The "Refresh" row is added after the form was put into the box.
        form.addRow(
            "Refresh",
            gui.comboBox(box,
                         self,
                         "refresh_rate",
                         items=[t for t, _ in OWMDS.RefreshRate],
                         callback=self.__invalidate_refresh))
        gui.separator(box, 10)
        self.runbutton = gui.button(box,
                                    self,
                                    "Run",
                                    callback=self._toggle_run)

        # --- "Graph" box: variable choices for color/shape/size/label ---
        box = gui.widgetBox(self.controlArea, "Graph")
        self.colorvar_model = itemmodels.VariableListModel()

        # Keyword arguments shared by the four variable combo boxes below.
        common_options = {
            "sendSelectedValue": True,
            "valueType": str,
            "orientation": "horizontal",
            "labelWidth": 50,
            "contentsLength": 12
        }

        self.cb_color_value = gui.comboBox(
            box,
            self,
            "color_value",
            label="Color",
            callback=self._on_color_index_changed,
            **common_options)
        self.cb_color_value.setModel(self.colorvar_model)

        self.shapevar_model = itemmodels.VariableListModel()
        self.cb_shape_value = gui.comboBox(
            box,
            self,
            "shape_value",
            label="Shape",
            callback=self._on_shape_index_changed,
            **common_options)
        self.cb_shape_value.setModel(self.shapevar_model)

        self.sizevar_model = itemmodels.VariableListModel()
        self.cb_size_value = gui.comboBox(box,
                                          self,
                                          "size_value",
                                          label="Size",
                                          callback=self._on_size_index_changed,
                                          **common_options)
        self.cb_size_value.setModel(self.sizevar_model)

        self.labelvar_model = itemmodels.VariableListModel()
        self.cb_label_value = gui.comboBox(
            box,
            self,
            "label_value",
            label="Label",
            callback=self._on_label_index_changed,
            **common_options)
        self.cb_label_value.setModel(self.labelvar_model)

        # Second form, still in the "Graph" box: symbol appearance options.
        form = QtGui.QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QtGui.QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10)
        form.addRow(
            "Symbol size",
            gui.hSlider(box,
                        self,
                        "symbol_size",
                        minValue=1,
                        maxValue=20,
                        callback=self._on_size_index_changed,
                        createLabel=False))
        form.addRow(
            "Symbol opacity",
            gui.hSlider(box,
                        self,
                        "symbol_opacity",
                        minValue=100,
                        maxValue=255,
                        step=100,
                        callback=self._on_color_index_changed,
                        createLabel=False))
        # NOTE(review): unlike its siblings, this slider is created inside a
        # fresh widgetBox on the control area rather than in `box` -- confirm
        # that this is intended.
        form.addRow(
            "Show similar pairs",
            gui.hSlider(gui.widgetBox(self.controlArea,
                                      orientation="horizontal"),
                        self,
                        "connected_pairs",
                        minValue=0,
                        maxValue=20,
                        createLabel=False,
                        callback=self._on_connected_changed))
        form.addRow(
            "Jitter",
            gui.comboBox(box,
                         self,
                         "jitter",
                         items=[text for text, _ in self.JitterAmount],
                         callback=self._update_plot))

        box.layout().addLayout(form)

        gui.rubber(self.controlArea)

        # --- "Zoom/Select" box: mutually exclusive tool actions ---
        box = QtGui.QGroupBox("Zoom/Select", )
        box.setLayout(QtGui.QHBoxLayout())
        box.layout().setMargin(2)

        group = QtGui.QActionGroup(self, exclusive=True)

        def icon(name):
            # Load "icons/Dlg_<name>.png" relative to the `widget` package.
            path = "icons/Dlg_{}.png".format(name)
            path = pkg_resources.resource_filename(widget.__name__, path)
            return QtGui.QIcon(path)

        action_select = QtGui.QAction(
            "Select",
            self,
            checkable=True,
            checked=True,
            icon=icon("arrow"),
            shortcut=QtGui.QKeySequence(Qt.ControlModifier + Qt.Key_1))
        action_zoom = QtGui.QAction(
            "Zoom",
            self,
            checkable=True,
            checked=False,
            icon=icon("zoom"),
            shortcut=QtGui.QKeySequence(Qt.ControlModifier + Qt.Key_2))
        action_pan = QtGui.QAction(
            "Pan",
            self,
            checkable=True,
            checked=False,
            icon=icon("pan_hand"),
            shortcut=QtGui.QKeySequence(Qt.ControlModifier + Qt.Key_3))

        # "Zoom to fit" is a plain action and is not added to the
        # exclusive group.
        action_reset_zoom = QtGui.QAction(
            "Zoom to fit",
            self,
            icon=icon("zoom_reset"),
            shortcut=QtGui.QKeySequence(Qt.ControlModifier + Qt.Key_0))
        # The lambda looks up self.plot when triggered; self.plot is only
        # created further below.
        action_reset_zoom.triggered.connect(lambda: self.plot.autoRange(
            padding=0.1, items=[self._scatter_item]))
        group.addAction(action_select)
        group.addAction(action_zoom)
        group.addAction(action_pan)
        self.addActions(group.actions() + [action_reset_zoom])
        action_select.setChecked(True)

        def button(action):
            # Icon-only tool button bound to `action`.
            b = QtGui.QToolButton()
            b.setToolButtonStyle(Qt.ToolButtonIconOnly)
            b.setDefaultAction(action)
            return b

        box.layout().addWidget(button(action_select))
        box.layout().addWidget(button(action_zoom))
        box.layout().addWidget(button(action_pan))
        box.layout().addSpacing(4)
        box.layout().addWidget(button(action_reset_zoom))
        box.layout().addStretch()

        self.controlArea.layout().addWidget(box)

        # --- "Output" box ---
        # The item order matches NoRole, AttrRole, AddAttrRole, MetaRole.
        box = gui.widgetBox(self.controlArea, "Output")
        self.output_combo = gui.comboBox(box,
                                         self,
                                         "output_embedding_role",
                                         items=[
                                             "Original features only",
                                             "Coordinates only",
                                             "Coordinates as features",
                                             "Coordinates as meta attributes"
                                         ],
                                         callback=self._invalidate_output,
                                         addSpace=4)
        gui.auto_commit(box,
                        self,
                        "autocommit",
                        "Send data",
                        checkbox_label="Send after any change",
                        box=None)
        self.inline_graph_report()

        # --- Main area: the plot, with axes and buttons hidden ---
        self.plot = pg.PlotWidget(background="w", enableMenu=False)
        self.plot.getPlotItem().hideAxis("bottom")
        self.plot.getPlotItem().hideAxis("left")
        self.plot.getPlotItem().hideButtons()
        self.plot.setRenderHint(QtGui.QPainter.Antialiasing)
        self.mainArea.layout().addWidget(self.plot)

        # Interaction tools. The pinch tool and the selection tool are
        # attached to the view box right away; selection starts as the
        # current tool.
        self.selection_tool = PlotSelectionTool(parent=self)
        self.zoom_tool = PlotZoomTool(parent=self)
        self.pan_tool = PlotPanTool(parent=self)
        self.pinch_tool = PlotPinchZoomTool(parent=self)
        self.pinch_tool.setViewBox(self.plot.getViewBox())
        self.selection_tool.setViewBox(self.plot.getViewBox())
        self.selection_tool.selectionFinished.connect(self.__selection_end)
        self.current_tool = self.selection_tool

        def activate_tool(action):
            # Detach the current tool, then attach the one that belongs to
            # the triggered action and set the matching cursor. Only the
            # three actions added to `group` above are handled; any other
            # action would leave `active` and `cur` unbound.
            self.current_tool.setViewBox(None)

            if action is action_select:
                active, cur = self.selection_tool, Qt.ArrowCursor
            elif action is action_zoom:
                active, cur = self.zoom_tool, Qt.ArrowCursor
            elif action is action_pan:
                active, cur = self.pan_tool, Qt.OpenHandCursor
            self.current_tool = active
            self.current_tool.setViewBox(self.plot.getViewBox())
            self.plot.getViewBox().setCursor(QtGui.QCursor(cur))

        group.triggered[QtGui.QAction].connect(activate_tool)
        self.graphButton.clicked.connect(self.save_graph)

        self._initialize()

    @check_sql_input
    def set_data(self, data):
        """
        Handler for the "Data" input.

        The table is always remembered in `self.signal_data`. When a distance
        matrix of the same length is already present, the data is swapped in
        directly and the controls are refreshed; otherwise the widget is
        flagged as invalidated. The selection mask is dropped in both cases.
        """
        self.signal_data = data

        matrix = self.matrix
        if matrix is None or data is None or len(matrix) != len(data):
            self._invalidated = True
        else:
            self.closeContext()
            self.data = data
            self.update_controls()
            self.openContext(data)
        self._selection_mask = None

    def set_disimilarity(self, matrix):
        """
        Handler for the "Distances" input.

        Stores `matrix` and the data attached to it (`matrix.row_items`).
        `self.matrix_data` is reset to None whenever the new matrix is None
        or carries no row items, so that items belonging to a previously
        received matrix are never paired with the new one. The widget is
        flagged as invalidated and the selection mask is dropped.
        """
        self.matrix = matrix
        if matrix is not None and matrix.row_items:
            self.matrix_data = matrix.row_items
        else:
            # Covers both "no matrix" and "matrix without row items".
            self.matrix_data = None
        self._invalidated = True
        self._selection_mask = None

    def _clear(self):
        """
        Reset the cached visual data, the four variable models and their
        selected values, stop the update loop and return to the Waiting state.
        """
        for cache_attr in ("_pen_data", "_shape_data", "_size_data",
                           "_label_data", "_similar_pairs"):
            setattr(self, cache_attr, None)

        # Every model is reduced to its single placeholder entry ...
        placeholders = (
            (self.colorvar_model, "Same color"),
            (self.shapevar_model, "Same shape"),
            (self.sizevar_model, "Same size"),
            (self.labelvar_model, "No labels"),
        )
        for model, text in placeholders:
            model[:] = [text]

        # ... and every selection points at that placeholder.
        self.color_value = self.colorvar_model[0]
        self.shape_value = self.shapevar_model[0]
        self.size_value = self.sizevar_model[0]
        self.label_value = self.labelvar_model[0]

        self.__set_update_loop(None)
        self.__state = OWMDS.Waiting

    def _clear_plot(self):
        """
        Clear the plot and drop the scatter item and the legend.

        Before the legend is removed from its scene, its anchor position (if
        one can be determined) is saved in `self.legend_anchor`.
        """
        self.plot.clear()
        self._scatter_item = None

        legend = self._legend_item
        if legend is None:
            return

        anchor = legend_anchor_pos(legend)
        if anchor is not None:
            self.legend_anchor = anchor
        scene = legend.scene()
        if scene is not None:
            scene.removeItem(legend)
        self._legend_item = None

    def update_controls(self):
        """
        Repopulate the color/shape/size/label models.

        For a matrix with ``axis == 0`` (column-wise distances) the label,
        shape and color models only offer "Attribute names", which is also
        selected for color and shape. Otherwise the models are filled from
        the variables and metas of `self.data`, and the class variable, if
        any, becomes the color variable.
        """
        if getattr(self.matrix, 'axis', 1) == 0:
            # Column-wise distances
            names_entry = "Attribute names"
            self.labelvar_model[:] = ["No labels", names_entry]
            self.shapevar_model[:] = ["Same shape", names_entry]
            self.colorvar_model[:] = ["Same color", names_entry]

            self.color_value = names_entry
            self.shape_value = names_entry
            return

        # initialize the graph state from data
        domain = self.data.domain
        candidates = list(domain.variables + domain.metas)
        primitive = [v for v in candidates if v.is_primitive()]
        discrete = [v for v in candidates if v.is_discrete]
        continuous = [v for v in candidates if v.is_continuous]
        # Only discrete variables with at most one value fewer than the
        # number of available symbols can be used for shapes.
        max_shapes = len(ScatterPlotItem.Symbols) - 1
        shapeable = [v for v in discrete if len(v.values) <= max_shapes]

        self.colorvar_model[:] = chain(
            ["Same color", self.colorvar_model.Separator], primitive)
        self.shapevar_model[:] = chain(
            ["Same shape", self.shapevar_model.Separator], shapeable)
        self.sizevar_model[:] = chain(
            ["Same size", "Stress", self.sizevar_model.Separator],
            continuous)
        self.labelvar_model[:] = chain(
            ["No labels", self.labelvar_model.Separator], candidates)

        if domain.class_var is not None:
            self.color_value = domain.class_var.name

    def _initialize(self):
        """
        Rebuild the widget state from the current input signals.

        Everything is cleared first. With neither data nor a matrix nothing
        else happens. When both data sets are present but differ in length,
        error 1 is raised and the plot is updated. Otherwise the effective
        distance matrix is taken from the input matrix, or computed as
        Euclidean distances on the preprocessed data when no matrix is given,
        and the controls and the settings context are refreshed.
        """
        # clear everything
        self.closeContext()
        self._clear()
        self.data = None
        self._effective_matrix = None
        self.embedding = None

        signal_data, matrix = self.signal_data, self.matrix

        # if no data nor matrix is present reset plot
        if signal_data is None and matrix is None:
            return

        matrix_data = self.matrix_data
        if signal_data and matrix_data and \
                len(signal_data) != len(matrix_data):
            self.error(1, "Data and distances dimensions do not match.")
            self._update_plot()
            return

        self.error(1)

        # The "Data" input takes precedence over the matrix' own items.
        if signal_data:
            self.data = signal_data
        elif matrix_data:
            self.data = matrix_data

        if matrix is None:
            prepared = Orange.projection.MDS().preprocess(self.data)
            self._effective_matrix = Orange.distance.Euclidean(prepared)
        else:
            self._effective_matrix = matrix
            if matrix.axis == 0:
                # Column-wise distances: no data is kept.
                self.data = None

        self.update_controls()
        self.openContext(self.data)

    def _toggle_run(self):
        """
        Toggle the optimization: stop it (and invalidate the output) when it
        is running, start it otherwise.
        """
        is_running = self.__state == OWMDS.Running
        if not is_running:
            self.start()
            return

        self.stop()
        self._invalidate_output()

    def start(self):
        """
        Start the optimization unless it is already running.

        A finished run is resumed; a waiting widget is started only when an
        effective distance matrix is available.
        """
        if self.__state == OWMDS.Running:
            return

        if self.__state == OWMDS.Finished:
            # Resume/continue from a previous run
            can_start = True
        else:
            can_start = (self.__state == OWMDS.Waiting
                         and self._effective_matrix is not None)

        if can_start:
            self.__start()

    def stop(self):
        """Interrupt the update loop if the optimization is running."""
        if self.__state == OWMDS.Running:
            # passing None closes the current loop without installing a new one
            self.__set_update_loop(None)

    def __start(self):
        """
        Create the MDS update loop and install it as the running loop.

        The initial point configuration is the current embedding when one
        exists, a Torgerson (PCA) embedding when that initialization is
        selected, and ``None`` otherwise (left to the MDS implementation).
        """
        self.__draw_similar_pairs = False
        X = self._effective_matrix

        if self.embedding is not None:
            init = self.embedding
        elif self.initialization == OWMDS.PCA:
            init = torgerson(X, n_components=2)
        else:
            init = None

        # number of iterations per single GUI update step
        _, step_size = OWMDS.RefreshRate[self.refresh_rate]
        if step_size == -1:
            # -1: run all iterations in a single step
            step_size = self.max_iter

        def update_loop(X, max_iter, step, init):
            """
            return an iterator over successive improved MDS point embeddings.

            Yields `(embedding, stress, progress)` after every chunk of at
            most `step` iterations; stops once `max_iter` iterations were
            done or the normalized stress improved by less than the MDS
            `eps` parameter.
            """
            # NOTE: this code MUST NOT call into QApplication.processEvents
            done = False
            iterations_done = 0
            # `numpy.float` was an alias of the builtin `float` and has been
            # removed from NumPy; use the builtin directly.
            oldstress = numpy.finfo(float).max

            while not done:
                step_iter = min(max_iter - iterations_done, step)
                mds = Orange.projection.MDS(dissimilarity="precomputed",
                                            n_components=2,
                                            n_init=1,
                                            max_iter=step_iter)

                mdsfit = mds.fit(X, init=init)
                iterations_done += step_iter

                embedding, stress = mdsfit.embedding_, mdsfit.stress_
                # normalize by the summed point norms; only used for the
                # convergence test below (the raw stress is what is yielded)
                stress /= numpy.sqrt(numpy.sum(embedding**2, axis=1)).sum()

                if iterations_done >= max_iter:
                    done = True
                elif (oldstress - stress) < mds.params["eps"]:
                    done = True
                # continue the next chunk from the current configuration
                init = embedding
                oldstress = stress

                yield embedding, mdsfit.stress_, iterations_done / max_iter

        self.__set_update_loop(update_loop(X, self.max_iter, step_size, init))
        # NOTE(review): __set_update_loop already calls progressBarInit for
        # a non-None loop; this second call looks redundant -- confirm
        # before removing.
        self.progressBarInit(processEvents=None)

    def __set_update_loop(self, loop):
        """
        Install `loop` as the current update coroutine.

        `loop` is a generator yielding `(embedding, stress, progress)`
        tuples, where `embedding` is a `(N, 2) ndarray` with the current
        MDS points, `stress` the current stress and `progress` a float
        ratio (0 <= progress <= 1). Passing `None` only removes the
        current loop.

        An update loop that is already in place is interrupted (closed)
        first.

        .. note::
            The `loop` must not explicitly yield control flow to the event
            loop (i.e. call `QApplication.processEvents`)

        """
        previous = self.__update_loop
        if previous is not None:
            previous.close()
            self.__update_loop = None
            self.progressBarFinished(processEvents=None)

        self.__update_loop = loop

        if loop is None:
            self.setStatusMessage("")
            self.runbutton.setText("Start")
            self.__state = OWMDS.Finished
            return

        self.progressBarInit(processEvents=None)
        self.setStatusMessage("Running")
        self.runbutton.setText("Stop")
        self.__state = OWMDS.Running
        # kick off the first iteration through the event queue
        QtGui.QApplication.postEvent(self, QEvent(QEvent.User))

    def __next_step(self):
        """
        Run one step of the update loop and redraw the plot.

        When the loop is exhausted it is removed, the output is committed
        and drawing of similar pairs is enabled; otherwise the progress bar
        and the embedding are updated and the next step is scheduled.
        """
        loop = self.__update_loop
        if loop is None:
            return

        try:
            embedding, _stress, progress = next(loop)
            assert self.__update_loop is loop
        except StopIteration:
            finished = True
        else:
            finished = False

        if finished:
            self.__set_update_loop(None)
            self.unconditional_commit()
            self.__draw_similar_pairs = True
        else:
            self.progressBarSet(100.0 * progress, processEvents=None)
            self.embedding = embedding

        self._update_plot()
        self.plot.autoRange(padding=0.1, items=[self._scatter_item])

        if not finished:
            # schedule next update
            QtGui.QApplication.postEvent(self, QEvent(QEvent.User),
                                         Qt.LowEventPriority)

    def customEvent(self, event):
        """
        Drive the update loop from `QEvent.User` events.

        A step is run only when no other step is in progress; on re-entry a
        warning is issued and the step is re-scheduled. The event is always
        passed on to the base class implementation.
        """
        wants_step = (event.type() == QEvent.User
                      and self.__update_loop is not None)

        if wants_step and self.__in_next_step:
            warnings.warn(
                "Re-entry in update loop detected. "
                "A rogue `proccessEvents` is on the loose.",
                RuntimeWarning)
            # re-schedule the update iteration.
            QtGui.QApplication.postEvent(self, QEvent(QEvent.User))
        elif wants_step:
            self.__in_next_step = True
            try:
                self.__next_step()
            finally:
                self.__in_next_step = False

        return super().customEvent(event)

    def __invalidate_embedding(self):
        """
        Reset the embedding to the selected default initialization (PCA or
        random) and restart the optimization if it was running.

        Does nothing when there is no embedding yet.
        """
        if self.embedding is None:
            return

        # remember this before the loop is removed (which changes the state)
        was_running = self.__state == OWMDS.Running
        if self.__update_loop is not None:
            self.__set_update_loop(None)

        distances = self._effective_matrix
        use_pca = self.initialization == OWMDS.PCA
        self.embedding = (torgerson(distances) if use_pca
                          else numpy.random.rand(len(distances), 2))

        self._update_plot()
        self.plot.autoRange(padding=0.1, items=[self._scatter_item])

        # restart the optimization if it was interrupted.
        if was_running:
            self.__start()

    def __invalidate_refresh(self):
        """
        Remove the current update loop and, if the optimization was running,
        start a new one.
        """
        # remember this before the loop is removed (which changes the state)
        was_running = self.__state == OWMDS.Running

        if self.__update_loop is not None:
            self.__set_update_loop(None)

        # restart the optimization if it was interrupted.
        # TODO: decrease the max iteration count by the already
        # completed iterations count.
        if was_running:
            self.__start()

    def handleNewSignals(self):
        """
        React to new input signals.

        If the inputs were invalidated, the state is re-initialized and the
        optimization started. The plot is then redrawn without similar-pair
        lines, auto-ranged, and the outputs are committed.
        """
        if self._invalidated:
            # consume the flag before rebuilding the state
            self._invalidated = False
            self._initialize()
            self.start()

        self.__draw_similar_pairs = False

        self._update_plot()
        self.plot.autoRange(padding=0.1)

        self.unconditional_commit()

    def _invalidate_output(self):
        """Refresh the widget outputs by calling `commit`."""
        self.commit()

    def _on_color_index_changed(self):
        """Drop the cached pen data and redraw the plot."""
        self._pen_data = None
        self._update_plot()

    def _on_shape_index_changed(self):
        """Drop the cached symbol data and redraw the plot."""
        self._shape_data = None
        self._update_plot()

    def _on_size_index_changed(self):
        """Drop the cached point size data and redraw the plot."""
        self._size_data = None
        self._update_plot()

    def _on_label_index_changed(self):
        """Drop the cached label items and redraw the plot."""
        self._label_data = None
        self._update_plot()

    def _on_connected_changed(self):
        """Drop the cached similar-pair indices and redraw the plot."""
        self._similar_pairs = None
        self._update_plot()

    def _update_plot(self):
        self._clear_plot()

        if self.embedding is not None:
            self._setup_plot()

    def _setup_plot(self):
        """
        Create the plot items for the current embedding and add them to
        the plot.

        Pens/brushes, symbols, sizes and labels are computed only when the
        corresponding cache (``_pen_data``, ``_shape_data``, ``_size_data``,
        ``_label_data``) is ``None``. Afterwards the optional similar-pair
        lines, the scatter plot item, the text labels and the legend are
        created.

        Expects ``self.embedding`` to be set; it is indexed unconditionally.
        """
        have_data = self.data is not None
        have_matrix_transposed = self.matrix is not None and not self.matrix.axis
        plotstyle = mdsplotutils.plotstyle
        # One entry per plotted point. `self.data` is None for column-wise
        # distances, so the default pen/brush arrays must not be sized from
        # it (this used to raise TypeError with "Same color" selected).
        n_points = len(self.embedding)

        def column(data, variable):
            # flat array of the values of `variable` in `data`
            a, _ = data.get_column_view(variable)
            return a.ravel()

        def attributes(matrix):
            # attributes of the table the distance matrix was computed from
            return matrix.row_items.domain.attributes

        def scale(a):
            # rescale to [0, 1] ignoring nans; all zeros for a constant array
            dmin, dmax = numpy.nanmin(a), numpy.nanmax(a)
            if dmax - dmin > 0:
                return (a - dmin) / (dmax - dmin)
            else:
                return numpy.zeros_like(a)

        def jitter(x, factor=1, rstate=None):
            # add uniform noise of amplitude `factor` percent of the span
            if rstate is None:
                rstate = numpy.random.RandomState()
            elif not isinstance(rstate, numpy.random.RandomState):
                rstate = numpy.random.RandomState(rstate)
            span = numpy.nanmax(x) - numpy.nanmin(x)
            if span < numpy.finfo(x.dtype).eps * 100:
                span = 1
            a = factor * span / 100.
            return x + (rstate.random_sample(x.shape) - 0.5) * a

        if self._pen_data is None:
            if self._selection_mask is not None:
                pointflags = numpy.where(self._selection_mask,
                                         mdsplotutils.Selected,
                                         mdsplotutils.NoFlags)
            else:
                pointflags = None

            color_index = self.cb_color_value.currentIndex()
            if have_data and color_index > 0:
                color_var = self.colorvar_model[color_index]
                if color_var.is_discrete:
                    palette = colorpalette.ColorPaletteGenerator(
                        len(color_var.values))
                    plotstyle = plotstyle.updated(discrete_palette=palette)

                color_data = mdsplotutils.color_data(self.data,
                                                     color_var,
                                                     plotstyle=plotstyle)
                # append the opacity as a fourth (alpha) column
                color_data = numpy.hstack((color_data,
                                           numpy.full((len(color_data), 1),
                                                      self.symbol_opacity)))
                pen_data = mdsplotutils.pen_data(color_data * 0.8, pointflags)
                brush_data = mdsplotutils.brush_data(color_data)
            elif have_matrix_transposed and \
                    self.colorvar_model[color_index] == 'Attribute names':
                attr = attributes(self.matrix)
                palette = colorpalette.ColorPaletteGenerator(len(attr))
                color_data = [palette.getRGB(i) for i in range(len(attr))]
                color_data = numpy.hstack((color_data,
                                           numpy.full((len(color_data), 1),
                                                      self.symbol_opacity)))
                pen_data = mdsplotutils.pen_data(color_data * 0.8, pointflags)
                brush_data = mdsplotutils.brush_data(color_data)
            else:
                # uniform gray points
                pen_data = make_pen(QtGui.QColor(Qt.darkGray), cosmetic=True)
                if self._selection_mask is not None:
                    pen_data = numpy.array([pen_data, plotstyle.selected_pen])
                    pen_data = pen_data[self._selection_mask.astype(int)]
                else:
                    pen_data = numpy.full(n_points,
                                          pen_data,
                                          dtype=object)
                brush_data = numpy.full(n_points,
                                        pg.mkColor((192, 192, 192,
                                                    self.symbol_opacity)),
                                        dtype=object)

            self._pen_data = pen_data
            self._brush_data = brush_data

        if self._shape_data is None:
            shape_index = self.cb_shape_value.currentIndex()
            if have_data and shape_index > 0:
                Symbols = ScatterPlotItem.Symbols
                symbols = numpy.array(list(Symbols.keys()))

                shape_var = self.shapevar_model[shape_index]
                data = column(self.data, shape_var)
                data = data % (len(Symbols) - 1)
                # the last symbol is reserved for missing values
                data[numpy.isnan(data)] = len(Symbols) - 1
                shape_data = symbols[data.astype(int)]
            elif have_matrix_transposed and \
                    self.shapevar_model[shape_index] == 'Attribute names':
                Symbols = ScatterPlotItem.Symbols
                symbols = numpy.array(list(Symbols.keys()))
                attr = [
                    i % (len(Symbols) - 1)
                    for i, _ in enumerate(attributes(self.matrix))
                ]
                shape_data = symbols[attr]
            else:
                shape_data = "o"
            self._shape_data = shape_data

        if self._size_data is None:
            MinPointSize = 3
            point_size = self.symbol_size + MinPointSize
            size_index = self.cb_size_value.currentIndex()
            if have_data and size_index == 1:
                # size by stress
                size_data = stress(self.embedding, self._effective_matrix)
                size_data = scale(size_data)
                size_data = MinPointSize + size_data * point_size
            elif have_data and size_index > 0:
                size_var = self.sizevar_model[size_index]
                size_data = column(self.data, size_var)
                size_data = scale(size_data)
                size_data = MinPointSize + size_data * point_size
            else:
                size_data = point_size
            self._size_data = size_data

        if self._label_data is None:
            label_index = self.cb_label_value.currentIndex()
            if have_data and label_index > 0:
                label_var = self.labelvar_model[label_index]
                label_data = column(self.data, label_var)
                label_data = [label_var.str_val(val) for val in label_data]
                label_items = [
                    pg.TextItem(text, anchor=(0.5, 0), color=0.0)
                    for text in label_data
                ]
            elif have_matrix_transposed and \
                    self.labelvar_model[label_index] == 'Attribute names':
                attr = attributes(self.matrix)
                label_items = [
                    pg.TextItem(str(text), anchor=(0.5, 0)) for text in attr
                ]
            else:
                label_items = None
            self._label_data = label_items

        emb_x, emb_y = self.embedding[:, 0], self.embedding[:, 1]
        if self.jitter > 0:
            _, jitter_factor = self.JitterAmount[self.jitter]
            # fixed seeds keep the jitter stable between redraws
            emb_x = jitter(emb_x, jitter_factor, rstate=42)
            emb_y = jitter(emb_y, jitter_factor, rstate=667)

        if self.connected_pairs and self.__draw_similar_pairs:
            if self._similar_pairs is None:
                # This code requires storing lower triangle of X (n x n / 2
                # doubles), n x n / 2 * 2 indices to X, n x n / 2 indices for
                # argsort result. If this becomes an issue, it can be reduced to
                # n x n argsort indices by argsorting the entire X. Then we
                # take the first n + 2 * p indices. We compute their coordinates
                # i, j in the original matrix. We keep those for which i < j.
                # n + 2 * p will suffice to exclude the diagonal (i = j). If the
                # number of those for which i < j is smaller than p, we instead
                # take i > j. Among those that remain, we take the first p.
                # Assuming that MDS can't show so many points that memory could
                # become an issue, I preferred using simpler code.
                m = self._effective_matrix
                n = len(m)
                p = (n * (n - 1) // 2 * self.connected_pairs) // 100
                indcs = numpy.triu_indices(n, 1)
                # indices of the p smallest distances above the diagonal
                order = numpy.argsort(m[indcs])[:p]
                self._similar_pairs = fpairs = numpy.empty(2 * p, dtype=int)
                fpairs[::2] = indcs[0][order]
                fpairs[1::2] = indcs[1][order]
            # `_similar_pairs` is a flat array [i0, j0, i1, j1, ...]; gather
            # the end point coordinates once instead of on every iteration
            pairs = self._similar_pairs
            pair_x, pair_y = emb_x[pairs], emb_y[pairs]
            pen = QtGui.QPen(QtGui.QBrush(QtGui.QColor(204, 204, 204)), 2)
            pen.setCosmetic(True)
            for i in range(len(pairs) // 2):
                item = QtGui.QGraphicsLineItem(
                    pair_x[i * 2],
                    pair_y[i * 2],
                    pair_x[i * 2 + 1],
                    pair_y[i * 2 + 1])
                item.setPen(pen)
                self.plot.addItem(item)

        # point index stored with each spot; read back in `select`
        data = numpy.arange(len(self.data if have_data else self.matrix))
        self._scatter_item = item = ScatterPlotItem(x=emb_x,
                                                    y=emb_y,
                                                    pen=self._pen_data,
                                                    brush=self._brush_data,
                                                    symbol=self._shape_data,
                                                    size=self._size_data,
                                                    data=data,
                                                    antialias=True)
        self.plot.addItem(item)

        if self._label_data is not None:
            for (x, y), text_item in zip(self.embedding, self._label_data):
                self.plot.addItem(text_item)
                text_item.setPos(x, y)

        self._legend_item = LegendItem()
        viewbox = self.plot.getViewBox()
        self._legend_item.setParentItem(viewbox)
        self._legend_item.setZValue(viewbox.zValue() + 10)
        self._legend_item.restoreAnchor(self.legend_anchor)

        color_var = shape_var = None
        color_index = self.cb_color_value.currentIndex()
        if have_data and 1 <= color_index < len(self.colorvar_model):
            color_var = self.colorvar_model[color_index]
            assert isinstance(color_var, Orange.data.Variable)
        shape_index = self.cb_shape_value.currentIndex()
        if have_data and 1 <= shape_index < len(self.shapevar_model):
            shape_var = self.shapevar_model[shape_index]
            assert isinstance(shape_var, Orange.data.Variable)

        # a legend is shown only for a shape variable or a discrete color
        if shape_var is not None or \
                (color_var is not None and color_var.is_discrete):

            legend_data = mdsplotutils.legend_data(color_var,
                                                   shape_var,
                                                   plotstyle=plotstyle)

            for color, symbol, text in legend_data:
                self._legend_item.addItem(
                    ScatterPlotItem(pen=color,
                                    brush=color,
                                    symbol=symbol,
                                    size=10), escape(text))
        else:
            self._legend_item.hide()

    def commit(self):
        """
        Send the embedding on "Data" and the selected rows on
        "Selected Data".

        Without an embedding both outputs are ``None``. With an embedding
        but no data table the output is a two-column (X, Y) table. With
        both, the embedding columns are placed into the data table according
        to ``output_embedding_role``.
        """
        output = embedding = None

        if self.embedding is not None:
            xy_domain = Orange.data.Domain([
                Orange.data.ContinuousVariable("X"),
                Orange.data.ContinuousVariable("Y")
            ])
            embedding = Orange.data.Table.from_numpy(xy_domain,
                                                     self.embedding)
            output = embedding

            if self.data is not None:
                role = self.output_embedding_role
                source = self.data.domain
                attrs = source.attributes
                class_vars = source.class_vars
                metas = source.metas
                xy_vars = embedding.domain.attributes

                if role == OWMDS.AttrRole:
                    attrs = xy_vars
                elif role == OWMDS.AddAttrRole:
                    attrs = source.attributes + xy_vars
                elif role == OWMDS.MetaRole:
                    metas += xy_vars

                output = Orange.data.Table.from_table(
                    Orange.data.Domain(attrs, class_vars, metas), self.data)

                # fill in the coordinates for the newly added columns
                if role == OWMDS.AttrRole:
                    output.X[:] = embedding.X
                elif role == OWMDS.AddAttrRole:
                    output.X[:, -2:] = embedding.X
                elif role == OWMDS.MetaRole:
                    output.metas[:, -2:] = embedding.X

        self.send("Data", output)

        subset = None
        if output is not None:
            mask = self._selection_mask
            if mask is not None and numpy.any(mask):
                subset = output[mask]
        self.send("Selected Data", subset)

    def onDeleteWidget(self):
        """Run the base class clean-up, then clear the plot and the state."""
        super().onDeleteWidget()
        self._clear_plot()
        self._clear()

    def __selection_end(self, path):
        """Select the points inside `path`, redraw and invalidate the output."""
        self.select(path)
        # pens depend on the selection; force them to be recomputed
        self._pen_data = None
        self._update_plot()
        self._invalidate_output()

    def select(self, region):
        """
        Update the selection with the points lying inside `region`.

        `region` must provide a ``contains(pos)`` method. Nothing happens
        when there is no scatter plot item. The keyboard modifiers held at
        the time of the call decide how the points are combined with the
        existing selection (see `select_indices`).
        """
        item = self._scatter_item
        if item is None:
            return

        indices = numpy.array(
            [spot.data() for spot in item.points()
             if region.contains(spot.pos())],
            dtype=int)

        # Query the keyboard state once, so that dropping the old mask and
        # updating the selection act on the same set of modifiers.
        modifiers = QtGui.QApplication.keyboardModifiers()
        if not modifiers:
            self._selection_mask = None

        self.select_indices(indices, modifiers)

    def select_indices(self, indices, modifiers=Qt.NoModifier):
        """
        Apply `indices` to the selection mask according to `modifiers`.

        With Alt the indices are deselected, with Control they are toggled,
        otherwise they are selected. The mask is recreated (all False) when
        it does not exist or none of Control/Shift/Alt is held. Does nothing
        without a data table.
        """
        if self.data is None:
            return

        combining = modifiers & (Qt.ControlModifier | Qt.ShiftModifier |
                                 Qt.AltModifier)
        if self._selection_mask is None or not combining:
            self._selection_mask = numpy.zeros(len(self.data), dtype=bool)

        mask = self._selection_mask
        if modifiers & Qt.AltModifier:
            mask[indices] = False
        elif modifiers & Qt.ControlModifier:
            mask[indices] = ~mask[indices]
        else:
            mask[indices] = True

    def save_graph(self):
        """Run an `OWSave` dialog for the plot item with the image writers."""
        from Orange.widgets.data.owsave import OWSave

        dialog = OWSave(
            data=self.plot.plotItem,
            file_formats=FileFormat.img_writers,
        )
        dialog.exec_()

    def send_report(self):
        """
        Add the plot, a caption listing the non-default color/shape/size/
        label choices, and the output setting to the report.

        Nothing is reported when there is no data table.
        """
        if self.data is None:
            return

        self.report_plot(self.plot)

        # (title, default entry, current value); entries left at their
        # default are passed on as False
        choices = (
            ("Color", "Same color", self.color_value),
            ("Shape", "Same shape", self.shape_value),
            ("Size", "Same size", self.size_value),
            ("Labels", "No labels", self.label_value),
        )
        items = tuple((title, value != default and value)
                      for title, default, value in choices)
        caption = report.render_items_vert(items)
        if caption:
            self.report_caption(caption)

        self.report_items((("Output", self.output_combo.currentText()), ))
Exemple #19
0
class OWHeatMap(widget.OWWidget):
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
    icon = "icons/Heatmap.svg"
    priority = 260
    keywords = []

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settings_version = 3

    settingsHandler = settings.DomainContextHandler()

    # Disable clustering for inputs bigger than this
    MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    MaxOrderedClustering = 1000

    threshold_low = settings.Setting(0.0)
    threshold_high = settings.Setting(1.0)

    merge_kmeans = settings.Setting(False)
    merge_kmeans_k = settings.Setting(50)

    # Display column with averages
    averages: bool = settings.Setting(True)
    # Display legend
    legend: bool = settings.Setting(True)
    # Annotations
    #: text row annotation (row names)
    annotation_var = settings.ContextSetting(None)
    #: color row annotation
    annotation_color_var = settings.ContextSetting(None)
    # Discrete variable used to split that data/heatmaps (vertically)
    split_by_var = settings.ContextSetting(None)
    # Selected row/column clustering method (name)
    col_clustering_method: str = settings.Setting(Clustering.None_.name)
    row_clustering_method: str = settings.Setting(Clustering.None_.name)

    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    column_label_pos: int = settings.Setting(1)
    selected_rows: List[int] = settings.Setting(None, schema_only=True)

    auto_commit = settings.Setting(True)

    graph_name = "scene"

    left_side_scrolling = True

    class Information(widget.OWWidget.Information):
        sampled = Msg("Data has been sampled")
        discrete_ignored = Msg("{} categorical feature{} ignored")
        row_clust = Msg("{}")
        col_clust = Msg("{}")
        sparse_densified = Msg("Showing this data may require a lot of memory")

    class Error(widget.OWWidget.Error):
        no_continuous = Msg("No numeric features")
        not_enough_features = Msg("Not enough features for column clustering")
        not_enough_instances = Msg("Not enough instances for clustering")
        not_enough_instances_k_means = Msg(
            "Not enough instances for k-means merging")
        not_enough_memory = Msg("Not enough memory to show this data")

    class Warning(widget.OWWidget.Warning):
        empty_clusters = Msg("Empty clusters were removed")

    def __init__(self):
        super().__init__()
        self.__pending_selection = self.selected_rows

        # A kingdom for a save_state/restore_state
        self.col_clustering = enum_get(Clustering, self.col_clustering_method,
                                       Clustering.None_)
        self.row_clustering = enum_get(Clustering, self.row_clustering_method,
                                       Clustering.None_)

        @self.settingsAboutToBePacked.connect
        def _():
            self.col_clustering_method = self.col_clustering.name
            self.row_clustering_method = self.row_clustering.name

        self.keep_aspect = False

        #: The original data with all features (retained to
        #: preserve the domain on the output)
        self.input_data = None
        #: The effective data striped of discrete features, and often
        #: merged using k-means
        self.data = None
        self.effective_data = None
        #: kmeans model used to merge rows of input_data
        self.kmeans_model = None
        #: merge indices derived from kmeans
        #: a list (len==k) of int ndarray where the i-th item contains
        #: the indices which merge the input_data into the heatmap row i
        self.merge_indices = None
        self.parts: Optional[Parts] = None
        self.__rows_cache = {}
        self.__columns_cache = {}

        # GUI definition
        colorbox = gui.vBox(self.controlArea, "Color")
        self.color_cb = gui.palette_combo_box(self.palette_name)
        self.color_cb.currentIndexChanged.connect(self.update_color_schema)
        colorbox.layout().addWidget(self.color_cb)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        lowslider = gui.hSlider(colorbox,
                                self,
                                "threshold_low",
                                minValue=0.0,
                                maxValue=1.0,
                                step=0.05,
                                ticks=True,
                                intOnly=False,
                                createLabel=False,
                                callback=self.update_lowslider)
        highslider = gui.hSlider(colorbox,
                                 self,
                                 "threshold_high",
                                 minValue=0.0,
                                 maxValue=1.0,
                                 step=0.05,
                                 ticks=True,
                                 intOnly=False,
                                 createLabel=False,
                                 callback=self.update_highslider)

        form.addRow("Low:", lowslider)
        form.addRow("High:", highslider)

        colorbox.layout().addLayout(form)

        mergebox = gui.vBox(
            self.controlArea,
            "Merge",
        )
        gui.checkBox(mergebox,
                     self,
                     "merge_kmeans",
                     "Merge by k-means",
                     callback=self.__update_row_clustering)
        ibox = gui.indentedBox(mergebox)
        gui.spin(ibox,
                 self,
                 "merge_kmeans_k",
                 minv=5,
                 maxv=500,
                 label="Clusters:",
                 keyboardTracking=False,
                 callbackOnReturn=True,
                 callback=self.update_merge)

        cluster_box = gui.vBox(self.controlArea, "Clustering")
        # Row clustering
        self.row_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.row_clustering, ClusteringRole)
        self.connect_control(
            "row_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_row_clustering(cb.itemData(idx, ClusteringRole))

        # Column clustering
        self.col_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.col_clustering, ClusteringRole)
        self.connect_control(
            "col_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_col_clustering(cb.itemData(idx, ClusteringRole))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )
        form.addRow("Rows:", self.row_cluster_cb)
        form.addRow("Columns:", self.col_cluster_cb)
        cluster_box.layout().addLayout(form)
        box = gui.vBox(self.controlArea, "Split By")

        self.row_split_model = DomainModel(
            placeholder="(None)",
            valid_types=(Orange.data.DiscreteVariable, ),
            parent=self,
        )
        self.row_split_cb = cb = ComboBoxSearch(
            enabled=not self.merge_kmeans,
            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=14,
            toolTip="Split the heatmap vertically by a categorical column")
        self.row_split_cb.setModel(self.row_split_model)
        self.connect_control("split_by_var",
                             lambda value, cb=cb: cbselect(cb, value))
        self.connect_control("merge_kmeans", self.row_split_cb.setDisabled)
        self.split_by_var = None

        self.row_split_cb.activated.connect(self.__on_split_rows_activated)
        box.layout().addWidget(self.row_split_cb)

        box = gui.vBox(self.controlArea, 'Annotation && Legends')

        gui.checkBox(box,
                     self,
                     'legend',
                     'Show legend',
                     callback=self.update_legend)

        gui.checkBox(box,
                     self,
                     'averages',
                     'Stripes with averages',
                     callback=self.update_averages_stripe)
        annotbox = QGroupBox("Row Annotations", flat=True)
        form = QFormLayout(annotbox,
                           formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        self.annotation_model = DomainModel(placeholder="(None)")
        self.annotation_text_cb = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength)
        self.annotation_text_cb.setModel(self.annotation_model)
        self.annotation_text_cb.activated.connect(self.set_annotation_var)
        self.connect_control("annotation_var", self.annotation_var_changed)

        self.row_side_color_model = DomainModel(
            order=(DomainModel.CLASSES, DomainModel.Separator,
                   DomainModel.METAS),
            placeholder="(None)",
            valid_types=DomainModel.PRIMITIVE,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
            parent=self,
        )
        self.row_side_color_cb = ComboBoxSearch(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            minimumContentsLength=12)
        self.row_side_color_cb.setModel(self.row_side_color_model)
        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
        self.connect_control("annotation_color_var",
                             self.annotation_color_var_changed)
        form.addRow("Text", self.annotation_text_cb)
        form.addRow("Color", self.row_side_color_cb)
        box.layout().addWidget(annotbox)
        posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
        posbox.setFlat(True)
        cb = gui.comboBox(posbox,
                          self,
                          "column_label_pos",
                          callback=self.update_column_annotations)
        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
        cb.setCurrentIndex(self.column_label_pos)
        gui.checkBox(self.controlArea,
                     self,
                     "keep_aspect",
                     "Keep aspect ratio",
                     box="Resize",
                     callback=self.__aspect_mode_changed)

        gui.rubber(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

        # Scene with heatmap
        class HeatmapScene(QGraphicsScene):
            widget: Optional[HeatmapGridWidget] = None

        self.scene = self.scene = HeatmapScene(parent=self)
        self.view = GraphicsView(
            self.scene,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
            widgetResizable=True,
        )
        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
        self.view.customContextMenuRequested.connect(
            self._on_view_context_menu)
        self.mainArea.layout().addWidget(self.view)
        self.selected_rows = []
        self.__font_inc = QAction("Increase Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+>"))
        self.__font_dec = QAction("Decrease Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+<"))
        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
            apply_all([self.__font_inc, self.__font_dec],
                      lambda a: a.setShortcutVisibleInContextMenu(True))
        self.addActions([self.__font_inc, self.__font_dec])

    @property
    def center_palette(self):
        """Whether the palette selected in `color_cb` has the Diverging flag."""
        current = self.color_cb.currentData()
        diverging_bits = current.flags & current.Diverging
        return bool(diverging_bits)

    @property
    def _column_label_pos(self) -> HeatmapGridWidget.Position:
        """Value stored under ``Qt.UserRole`` for the selected label position.

        `column_label_pos` is used as an index into `ColumnLabelsPosData`.
        """
        entry = ColumnLabelsPosData[self.column_label_pos]
        return entry[Qt.UserRole]

    def annotation_color_var_changed(self, value):
        """Select `value` in the row side color combo box (by ``Qt.EditRole``)."""
        combo = self.row_side_color_cb
        cbselect(combo, value, Qt.EditRole)

    def annotation_var_changed(self, value):
        """Select `value` in the annotation text combo box (by ``Qt.EditRole``)."""
        combo = self.annotation_text_cb
        cbselect(combo, value, Qt.EditRole)

    def set_row_clustering(self, method: Clustering) -> None:
        """Switch the row clustering method.

        Nothing happens when `method` already is the current row clustering;
        otherwise it is stored, selected in the row clustering combo box and
        a row clustering update is triggered.
        """
        assert isinstance(method, Clustering)
        changed = self.row_clustering != method
        if not changed:
            return
        self.row_clustering = method
        cbselect(self.row_cluster_cb, method, ClusteringRole)
        self.__update_row_clustering()

    def set_col_clustering(self, method: Clustering) -> None:
        """Switch the column clustering method.

        Nothing happens when `method` already is the current column
        clustering; otherwise it is stored, selected in the column clustering
        combo box and a column clustering update is triggered.
        """
        assert isinstance(method, Clustering)
        changed = self.col_clustering != method
        if not changed:
            return
        self.col_clustering = method
        cbselect(self.col_cluster_cb, method, ClusteringRole)
        self.__update_column_clustering()

    def sizeHint(self) -> QSize:
        """Return the inherited size hint expanded to at least 900x700."""
        inherited = super().sizeHint()
        minimum = QSize(900, 700)
        return inherited.expandedTo(minimum)

    def color_palette(self):
        """Return the lookup table of the palette selected in `color_cb`."""
        selected = self.color_cb.currentData()
        return selected.lookup_table()

    def color_map(self) -> GradientColorMap:
        """Build the gradient color map for the current palette settings.

        The map is constructed from the palette lookup table, the
        (low, high) threshold pair and a center of 0 when the palette is
        diverging (``None`` otherwise).
        """
        table = self.color_palette()
        thresholds = (self.threshold_low, self.threshold_high)
        if self.center_palette:
            center = 0
        else:
            center = None
        return GradientColorMap(table, thresholds, center)

    def clear(self):
        """Reset the widget to the empty (no data) state."""
        # Drop the data tables and the k-means merge state.
        self.data = self.input_data = self.effective_data = None
        self.kmeans_model = self.merge_indices = None

        # Each model is emptied before the variable selected from it is reset.
        self.annotation_model.set_domain(None)
        self.annotation_var = None
        self.row_side_color_model.set_domain(None)
        self.annotation_color_var = None
        self.row_split_model.set_domain(None)
        self.split_by_var = None

        self.parts = None
        self.clear_scene()
        self.selected_rows = []

        # Forget cached clustering results.
        self.__columns_cache.clear()
        self.__rows_cache.clear()
        self.__update_clustering_enable_state(None)

    def clear_scene(self):
        """Remove the heat map grid from the scene and reset the view rects."""
        grid = self.scene.widget
        if grid is not None:
            # Detach the handlers that were connected to the grid widget.
            grid.layoutDidActivate.disconnect(self.__on_layout_activate)
            grid.selectionFinished.disconnect(self.on_selection_finished)
        self.scene.widget = None
        self.scene.clear()

        view = self.view
        for reset_rect in (view.setSceneRect,
                           view.setHeaderSceneRect,
                           view.setFooterSceneRect):
            reset_rect(QRectF())

    @Inputs.data
    def set_dataset(self, data=None):
        """Set the input dataset to display.

        The current context is closed and the widget cleared first. The
        input is then normalized:

        - an `SqlTable` is converted to a `Table` (directly when its
          approximate length is below 4000, otherwise from a sample created
          with ``sample_time(1)`` and a partial download of 2000 rows),
        - an empty table is treated as no data,
        - a sparse `X` is converted to dense (a `MemoryError` drops the data
          and raises the `not_enough_memory` error),
        - discrete attributes are removed; when no attribute remains the
          `no_continuous` error is shown and the data is dropped.

        `self.input_data` keeps the table as it was before the discrete
        attributes were filtered out; `self.data` is the filtered table.
        """
        self.closeContext()
        self.clear()
        self.clear_messages()

        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)

        if data is not None and not len(data):
            data = None

        if data is not None and sp.issparse(data.X):
            try:
                data = data.to_dense()
            except MemoryError:
                data = None
                self.Error.not_enough_memory()
            else:
                self.Information.sparse_densified()

        input_data = data

        # Data contains no attributes or meta attributes only
        if data is not None and len(data.domain.attributes) == 0:
            self.Error.no_continuous()
            input_data = data = None

        # Data contains some discrete attributes which must be filtered
        if data is not None and \
                any(var.is_discrete for var in data.domain.attributes):
            ndisc = sum(var.is_discrete for var in data.domain.attributes)
            data = data.transform(
                Domain([
                    var for var in data.domain.attributes if var.is_continuous
                ], data.domain.class_vars, data.domain.metas))
            if not data.domain.attributes:
                self.Error.no_continuous()
                input_data = data = None
            else:
                self.Information.discrete_ignored(ndisc,
                                                  "s" if ndisc > 1 else "")

        self.data = data
        self.input_data = input_data

        if data is not None:
            self.annotation_model.set_domain(self.input_data.domain)
            self.row_side_color_model.set_domain(self.input_data.domain)
            self.annotation_var = None
            self.annotation_color_var = None
            self.row_split_model.set_domain(data.domain)
            # Default split: the discrete class variable, if there is one.
            if data.domain.has_discrete_class:
                self.split_by_var = data.domain.class_var
            else:
                self.split_by_var = None
            self.openContext(self.input_data)
            # A split variable restored from the context may not be valid
            # for the current model.
            if self.split_by_var not in self.row_split_model:
                self.split_by_var = None

        self.update_heatmaps()

        # `update_heatmaps` leaves the scene without a grid widget when it
        # reports an error (too few features/instances), so the pending
        # selection can only be applied when a grid was actually built.
        # It stays pending otherwise.
        grid = self.scene.widget
        if data is not None and self.__pending_selection is not None \
                and grid is not None:
            grid.selectRows(self.__pending_selection)
            self.selected_rows = self.__pending_selection
            self.__pending_selection = None

        self.unconditional_commit()

    def __on_split_rows_activated(self):
        """Apply the variable currently selected in the split combo box."""
        chosen = self.row_split_cb.currentData(Qt.EditRole)
        self.set_split_variable(chosen)

    def set_split_variable(self, var):
        """Set the row split variable and rebuild the heat maps on change."""
        changed = var != self.split_by_var
        if not changed:
            return
        self.split_by_var = var
        self.update_heatmaps()

    def update_heatmaps(self):
        """Rebuild the heat map scene from `self.data`.

        Without data the widget is cleared. Otherwise the scene and the
        messages are reset and either an error is raised (not enough
        features/instances for the requested clustering or k-means merge) or
        the heat maps are constructed and the row selection is emptied.
        """
        if self.data is None:
            self.clear()
            return

        self.clear_scene()
        self.clear_messages()

        cluster_cols = self.col_clustering != Clustering.None_
        cluster_rows = self.row_clustering != Clustering.None_
        n_rows = len(self.data)
        n_cols = len(self.data.domain.attributes)

        if cluster_cols and n_cols < 2:
            self.Error.not_enough_features()
            return
        if (cluster_cols or cluster_rows) and n_rows < 2:
            self.Error.not_enough_instances()
            return
        if self.merge_kmeans and n_rows < 3:
            self.Error.not_enough_instances_k_means()
            return

        parts = self.construct_heatmaps(self.data, self.split_by_var)
        self.construct_heatmaps_scene(parts, self.effective_data)
        self.selected_rows = []

    def update_merge(self):
        """Invalidate the k-means merge state and refresh when merging is on.

        The stored model and merge indices are always dropped; the heat maps
        are rebuilt and the outputs committed only when there is data and
        `merge_kmeans` is enabled.
        """
        self.kmeans_model = self.merge_indices = None
        if self.data is None or not self.merge_kmeans:
            return
        self.update_heatmaps()
        self.commit()

    def _make_parts(self, data, group_var=None):
        """
        Build the initial, unclustered `Parts` for `data`.

        With a (discrete) `group_var` one row part is created per value of
        the variable, holding the indices of the rows whose column data
        equals that value's index; otherwise a single row part spans all
        rows. There is always exactly one column part covering all
        attributes. The span is the nan-ignoring (min, max) of `data.X`.
        """
        if group_var is None:
            row_groups = [
                RowPart(title=None,
                        indices=range(0, len(data)),
                        cluster=None,
                        cluster_ordered=None)
            ]
        else:
            assert group_var.is_discrete
            group_column = table_column_data(data, group_var)
            row_groups = [
                RowPart(title=value,
                        indices=np.flatnonzero(group_column == code),
                        cluster=None,
                        cluster_ordered=None)
                for code, value in enumerate(group_var.values)
            ]

        all_columns = ColumnPart(title=None,
                                 indices=range(0, len(data.domain.attributes)),
                                 domain=data.domain,
                                 cluster=None,
                                 cluster_ordered=None)

        low = np.nanmin(data.X)
        high = np.nanmax(data.X)
        return Parts(row_groups, [all_columns], span=(low, high))

    def cluster_rows(self,
                     data: Table,
                     parts: 'Parts',
                     ordered=False) -> 'Parts':
        """Return `parts` with hierarchical clustering filled in for its rows.

        For every row part that reports `can_cluster`, a missing `cluster`
        is computed (Ward linkage on Euclidean distances of the part's rows)
        and, when `ordered` is true, a missing `cluster_ordered` is computed
        with optimal leaf ordering. Clusterings already present on a part are
        reused. Parts that cannot be clustered are passed through unchanged.

        The size assertions use ``<=`` so that they agree with the limits
        applied in `__update_clustering_enable_state` (``N <= Max...``); a
        strict ``<`` would fail for an input of exactly the maximum size.
        """
        row_groups = []
        for row in parts.rows:
            # Start from whatever the part already carries (possibly None).
            cluster = row.cluster
            cluster_ord = row.cluster_ordered

            if row.can_cluster:
                need_ordered = ordered and cluster_ord is None
                matrix = None
                # The distance matrix is only needed when something has to be
                # (re)computed.
                if cluster is None or need_ordered:
                    subset = data[row.indices]
                    matrix = Orange.distance.Euclidean(subset)

                if cluster is None:
                    assert len(matrix) <= self.MaxClustering
                    cluster = hierarchical.dist_matrix_clustering(
                        matrix, linkage=hierarchical.WARD)
                if need_ordered:
                    assert len(matrix) <= self.MaxOrderedClustering
                    cluster_ord = hierarchical.optimal_leaf_ordering(
                        cluster,
                        matrix,
                    )
            row_groups.append(
                row._replace(cluster=cluster, cluster_ordered=cluster_ord))

        return parts._replace(rows=row_groups)

    def cluster_columns(self, data, parts, ordered=False):
        """Return `parts` with hierarchical clustering filled in for columns.

        Exactly one column part is supported and all attributes of `data`
        must be continuous. A missing `cluster` is computed with Ward linkage
        on a Pearson distance matrix between columns; when `ordered` is true
        a missing `cluster_ordered` is computed with optimal leaf ordering.
        Clusterings already present on the part are reused.

        The size assertions use ``<=`` so that they agree with the limits
        applied in `__update_clustering_enable_state` (``M <= Max...``); a
        strict ``<`` would fail for an input of exactly the maximum size.
        """
        assert len(parts.columns) == 1, "columns split is no longer supported"
        assert all(var.is_continuous for var in data.domain.attributes)

        col0 = parts.columns[0]
        # Start from whatever the part already carries (possibly None).
        cluster = col0.cluster
        cluster_ord = col0.cluster_ordered
        need_ordered = ordered and cluster_ord is None

        matrix = None
        # The distance matrix is only needed when something has to be
        # (re)computed.
        if cluster is None or need_ordered:
            data = Orange.distance._preprocess(data)
            matrix = np.asarray(Orange.distance.PearsonR(data, axis=0))
            # nan values break clustering below
            matrix = np.nan_to_num(matrix)

        if cluster is None:
            assert matrix is not None
            assert len(matrix) <= self.MaxClustering
            cluster = hierarchical.dist_matrix_clustering(
                matrix, linkage=hierarchical.WARD)
        if need_ordered:
            assert len(matrix) <= self.MaxOrderedClustering
            cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix)

        col_groups = [
            col._replace(cluster=cluster, cluster_ordered=cluster_ord)
            for col in parts.columns
        ]
        return parts._replace(columns=col_groups)

    def construct_heatmaps(self, data, group_var=None) -> 'Parts':
        """Build the (optionally clustered) `Parts` describing the heat map.

        When `merge_kmeans` is enabled the rows are replaced by k-means
        centroids (computed once and kept in `kmeans_model`/`merge_indices`)
        and `group_var` is ignored; otherwise `data` is used as is. The table
        that is actually displayed is stored in `self.effective_data`.

        Row descriptions are restored from / stored into the rows cache,
        keyed by the split variable and the number of k-means clusters.
        """
        if self.merge_kmeans:
            if self.kmeans_model is None:
                # Keep only continuous attributes of the unfiltered input.
                effective_data = self.input_data.transform(
                    Orange.data.Domain([
                        var for var in self.input_data.domain.attributes
                        if var.is_continuous
                    ], self.input_data.domain.class_vars,
                                       self.input_data.domain.metas))
                # Never ask for more clusters than len(data) - 1.
                nclust = min(self.merge_kmeans_k, len(effective_data) - 1)
                self.kmeans_model = kmeans_compress(effective_data, k=nclust)
                effective_data.domain = self.kmeans_model.domain
                # Row indices of the input belonging to each cluster label.
                merge_indices = [
                    np.flatnonzero(self.kmeans_model.labels == ind)
                    for ind in range(nclust)
                ]
                not_empty_indices = [
                    i for i, x in enumerate(merge_indices) if len(x) > 0
                ]
                self.merge_indices = \
                    [merge_indices[i] for i in not_empty_indices]
                # Some clusters had no members; warn and drop them.
                if len(merge_indices) != len(self.merge_indices):
                    self.Warning.empty_clusters()
                # One row per non-empty cluster: its centroid.
                effective_data = Orange.data.Table(
                    Orange.data.Domain(effective_data.domain.attributes),
                    self.kmeans_model.centroids[not_empty_indices])
            else:
                # Model already computed; reuse the previously built table.
                effective_data = self.effective_data

            # Splitting by a variable is not applied to merged rows.
            group_var = None
        else:
            self.kmeans_model = None
            self.merge_indices = None
            effective_data = data

        self.effective_data = effective_data

        self.__update_clustering_enable_state(effective_data)

        parts = self._make_parts(effective_data, group_var)
        # Restore/update the row/columns items descriptions from cache if
        # available
        rows_cache_key = (group_var,
                          self.merge_kmeans_k if self.merge_kmeans else None)
        if rows_cache_key in self.__rows_cache:
            parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)

        if self.row_clustering != Clustering.None_:
            parts = self.cluster_rows(
                effective_data,
                parts,
                ordered=self.row_clustering == Clustering.OrderedClustering)
        if self.col_clustering != Clustering.None_:
            parts = self.cluster_columns(
                effective_data,
                parts,
                ordered=self.col_clustering == Clustering.OrderedClustering)

        # Cache the updated parts
        self.__rows_cache[rows_cache_key] = parts
        return parts

    def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None:
        """Reduce `parts` to the selected clusterings and set up the scene.

        Every row/column item keeps, in its `cluster` field, only the
        clustering that matches the current row/column clustering method
        (none, plain, or ordered); `cluster_ordered` is always reset.
        """
        _T = TypeVar("_T", bound=Union[RowPart, ColumnPart])

        def select_cluster(clustering: Clustering, item: _T) -> _T:
            if clustering == Clustering.None_:
                chosen = None
            elif clustering == Clustering.Clustering:
                chosen = item.cluster
            elif clustering == Clustering.OrderedClustering:
                chosen = item.cluster_ordered
            else:  # pragma: no cover
                raise TypeError()
            return item._replace(cluster=chosen, cluster_ordered=None)

        selected_rows = []
        for row_item in parts.rows:
            selected_rows.append(select_cluster(self.row_clustering, row_item))

        selected_cols = []
        for col_item in parts.columns:
            selected_cols.append(select_cluster(self.col_clustering, col_item))

        reduced = Parts(columns=selected_cols,
                        rows=selected_rows,
                        span=parts.span)
        self.setup_scene(reduced, data)

    def setup_scene(self, parts, data):
        # type: (Parts, Table) -> None
        """Create a new `HeatmapGridWidget` for `parts`/`data` and show it.

        The grid is added to the scene, configured from the current widget
        settings, connected to the layout/selection handlers and installed as
        the view's central widget. `self.parts` is set to the
        `HeatmapGridWidget.Parts` passed to the grid.
        """
        grid = HeatmapGridWidget()
        grid.setColorMap(self.color_map())
        self.scene.addItem(grid)
        self.scene.widget = grid

        column_names = [attr.name for attr in data.domain.attributes]
        row_items = [
            HeatmapGridWidget.RowItem(row.title, row.indices, row.cluster)
            for row in parts.rows
        ]
        column_items = [
            HeatmapGridWidget.ColumnItem(col.title, col.indices, col.cluster)
            for col in parts.columns
        ]
        grid_parts = HeatmapGridWidget.Parts(
            rows=row_items,
            columns=column_items,
            data=data.X,
            span=parts.span,
            row_names=None,
            col_names=column_names,
        )
        grid.setHeatmaps(grid_parts)

        side = self.row_side_colors()
        if side is not None:
            colors, colormap, color_var = side[0], side[1], side[2]
            grid.setRowSideColorAnnotations(colors,
                                            colormap,
                                            name=color_var.name)

        grid.setColumnLabelsPosition(self._column_label_pos)
        if self.keep_aspect:
            aspect_mode = Qt.KeepAspectRatio
        else:
            aspect_mode = Qt.IgnoreAspectRatio
        grid.setAspectRatioMode(aspect_mode)
        grid.setShowAverages(self.averages)
        grid.setLegendVisible(self.legend)

        grid.layoutDidActivate.connect(self.__on_layout_activate)
        grid.selectionFinished.connect(self.on_selection_finished)

        self.update_annotations()
        self.view.setCentralWidget(grid)
        self.parts = grid_parts

    def __update_scene_rects(self):
        """Sync the scene and view rects with the grid widget's geometry."""
        grid = self.scene.widget
        if grid is not None:
            geometry = grid.geometry()
            for target in (self.scene, self.view):
                target.setSceneRect(geometry)
            self.view.setHeaderSceneRect(grid.headerGeometry())
            self.view.setFooterSceneRect(grid.footerGeometry())

    def __on_layout_activate(self):
        """Handle the grid widget's `layoutDidActivate` signal.

        Connected in `setup_scene`; refreshes the scene/view rectangles.
        """
        self.__update_scene_rects()

    def __aspect_mode_changed(self):
        """Apply the `keep_aspect` setting to the current grid widget."""
        grid = self.scene.widget
        if grid is None:
            return
        keep = self.keep_aspect
        grid.setAspectRatioMode(
            Qt.KeepAspectRatio if keep else Qt.IgnoreAspectRatio)
        # With a fixed aspect ratio the vertical size hint is fixed; without
        # it the widget is allowed to shrink vertically.
        policy = grid.sizePolicy()
        policy.setVerticalPolicy(
            QSizePolicy.Fixed if keep else QSizePolicy.Preferred)
        grid.setSizePolicy(policy)

    def __update_clustering_enable_state(self, data):
        """Enable/disable clustering options according to the size of `data`.

        Plain clustering is allowed up to `MaxClustering` rows (columns) and
        ordered clustering up to `MaxOrderedClustering`. When the current
        row/column clustering setting exceeds what is allowed it is
        downgraded (ordered -> plain -> none) and an information message
        explaining why is shown. The matching combo box items are
        enabled/disabled accordingly. `data` may be ``None``, which is
        treated as an empty (0 x 0) input.
        """
        if data is not None:
            N = len(data)
            M = len(data.domain.attributes)
        else:
            N = M = 0

        rc_enabled = N <= self.MaxClustering
        rco_enabled = N <= self.MaxOrderedClustering
        cc_enabled = M <= self.MaxClustering
        cco_enabled = M <= self.MaxOrderedClustering
        row_clust, col_clust = self.row_clustering, self.col_clustering

        row_clust_msg = ""
        col_clust_msg = ""

        # The two checks are sequential on purpose: an ordered clustering
        # that was just downgraded to plain clustering may have to be
        # disabled altogether by the second check.
        if not rco_enabled and row_clust == Clustering.OrderedClustering:
            row_clust = Clustering.Clustering
            row_clust_msg = "Row cluster ordering was disabled due to the " \
                            "input matrix being too big"
        if not rc_enabled and row_clust == Clustering.Clustering:
            row_clust = Clustering.None_
            row_clust_msg = "Row clustering was disabled due to the " \
                            "input matrix being too big"

        if not cco_enabled and col_clust == Clustering.OrderedClustering:
            col_clust = Clustering.Clustering
            col_clust_msg = "Column cluster ordering was disabled due to " \
                            "the input matrix being too big"
        if not cc_enabled and col_clust == Clustering.Clustering:
            col_clust = Clustering.None_
            col_clust_msg = "Column clustering was disabled due to the " \
                            "input matrix being too big"

        self.col_clustering = col_clust
        self.row_clustering = row_clust

        self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg))
        self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg))

        # Disable/enable the combobox items for the clustering methods
        def setenabled(cb: QComboBox, clu: bool, clu_op: bool):
            model = cb.model()
            assert isinstance(model, QStandardItemModel)
            idx = cb.findData(Clustering.OrderedClustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu_op)
            idx = cb.findData(Clustering.Clustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu)

        setenabled(self.row_cluster_cb, rc_enabled, rco_enabled)
        setenabled(self.col_cluster_cb, cc_enabled, cco_enabled)

    def update_averages_stripe(self):
        """Apply the `averages` setting to the grid widget, if there is one."""
        grid = self.scene.widget
        if grid is None:
            return
        grid.setShowAverages(self.averages)

    def update_lowslider(self):
        """Keep the low threshold below the high one, then refresh colors.

        When the low slider has reached or passed the high slider it is moved
        to one below the high slider's value.
        """
        controls = self.controls
        low_slider = controls.threshold_low
        high_slider = controls.threshold_high
        if low_slider.value() >= high_slider.value():
            low_slider.setSliderPosition(high_slider.value() - 1)
        self.update_color_schema()

    def update_highslider(self):
        """Keep the high threshold above the low one, then refresh colors.

        When the low slider has reached or passed the high slider, the high
        slider is moved to one above the low slider's value.
        """
        controls = self.controls
        low_slider = controls.threshold_low
        high_slider = controls.threshold_high
        if low_slider.value() >= high_slider.value():
            high_slider.setSliderPosition(low_slider.value() + 1)
        self.update_color_schema()

    def update_color_schema(self):
        """Store the selected palette's name and recolor the grid widget."""
        selected = self.color_cb.currentData()
        self.palette_name = selected.name
        grid = self.scene.widget
        if grid is None:
            return
        grid.setColorMap(self.color_map())

    def __update_column_clustering(self):
        """Rebuild the heat maps and commit the outputs.

        Called from `set_col_clustering` after the column clustering method
        has changed.
        """
        self.update_heatmaps()
        self.commit()

    def __update_row_clustering(self):
        """Rebuild the heat maps and commit the outputs.

        Called from `set_row_clustering` after the row clustering method has
        changed.
        """
        self.update_heatmaps()
        self.commit()

    def update_legend(self):
        """Apply the `legend` setting to the grid widget, if there is one."""
        grid = self.scene.widget
        if grid is None:
            return
        grid.setLegendVisible(self.legend)

    def row_annotation_var(self):
        """Return the variable selected for row text annotations (or None)."""
        return self.annotation_var

    def row_annotation_data(self):
        """Return the annotation column of `input_data` as strings.

        Returns ``None`` when no row annotation variable is selected.
        """
        annotation = self.row_annotation_var()
        if annotation is not None:
            return column_str_from_table(self.input_data, annotation)
        return None

    def _merge_row_indices(self):
        if self.merge_kmeans and self.kmeans_model is not None:
            return self.merge_indices
        else:
            return None

    def set_annotation_var(self, var: Union[None, Variable, int]):
        """Set the row text annotation variable.

        An ``int`` is interpreted as an index into `annotation_model`. The
        annotations are refreshed only when the variable actually changes.
        """
        resolved = self.annotation_model[var] if isinstance(var, int) else var
        changed = self.annotation_var != resolved
        if not changed:
            return
        self.annotation_var = resolved
        self.update_annotations()

    def update_annotations(self):
        """Push the row text annotations to the grid widget.

        With merged (k-means) rows the labels of the merged members are
        joined into one elided string per row. Without an annotation column
        the row labels are hidden and cleared.
        """
        grid = self.scene.widget
        if grid is None:
            return

        labels = self.row_annotation_data()
        merges = self._merge_row_indices()
        if merges is not None and labels is not None:

            def join(group):
                return join_elided(", ", 42, group, " ({} more)")

            labels = aggregate_apply(join, labels, merges)

        if labels is None:
            grid.setRowLabelsVisible(False)
            grid.setRowLabels(None)
        else:
            grid.setRowLabels(labels)
            grid.setRowLabelsVisible(True)

    def row_side_colors(self):
        """Return ``(data, colormap, var)`` for the row side color stripe.

        Returns ``None`` when no color annotation variable is selected. The
        column values are taken from `input_data`, aggregated over merged
        rows when a k-means merge is in effect, and colorized. For a
        continuous variable the color map span is the nan-ignoring
        (min, max) of the column before aggregation.
        """
        color_var = self.annotation_color_var
        if color_var is None:
            return None

        values = column_data_from_table(self.input_data, color_var)
        low, high = np.nanmin(values), np.nanmax(values)

        merges = self._merge_row_indices()
        if merges is not None:
            values = aggregate(color_var, values, merges)

        colored, colormap = self._colorize(color_var, values)
        if color_var.is_continuous:
            colormap.span = (low, high)
        return colored, colormap, color_var

    def set_annotation_color_var(self, var: Union[None, Variable, int]):
        """Set the current side color annotation variable.

        An ``int`` is interpreted as an index into `row_side_color_model`.
        The side colors are refreshed only when the variable changes.
        """
        if isinstance(var, int):
            resolved = self.row_side_color_model[var]
        else:
            resolved = var
        changed = self.annotation_color_var != resolved
        if not changed:
            return
        self.annotation_color_var = resolved
        self.update_row_side_colors()

    def update_row_side_colors(self):
        """Push the row side color annotations to the grid widget.

        Passes ``None`` to the grid when no color annotation is available.
        """
        grid = self.scene.widget
        if grid is None:
            return
        side = self.row_side_colors()
        if side is not None:
            colors, colormap, color_var = side[0], side[1], side[2]
            grid.setRowSideColorAnnotations(colors, colormap, color_var.name)
        else:
            grid.setRowSideColorAnnotations(None)

    def _colorize(self, var: Variable,
                  data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
        """Return `data` prepared for display together with its color map.

        The RGB colors are taken from ``var.palette.qcolors_w_nan``; the last
        entry is treated as the color for missing values.

        - Discrete `var`: missing values are mapped to ``-1`` and the result
          is cast to ``int``. The "N/A" value (and its color) is included
          only when there are missing values.
        - Continuous `var`: `data` is returned as is with a gradient color
          map built from all colors but the last.

        The input array is never modified: the discrete branch builds a new
        array instead of writing ``-1`` into `data`, because the caller's
        array may share memory with the input table column (NOTE(review):
        depends on `column_data_from_table` -- confirm).

        Raises `TypeError` for any other kind of variable.
        """
        palette = var.palette  # type: Palette
        colors = np.array(
            [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
            dtype=np.uint8,
        )
        if var.is_discrete:
            mask = np.isnan(data)
            data = np.where(mask, -1, data).astype(int)
            if mask.any():
                values = (*var.values, "N/A")
            else:
                values = var.values
                colors = colors[:-1]
            return data, CategoricalColorMap(colors, values)
        elif var.is_continuous:
            cmap = GradientColorMap(colors[:-1])
            return data, cmap
        else:
            raise TypeError

    def update_column_annotations(self):
        """Apply the column label position to the grid widget.

        Does nothing without data or without a grid widget.
        """
        grid = self.scene.widget
        if self.data is None or grid is None:
            return
        grid.setColumnLabelsPosition(self._column_label_pos)

    def __adjust_font_size(self, diff):
        """Change the grid widget's font point size by `diff`.

        The decrease action is enabled only while the new size is above 1.0
        and the increase action only while it is at most 32. The font itself
        is changed only when the new size is above 1.0.
        """
        grid = self.scene.widget
        if grid is None:
            return
        target = grid.font().pointSizeF() + diff
        above_minimum = target > 1.0

        self.__font_dec.setEnabled(above_minimum)
        self.__font_inc.setEnabled(target <= 32)
        if not above_minimum:
            return
        resized = QFont()
        resized.setPointSizeF(target)
        grid.setFont(resized)

    def _on_view_context_menu(self, pos):
        """Show the context menu of the heat map view at `pos`.

        The menu lists the view's own actions, the font size actions and a
        checkable "Keep aspect ratio" entry. Nothing is shown when there is
        no grid widget.
        """
        grid = self.scene.widget
        if grid is None:
            return
        assert isinstance(grid, HeatmapGridWidget)

        viewport = self.view.viewport()
        menu = QMenu(viewport)
        menu.setAttribute(Qt.WA_DeleteOnClose)
        menu.addActions(self.view.actions())
        menu.addSeparator()
        menu.addActions([self.__font_inc, self.__font_dec])
        menu.addSeparator()

        def ontoggled(state):
            # Store the new setting, then apply it to the grid.
            self.keep_aspect = state
            self.__aspect_mode_changed()

        aspect_action = QAction("Keep aspect ratio", menu, checkable=True)
        aspect_action.setChecked(self.keep_aspect)
        aspect_action.toggled.connect(ontoggled)
        menu.addAction(aspect_action)

        menu.popup(self.view.viewport().mapToGlobal(pos))

    def on_selection_finished(self):
        """Store the grid's selected rows (none without a grid) and commit."""
        grid = self.scene.widget
        self.selected_rows = [] if grid is None else list(grid.selectedRows())
        self.commit()

    def commit(self):
        """Send the selected rows and the annotated input table.

        Selected row indices refer to displayed rows; when rows are merged
        with k-means they are expanded to the indices of all merged members
        before subsetting `input_data`. Without input data or a selection
        ``None`` is sent as the selected data.
        """
        merge_indices = self.merge_indices if self.merge_kmeans else None

        selected = None
        indices = None
        if self.input_data is not None and self.selected_rows:
            indices = self.selected_rows
            if merge_indices is not None:
                # expand merged indices
                members = [merge_indices[row] for row in indices]
                indices = np.hstack(members)
            selected = self.input_data[indices]

        annotated = create_annotated_table(self.input_data, indices)
        outputs = self.Outputs
        outputs.selected_data.send(selected)
        outputs.annotated_data.send(annotated)

    def onDeleteWidget(self):
        """Clear the widget state before delegating to the base class."""
        self.clear()
        super().onDeleteWidget()

    def send_report(self):
        """Report the clustering, split and annotation settings and the plot.

        For the split and row annotation entries the variable's name is
        reported, or ``False`` when no variable is selected.
        """

        def sorting_label(method):
            return "Clustering" if method else "No sorting"

        def name_or_false(var):
            return var is not None and var.name

        columns_entry = ("Columns:", sorting_label(self.col_clustering))
        rows_entry = ("Rows:", sorting_label(self.row_clustering))
        split_entry = ("Split:", name_or_false(self.split_by_var))
        annotation_entry = ("Row annotation",
                            name_or_false(self.annotation_var))

        self.report_items(
            (columns_entry, rows_entry, split_entry, annotation_entry))
        self.report_plot()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Migrate stored settings from versions older than 3.

        The boolean ``row_clustering``/``col_clustering`` entries are removed
        and replaced by ``row_clustering_method``/``col_clustering_method``
        holding the name of `Clustering.OrderedClustering` (for a true value)
        or `Clustering.None_`. Settings with no version or version >= 3 are
        left untouched.
        """
        if version is None or not version < 3:
            return

        def st2cl(state: bool) -> Clustering:
            if state:
                return Clustering.OrderedClustering
            return Clustering.None_

        old_rows = settings.pop("row_clustering", False)
        old_cols = settings.pop("col_clustering", False)
        settings["row_clustering_method"] = st2cl(old_rows).name
        settings["col_clustering_method"] = st2cl(old_cols).name
Exemple #20
0
class OWColor(widget.OWWidget):
    """Widget for assigning color palettes to variables.

    Discrete and continuous variables of the input table are replaced by
    proxies, which are listed in two table views. The table rebuilt on the
    proxied domain is sent to the "Data" output.
    """

    name = "Color"
    description = "Set color legend for variables."
    icon = "icons/Colors.svg"

    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Data", Orange.data.Table)]

    settingsHandler = settings.PerfectDomainContextHandler()
    #: Per-context list of (name, values, colors-or-False) for discrete vars
    disc_data = settings.ContextSetting([])
    #: Per-context list of (name, colors-or-False) for continuous vars
    cont_data = settings.ContextSetting([])
    color_settings = settings.Setting(None)
    selected_schema_index = settings.Setting(0)
    auto_apply = settings.Setting(True)

    want_main_area = False

    def __init__(self):
        super().__init__()
        self.data = None
        self.orig_domain = self.domain = None
        #: Proxies of the discrete / continuous variables of the input
        self.disc_colors = []
        self.cont_colors = []

        box = gui.hBox(self.controlArea, "Discrete Variables")
        self.disc_model = DiscColorTableModel()
        disc_view = self.disc_view = DiscreteTable(self.disc_model)
        disc_view.horizontalHeader().setResizeMode(QHeaderView.ResizeToContents)
        self.disc_model.dataChanged.connect(self._on_data_changed)
        box.layout().addWidget(disc_view)

        box = gui.hBox(self.controlArea, "Numeric Variables")
        self.cont_model = ContColorTableModel()
        cont_view = self.cont_view = ContinuousTable(self, self.cont_model)
        cont_view.setColumnWidth(1, 256)
        self.cont_model.dataChanged.connect(self._on_data_changed)
        box.layout().addWidget(cont_view)

        box = gui.auto_commit(self.controlArea, self, "auto_apply", "Apply",
                              orientation=Qt.Horizontal,
                              checkbox_label="Apply automatically")
        box.layout().insertSpacing(0, 20)
        box.layout().insertWidget(0, self.report_button)

    def _create_proxies(self, variables):
        """Return `variables` with discrete/continuous ones replaced by proxies.

        Every created proxy is also appended to ``disc_colors`` or
        ``cont_colors``. Variables of other types are passed through
        unchanged.
        """
        part_vars = []
        for var in variables:
            if var.is_discrete or var.is_continuous:
                var = var.make_proxy()
                if var.is_discrete:
                    # Give the proxy its own copy of the value list.
                    var.values = var.values[:]
                    self.disc_colors.append(var)
                else:
                    self.cont_colors.append(var)
            part_vars.append(var)
        return part_vars

    def set_data(self, data):
        """Handle data input signal.

        Builds a proxied domain for `data`, restores the context settings
        and refreshes both table models. The models are refreshed when
        `data` is ``None`` as well, so that the views do not keep showing
        the variables of the previous input.
        """
        self.closeContext()
        self.disc_colors = []
        self.cont_colors = []
        if data is None:
            self.data = self.domain = self.orig_domain = None
        else:
            domain = self.orig_domain = data.domain
            domain = Orange.data.Domain(self._create_proxies(domain.attributes),
                                        self._create_proxies(domain.class_vars),
                                        self._create_proxies(domain.metas))
            self.openContext(data)
            self.data = Orange.data.Table(domain, data)
            self.data.domain = domain

        # Refresh unconditionally: with no data both lists are empty and the
        # views are emptied.
        self.disc_model.set_data(self.disc_colors)
        self.cont_model.set_data(self.cont_colors)
        self.disc_view.resizeColumnsToContents()
        self.cont_view.resizeColumnsToContents()
        self.commit()

    def storeSpecificSettings(self):
        """Save names, values and explicitly set colors into the context."""
        # Store the colors that were changed -- but not others
        self.current_context.disc_data = \
            [(var.name, var.values, "colors" in var.attributes and var.colors)
             for var in self.disc_colors]
        self.current_context.cont_data = \
            [(var.name, "colors" in var.attributes and var.colors)
             for var in self.cont_colors]

    def retrieveSpecificSettings(self):
        """Apply names, values and colors stored in the context to proxies.

        A stored color entry of ``False`` means that no color was stored for
        that variable; its colors are left untouched.
        """
        disc_data = getattr(self.current_context, "disc_data", ())
        for var, (name, values, colors) in zip(self.disc_colors, disc_data):
            var.name = name
            var.values = values[:]
            if colors is not False:
                var.colors = colors
        cont_data = getattr(self.current_context, "cont_data", ())
        for var, (name, colors) in zip(self.cont_colors, cont_data):
            var.name = name
            if colors is not False:
                var.colors = colors

    def _on_data_changed(self, *args):
        """Commit when either table model reports changed data."""
        self.commit()

    def commit(self):
        """Send the table with the proxied domain (or None) to the output."""
        self.send("Data", self.data)

    def send_report(self):
        """Send report"""
        def _report_variables(variables, orig_variables):
            """Return HTML table rows describing `variables`.

            Names and discrete values that differ from `orig_variables` are
            rendered as "new (was: old)".
            """
            from Orange.canvas.report import colored_square as square

            def was(n, o):
                return n if n == o else "{} (was: {})".format(n, o)

            # definition of td element for continuous gradient
            # with support for pre-standard css (needed at least for Qt 4.8)
            max_values = max(
                (len(var.values) for var in variables if var.is_discrete),
                default=1)
            defs = ("-webkit-", "-o-", "-moz-", "")
            cont_tpl = '<td colspan="{}">' \
                       '<span class="legend-square" style="width: 100px; '.\
                format(max_values) + \
                " ".join(map(
                    "background: {}linear-gradient("
                    "left, rgb({{}}, {{}}, {{}}), {{}}rgb({{}}, {{}}, {{}}));"
                    .format, defs)) + \
                '"></span></td>'

            rows = []
            for var, ovar in zip(variables, orig_variables):
                if var.is_discrete:
                    values = "    \n".join(
                        "<td>{} {}</td>".
                        format(square(*var.colors[i]), was(value, ovalue))
                        for i, (value, ovalue) in
                        enumerate(zip(var.values, ovar.values)))
                elif var.is_continuous:
                    col = var.colors
                    # Start color, optional "black, " stop, end color; the
                    # tuple is repeated once per css prefix in the template.
                    colors = col[0][:3] + ("black, " * col[2], ) + col[1][:3]
                    values = cont_tpl.format(*colors * len(defs))
                else:
                    continue
                name = was(var.name, ovar.name)
                rows.append(
                    '<tr style="height: 2em">\n'
                    '  <th style="text-align: right">{}</th>{}\n</tr>\n'.
                    format(name, values))
            return "".join(rows)

        if not self.data:
            return
        domain = self.data.domain
        orig_domain = self.orig_domain
        sections = (
            (name, _report_variables(vars, ovars))
            for name, vars, ovars in (
                ("Features", domain.attributes, orig_domain.attributes),
                ("Outcome" + "s" * (len(domain.class_vars) > 1),
                 domain.class_vars, orig_domain.class_vars),
                ("Meta attributes", domain.metas, orig_domain.metas)))
        # Sections without any reportable variable are left out.
        table = "".join("<tr><th>{}</th></tr>{}".format(name, rows)
                        for name, rows in sections if rows)
        if table:
            self.report_raw("<table>{}</table>".format(table))
Exemple #21
0
class OWSNR(OWWidget):
    """Compute SNR, average or standard deviation of the input rows.

    Rows of ``data.X`` are aggregated column-wise, either over the whole
    table or per unique value of one or two selected coordinate variables.
    """

    # Widget's name as displayed in the canvas
    name = "SNR"

    # Short widget description
    description = (
        "Calculates Signal-to-Noise Ratio (SNR), Averages or Standard Deviation by coordinates."
    )

    icon = "icons/snr.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        final_data = Output("SNR", Orange.data.Table, default=True)

    # Combo box labels mapped to the values stored in ``out_choiced``
    OUT_OPTIONS = {
        'Signal-to-noise ratio': 0,  #snr
        'Average': 1,  # average
        'Standard Deviation': 2
    }  # std

    settingsHandler = settings.DomainContextHandler()
    group_x = settings.ContextSetting(None)
    group_y = settings.ContextSetting(None)
    out_choiced = settings.Setting(0)

    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    class Warning(OWWidget.Warning):
        nodata = Msg("No useful data on input!")

    def __init__(self):
        super().__init__()

        self.data = None
        self.set_data(self.data)  # show warning

        self.group_x = None
        self.group_y = None

        # methods in this widget assume group axes in metas
        self.xy_model = DomainModel(DomainModel.METAS,
                                    placeholder="None",
                                    separators=False,
                                    valid_types=Orange.data.ContinuousVariable)
        self.group_view_x = gui.comboBox(self.controlArea,
                                         self,
                                         "group_x",
                                         box="Select axis: x",
                                         model=self.xy_model,
                                         callback=self.grouping_changed)

        self.group_view_y = gui.comboBox(self.controlArea,
                                         self,
                                         "group_y",
                                         box="Select axis: y",
                                         model=self.xy_model,
                                         callback=self.grouping_changed)

        self.selected_out = gui.comboBox(self.controlArea,
                                         self,
                                         "out_choiced",
                                         box="Select Output:",
                                         items=self.OUT_OPTIONS,
                                         callback=self.out_choice_changed)

        gui.auto_commit(self.controlArea, self, "autocommit", "Apply")

        # prepare interface according to the new context
        self.contextAboutToBeOpened.connect(
            lambda x: self.init_attr_values(x[0]))

    def init_attr_values(self, domain):
        """Load `domain` into the axis model and reset both axis choices."""
        self.xy_model.set_domain(domain)
        self.group_x = None
        self.group_y = None

    @Inputs.data
    def set_data(self, dataset):
        """Store the input table, open its context and commit.

        The ``nodata`` warning is shown when `dataset` is ``None``.
        """
        self.Warning.nodata.clear()
        self.closeContext()
        self.data = dataset
        self.group = None
        if dataset is None:
            self.Warning.nodata()
        else:
            self.openContext(dataset.domain)

        self.commit()

    def calc_table_np(self, array):
        """Aggregate the rows of `array` into a one-row table.

        The statistic is chosen by ``out_choiced``: 0 gives nanmean/nanstd,
        1 gives nanmean, anything else gives nanstd, all computed along
        axis 0.

        NOTE(review): an empty `array` is returned unchanged, i.e. as a
        numpy array and not as a table.
        """
        if len(array) == 0:
            return array
        if self.out_choiced == 0:  #snr
            return self.make_table(
                (bottleneck.nanmean(array, axis=0) /
                 bottleneck.nanstd(array, axis=0)).reshape(1, -1), self.data)
        elif self.out_choiced == 1:  #avg
            return self.make_table(
                bottleneck.nanmean(array, axis=0).reshape(1, -1), self.data)
        else:  # std
            return self.make_table(
                bottleneck.nanstd(array, axis=0).reshape(1, -1), self.data)

    @staticmethod
    def make_table(array, data_table):
        """Build a table from `array` on the domain of `data_table`.

        Class and meta values are copied from the first row of `data_table`.
        Each of them is then replaced by ``Unknown`` unless the whole
        corresponding column of `data_table` equals that first-row value.
        """
        new_table = Orange.data.Table.from_numpy(
            data_table.domain,
            X=array.copy(),
            Y=np.atleast_2d(data_table.Y[0]).copy(),
            metas=np.atleast_2d(data_table.metas[0]).copy())
        extra_vars = data_table.domain.class_vars + data_table.domain.metas
        with new_table.unlocked():
            for var in extra_vars:
                index = data_table.domain.index(var)
                col, _ = data_table.get_column_view(index)
                val = var.to_val(new_table[0, var])
                if not np.all(col == val):
                    new_table[0, var] = Orange.data.Unknown

        return new_table

    def grouping_changed(self):
        """Calls commit() indirectly to respect auto_commit setting."""
        self.commit()

    def out_choice_changed(self):
        """Commit after the output statistic was changed."""
        self.commit()

    def _coordinate_values(self, attr):
        """Return the column of `attr` in ``self.data`` as a 1D array."""
        var = self.data.domain[attr]
        # A single-attribute domain places the column into X of the
        # transformed table.
        return self.data.transform(Orange.data.Domain([var])).X[:, 0]

    def _aggregate_by(self, attrs):
        """Aggregate ``self.data.X`` per unique combination of `attrs`.

        Returns a table with one row for every distinct combination of grid
        positions of the coordinate variables in `attrs`; the coordinate
        columns of the result hold the matching grid values.
        """
        columns = [self._coordinate_values(attr) for attr in attrs]
        linspaces = [values_to_linspace(col) for col in columns]
        # index_values_nan also returns a NaN mask, which is not used here.
        grid_idx = [index_values_nan(col, ls)[0]
                    for col, ls in zip(columns, linspaces)]

        # trick:
        # https://stackoverflow.com/questions/31878240/numpy-average-of-values-corresponding-to-unique-coordinate-positions
        coo = np.hstack([idx.reshape(-1, 1) for idx in grid_idx])
        sortidx = np.lexsort(coo.T)
        sorted_coo = coo[sortidx]
        # True at the first row of every run of identical coordinates.
        starts_group = np.append(
            True, np.any(np.diff(sorted_coo, axis=0), axis=1))
        unq_coo = sorted_coo[starts_group]
        # Row indices into self.data, one array per unique coordinate.
        bins = np.split(sortidx, np.flatnonzero(starts_group)[1:])

        tables = [self.calc_table_np(self.data.X[rows]) for rows in bins]
        result = Orange.data.Table.concatenate(tables, axis=0)

        with result.unlocked():
            for column, (attr, ls) in enumerate(zip(attrs, linspaces)):
                result[:, attr] = \
                    np.linspace(*ls)[unq_coo[:, column]].reshape(-1, 1)
        return result

    def select_2coordinates(self, attr_x, attr_y):
        """Aggregate per unique (`attr_x`, `attr_y`) coordinate pair."""
        return self._aggregate_by((attr_x, attr_y))

    def select_1coordinate(self, attr):
        """Aggregate per unique value of the coordinate `attr`."""
        return self._aggregate_by((attr,))

    def select_coordinate(self):
        """Aggregate according to the current ``group_x``/``group_y`` choice.

        With no axis selected the whole table is aggregated into one row;
        with one axis selected rows are grouped by it; with both, by pairs.
        """
        if self.group_y is None and self.group_x is None:
            final_data = self.calc_table_np(self.data.X)
        elif None in [self.group_x, self.group_y]:
            if self.group_x is None:
                group = self.group_y
            else:
                group = self.group_x
            final_data = self.select_1coordinate(group)
        else:
            final_data = self.select_2coordinates(self.group_x, self.group_y)

        return final_data

    def commit(self):
        """Send the aggregated table, or None when there is no input."""
        final_data = None
        if self.data is not None:
            final_data = self.select_coordinate()

        self.Outputs.final_data.send(final_data)
Exemple #22
0
class OWDistanceMap(widget.OWWidget):
    """Show a distance matrix as a color map with optional clustering.

    Rows/columns can be reordered by hierarchical clustering, annotated with
    labels, and selected; the selection is sent to the outputs.
    """

    name = "距离地图"
    description = "可视化距离矩阵。"
    icon = "icons/DistanceMap.svg"
    priority = 1200
    keywords = []

    class Inputs:
        distances = Input("距离", Orange.misc.DistMatrix)

    class Outputs:
        selected_data = Output("被选数据", Orange.data.Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)
        features = Output("特征", widget.AttributeList, dynamic=False)

    settingsHandler = settings.PerfectDomainContextHandler()

    #: type of ordering to apply to matrix rows/columns
    NoOrdering, Clustering, OrderedClustering = 0, 1, 2

    sorting = settings.Setting(NoOrdering)

    colormap = settings.Setting(_default_colormap_index)
    color_gamma = settings.Setting(0.0)
    color_low = settings.Setting(0.0)
    color_high = settings.Setting(1.0)

    annotation_idx = settings.ContextSetting(0)

    autocommit = settings.Setting(True)

    graph_name = "grid_widget"

    # Disable clustering for inputs bigger than this
    if hierarchical._HAS_NN_CHAIN:
        _MaxClustering = 25000
    else:
        _MaxClustering = 3000

    # Disable cluster leaf ordering for inputs bigger than this
    _MaxOrderedClustering = 1000

    def __init__(self):
        super().__init__()

        self.matrix = None
        # Cached clustering trees; computed lazily, reset in clear()
        self._tree = None
        self._ordered_tree = None
        self._sorted_matrix = None
        # Permutation applied to the matrix; None when it is not reordered
        self._sort_indices = None
        self._selection = None

        self.sorting_cb = gui.comboBox(self.controlArea,
                                       self,
                                       "sorting",
                                       box="排序",
                                       items=["无", "聚类", "使用有序叶子进行聚类"],
                                       callback=self._invalidate_ordering)

        box = gui.vBox(self.controlArea, "颜色")
        self.colormap_cb = gui.comboBox(box,
                                        self,
                                        "colormap",
                                        callback=self._update_color)
        self.colormap_cb.setIconSize(QSize(64, 16))
        self.palettes = list(_color_palettes)

        init_color_combo(self.colormap_cb, self.palettes, QSize(64, 16))
        self.colormap_cb.setCurrentIndex(self.colormap)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        #         form.addRow(
        #             "Gamma",
        #             gui.hSlider(box, self, "color_gamma", minValue=0.0, maxValue=1.0,
        #                         step=0.05, ticks=True, intOnly=False,
        #                         createLabel=False, callback=self._update_color)
        #         )
        form.addRow(
            "低:",
            gui.hSlider(box,
                        self,
                        "color_low",
                        minValue=0.0,
                        maxValue=1.0,
                        step=0.05,
                        ticks=True,
                        intOnly=False,
                        createLabel=False,
                        callback=self._update_color))
        form.addRow(
            "高:",
            gui.hSlider(box,
                        self,
                        "color_high",
                        minValue=0.0,
                        maxValue=1.0,
                        step=0.05,
                        ticks=True,
                        intOnly=False,
                        createLabel=False,
                        callback=self._update_color))
        box.layout().addLayout(form)

        self.annot_combo = gui.comboBox(self.controlArea,
                                        self,
                                        "annotation_idx",
                                        box="注解",
                                        callback=self._invalidate_annotations,
                                        contentsLength=12)
        self.annot_combo.setModel(itemmodels.VariableListModel())
        self.annot_combo.model()[:] = ["无", "列举"]
        self.controlArea.layout().addStretch()

        gui.auto_commit(self.controlArea, self, "autocommit", "选中发送")

        self.view = pg.GraphicsView(background="w")
        self.mainArea.layout().addWidget(self.view)

        # Grid layout: dendrograms in row 0 / column 0, the matrix view box
        # at (1, 1), labels in row 2 / column 2.
        self.grid_widget = pg.GraphicsWidget()
        self.grid = QGraphicsGridLayout()
        self.grid_widget.setLayout(self.grid)

        self.viewbox = pg.ViewBox(enableMouse=False, enableMenu=False)
        self.viewbox.setAcceptedMouseButtons(Qt.NoButton)
        self.viewbox.setAcceptHoverEvents(False)
        self.grid.addItem(self.viewbox, 1, 1)

        self.left_dendrogram = DendrogramWidget(
            self.grid_widget,
            orientation=DendrogramWidget.Left,
            selectionMode=DendrogramWidget.NoSelection,
            hoverHighlightEnabled=False)
        self.left_dendrogram.setAcceptedMouseButtons(Qt.NoButton)
        self.left_dendrogram.setAcceptHoverEvents(False)

        self.top_dendrogram = DendrogramWidget(
            self.grid_widget,
            orientation=DendrogramWidget.Top,
            selectionMode=DendrogramWidget.NoSelection,
            hoverHighlightEnabled=False)
        self.top_dendrogram.setAcceptedMouseButtons(Qt.NoButton)
        self.top_dendrogram.setAcceptHoverEvents(False)

        self.grid.addItem(self.left_dendrogram, 1, 0)
        self.grid.addItem(self.top_dendrogram, 0, 1)

        self.right_labels = TextList(alignment=Qt.AlignLeft)

        self.bottom_labels = TextList(orientation=Qt.Horizontal,
                                      alignment=Qt.AlignRight)

        self.grid.addItem(self.right_labels, 1, 2)
        self.grid.addItem(self.bottom_labels, 2, 1)

        self.view.setCentralItem(self.grid_widget)

        self.left_dendrogram.hide()
        self.top_dendrogram.hide()
        self.right_labels.hide()
        self.bottom_labels.hide()

        self.matrix_item = None
        self.dendrogram = None

        self.grid_widget.scene().installEventFilter(self)

    @Inputs.distances
    def set_distances(self, matrix):
        """Handle the distance matrix input.

        Matrices with fewer than two rows are rejected with an error. The
        clustering options of the sorting combo box are disabled (and the
        current ``sorting`` downgraded) when the matrix exceeds
        ``_MaxOrderedClustering`` / ``_MaxClustering`` rows.
        """
        self.closeContext()
        self.clear()
        self.error()
        if matrix is not None:
            N, _ = matrix.shape
            if N < 2:
                self.error("Empty distance matrix.")
                matrix = None

        self.matrix = matrix
        if matrix is not None:
            self.set_items(matrix.row_items, matrix.axis)
        else:
            self.set_items(None)

        if matrix is not None:
            N, _ = matrix.shape
        else:
            N = 0

        model = self.sorting_cb.model()
        item = model.item(2)

        msg = None
        if N > OWDistanceMap._MaxOrderedClustering:
            item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            if self.sorting == OWDistanceMap.OrderedClustering:
                self.sorting = OWDistanceMap.Clustering
                msg = "Cluster ordering was disabled due to the input " \
                      "matrix being to big"
        else:
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        item = model.item(1)
        if N > OWDistanceMap._MaxClustering:
            item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            if self.sorting == OWDistanceMap.Clustering:
                self.sorting = OWDistanceMap.NoOrdering
            msg = "Clustering was disabled due to the input " \
                  "matrix being to big"
        else:
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        self.information(msg)

    def set_items(self, items, axis=1):
        """Store `items` and fill the annotation combo box accordingly.

        For a table with a truthy `axis`, the annotation choices are the
        visible variables and metas of its domain, and the context is
        opened. ``annotation_idx`` is clipped to the number of choices.
        """
        self.items = items
        model = self.annot_combo.model()
        if items is None:
            model[:] = ["None", "Enumeration"]
        elif not axis:
            model[:] = ["None", "Enumeration", "Attribute names"]
        elif isinstance(items, Orange.data.Table):
            annot_vars = list(filter_visible(items.domain.variables)) + list(
                items.domain.metas)
            model[:] = ["None", "Enumeration"] + annot_vars
            self.annotation_idx = 0
            self.openContext(items.domain)
        elif isinstance(items, list) and \
                all(isinstance(item, Orange.data.Variable) for item in items):
            model[:] = ["None", "Enumeration", "Name"]
        else:
            model[:] = ["None", "Enumeration"]
        self.annotation_idx = min(self.annotation_idx, len(model) - 1)

    def clear(self):
        """Drop the matrix, cached trees, ordering, selection and plot."""
        self.matrix = None
        self.cluster = None
        self._tree = None
        self._ordered_tree = None
        self._sorted_matrix = None
        # Do not keep a permutation that belongs to the dropped matrix.
        self._sort_indices = None
        self._selection = []
        self._clear_plot()

    def handleNewSignals(self):
        """Rebuild ordering, scene and labels, then commit unconditionally."""
        if self.matrix is not None:
            self._update_ordering()
            self._setup_scene()
            self._update_labels()
        self.unconditional_commit()

    def _clear_plot(self):
        """Remove the matrix item and hide dendrograms and labels."""
        def remove(item):
            item.setParentItem(None)
            item.scene().removeItem(item)

        if self.matrix_item is not None:
            self.matrix_item.selectionChanged.disconnect(
                self._invalidate_selection)
            remove(self.matrix_item)
            self.matrix_item = None

        self._set_displayed_dendrogram(None)
        self._set_labels(None)

    def _cluster_tree(self):
        """Return the clustering tree of the matrix, computing it once."""
        if self._tree is None:
            self._tree = hierarchical.dist_matrix_clustering(self.matrix)
        return self._tree

    def _ordered_cluster_tree(self):
        """Return the leaf-ordered clustering tree, computing it once."""
        if self._ordered_tree is None:
            tree = self._cluster_tree()
            self._ordered_tree = \
                hierarchical.optimal_leaf_ordering(tree, self.matrix)
        return self._ordered_tree

    def _setup_scene(self):
        """Create the matrix item for the sorted matrix and show the tree."""
        self._clear_plot()
        self.matrix_item = DistanceMapItem(self._sorted_matrix)
        # Scale the y axis to compensate for pg.ViewBox's y axis invert
        self.matrix_item.setTransform(QTransform.fromScale(1, -1), )
        self.viewbox.addItem(self.matrix_item)
        # Set fixed view box range.
        h, w = self._sorted_matrix.shape
        self.viewbox.setRange(QRectF(0, -h, w, h), padding=0)

        self.matrix_item.selectionChanged.connect(self._invalidate_selection)

        if self.sorting == OWDistanceMap.NoOrdering:
            tree = None
        elif self.sorting == OWDistanceMap.Clustering:
            tree = self._cluster_tree()
        elif self.sorting == OWDistanceMap.OrderedClustering:
            tree = self._ordered_cluster_tree()

        self._set_displayed_dendrogram(tree)

        self._update_color()

    def _set_displayed_dendrogram(self, root):
        """Show `root` in both dendrograms, or hide them if it is None."""
        self.left_dendrogram.set_root(root)
        self.top_dendrogram.set_root(root)
        self.left_dendrogram.setVisible(root is not None)
        self.top_dendrogram.setVisible(root is not None)

        # Maximum size 0 when there is no tree, -1 otherwise.
        constraint = 0 if root is None else -1  # 150
        self.left_dendrogram.setMaximumWidth(constraint)
        self.top_dendrogram.setMaximumHeight(constraint)

    def _invalidate_ordering(self):
        """Recompute ordering, scene, labels and selection after a change
        of the sorting option."""
        self._sorted_matrix = None
        if self.matrix is not None:
            self._update_ordering()
            self._setup_scene()
            self._update_labels()
            self._invalidate_selection()

    def _update_ordering(self):
        """Set ``_sorted_matrix`` and ``_sort_indices`` for ``sorting``."""
        if self.sorting == OWDistanceMap.NoOrdering:
            self._sorted_matrix = self.matrix
            self._sort_indices = None
        else:
            if self.sorting == OWDistanceMap.Clustering:
                tree = self._cluster_tree()
            elif self.sorting == OWDistanceMap.OrderedClustering:
                tree = self._ordered_cluster_tree()

            leaves = hierarchical.leaves(tree)
            indices = numpy.array([leaf.value.index for leaf in leaves])
            X = self.matrix
            # Apply the same permutation to rows and columns.
            self._sorted_matrix = X[indices[:, numpy.newaxis],
                                    indices[numpy.newaxis, :]]
            self._sort_indices = indices

    def _invalidate_annotations(self):
        """Refresh the labels after the annotation choice changed."""
        if self.matrix is not None:
            self._update_labels()

    def _update_labels(self):
        """Compute labels for the current annotation choice and show them.

        Labels are ``None`` (nothing shown) for the first choice and
        whenever none of the cases below applies to the current items.
        """
        # Default also covers combinations not handled below, which
        # previously left `labels` unbound.
        labels = None
        if self.annotation_idx == 0:  # None
            pass
        elif self.annotation_idx == 1:  # Enumeration
            labels = [str(i + 1) for i in range(self.matrix.shape[0])]
        elif self.annot_combo.model()[
                self.annotation_idx] == "Attribute names":
            attr = self.matrix.row_items.domain.attributes
            labels = [str(attr[i]) for i in range(self.matrix.shape[0])]
        elif self.annotation_idx == 2 and \
                isinstance(self.items, (list, widget.AttributeList)):
            # set_items() offers "Name" for any list of variables, so plain
            # lists are accepted here as well.
            labels = [v.name for v in self.items]
        elif isinstance(self.items, Orange.data.Table):
            var = self.annot_combo.model()[self.annotation_idx]
            column, _ = self.items.get_column_view(var)
            labels = [var.str_val(value) for value in column]

        self._set_labels(labels)

    def _set_labels(self, labels):
        """Show `labels` (reordered like the matrix) on the right and bottom.

        Both label lists are hidden when `labels` is empty or None.
        """
        self._labels = labels

        if labels and self.sorting != OWDistanceMap.NoOrdering:
            sortind = self._sort_indices
            labels = [labels[i] for i in sortind]

        for textlist in [self.right_labels, self.bottom_labels]:
            textlist.set_labels(labels or [])
            textlist.setVisible(bool(labels))

        constraint = -1 if labels else 0
        self.right_labels.setMaximumWidth(constraint)
        self.bottom_labels.setMaximumHeight(constraint)

    def _update_color(self):
        """Build a lookup table from the chosen palette and apply it.

        The palette colors are spread between ``color_low`` and
        ``color_high`` (scaled to 0..255) and interpolated per channel;
        below that range the channels are 255 and above it 0.
        """
        if self.matrix_item:
            name, colors = self.palettes[self.colormap]
            # Use the palette variant with the largest number of colors.
            n, colors = max(colors.items())
            colors = numpy.array(colors, dtype=numpy.ubyte)
            low, high = self.color_low * 255, self.color_high * 255
            points = numpy.linspace(low, high, n)
            space = numpy.linspace(0, 255, 255)

            r = numpy.interp(space, points, colors[:, 0], left=255, right=0)
            g = numpy.interp(space, points, colors[:, 1], left=255, right=0)
            b = numpy.interp(space, points, colors[:, 2], left=255, right=0)
            colortable = numpy.c_[r, g, b]
            self.matrix_item.setLookupTable(colortable)

    def _invalidate_selection(self):
        """Translate the item's selection to input indices and commit."""
        ranges = self.matrix_item.selections()
        # Flatten twice: first the nested sequences of ranges, then the
        # ranges themselves into single indices.
        ranges = reduce(iadd, ranges, [])
        indices = reduce(iadd, ranges, [])
        if self.sorting != OWDistanceMap.NoOrdering:
            # Map positions in the displayed matrix back to the input order.
            sortind = self._sort_indices
            indices = [sortind[i] for i in indices]
        self._selection = list(sorted(set(indices)))
        self.commit()

    def commit(self):
        """Send the selected rows/columns, annotated data and features.

        For table items, the selection picks rows when the matrix axis is 1
        and attributes when it is 0. For an attribute list, the selected
        attributes are sent on the features output.
        """
        datasubset = None
        featuresubset = None

        if not self._selection:
            pass
        elif isinstance(self.items, Orange.data.Table):
            indices = self._selection
            if self.matrix.axis == 1:
                datasubset = self.items.from_table_rows(self.items, indices)
            elif self.matrix.axis == 0:
                domain = Orange.data.Domain(
                    [self.items.domain[i] for i in indices],
                    self.items.domain.class_vars, self.items.domain.metas)
                datasubset = self.items.transform(domain)
        elif isinstance(self.items, widget.AttributeList):
            subset = [self.items[i] for i in self._selection]
            featuresubset = widget.AttributeList(subset)

        self.Outputs.selected_data.send(datasubset)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.items, self._selection))
        self.Outputs.features.send(featuresubset)

    def onDeleteWidget(self):
        """Call the base class handler and then clear the widget state."""
        super().onDeleteWidget()
        self.clear()

    def send_report(self):
        """Report the sorting and annotation choices and, if any, the plot."""
        annot = self.annot_combo.currentText()
        if self.annotation_idx <= 1:
            annot = annot.lower()
        self.report_items((("Sorting", self.sorting_cb.currentText().lower()),
                           ("Annotations", annot)))
        if self.matrix is not None:
            self.report_plot()
class OWCorrespondenceAnalysis(widget.OWWidget):
    # Widget that runs ``correspondence`` on a contingency (two variables) or
    # Burt (any other count) table built from the selected categorical
    # variables, and plots the factor scores for two chosen components.
    name = "Correspondence Analysis"
    description = "Correspondence analysis for categorical multivariate data."
    icon = "icons/CorrespondenceAnalysis.svg"
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)

    # Custom event type posted by `_invalidate` and handled in `customEvent`.
    Invalidate = QEvent.registerEventType()

    settingsHandler = settings.DomainContextHandler()

    # Rows of `varlist` that are selected in the variable list view.
    selected_var_indices = settings.ContextSetting([])

    graph_name = "plot.plotItem"

    class Error(widget.OWWidget.Error):
        empty_data = widget.Msg("Empty dataset")
        no_disc_vars = widget.Msg("No categorical data")

    def __init__(self):
        """Create the variable list, the two axis combos, the info label
        and the plot."""
        super().__init__()

        self.data = None
        # Result of the last `correspondence` call; None until computed.
        # Initialised here so methods reading it are safe before any input.
        self.ca = None
        self.__invalidated = False
        self.component_x = 0
        self.component_y = 1

        box = gui.vBox(self.controlArea, "Variables")
        self.varlist = itemmodels.VariableListModel()
        self.varview = view = QListView(selectionMode=QListView.MultiSelection,
                                        uniformItemSizes=True)
        view.setModel(self.varlist)
        view.selectionModel().selectionChanged.connect(self._var_changed)

        box.layout().addWidget(view)

        axes_box = gui.vBox(self.controlArea, "Axes")
        box = gui.vBox(axes_box, "Axis X", margin=0)
        box.setFlat(True)
        self.axis_x_cb = gui.comboBox(box,
                                      self,
                                      "component_x",
                                      callback=self._component_changed)

        box = gui.vBox(axes_box, "Axis Y", margin=0)
        box.setFlat(True)
        self.axis_y_cb = gui.comboBox(box,
                                      self,
                                      "component_y",
                                      callback=self._component_changed)

        self.infotext = gui.widgetLabel(
            gui.vBox(self.controlArea, "Contribution to Inertia"), "\n")

        gui.rubber(self.controlArea)

        self.plot = pg.PlotWidget(background="w")
        self.plot.setMenuEnabled(False)
        self.mainArea.layout().addWidget(self.plot)

    @Inputs.data
    def set_data(self, data):
        """Handle a new input table.

        An empty table raises the ``empty_data`` error and is treated as no
        data; a table without discrete variables raises ``no_disc_vars``.
        Otherwise the first (up to) two discrete variables are preselected
        before the saved context is opened and the selection restored.
        """
        self.closeContext()
        self.clear()
        self.Error.clear()

        if data is not None and not len(data):
            self.Error.empty_data()
            data = None

        self.data = data
        if data is not None:
            self.varlist[:] = [
                var for var in data.domain.variables if var.is_discrete
            ]
            if not len(self.varlist[:]):
                self.Error.no_disc_vars()
                self.data = None
            else:
                self.selected_var_indices = [0, 1][:len(self.varlist)]
                self.component_x = 0
                # Use the second component for Y only when the last
                # preselected variable has more than one value.
                self.component_y = int(
                    len(self.varlist[self.selected_var_indices[-1]].values) > 1
                )
                self.openContext(data)
                self._restore_selection()
        self._update_CA()

    def clear(self):
        """Drop the data, the analysis result, the plot and the variable list."""
        self.data = None
        self.ca = None
        self.plot.clear()
        self.varlist[:] = []

    def selected_vars(self):
        """Return the variables selected in the list view, in row order."""
        rows = sorted(ind.row()
                      for ind in self.varview.selectionModel().selectedRows())
        return [self.varlist[i] for i in rows]

    def _restore_selection(self):
        """Select the rows in ``selected_var_indices`` with the selection
        model's signals blocked."""
        def restore(view, indices):
            with itemmodels.signal_blocking(view.selectionModel()):
                select_rows(view, indices)

        restore(self.varview, self.selected_var_indices)

    def _p_axes(self):
        """Return the ``(x, y)`` component indices currently chosen."""
        return (self.component_x, self.component_y)

    def _clamp_components(self, rfs):
        """Keep both component indices below ``rfs``, the number of axes.

        ``rfs`` is the value returned by `update_XY`; ``None`` (nothing
        computed) leaves the components untouched.
        """
        if rfs is None:
            return
        if self.component_x >= rfs:
            self.component_x = rfs - 1
        if self.component_y >= rfs:
            self.component_y = rfs - 1

    def _var_changed(self):
        """React to a change of the selected variables."""
        self.selected_var_indices = sorted(
            ind.row() for ind in self.varview.selectionModel().selectedRows())
        self._clamp_components(self.update_XY())
        self._invalidate()

    def _component_changed(self):
        """Redraw the plot and info label after an axis combo change."""
        if self.ca is not None:
            self._setup_plot()
            self._update_info()

    def _invalidate(self):
        """Post an ``Invalidate`` event so the recomputation runs later from
        `customEvent`."""
        self.__invalidated = True
        QApplication.postEvent(self, QEvent(self.Invalidate))

    def customEvent(self, event):
        """Handle the ``Invalidate`` event by recomputing the analysis."""
        if event.type() == self.Invalidate:
            self.__invalidated = False
            self.ca = None
            self.plot.clear()
            self._update_CA()
            return
        return super().customEvent(event)

    def _update_CA(self):
        """Recompute the analysis and refresh the plot and info label."""
        # Clamp as `_var_changed` does, so a component index can never
        # point past the axes that were actually computed.
        self._clamp_components(self.update_XY())
        # `update_XY` emptied and refilled the combos; re-assigning the
        # attributes is presumably what pushes the values back into them.
        self.component_x, self.component_y = self.component_x, self.component_y

        self._setup_plot()
        self._update_info()

    def update_XY(self):
        """Run the analysis for the selected variables and refill both combos.

        Returns the number of columns of ``row_factors``, or ``None`` when no
        variable is selected (in which case ``self.ca`` is left unchanged).
        """
        self.axis_x_cb.clear()
        self.axis_y_cb.clear()
        ca_vars = self.selected_vars()
        if len(ca_vars) == 0:
            return

        multi = len(ca_vars) != 2
        if multi:
            _, ctable = burt_table(self.data, ca_vars)
        else:
            ctable = contingency.get_contingency(self.data, *ca_vars[::-1])

        self.ca = correspondence(ctable)
        rfs = self.ca.row_factors.shape[1]
        axes = ["{}".format(i + 1) for i in range(rfs)]
        self.axis_x_cb.addItems(axes)
        self.axis_y_cb.addItems(axes)
        return rfs

    def _inertia_percentages(self):
        """Return the per-axis inertia of ``self.ca`` scaled to percent.

        When the inertia sums to zero the values are only multiplied by 100,
        which avoids dividing by zero.
        """
        inertia = self.ca.inertia_of_axis()
        total = np.sum(inertia)
        if total == 0:
            return 100 * inertia
        return 100 * inertia / total

    def _setup_plot(self):
        """Draw one scatter item plus value labels per selected variable and
        label both axes with the component number and its inertia share."""
        def get_minmax(groups):
            minmax = [float('inf'), float('-inf'), float('inf'), float('-inf')]
            for group in groups:
                for p in group:
                    minmax[0] = min(p[0], minmax[0])
                    minmax[1] = max(p[0], minmax[1])
                    minmax[2] = min(p[1], minmax[2])
                    minmax[3] = max(p[1], minmax[3])
            return minmax

        self.plot.clear()
        if self.ca is None:
            # Nothing computed: leave the plot empty.
            return

        variables = self.selected_vars()
        colors = colorpalette.ColorPaletteGenerator(len(variables))
        p_axes = self._p_axes()

        if len(variables) == 2:
            # Two variables: one contributes the rows, the other the columns.
            groups = [self.ca.row_factors[:, p_axes],
                      self.ca.col_factors[:, p_axes]]
        else:
            # Otherwise the rows hold every variable's values one after
            # another; cut them into one slice per variable.
            factors = self.ca.row_factors[:, p_axes]
            counts = [len(var.values) for var in variables]
            bounds = np.cumsum([0] + counts)
            groups = [factors[s:e] for s, e in zip(bounds, bounds[1:])]

        minmax = get_minmax(groups)

        # Pad each range by 5 %, or by 1 when the extent is (near) zero.
        margin = abs(minmax[0] - minmax[1])
        margin = margin * 0.05 if margin > 1e-10 else 1
        self.plot.setXRange(minmax[0] - margin, minmax[1] + margin)
        margin = abs(minmax[2] - minmax[3])
        margin = margin * 0.05 if margin > 1e-10 else 1
        self.plot.setYRange(minmax[2] - margin, minmax[3] + margin)

        for i, (var, coords) in enumerate(zip(variables, groups)):
            color_outline = colors[i]
            color_outline.setAlpha(200)
            color = QColor(color_outline)
            color.setAlpha(120)
            item = ScatterPlotItem(
                x=coords[:, 0],
                y=coords[:, 1],
                brush=QBrush(color),
                pen=pg.mkPen(color_outline.darker(120), width=1.5),
                size=np.full((coords.shape[0], ), 10.1),
            )
            self.plot.addItem(item)

            for value_name, point in zip(var.values, coords):
                item = pg.TextItem(value_name, anchor=(0.5, 0))
                self.plot.addItem(item)
                item.setPos(point[0], point[1])

        inertia = self._inertia_percentages()

        ax = self.plot.getAxis("bottom")
        ax.setLabel("Component {} ({:.1f}%)".format(p_axes[0] + 1,
                                                    inertia[p_axes[0]]))
        ax = self.plot.getAxis("left")
        ax.setLabel("Component {} ({:.1f}%)".format(p_axes[1] + 1,
                                                    inertia[p_axes[1]]))

    def _update_info(self):
        """Show the inertia share of the two chosen components."""
        if self.ca is None:
            self.infotext.setText("\n\n")
        else:
            fmt = ("Axis 1: {:.2f}\n" "Axis 2: {:.2f}")
            inertia = self._inertia_percentages()
            ax1, ax2 = self._p_axes()
            self.infotext.setText(fmt.format(inertia[ax1], inertia[ax2]))

    def send_report(self):
        """Report the instance count, the selected variables and the plot.

        Nothing is reported without data or without a selected variable.
        """
        if self.data is None:
            return

        selected = self.selected_vars()
        if not selected:
            return

        items = OrderedDict()
        items["Data instances"] = len(self.data)
        if len(selected) == 1:
            items["Selected variable"] = selected[0]
        else:
            items["Selected variables"] = "{} and {}".format(
                ", ".join(var.name for var in selected[:-1]),
                selected[-1].name)
        self.report_items(items)

        self.report_plot()
Exemple #24
0
class OWTextableLength(OWTextableBaseWidget):
    """Orange widget for length computation"""

    name = "Length"
    description = "Compute the (average) length of segments"
    icon = "icons/Length.png"
    priority = 8003

    inputs = [('Segmentation', Segmentation, "inputData", widget.Multiple)]
    outputs = [('Textable table', Table, widget.Default),
               ('Orange table', Orange.data.Table)]

    settingsHandler = SegmentationListContextHandler(
        version=__version__.rsplit(".", 1)[0]
    )
    segmentations = SegmentationsInputList()  # type: list

    # Settings...
    computeAverage = settings.Setting(False)
    computeStdev = settings.Setting(False)
    mergeContexts = settings.Setting(False)

    # Indices into `segmentations`; -1 means "nothing chosen".
    units = settings.ContextSetting(-1)
    averagingSegmentation = settings.ContextSetting(-1)
    _contexts = settings.ContextSetting(-1)
    mode = settings.ContextSetting(u'No context')
    contextAnnotationKey = settings.ContextSetting(u'(none)')

    want_main_area = False

    def __init__(self, *args, **kwargs):
        """Initialize a Length widget"""
        super().__init__(*args, **kwargs)

        self.windowSize = 1

        self.infoBox = InfoBox(
            widget=self.controlArea,
            stringClickSend=u", please click 'Send' when ready.",
        )
        self.sendButton = SendButton(
            widget=self.controlArea,
            master=self,
            callback=self.sendData,
            infoBoxAttribute='infoBox',
            buttonLabel=u'Send',
            checkboxLabel=u'Send automatically',
            sendIfPreCallback=self.updateGUI,
        )

        # GUI...

        # Units box
        self.unitsBox = gui.widgetBox(
            widget=self.controlArea,
            box=u'Units',
            orientation='vertical',
            addSpace=True,
        )
        self.unitSegmentationCombo = gui.comboBox(
            widget=self.unitsBox,
            master=self,
            value='units',
            orientation='horizontal',
            label=u'Segmentation:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"The segmentation whose segments constitute the\n"
                u"units of length."
            ),
        )
        self.unitSegmentationCombo.setMinimumWidth(120)
        gui.separator(widget=self.unitsBox, height=3)

        # Averaging box...
        self.averagingBox = gui.widgetBox(
            widget=self.controlArea,
            box=u'Averaging',
            orientation='vertical',
            addSpace=True,
        )
        averagingBoxLine1 = gui.widgetBox(
            widget=self.averagingBox,
            box=False,
            orientation='horizontal',
            addSpace=True,
        )
        gui.checkBox(
            widget=averagingBoxLine1,
            master=self,
            value='computeAverage',
            label=u'Average over segmentation:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"Check this box in order to measure the average\n"
                u"length of segments.\n\n"
                u"Leaving this box unchecked implies that no\n"
                u"averaging will take place."
            ),
        )
        self.averagingSegmentationCombo = gui.comboBox(
            widget=averagingBoxLine1,
            master=self,
            value='averagingSegmentation',
            orientation='horizontal',
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"The segmentation whose segment length will be\n"
                u"measured and averaged (if the box to the left\n"
                u"is checked)."
            ),
        )
        self.computeStdevCheckBox = gui.checkBox(
            widget=self.averagingBox,
            master=self,
            value='computeStdev',
            label=u'Compute standard deviation',
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"Check this box to compute not only length average\n"
                u"but also standard deviation (if the above box\n"
                u"is checked).\n\n"
                u"Note that computing standard deviation can be a\n"
                u"lengthy operation for large segmentations."
            ),
        )
        gui.separator(widget=self.averagingBox, height=2)

        # Contexts box...
        self.contextsBox = gui.widgetBox(
            widget=self.controlArea,
            box=u'Contexts',
            orientation='vertical',
            addSpace=True,
        )
        self.modeCombo = gui.comboBox(
            widget=self.contextsBox,
            master=self,
            value='mode',
            sendSelectedValue=True,
            items=[
                u'No context',
                u'Sliding window',
                u'Containing segmentation',
            ],
            orientation='horizontal',
            label=u'Mode:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"Context specification mode.\n\n"
                u"Contexts define the rows of the resulting table.\n\n"
                u"'No context': simply return the length of the\n"
                u"'Units' segmentation, or the average length of\n"
                u"segments in the 'Averaging' segmentation (if any),\n"
                u"so that the output table contains a single row.\n\n"
                u"'Sliding window': contexts are defined as all the\n"
                u"successive, overlapping sequences of n segments\n"
                u"in the 'Averaging' segmentation; this mode is\n"
                u"available only if the 'Averaging' box is checked.\n\n"
                u"'Containing segmentation': contexts are defined\n"
                u"as the distinct segments occurring in a given\n"
                u"segmentation (which may or may not be the same\n"
                u"as the 'Units' and/or 'Averaging' segmentation)."
            ),
        )
        self.slidingWindowBox = gui.widgetBox(
            widget=self.contextsBox,
            orientation='vertical',
        )
        gui.separator(widget=self.slidingWindowBox, height=3)
        self.windowSizeSpin = gui.spin(
            widget=self.slidingWindowBox,
            master=self,
            value='windowSize',
            minv=1,
            maxv=1,
            step=1,
            orientation='horizontal',
            label=u'Window size:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            keyboardTracking=False,
            tooltip=(
                u"The length of segment sequences defining contexts."
            ),
        )
        self.containingSegmentationBox = gui.widgetBox(
            widget=self.contextsBox,
            orientation='vertical',
        )
        gui.separator(widget=self.containingSegmentationBox, height=3)
        self.contextSegmentationCombo = gui.comboBox(
            widget=self.containingSegmentationBox,
            master=self,
            value='_contexts',
            orientation='horizontal',
            label=u'Segmentation:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"The segmentation whose segment types define\n"
                u"the contexts in which length will be measured."
            ),
        )
        gui.separator(widget=self.containingSegmentationBox, height=3)
        self.contextAnnotationCombo = gui.comboBox(
            widget=self.containingSegmentationBox,
            master=self,
            value='contextAnnotationKey',
            sendSelectedValue=True,
            emptyString=u'(none)',
            orientation='horizontal',
            label=u'Annotation key:',
            labelWidth=190,
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"Indicate whether context types are defined by\n"
                u"the content of segments in the above specified\n"
                u"segmentation (value 'none') or by their\n"
                u"annotation values for a specific annotation key."
            ),
        )
        gui.separator(widget=self.containingSegmentationBox, height=3)
        gui.checkBox(
            widget=self.containingSegmentationBox,
            master=self,
            value='mergeContexts',
            label=u'Merge contexts',
            callback=self.sendButton.settingsChanged,
            tooltip=(
                u"Check this box if you want to treat all segments\n"
                u"of the above specified segmentation as forming\n"
                u"a single context (hence the resulting table\n"
                u"contains a single row)."
            ),
        )
        gui.separator(widget=self.contextsBox, height=3)

        gui.rubber(self.controlArea)

        # Send button...
        self.sendButton.draw()

        # Info box...
        self.infoBox.draw()

        self.sendButton.sendIf()
        self.adjustSizeWithTimer()

    def inputData(self, newItem, newId=None):
        """Process incoming data."""
        self.closeContext()
        updateMultipleInputs(
            self.segmentations,
            newItem,
            newId,
            self.onInputRemoval
        )
        self.infoBox.inputChanged()
        self.updateGUI()

    def sendData(self):
        """Check input, compute (average) length table, then send it"""

        # Check that there's something on input...
        if len(self.segmentations) == 0:
            self.infoBox.setText(u'Widget needs input.', 'warning')
            self.send('Textable table', None)
            self.send('Orange table', None)
            return
        assert self.units >= 0

        # Units parameter...
        units = self.segmentations[self.units][1]

        # Averaging parameters...
        if self.computeAverage:
            assert self.averagingSegmentation >= 0
            averaging = {
                'segmentation':
                    self.segmentations[self.averagingSegmentation][1],
                'std_deviation': bool(self.computeStdev),
            }
        else:
            averaging = None

        self.infoBox.setText(u"Processing, please wait...", "warning")
        self.controlArea.setDisabled(True)

        # The finally clause closes the progress bar and re-enables the
        # controls even when the computation raises; otherwise the widget
        # would stay locked after an error.
        progressBar = None
        try:
            # Case 1: sliding window...
            if self.mode == 'Sliding window':

                # Compute length...
                progressBar = ProgressBar(
                    self,
                    iterations=len(units) - (self.windowSize - 1)
                )
                table = Processor.length_in_window(
                    units,
                    averaging=averaging,
                    window_size=self.windowSize,
                    progress_callback=progressBar.advance,
                )

            # Case 2: Containing segmentation or no context...
            else:

                # Parameters for mode 'Containing segmentation'...
                if self.mode == 'Containing segmentation':
                    assert self._contexts >= 0
                    contexts = {
                        'segmentation': self.segmentations[self._contexts][1],
                        'annotation_key': self.contextAnnotationKey or None,
                        'merge': self.mergeContexts,
                    }
                    # The combo's placeholder entry means "no key".
                    if contexts['annotation_key'] == u'(none)':
                        contexts['annotation_key'] = None
                    num_iterations = len(contexts['segmentation'])
                # Parameters for mode 'No context'...
                else:
                    contexts = None
                    num_iterations = 1

                # Compute length...
                progressBar = ProgressBar(
                    self,
                    iterations=num_iterations
                )
                table = Processor.length_in_context(
                    units,
                    averaging,
                    contexts,
                    progress_callback=progressBar.advance,
                )
        finally:
            if progressBar is not None:
                progressBar.finish()
            self.controlArea.setDisabled(False)

        if not len(table.row_ids):
            self.infoBox.setText(u'Resulting table is empty.', 'warning')
            self.send('Textable table', None)
            self.send('Orange table', None)
        else:
            self.infoBox.setText(u'Table sent to output.')
            self.send('Textable table', table)
            self.send('Orange table', table.to_orange_table())

        self.sendButton.resetSettingsChangedFlag()

    def onInputRemoval(self, index):
        """Handle removal of input with given index"""
        # Keep the stored indices pointing at the same inputs after removal.
        if index < self.units:
            self.units -= 1
        elif index == self.units and self.units == len(self.segmentations) - 1:
            self.units -= 1
        if self.mode == u'Containing segmentation':
            if index == self._contexts:
                self.mode = u'No context'
                self._contexts = -1
            elif index < self._contexts:
                self._contexts -= 1
                if self._contexts < 0:
                    self.mode = u'No context'
        if self.computeAverage \
                and self.averagingSegmentation != self.units:
            if index == self.averagingSegmentation:
                self.computeAverage = False
                self.averagingSegmentation = -1
            elif index < self.averagingSegmentation:
                self.averagingSegmentation -= 1
                if self.averagingSegmentation < 0:
                    self.computeAverage = False

    def updateGUI(self):
        """Update GUI state"""

        self.unitSegmentationCombo.clear()
        self.averagingSegmentationCombo.clear()

        if self.mode == u'No context':
            self.containingSegmentationBox.setVisible(False)
            self.slidingWindowBox.setVisible(False)

        if len(self.segmentations) == 0:
            self.units = -1
            self.unitsBox.setDisabled(True)
            self.averagingBox.setDisabled(True)
            self.mode = 'No context'
            self.contextsBox.setDisabled(True)
            return
        else:
            if len(self.segmentations) == 1:
                self.units = 0
                self.averagingSegmentation = 0
            for segmentation in self.segmentations:
                self.unitSegmentationCombo.addItem(segmentation[1].label)
                self.averagingSegmentationCombo.addItem(segmentation[1].label)
            self.units = max(self.units, 0)
            self.averagingSegmentation = max(self.averagingSegmentation, 0)
            self.unitsBox.setDisabled(False)
            self.averagingBox.setDisabled(False)
            self.contextsBox.setDisabled(False)
            if self.computeAverage:
                # 'Sliding window' is offered only while averaging is on.
                if self.modeCombo.itemText(1) != u'Sliding window':
                    self.modeCombo.insertItem(1, u'Sliding window')
                self.averagingSegmentationCombo.setDisabled(False)
                self.computeStdevCheckBox.setDisabled(False)
            else:
                self.averagingSegmentationCombo.setDisabled(True)
                self.computeStdevCheckBox.setDisabled(True)
                self.computeStdev = False
                if self.mode == u'Sliding window':
                    self.mode = u'No context'
                if self.modeCombo.itemText(1) == u'Sliding window':
                    self.modeCombo.removeItem(1)

        if self.mode == 'Sliding window':
            self.containingSegmentationBox.setVisible(False)
            self.slidingWindowBox.setVisible(True)
            self.windowSizeSpin.setRange(
                1,
                len(self.segmentations[self.units][1])
            )
            self.windowSize = self.windowSize or 1

        elif self.mode == 'Containing segmentation':
            self.slidingWindowBox.setVisible(False)
            self.containingSegmentationBox.setVisible(True)
            self.contextSegmentationCombo.clear()
            for inputItem in self.segmentations:
                self.contextSegmentationCombo.addItem(inputItem[1].label)
            self._contexts = max(self._contexts, 0)
            segmentation = self.segmentations[self._contexts]
            self.contextAnnotationCombo.clear()
            self.contextAnnotationCombo.addItem(u'(none)')
            contextAnnotationKeys = segmentation[1].get_annotation_keys()
            for key in contextAnnotationKeys:
                self.contextAnnotationCombo.addItem(key)
            if self.contextAnnotationKey not in contextAnnotationKeys:
                self.contextAnnotationKey = u'(none)'
            # NOTE(review): self-assignment presumably re-syncs the combo
            # that was just repopulated with the attribute value -- confirm.
            self.contextAnnotationKey = self.contextAnnotationKey

    def handleNewSignals(self):
        """Overridden: called after multiple signals have been added"""
        self.openContext(self.uuid, self.segmentations)
        self.updateGUI()
        self.sendButton.sendIf()
Exemple #25
0
class OWEditDomain(widget.OWWidget):
    # Widget with a single "Data" table input and a single "Data" table
    # output.  The user-facing strings below are Chinese;
    # NOTE(review): `description` reads roughly "rename variables, edit
    # categories and variable annotations" -- confirm the translation.
    name = "文本编辑"
    description = "重命名变量,编辑种类和变量注释。"
    icon = "icons/EditDomain.svg"
    priority = 3125
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    class Error(widget.OWWidget.Error):
        duplicate_var_name = widget.Msg("A variable name is duplicated.")

    settingsHandler = settings.DomainContextHandler()
    settings_version = 2

    # Per-context store of the edits; emptied by `clear` and `reset_all`.
    _domain_change_store = settings.ContextSetting({})
    # Name of the selected variable, kept so `_restore` can reselect it.
    _selected_item = settings.ContextSetting(None)  # type: Optional[str]

    # All controls of this widget are placed in the main area.
    want_control_area = False

    def __init__(self):
        """Build the widget UI.

        A horizontal layout holds the variable list box and the editor box;
        a button row (apply / reset selected / reset all) is added to the
        main layout after it.
        """
        super().__init__()
        self.data = None  # type: Optional[Orange.data.Table]
        #: The current selected variable index
        self.selected_index = -1
        self._invalidated = False

        mainlayout = self.mainArea.layout()
        assert isinstance(mainlayout, QVBoxLayout)
        layout = QHBoxLayout()
        mainlayout.addLayout(layout)
        box = QGroupBox("变量")
        box.setLayout(QVBoxLayout())
        layout.addWidget(box)

        self.variables_model = VariableListModel(parent=self)
        self.variables_view = self.domain_view = QListView(
            selectionMode=QListView.SingleSelection,
            uniformItemSizes=True,
        )
        self.variables_view.setItemDelegate(VariableEditDelegate(self))
        self.variables_view.setModel(self.variables_model)
        self.variables_view.selectionModel().selectionChanged.connect(
            self._on_selection_changed
        )
        box.layout().addWidget(self.variables_view)

        box = QGroupBox("编辑")
        box.setLayout(QVBoxLayout(margin=4))
        layout.addWidget(box)

        self.editor_stack = QStackedWidget()

        # Insertion order matters: `open_editor` picks the editor by its
        # position in the stack (0 = discrete ... 3 = generic).
        for editor_type in (DiscreteVariableEditor,
                            ContinuousVariableEditor,
                            TimeVariableEditor,
                            VariableEditor):
            self.editor_stack.addWidget(editor_type())

        box.layout().addWidget(self.editor_stack)

        bbox = QDialogButtonBox()
        bbox.setStyleSheet(
            "button-layout: {:d};".format(QDialogButtonBox.MacLayout))
        bapply = QPushButton(
            "应用",
            objectName="button-apply",
            toolTip="应用更改并在输出时提交数据",
            default=True,
            autoDefault=False
        )
        bapply.clicked.connect(self.commit)
        breset = QPushButton(
            "重置选定",
            objectName="button-reset",
            # Was "静止" (freeze); the button resets, matching "重置" used
            # by the label and by the "reset all" tooltip.
            toolTip="将所选变量重置到其输入状态。",
            autoDefault=False
        )
        breset.clicked.connect(self.reset_selected)
        breset_all = QPushButton(
            "重置全部",
            objectName="button-reset-all",
            toolTip="将所有变量重置为其输入状态。",
            autoDefault=False
        )
        breset_all.clicked.connect(self.reset_all)

        bbox.addButton(bapply, QDialogButtonBox.AcceptRole)
        bbox.addButton(breset, QDialogButtonBox.ResetRole)
        bbox.addButton(breset_all, QDialogButtonBox.ResetRole)

        mainlayout.addWidget(bbox)
        self.variables_view.setFocus(Qt.NoFocusReason)  # initial focus

    @Inputs.data
    def set_data(self, data):
        """Replace the input dataset.

        The previous context is closed and the widget cleared first.  When
        ``data`` is not None its domain is loaded, the matching context is
        opened and the saved state restored.  ``commit`` runs in either case.
        """
        self.closeContext()
        self.clear()
        self.data = data

        has_data = data is not None
        if has_data:
            self.set_domain(data.domain)
            self.openContext(data)
            self._restore()

        self.commit()

    def clear(self):
        """Clear the widget state."""
        self.data = None
        # NOTE(review): emptying the model presumably drops the view's
        # selection, which would run `_on_selection_changed` and reset
        # `selected_index`; the assert below depends on that -- confirm.
        self.variables_model.clear()
        assert self.selected_index == -1
        self.selected_index = -1

        # Forget the remembered selection name and the stored domain changes.
        self._selected_item = None
        self._domain_change_store = {}

    def reset_selected(self):
        """Reset the currently selected variable to its original state.

        Does nothing when no variable is selected or when the selected
        variable has no transform recorded under ``TransformRole``.
        """
        row = self.selected_var_index()
        if row < 0:
            return

        variables = self.variables_model
        model_index = variables.index(row)
        variable = model_index.data(Qt.EditRole)
        if not model_index.data(TransformRole):
            return  # nothing to reset

        current_editor = self.editor_stack.currentWidget()
        # Keep the editor's change signal away from our slot while the
        # model and the editor are being reset.
        with disconnected(current_editor.variable_changed,
                          self._on_variable_changed):
            variables.setData(model_index, [], TransformRole)
            current_editor.set_data(variable, [])
        self._invalidate()

    def reset_all(self):
        """Reset all variables to their original state.

        The stored domain changes are always discarded; the model rows are
        only touched (and the widget invalidated) while data is present.
        """
        self._domain_change_store = {}
        if self.data is None:
            return

        variables = self.variables_model
        for row in range(variables.rowCount()):
            variables.setData(variables.index(row), [], TransformRole)

        current = self.selected_var_index()
        if current >= 0:
            # Reopen the editor so it shows the reset variable.
            self.open_editor(current)
        self._invalidate()

    def selected_var_index(self):
        """Return the row of the selected variable, or -1 if none is selected.

        At most one selected index is expected (asserted).
        """
        selected = self.variables_view.selectedIndexes()
        assert len(selected) <= 1
        if not selected:
            return -1
        return selected[0].row()

    def set_domain(self, domain):
        # type: (Orange.data.Domain) -> None
        """Fill the variable model with ``abstract(v)`` for every variable
        of ``domain``: ``domain.variables`` first, then ``domain.metas``."""
        all_variables = domain.variables + domain.metas
        self.variables_model[:] = list(map(abstract, all_variables))

    def _restore(self, ):
        """
        Restore the edit transform from saved state.

        Every model row whose variable has a saved transform gets it back
        under ``TransformRole``.  Afterwards the variable named by
        ``self._selected_item`` is selected; when no name is stored, or the
        name is not in the model, the first row is selected.  An empty model
        leaves the selection untouched.
        """
        model = self.variables_model
        for row in range(model.rowCount()):
            midx = model.index(row, 0)
            var = model.data(midx, Qt.EditRole)
            tr = self._restore_transform(var)
            if tr:
                model.setData(midx, tr, TransformRole)

        # Restore the current variable selection.  The match is recorded in
        # its own variable: reusing the loop index would leave it on the
        # last row after an unsuccessful search and skip the fallback.
        selected_row = -1
        if self._selected_item is not None:
            for row, var in enumerate(model):
                if var.name == self._selected_item:
                    selected_row = row
                    break
        if selected_row == -1 and model.rowCount():
            selected_row = 0

        if selected_row != -1:
            itemmodels.select_row(self.variables_view, selected_row)

    def _on_selection_changed(self):
        self.selected_index = self.selected_var_index()
        if self.selected_index != -1:
            self._selected_item = self.variables_model[self.selected_index].name
        else:
            self._selected_item = None
        self.open_editor(self.selected_index)

    def open_editor(self, index):
        # type: (int) -> None
        """
        Show the editor matching the variable at row `index`.

        The current editor is cleared first; an out-of-range `index` leaves
        it cleared and opens nothing.
        """
        self.clear_editor()
        model = self.variables_model
        if index < 0 or index >= model.rowCount():
            return
        midx = model.index(index, 0)
        var = model.data(midx, Qt.EditRole)
        transforms = model.data(midx, TransformRole)
        if transforms is None:
            transforms = []

        # Position in `editor_stack` of the editor for each variable type;
        # any other type falls back to position 3.
        stack_position = {Categorical: 0, Real: 1, Time: 2, String: 3}
        editor = self.editor_stack.widget(stack_position.get(type(var), 3))
        self.editor_stack.setCurrentWidget(editor)
        # Load the editor before connecting, then listen for user edits.
        editor.set_data(var, transforms)
        editor.variable_changed.connect(
            self._on_variable_changed, Qt.UniqueConnection
        )

    def clear_editor(self):
        """Detach the current editor from this widget and empty it."""
        editor = self.editor_stack.currentWidget()
        slot = self._on_variable_changed
        try:
            editor.variable_changed.disconnect(slot)
        except TypeError:
            # NOTE(review): presumably raised when the slot was not
            # connected; either way there is nothing to undo.
            pass
        editor.set_data(None)

    @Slot()
    def _on_variable_changed(self):
        """
        Record the edit the user made in the current editor.

        Reads the (variable, transforms) pair from the editor, stores the
        transforms on the selected model row and in the saved-state store,
        then marks the widget as modified.
        """
        # A valid row index is strictly smaller than the number of rows.
        assert 0 <= self.selected_index < len(self.variables_model)
        editor = self.editor_stack.currentWidget()
        var, transform = editor.get_data()
        model = self.variables_model
        midx = model.index(self.selected_index, 0)
        model.setData(midx, transform, TransformRole)
        self._store_transform(var, transform)
        self._invalidate()

    def _store_transform(self, var, transform):
        # type: (Variable, List[Transform]) -> None
        """Save `transform` for `var`, both in deconstructed form."""
        stored = [deconstruct(step) for step in transform]
        self._domain_change_store[deconstruct(var)] = stored

    def _restore_transform(self, var):
        # type: (Variable) -> List[Transform]
        """
        Return the transforms saved for `var`, reconstructed.

        Saved entries that fail to reconstruct with a `NameError` or
        `TypeError` are skipped and reported through a `UserWarning`.
        """
        key = deconstruct(var)
        restored = []
        for item in self._domain_change_store.get(key, []):
            try:
                transform = reconstruct(*item)
            except (NameError, TypeError) as err:
                warnings.warn(
                    "Failed to restore transform: {}, {!r}".format(item, err),
                    UserWarning, stacklevel=2
                )
            else:
                restored.append(transform)
        return restored

    def _invalidate(self):
        self._set_modified(True)

    def _set_modified(self, state):
        """Store the modified `state` and italicise the apply button to match."""
        self._invalidated = state
        button = self.findChild(QPushButton, "button-apply")
        if not isinstance(button, QPushButton):
            return  # no apply button found
        font = button.font()
        font.setItalic(state)
        button.setFont(font)

    def commit(self):
        """
        Apply the edits to the input data and send the result to the output.

        Sends None when there is no input data or when the edits produce
        duplicate variable names (the `duplicate_var_name` error is shown in
        that case). Sends the input data itself when no variable has any
        transforms.
        """
        self._set_modified(False)
        self.Error.duplicate_var_name.clear()

        data = self.data
        if data is None:
            self.Outputs.data.send(None)
            return

        model = self.variables_model
        # (variable, transforms) for every model row, in row order
        rows = []  # type: List[Tuple[Variable, List[Transform]]]
        for i in range(model.rowCount()):
            midx = model.index(i, 0)
            rows.append((model.data(midx, Qt.EditRole),
                         model.data(midx, TransformRole)))

        if not any(tr for _, tr in rows):
            # nothing was edited; pass the input through unchanged
            self.Outputs.data.send(data)
            return

        input_vars = data.domain.variables + data.domain.metas
        assert all(v_.name == v.name
                   for v, (v_, _) in zip(input_vars, rows))
        output_vars = [
            apply_transform(v, tr) if tr is not None and len(tr) > 0 else v
            for (_, tr), v in zip(rows, input_vars)
        ]

        unique_names = {v.name for v in output_vars}
        if len(unique_names) != len(output_vars):
            self.Error.duplicate_var_name()
            self.Outputs.data.send(None)
            return

        # Rebuild the domain keeping the attribute / class / meta split of
        # the input domain.
        n_attrs = len(data.domain.attributes)
        n_class = len(data.domain.class_vars)
        split = n_attrs + n_class
        new_domain = Orange.data.Domain(
            output_vars[:n_attrs], output_vars[n_attrs:split],
            output_vars[split:]
        )
        self.Outputs.data.send(data.transform(new_domain))

    def sizeHint(self):
        """Return the inherited size hint grown to at least 660 x 550."""
        inherited = super().sizeHint()
        minimum = QSize(660, 550)
        return inherited.expandedTo(minimum)

    def send_report(self):
        """Report the transforms of edited variables, or "No changes"."""
        if self.data is None:
            self.report_data(None)
            return

        model = self.variables_model
        parts = []
        for i in range(model.rowCount()):
            midx = model.index(i)
            var = model.data(midx, Qt.EditRole)
            trs = model.data(midx, TransformRole)
            if trs:
                parts.append(report_transform(var, trs))

        if parts:
            entries = "".join("<li>{}</li>".format(part) for part in parts)
            html = "<ul>" + entries + "</ul>"
        else:
            html = "No changes"
        self.report_raw("", html)

    @classmethod
    def migrate_context(cls, context, version):
        """
        Convert the `domain_change_hints` of old contexts in place.

        For `version` None or <= 1, every hint about a variable from
        ``Orange.data.variable`` is turned into a list of transforms
        (rename, annotate, categories mapping) and the result is written to
        ``context.values["_domain_change_store"]``. Newer contexts are left
        untouched.
        """
        # pylint: disable=bad-continuation
        if version is not None and version > 1:
            return

        hints = context.values.get("domain_change_hints", ({}, -2))[0]
        ns = "Orange.data.variable"
        # Builders of `reconstruct` arguments from the old
        # (name, args, attrs) descriptions, keyed by the old class name.
        mapping = {
            "DiscreteVariable":
                lambda name, args, attrs:
                    ("Categorical", (name, tuple(args[0][1]), None, ())),
            "TimeVariable":
                lambda name, _, attrs:
                    ("Time", (name, ())),
            "ContinuousVariable":
                lambda name, _, attrs:
                    ("Real", (name, (3, "f"), ())),
            "StringVariable":
                lambda name, _, attrs:
                    ("String", (name, ())),
        }
        store = {}
        for (module, class_name, *rest), target in hints.items():
            if module != ns:
                continue
            build = mapping.get(class_name)
            if build is None:
                continue
            key_mapped = build(*rest)
            item_mapped = build(*target[2:])
            src = reconstruct(*key_mapped)   # type: Variable
            dst = reconstruct(*item_mapped)  # type: Variable

            # Express the difference between source and target as transforms.
            trs = []
            if src.name != dst.name:
                trs.append(Rename(dst.name))
            if src.annotations != dst.annotations:
                trs.append(Annotate(dst.annotations))
            if isinstance(src, Categorical) \
                    and src.categories != dst.categories:
                assert len(src.categories) == len(dst.categories)
                trs.append(CategoriesMapping(
                    list(zip(src.categories, dst.categories))))

            key = deconstruct(src)
            store[key] = [deconstruct(tr) for tr in trs]
        context.values["_domain_change_store"] = (store, -2)
Exemple #26
0
class OWTestLearners(OWWidget):
    """
    "Test and Score" widget: evaluates the input learners on the input data.

    Takes training data, optional separate test data, any number of
    learners and an optional preprocessor; outputs the predictions and the
    evaluation results.
    """
    name = "Test and Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100
    keywords = []

    class Inputs:
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    # Version of the saved settings; older ones go through migrate_settings
    settings_version = 3
    UserAdviceMessages = [
        widget.Message("Click on the table header to select shown columns",
                       "click_header")
    ]

    settingsHandler = settings.PerfectDomainContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    #: (stored as an index into `NFolds`)
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    #: (stored as an index into `NRepeats`)
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    #: (stored as an index into `SampleSizes`)
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    fold_feature_selected = settings.ContextSetting(False)

    # Combo box entry that stands for "no specific target class"
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    class Error(OWWidget.Error):
        train_data_empty = Msg("Train dataset is empty.")
        test_data_empty = Msg("Test dataset is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg(
            "Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train datasets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        no_class_values = Msg("Target variable has no values.")
        only_one_class_var_value = Msg("Target variable has only one value.")
        test_data_incompatible = Msg(
            "Test data may be incompatible with train data.")

    class Warning(OWWidget.Warning):
        # The placeholder is filled with the text from _which_missing_data
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")
        test_data_transformed = Msg(
            "Test data has been transformed to match the train data.")

    def __init__(self):
        """Initialise the widget state and build the control and main areas."""
        super().__init__()

        # Training data, separate test data and preprocessor inputs
        self.data = None
        self.test_data = None
        self.preprocessor = None
        # Whether the respective input has instances with unknown target
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        # Scorers usable with the current data (see _update_scorers)
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, InputLearner]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[Task]
        self.__executor = ThreadExecutor()

        # Sampling box. The radio buttons are appended in the order of the
        # resampling constants (KFold, FeatureFold, ShuffleSplit,
        # LeaveOneOut, TestOnTrain, TestOnTest).
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # Candidate fold features: discrete meta variables of the data
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class box; its items are filled in _update_class_selection
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)
        # Score table shown in the main area
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)

    @staticmethod
    def sizeHint():
        """Return the preferred widget size: width 780, height 1."""
        width, height = 780, 1
        return QSize(width, height)

    def _update_controls(self):
        """
        Refill the fold feature model from the data and sync its controls.

        The "by feature" radio button and combo box are enabled only when
        the model has entries; when it has none and "by feature" was
        selected, the resampling falls back to k-fold.
        """
        self.fold_feature = None
        model = self.feature_model
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                self.fold_feature = model[0]

        has_features = bool(model)
        feature_fold = OWTestLearners.FeatureFold
        fold_button = self.controls.resampling.buttons[feature_fold]
        fold_button.setEnabled(has_features)
        self.features_combo.setEnabled(has_features)
        if not has_features and self.resampling == feature_fold:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if key in self.learners and learner is None:
            # Removed
            self._invalidate([key])
            del self.learners[key]
        elif learner is not None:
            self.learners[key] = InputLearner(learner, None, None)
            self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Validate and store the input training dataset.

        Data that is empty or fails one of the target variable checks is
        rejected (treated as None) with the matching error shown.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.data_sampled.clear()
        for message in (self.Error.train_data_empty,
                        self.Error.class_required,
                        self.Error.too_many_classes,
                        self.Error.no_class_values,
                        self.Error.only_one_class_var_value):
            message.clear()

        if data is not None and not data:
            self.Error.train_data_empty()
            data = None
        if data:
            domain = data.domain
            # All conditions are evaluated up front; the first one that
            # holds shows its error and rejects the data.
            checks = [
                (not domain.class_vars, self.Error.class_required),
                (len(domain.class_vars) > 1, self.Error.too_many_classes),
                (np.isnan(data.Y).all(), self.Error.no_class_values),
                (domain.has_discrete_class
                 and len(domain.class_var.values) == 1,
                 self.Error.only_one_class_var_value),
            ]
            for failed, error in checks:
                if failed:
                    error()
                    data = None
                    break

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.Information.data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)

        has_missing = data is not None and np.isnan(data.Y).any()
        self.train_data_missing_vals = has_missing
        if has_missing or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        # Context settings are closed before the controls are rebuilt and
        # reopened for the new domain afterwards.
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            if self.fold_feature_selected and bool(self.feature_model):
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Validate and store the separate testing dataset.

        Empty data and data without a target variable are rejected (treated
        as None) with the matching error shown. Results are invalidated only
        when "Test on test data" is the selected resampling.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not data:
            self.Error.test_data_empty()
            data = None

        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.Information.test_data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)

        has_missing = data is not None and np.isnan(data.Y).any()
        self.test_data_missing_vals = has_missing
        if self.train_data_missing_vals or has_missing:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test "
        }[(self.train_data_missing_vals, self.test_data_missing_vals)]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        if self.data is None or self.data.domain.class_var is None:
            self.scorers = []
            return
        self.scorers = usable_scorers(self.data.domain.class_var)

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Set the input preprocessor to apply on the training data.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """
        Reimplemented from OWWidget.handleNewSignals.

        Refreshes the target class selection, the score table header and
        rows, then runs a pending update, if any.
        """
        self._update_class_selection()
        self.score_table.update_header(self.scorers)
        self.update_stats_model()
        if not self.__needupdate:
            return
        self.__update()

    def kfold_changed(self):
        self.resampling = OWTestLearners.KFold
        self._param_changed()

    def fold_feature_changed(self):
        self.resampling = OWTestLearners.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        self.resampling = OWTestLearners.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        self._invalidate()
        self.__update()

    def update_stats_model(self):
        """
        Rebuild the rows of the score table from the current results.

        One row is appended per input learner: its name, the train and test
        times (for successful results) and one cell per scorer. Failed
        evaluations are shown in red with the exception as tooltip and are
        also reported through the widget error.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.score_table.model
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # Index of the selected target class among the class values;
        # None means the stored (averaged) scores are shown.
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            head = QStandardItem(name)
            # Every item carries the input key under Qt.UserRole
            head.setData(key, Qt.UserRole)
            results = slot.results
            if results is not None and results.success:
                train = QStandardItem("{:.3f}".format(
                    results.value.train_time))
                train.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                train.setData(key, Qt.UserRole)
                test = QStandardItem("{:.3f}".format(results.value.test_time))
                test.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                test.setData(key, Qt.UserRole)
                row = [head, train, test]
            else:
                row = [head]
            if isinstance(results, Try.Fail):
                head.setToolTip(str(results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                # A domain transformation failure while testing on separate
                # test data gets its own error message instead of being
                # listed with the other failures.
                if isinstance(results.exception, DomainTransformationError) \
                        and self.resampling == self.TestOnTest:
                    self.Error.test_data_incompatible()
                    self.Information.test_data_transformed.clear()
                else:
                    errors.append("{name} failed with error:\n"
                                  "{exc.__class__.__name__}: {exc!s}".format(
                                      name=name, exc=slot.results.exception))

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                # Scores for one target class are computed here from the
                # one-vs-rest results rather than taken from slot.stats.
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(slot.results.value,
                                                      target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [
                        Try(scorer_caller(scorer, ovr_results, target=1))
                        for scorer in self.scorers
                    ]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat, scorer in zip(stats, self.scorers):
                    item = QStandardItem()
                    item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                    if stat.success:
                        item.setData(float(stat.value[0]), Qt.DisplayRole)
                    else:
                        item.setToolTip(str(stat.exception))
                        # Only failures of currently shown scores count
                        # towards the "scores not computed" warning.
                        if scorer.name in self.score_table.shown_scores:
                            has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        # Resort rows based on current sorting
        header = self.score_table.view.horizontalHeader()
        model.sort(header.sortIndicatorSection(), header.sortIndicatorOrder())

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _update_class_selection(self):
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            items = [self.TARGET_AVERAGE] + class_var.values
            self.class_selection_combo.addItems(items)

            class_index = 0
            if self.class_selection in class_var.values:
                class_index = class_var.values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        self.update_stats_model()

    def _invalidate(self, which=None):
        """
        Discard the results of the given learners and request an update.

        Parameters
        ----------
        which : Optional[Iterable[Any]]
            Input keys of the learners to invalidate; None invalidates all.
        """
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        keys = self.learners.keys() if which is None else which

        model = self.score_table.model
        # Input key of every row of the score table, in row order
        row_keys = [model.item(r, 0).data(Qt.UserRole)
                    for r in range(model.rowCount())]

        for key in keys:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)

            if key not in row_keys:
                continue
            # Blank the displayed values of that learner's row, keeping
            # the name in the first column.
            row = row_keys.index(key)
            for column in range(1, model.columnCount()):
                cell = model.item(row, column)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.__needupdate = True

    def commit(self):
        """
        Send the merged evaluation results and the predictions to the outputs.

        Only learners with successful results contribute; None is sent on
        both outputs when there are none. When building the predictions
        runs out of memory, the memory error is shown and None is sent as
        predictions.
        """
        self.Error.memory_error.clear()
        succeeded = [
            slot for slot in self.learners.values()
            if slot.results is not None and slot.results.success
        ]
        combined = predictions = None
        if succeeded:
            # Evaluation results
            combined = results_merge(
                [slot.results.value for slot in succeeded])
            names = [learner_name(slot.learner) for slot in succeeded]
            combined.learner_names = names

            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """
        Report on the testing schema and results.

        Reports the sampling type (for every resampling mode, including
        cross validation by feature), the target class for data with a
        discrete class, and the score table. Nothing is reported without
        data or learners.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".format(
                stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            items = [("Sampling type", "Cross validation by feature")]
            if self.fold_feature is not None:
                # Report the feature by name; fall back to the value itself
                # should it not be a variable object.
                items.append(("Feature", getattr(self.fold_feature, "name",
                                                 self.fold_feature)))
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [
                ("Sampling type",
                 "{}Shuffle split, {} random samples with {}% data ".format(
                     stratified, self.NRepeats[self.n_repeats],
                     self.SampleSizes[self.sample_size]))
            ]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.score_table.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        if version < 2:
            if settings_["resampling"] > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')
            ]

    @Slot(float)
    def setProgressValue(self, value):
        self.progressBarSet(value)

    def __update(self):
        """
        Re-check the inputs and start evaluating learners lacking results.

        A running evaluation is cancelled first and the widget's messages
        are reset. If a precondition fails (no data, no learners, fewer
        instances than the chosen number of folds, missing test data or a
        test class variable that differs from the training one) the state is
        set to `State.Waiting`, `commit()` is called and the method returns.
        Otherwise a testing callable is assembled for the selected
        resampling method and handed to `__submit`.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Error.test_data_incompatible.clear()
        self.Warning.test_data_missing.clear()
        # Shown only for TestOnTest when both tables are present and their
        # attribute tuples differ.
        self.Information.test_data_transformed(
            shown=self.resampling == self.TestOnTest and self.data is not None
            and self.test_data is not None and
            self.data.domain.attributes != self.test_data.domain.attributes)
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            # fewer instances than requested folds
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                # the "missing" warning is skipped while the "empty test
                # data" error is already displayed
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            # test data is connected but the chosen method does not use it
            self.Warning.test_data_unused()

        # fixed seed passed as `random_state` to the samplers that take one
        rstate = 42
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # `replace_learners` below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData(store_data=True,
                                                 store_models=True), self.data,
                self.test_data, learners_c, self.preprocessor)
        else:
            if self.resampling == OWTestLearners.KFold:
                sampler = Orange.evaluation.CrossValidation(
                    k=self.NFolds[self.n_folds], random_state=rstate)
            elif self.resampling == OWTestLearners.FeatureFold:
                sampler = Orange.evaluation.CrossValidationFeature(
                    feature=self.fold_feature)
            elif self.resampling == OWTestLearners.LeaveOneOut:
                sampler = Orange.evaluation.LeaveOneOut()
            elif self.resampling == OWTestLearners.ShuffleSplit:
                sampler = Orange.evaluation.ShuffleSplit(
                    n_resamples=self.NRepeats[self.n_repeats],
                    train_size=self.SampleSizes[self.sample_size] / 100,
                    test_size=None,
                    stratified=self.shuffle_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestLearners.TestOnTrain:
                sampler = Orange.evaluation.TestOnTrainingData(
                    store_models=True)
            else:
                # NOTE(review): with assertions disabled (-O) this branch
                # falls through and `sampler` is unbound below.
                assert False, "self.resampling %s" % self.resampling

            sampler.store_data = True
            test_f = partial(sampler, self.data, learners_c, self.preprocessor)

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation, then put the original learner objects back
            # into the results in place of the deep copies.
            res = evalfunc(*args, **kwargs)
            # the results must list exactly the copies that were submitted
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[[float], None]], Results]) -> None
        """
        Schedule `testfunc` on the executor and mark the widget as running.

        Must NOT be called while an evaluation is pending or running;
        cancel the existing task first.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            A callable accepting a single `callback` keyword argument and
            returning a Results instance.
        """
        assert self.__state != State.Running
        job = Task()

        def report_progress(fraction):
            # Called by the testing function; aborts it once the task has
            # been cancelled, otherwise queues a progress bar update.
            if job.cancelled:
                raise UserInterrupt()
            percent = Q_ARG(float, 100 * fraction)
            QMetaObject.invokeMethod(
                self, "setProgressValue", Qt.QueuedConnection, percent)

        def notify_done(_):
            # Queue the completion handler, passing the task along.
            QMetaObject.invokeMethod(
                self, "__task_complete", Qt.QueuedConnection,
                Q_ARG(object, job))

        job.future = self.__executor.submit(
            partial(testfunc, callback=report_progress))
        job.future.add_done_callback(notify_done)

        # Reflect the running evaluation in the GUI ...
        self.progressBarInit()
        self.setBlocking(True)
        self.setStatusMessage("Running")

        # ... and in the widget's bookkeeping.
        self.__state = State.Running
        self.__task = job

    @Slot(object)
    def __task_complete(self, task):
        """
        Handle a finished evaluation task.

        Qt slot invoked by name (queued) from the done-callback installed in
        `__submit`. A task that is no longer the current one is ignored.
        Otherwise the GUI is unblocked, the results are split per learner,
        scores are computed and stored in `self.learners`, the score table
        is refreshed and `commit()` is called. If fetching the result
        raises, the error is logged and shown, and the state becomes
        `State.Done` without updating any learner.

        Parameters
        ----------
        task : Task
            The task whose future has completed.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        if self.__task is not task:
            # `cancel()` resets `self.__task`, so a task that is not the
            # current one is expected to have been cancelled.
            assert task.cancelled
            # NOTE(review): the logged argument is the literal "<>"
            # placeholder, not the task itself.
            log.debug("Reaping cancelled task: %r", "<>")
            return

        self.setBlocking(False)
        self.progressBarFinished()
        self.setStatusMessage("")
        result = task.future
        assert result.done()
        self.__task = None
        try:
            results = result.result()  # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:
            log.exception("testing error (in __task_complete):", exc_info=True)
            # show only the exception type and message, without traceback
            self.error("\n".join(traceback.format_exception_only(type(er),
                                                                 er)))
            self.__state = State.Done
            return

        self.__state = State.Done

        # reverse lookup: learner object -> key of its slot in self.learners
        learner_key = {
            slot.learner: key
            for key, slot in self.learners.items()
        }
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            # `stats` stays None (and `result` unwrapped) when the class
            # variable is not primitive
            stats = None
            if class_var.is_primitive():
                # presumably the exception recorded for this learner, if
                # any -- TODO confirm against Results.failed
                ex = result.failed[0]
                if ex:
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    # one entry per scorer
                    stats = [
                        Try(scorer_caller(scorer, result))
                        for scorer in self.scorers
                    ]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self.score_table.update_header(self.scorers)
        self.update_stats_model()

        self.commit()

    def cancel(self):
        """
        Cancel the evaluation that is currently pending or running.

        Does nothing when there is no current task.
        """
        pending = self.__task
        if pending is None:
            return

        assert self.__state == State.Running
        self.__state = State.Cancelled
        # Forget the task before cancelling it.
        self.__task = None
        pending.cancel()
        assert pending.future.done()

    def onDeleteWidget(self):
        """Cancel any current evaluation, then run the base-class handler."""
        self.cancel()
        super().onDeleteWidget()
Exemple #27
0
class OWUnivariateRegression(OWBaseLearner):
    """
    Widget that regresses one continuous variable on another after a
    polynomial expansion of the input variable and plots the fit.
    """
    name = "Polynomial Regression"
    description = "Univariate regression with polynomial expansion."
    icon = "icons/UnivariateRegression.svg"

    #: Optional base learner, received by `set_learner`
    inputs = [("Learner", Learner, "set_learner")]

    #: Outputs declared by this class. `apply` also sends "Learner" and
    #: "Predictor"; presumably those are declared by the base class.
    outputs = [("Coefficients", Table), ("Data", Table)]

    #: Qualified names of the widgets this one replaces
    replaces = [
        "Orange.widgets.regression.owunivariateregression."
        "OWUnivariateRegression",
        "orangecontrib.prototypes.widgets.owpolynomialregression."
        "OWPolynomialRegression"
    ]

    #: Learner class instantiated in `apply`
    LEARNER = PolynomialLearner

    learner_name = settings.Setting("Univariate Regression")

    #: Degree of the polynomial expansion (spin box ranging from 0 to 10)
    polynomialexpansion = settings.Setting(1)

    #: Indices of the selected input/target variables in `x_var_model` and
    #: `y_var_model`
    x_var_index = settings.ContextSetting(0)
    y_var_index = settings.ContextSetting(1)
    #: Draw a line from each actual value to its predicted value
    error_bars_enabled = settings.Setting(False)

    #: Regressor name shown when no learner is on the input
    default_learner_name = "Linear Regression"
    #: Plot items of the error bars currently drawn
    error_plot_items = []

    #: Values displayed by the labels in the "Info" box
    rmse = ""
    mae = ""
    regressor_name = ""

    want_main_area = True
    graph_name = 'Regression graph'

    class Error(OWWidget.Error):
        """
        Error messages shown by the widget.
        """
        all_none = Msg("One of the features has no defined values")

    def add_main_layout(self):
        """
        Initialise the widget state and build its GUI.

        The control area gets the "Info" labels, the input/target variable
        combo boxes with the polynomial expansion spin box, and the error
        bars check box; the main area gets the pyqtgraph plot.
        """

        self.data = None
        self.preprocessors = None
        self.learner = None

        self.scatterplot_item = None
        self.plot_item = None

        self.x_label = 'x'
        self.y_label = 'y'

        self.rmse = ""
        self.mae = ""
        self.regressor_name = self.default_learner_name

        # info box; the labels are formatted from the widget's
        # `regressor_name`, `mae` and `rmse` attributes
        info_box = gui.vBox(self.controlArea, "Info")
        self.regressor_label = gui.label(
            widget=info_box,
            master=self,
            label="Regressor: %(regressor_name).30s")
        gui.label(widget=info_box,
                  master=self,
                  label="Mean absolute error: %(mae).6s")
        gui.label(widget=info_box,
                  master=self,
                  label="Root mean square error: %(rmse).6s")

        box = gui.vBox(self.controlArea, "Variables")

        # input (x) variable selection; every change re-runs `apply`
        self.x_var_model = itemmodels.VariableListModel()
        self.comboBoxAttributesX = gui.comboBox(box,
                                                self,
                                                value='x_var_index',
                                                label="Input: ",
                                                orientation=Qt.Horizontal,
                                                callback=self.apply,
                                                maximumContentsLength=15)
        self.comboBoxAttributesX.setModel(self.x_var_model)
        # polynomial degree, limited to the range 0..10
        self.expansion_spin = gui.doubleSpin(gui.indentedBox(box),
                                             self,
                                             "polynomialexpansion",
                                             0,
                                             10,
                                             label="Polynomial expansion:",
                                             callback=self.apply)

        gui.separator(box, height=8)
        # target (y) variable selection
        self.y_var_model = itemmodels.VariableListModel()
        self.comboBoxAttributesY = gui.comboBox(box,
                                                self,
                                                value="y_var_index",
                                                label="Target: ",
                                                orientation=Qt.Horizontal,
                                                callback=self.apply,
                                                maximumContentsLength=15)
        self.comboBoxAttributesY.setModel(self.y_var_model)

        properties_box = gui.vBox(self.controlArea, "Properties")
        self.error_bars_checkbox = gui.checkBox(widget=properties_box,
                                                master=self,
                                                value='error_bars_enabled',
                                                label="Show error bars",
                                                callback=self.apply)

        gui.rubber(self.controlArea)

        # main area GUI
        self.plotview = pg.PlotWidget(background="w")
        self.plot = self.plotview.getPlotItem()

        # axes are drawn in the palette's text colour
        axis_color = self.palette().color(QPalette.Text)
        axis_pen = QPen(axis_color)

        # tick labels use 2/3 of the widget font's pixel size, at least 11
        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        axis = self.plot.getAxis("bottom")
        axis.setLabel(self.x_label)
        axis.setPen(axis_pen)
        axis.setTickFont(tickfont)

        axis = self.plot.getAxis("left")
        axis.setLabel(self.y_label)
        axis.setPen(axis_pen)
        axis.setTickFont(tickfont)

        # initial view is the unit square; `set_range` adjusts it to the data
        self.plot.setRange(xRange=(0.0, 1.0),
                           yRange=(0.0, 1.0),
                           disableAutoRange=True)

        self.mainArea.layout().addWidget(self.plotview)

    def send_report(self):
        """
        Add the plot to the report, captioned with the polynomial
        expansion. Nothing is reported when there is no data.
        """
        if self.data is None:
            return
        expansion_item = ("Polynomial Expansion: ", self.polynomialexpansion)
        caption = report.render_items_vert((expansion_item, ))
        self.report_plot(self.plot)
        if not caption:
            return
        self.report_caption(caption)

    def clear(self):
        """Forget the data, blank the error measures and empty the plot."""
        self.data = None
        # chained assignment binds left to right: rmse first, then mae
        self.rmse = self.mae = ""
        self.clear_plot()

    def clear_plot(self):
        """
        Remove the regression line, the scatter points and the error bars
        from the plot, then clear the plot view.
        """
        # The regression line is handled first, then the scatter points.
        for attr_name in ("plot_item", "scatterplot_item"):
            item = getattr(self, attr_name)
            if item is None:
                continue
            item.setParentItem(None)
            self.plotview.removeItem(item)
            setattr(self, attr_name, None)

        self.remove_error_items()

        self.plotview.clear()

    @check_sql_input
    def set_data(self, data):
        """
        Store the input data and refresh the variable selection.

        Both variable models are filled with the continuous variables of
        the domain. The input index is clamped to the valid range; the
        target index is set to the first of the trailing `nclass`
        continuous variables, where `nclass` is the number of continuous
        class variables, or to the last continuous variable when there is
        no continuous class variable.
        """
        self.clear()
        self.data = data
        if data is None:
            return

        domain = data.domain
        continuous = [var for var in domain.variables if var.is_continuous]
        n_class = sum(1 for var in domain.class_vars if var.is_continuous)

        self.x_var_model[:] = continuous
        self.y_var_model[:] = continuous

        # NOTE(review): with no continuous variables `last` is -1 and both
        # indices end up as -1.
        last = len(continuous) - 1
        self.x_var_index = min(max(0, self.x_var_index), last)
        if n_class > 0:
            self.y_var_index = min(max(0, last + 1 - n_class), last)
        else:
            self.y_var_index = min(max(0, last), last)

    def set_learner(self, learner):
        """
        Store the input learner and update the displayed regressor name;
        without a learner the default name is shown.
        """
        self.learner = learner
        if learner is None:
            self.regressor_name = self.default_learner_name
        else:
            self.regressor_name = learner.name

    def handleNewSignals(self):
        """Re-run `apply`; presumably called once new inputs are set."""
        self.apply()

    def plot_scatter_points(self, x_data, y_data):
        """
        Replace the scatter plot item with one showing the given points.

        Each point carries its position in `x_data` as its data value.
        """
        previous = self.scatterplot_item
        if previous:
            self.plotview.removeItem(previous)

        self.n_points = len(x_data)
        point_ids = np.arange(self.n_points)
        outline = pg.mkPen(0.2)
        fill = pg.mkBrush(0.7)
        self.scatterplot_item = pg.ScatterPlotItem(
            x=x_data, y=y_data, data=point_ids, symbol="o", size=10,
            pen=outline, brush=fill, antialias=True)
        self.scatterplot_item.opts["useCache"] = False

        self.plotview.addItem(self.scatterplot_item)
        self.plotview.replot()

    def set_range(self, x_data, y_data):
        """
        Fit the view to the bounding box of the data (NaNs ignored), with
        a small padding.
        """
        left = np.nanmin(x_data)
        right = np.nanmax(x_data)
        bottom = np.nanmin(y_data)
        top = np.nanmax(y_data)
        bounds = QRectF(left, bottom, right - left, top - bottom)
        self.plotview.setRange(bounds, padding=0.025)
        self.plotview.replot()

    def plot_regression_line(self, x_data, y_data):
        """Replace the regression curve with a red line through the points."""
        previous = self.plot_item
        if previous:
            self.plotview.removeItem(previous)

        red_pen = pg.mkPen(QColor(255, 0, 0), width=3)
        self.plot_item = pg.PlotCurveItem(
            x=x_data, y=y_data, pen=red_pen, antialias=True)

        self.plotview.addItem(self.plot_item)
        self.plotview.replot()

    def remove_error_items(self):
        """Take all error bar items off the plot and forget them."""
        drawn = self.error_plot_items
        for bar in drawn:
            self.plotview.removeItem(bar)
        # Rebind to a fresh list instead of emptying `drawn` in place.
        self.error_plot_items = []

    def plot_error_bars(self, x, actual, predicted):
        """
        Redraw the error bars.

        Existing bars are always removed. When error bars are enabled, a
        thin grey vertical segment is drawn at every `x` from the actual to
        the predicted value.
        """
        self.remove_error_items()
        if self.error_bars_enabled:
            for x_pos, observed, fitted in zip(x, actual, predicted):
                grey_pen = pg.mkPen(QColor(150, 150, 150), width=1)
                bar = pg.PlotCurveItem(x=[x_pos, x_pos],
                                       y=[observed, fitted],
                                       pen=grey_pen,
                                       antialias=True)
                self.plotview.addItem(bar)
                self.error_plot_items.append(bar)
        self.plotview.replot()

    def apply(self):
        """
        Fit the model, refresh the plot and send all outputs.

        A `LEARNER` instance is built around the input learner (or a
        `LinearRegressionLearner` when there is none). With data present,
        the model is fitted on the selected input/target pair, the error
        measures are updated and the error bars, points, regression line
        and axis labels are redrawn. If every row has a missing input or
        target value, `Error.all_none` is shown, the plot is cleared and
        nothing is sent. Otherwise "Learner", "Predictor" and
        "Coefficients" are sent and `send_data` is called.
        """
        degree = int(self.polynomialexpansion)
        learner = self.LEARNER(preprocessors=self.preprocessors,
                               degree=degree,
                               learner=LinearRegressionLearner()
                               if self.learner is None else self.learner)
        learner.name = self.learner_name
        predictor = None

        self.Error.clear()

        if self.data is not None:
            attributes = self.x_var_model[self.x_var_index]
            class_var = self.y_var_model[self.y_var_index]
            # two-column table: the selected input and the selected target
            data_table = Table(Domain([attributes], class_vars=[class_var]),
                               self.data)

            # every row has a missing input or target value
            if sum(
                    math.isnan(line[0]) or math.isnan(line.get_class())
                    for line in data_table) == len(data_table):
                self.Error.all_none()
                self.clear_plot()
                return

            predictor = learner(data_table)

            preprocessed_data = data_table
            if self.preprocessors is not None:
                for preprocessor in self.preprocessors:
                    preprocessed_data = preprocessor(preprocessed_data)

            x = preprocessed_data.X.ravel()
            y = preprocessed_data.Y.ravel()

            # 1000 evenly spaced inputs over the data range for the curve
            linspace = np.linspace(np.nanmin(x), np.nanmax(x),
                                   1000).reshape(-1, 1)
            values = predictor(linspace, predictor.Value)

            # calculate prediction for x from data
            predicted = TestOnTrainingData(preprocessed_data, [learner])
            self.rmse = round(RMSE(predicted)[0], 6)
            self.mae = round(MAE(predicted)[0], 6)

            # plot error bars
            self.plot_error_bars(x, predicted.actual,
                                 predicted.predicted.ravel())

            # plot data points
            self.plot_scatter_points(x, y)

            # plot regression line
            self.plot_regression_line(linspace.ravel(), values.ravel())

            # the axis labels are the selected variable objects themselves
            x_label = self.x_var_model[self.x_var_index]
            axis = self.plot.getAxis("bottom")
            axis.setLabel(x_label)

            y_label = self.y_var_model[self.y_var_index]
            axis = self.plot.getAxis("left")
            axis.setLabel(y_label)

            self.set_range(x, y)

        self.send("Learner", learner)
        self.send("Predictor", predictor)

        # Send model coefficients
        model = None
        if predictor is not None:
            # unwrap to the innermost model object exposed by the predictor
            model = predictor.model
            if hasattr(model, "model"):
                model = model.model
            elif hasattr(model, "skl_model"):
                model = model.skl_model
        if model is not None and hasattr(model, "coef_"):
            domain = Domain([ContinuousVariable("coef", number_of_decimals=7)],
                            metas=[StringVariable("name")])
            # the first entry adds the intercept to coef_[0]
            coefs = [model.intercept_ + model.coef_[0]] + list(model.coef_[1:])
            # `x_label` is bound here: a predictor exists only when the data
            # branch above has run
            names = ["1", x_label] + \
                    ["{}^{}".format(x_label, i) for i in range(2, degree + 1)]
            coef_table = Table(domain, list(zip(coefs, names)))
            self.send("Coefficients", coef_table)
        else:
            self.send("Coefficients", None)

        self.send_data()

    def send_data(self):
        """
        Send the polynomially expanded input variable on the "Data" output.

        Rows whose input value is missing are dropped before the expansion.
        The output columns are the constant "1", the input variable itself
        (when the expansion is above 0) and one column per higher power.
        Without data, None is sent.
        """
        if self.data is None:
            self.send("Data", None)
            return

        x_var = self.x_var_model[self.x_var_index]
        y_var = self.y_var_model[self.y_var_index]
        table = Table(Domain([x_var], class_vars=[y_var]), self.data)

        degree = int(self.polynomialexpansion)
        expander = skl_preprocessing.PolynomialFeatures(degree)

        complete_rows = ~np.isnan(table.X).any(axis=1)
        expanded = expander.fit_transform(table.X[complete_rows])

        x_attr = table.domain.attributes[0]
        label = x_attr.name
        columns = [ContinuousVariable("1")]
        if self.polynomialexpansion > 0:
            columns.append(x_attr)
        columns.extend(
            ContinuousVariable("{}^{}".format(label, power))
            for power in range(2, degree + 1))

        self.send("Data", Table(Domain(columns), expanded))

    def add_bottom_buttons(self):
        """
        Add nothing.

        Presumably overrides a base-class hook so that no bottom buttons
        are created -- TODO confirm against OWBaseLearner.
        """
        pass
Exemple #28
0
class OWROCAnalysis(widget.OWWidget):
    """Widget plotting ROC curves for the input evaluation results."""
    name = "ROC分析(ROC Analysis)"
    description = "根据分类器的评估结果显示接受者操作曲线。"
    icon = "icons/ROCAnalysis.svg"
    priority = 1010
    keywords = ['fenxi']
    category = '评估(Evaluate)'

    class Inputs:
        # Evaluation results input; also accepts connections made under the
        # former signal name "Evaluation Results".
        evaluation_results = Input("评估结果(Evaluation Results)",
                                   Orange.evaluation.Results,
                                   replaces=["Evaluation Results"])

    buttons_area_orientation = None
    settingsHandler = EvaluationResultsContextHandler()
    #: Index of the target class selected in the target combo box
    target_index = settings.ContextSetting(0)
    #: Indices of the classifiers selected in the classifier list box
    selected_classifiers = settings.ContextSetting([])

    #: Check box states for the performance line and the default
    #: threshold point
    display_perf_line = settings.Setting(True)
    display_def_threshold = settings.Setting(True)

    #: False positive / false negative costs (spin boxes from 1 to 1000)
    fp_cost = settings.Setting(500)
    fn_cost = settings.Setting(500)
    #: Prior probability of the target class in percent (spin box from 1
    #: to 99)
    target_prior = settings.Setting(50.0, schema_only=True)

    #: ROC Averaging Types
    Merge, Vertical, Threshold, NoAveraging = 0, 1, 2, 3
    roc_averaging = settings.Setting(Merge)

    #: Check box states for the ROC convex hull and the convex ROC curves
    display_convex_hull = settings.Setting(False)
    display_convex_curve = settings.Setting(False)

    graph_name = "plot"

    def __init__(self):
        """
        Initialise the widget state and build its GUI.

        The control area gets the target and classifier selection, the
        curve averaging options and the analysis options (default threshold
        point, performance line, costs and target prior); the main area
        gets the ROC plot.
        """
        super().__init__()

        self.results = None
        self.classifier_names = []
        self.perf_line = None
        self.colors = []
        # caches keyed by (target, classifier index); see `curve_data` and
        # `plot_curves`
        self._curve_data = {}
        self._plot_curves = {}
        self._rocch = None
        self._perf_line = None
        self._tooltip_cache = None

        # target class and classifier selection
        box = gui.vBox(self.controlArea, "绘制")
        self.target_cb = gui.comboBox(box,
                                      self,
                                      "target_index",
                                      label="目标",
                                      orientation=Qt.Horizontal,
                                      callback=self._on_target_changed,
                                      contentsLength=8,
                                      searchable=True)

        gui.widgetLabel(box, "分类器")
        # height hint of four text lines for the classifier list
        line_height = 4 * QFontMetrics(self.font()).lineSpacing()
        self.classifiers_list_box = gui.listBox(
            box,
            self,
            "selected_classifiers",
            "classifier_names",
            selectionMode=QListView.MultiSelection,
            callback=self._on_classifiers_changed,
            sizeHint=QSize(0, line_height))

        # curve averaging; the item order matches the values of Merge,
        # Vertical, Threshold and NoAveraging
        abox = gui.vBox(self.controlArea, "曲线")
        gui.comboBox(abox,
                     self,
                     "roc_averaging",
                     items=["合并折叠预测", "平均真阳性率", "阈值处的平均真阳性率和假阳性率", "显示单个曲线"],
                     callback=self._replot)

        gui.checkBox(abox,
                     self,
                     "display_convex_curve",
                     "显示凸ROC曲线",
                     callback=self._replot)
        gui.checkBox(abox,
                     self,
                     "display_convex_hull",
                     "显示ROC凸包",
                     callback=self._replot)

        # analysis options
        box = gui.vBox(self.controlArea, "分析")

        gui.checkBox(box,
                     self,
                     "display_def_threshold",
                     "默认阈值(0.5)点",
                     callback=self._on_display_def_threshold_changed)

        gui.checkBox(box,
                     self,
                     "display_perf_line",
                     "显示性能线",
                     callback=self._on_display_perf_line_changed)
        # the cost and prior spin boxes are laid out in an indented grid
        grid = QGridLayout()
        gui.indentedBox(box, orientation=grid)

        sp = gui.spin(box,
                      self,
                      "fp_cost",
                      1,
                      1000,
                      10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        grid.addWidget(QLabel("假阳性率损失:"), 0, 0)
        grid.addWidget(sp, 0, 1)

        sp = gui.spin(box,
                      self,
                      "fn_cost",
                      1,
                      1000,
                      10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        # NOTE(review): this label and the prior label below are added
        # without an explicit row/column, unlike the first one.
        grid.addWidget(QLabel("假阴性率损失:"))
        grid.addWidget(sp, 1, 1)
        self.target_prior_sp = gui.spin(box,
                                        self,
                                        "target_prior",
                                        1,
                                        99,
                                        alignment=Qt.AlignRight,
                                        spinType=float,
                                        callback=self._on_target_prior_changed)
        self.target_prior_sp.setSuffix(" %")
        # NOTE(review): the action's parent is `sp`, which at this point is
        # the fn_cost spin box, not the prior spin box -- confirm intent.
        self.target_prior_sp.addAction(QAction("Auto", sp))
        grid.addWidget(QLabel("先验概率:"))
        grid.addWidget(self.target_prior_sp, 2, 1)

        # main area: the ROC plot
        self.plotview = GraphicsView(background=None)
        self.plotview.setFrameStyle(QFrame.StyledPanel)
        self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)

        self.plot = PlotItem(enableMenu=False)
        self.plot.setMouseEnabled(False, False)
        self.plot.hideButtons()

        # tick labels use 2/3 of the widget font's pixel size, at least 11
        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        axis = self.plot.getAxis("bottom")
        axis.setTickFont(tickfont)
        axis.setLabel("假阳性率 (1-特异度)")
        axis.setGrid(16)

        axis = self.plot.getAxis("left")
        axis.setTickFont(tickfont)
        axis.setLabel("真阳性率 (灵敏度)")
        axis.setGrid(16)

        # both axes span [0, 1]
        self.plot.showGrid(True, True, alpha=0.2)
        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)

        self.plotview.setCentralItem(self.plot)
        self.mainArea.layout().addWidget(self.plotview)

    @Inputs.evaluation_results
    def set_results(self, results):
        """
        Receive new evaluation results.

        The previous context is closed and the widget cleared. Results that
        pass `check_results_adequacy` are used to initialise the widget,
        open a context and draw the plot; otherwise the warning is cleared.
        """
        self.closeContext()
        self.clear()
        self.results = check_results_adequacy(results, self.Error)
        if self.results is None:
            self.warning()
            return

        self._initialize(self.results)
        class_var = self.results.domain.class_var
        self.openContext(class_var, self.classifier_names)
        self._setup_plot()

    def clear(self):
        """Reset the widget to its state without any input results."""
        self.results = None
        self.plot.clear()

        # classifier list and target combo box
        self.classifier_names = []
        self.selected_classifiers = []
        self.target_cb.clear()
        self.colors = []

        # cached curve data and plot items
        self._curve_data, self._plot_curves = {}, {}
        self._rocch = self._perf_line = self._tooltip_cache = None

    def _initialize(self, results):
        """
        Populate the classifier list and the target combo box from
        `results`, select all classifiers and the first target class, and
        set the initial target prior.
        """
        names = getattr(results, "learner_names", None)
        if names is None:
            # fall back to "#1", "#2", ... - one name per prediction set
            names = [
                "#{}".format(number)
                for number in range(1, len(results.predicted) + 1)
            ]

        count = len(names)
        self.colors = colorpalettes.get_default_curve_colors(count)

        self.classifier_names = names
        self.selected_classifiers = list(range(count))
        # give every list entry an icon in its curve colour
        for row in range(count):
            self.classifiers_list_box.item(row).setIcon(
                colorpalettes.ColorIcon(self.colors[row]))

        self.target_cb.addItems(results.data.domain.class_var.values)
        self.target_index = 0
        self._set_target_prior()

    def _set_target_prior(self):
        """
        Set the target prior to the rounded percentage of instances with a
        known class whose class equals the target, and grey out the spin
        box text. Does nothing when the results carry no data.
        """
        data = self.results.data
        if not data:
            return

        # target_index can be compared with Y directly: the values in the
        # dropdown are sorted in the same order as the values in the table
        class_column = data.Y
        n_target = np.count_nonzero(class_column == self.target_index)
        n_known = np.count_nonzero(~np.isnan(class_column))
        self.target_prior = np.round(n_target / n_known * 100)

        # grey text marks a value that was set automatically
        self.target_prior_sp.setStyleSheet("color: gray;")

    def curve_data(self, target, clf_idx):
        """
        Return the `ROCData` for the given target and classifier,
        computing and caching it on first use.
        """
        key = (target, clf_idx)
        cache = self._curve_data
        if key not in cache:
            # pylint: disable=no-member
            cache[key] = ROCData.from_results(self.results, clf_idx, target)
        return cache[key]

    def plot_curves(self, target, clf_idx):
        """
        Return the `PlotCurves` for the given target and classifier.

        Its fields are functions wrapped with `once` that build the plot
        items for the merged curve, the per-fold curves and the two
        averaged curves. The result is cached per (target, classifier).
        """
        roc = self.curve_data(target, clf_idx)
        key = (target, clf_idx)
        if key in self._plot_curves:
            return self._plot_curves[key]

        # curve pen in the classifier's colour plus a lighter, wider shadow
        pen = QPen(self.colors[clf_idx], 1)
        pen.setCosmetic(True)
        shadow_pen = QPen(pen.color().lighter(160), 2.5)
        shadow_pen.setCosmetic(True)

        name = self.classifier_names[clf_idx]

        @once
        def merged():
            return plot_curve(
                roc.merged, pen=pen, shadow_pen=shadow_pen, name=name)

        @once
        def folds():
            return [
                plot_curve(fold_curve, pen=pen, shadow_pen=shadow_pen)
                for fold_curve in roc.folds
            ]

        @once
        def avg_vert():
            return plot_avg_curve(
                roc.avg_vertical, pen=pen, shadow_pen=shadow_pen, name=name)

        @once
        def avg_thres():
            return plot_avg_curve(
                roc.avg_threshold, pen=pen, shadow_pen=shadow_pen, name=name)

        self._plot_curves[key] = PlotCurves(
            merge=merged,
            folds=folds,
            avg_vertical=avg_vert,
            avg_threshold=avg_thres)
        return self._plot_curves[key]

    def _setup_plot(self):
        """Add the ROC curves for the current target/classifier selection.

        One of the nested helpers below is chosen according to
        ``self.roc_averaging``. Each helper adds its graphics items to
        ``self.plot`` and returns the list of hull curves for the selected
        classifiers. The helpers read ``curves``, ``selected`` and
        ``foreground``, which are assigned further down in this method
        before the chosen helper is called.

        Afterwards the optional convex hull, the dashed diagonal, the axis
        ticks and the "undefined curves" warning are updated.
        """
        def merge_averaging():
            """Plot the merged curve of every selected classifier."""
            for curve in curves:
                graphics = curve.merge()
                # Note: rebinds the loop variable to the curve data object.
                curve = graphics.curve
                self.plot.addItem(graphics.curve_item)

                if self.display_convex_curve:
                    self.plot.addItem(graphics.hull_item)

                if self.display_def_threshold and curve.is_valid:
                    points = curve.points
                    # Label the point whose threshold is closest to 0.5.
                    ind = np.argmin(np.abs(points.thresholds - 0.5))
                    item = pg.TextItem(text="{:.3f}".format(
                        points.thresholds[ind]),
                                       color=foreground)
                    item.setPos(points.fpr[ind], points.tpr[ind])
                    self.plot.addItem(item)

            hull_curves = [curve.merged.hull for curve in selected]
            if hull_curves:
                # Store the joint hull and create the performance line that
                # `_update_perf_line` positions later.
                self._rocch = convex_hull(hull_curves)
                iso_pen = QPen(foreground, 1.0)
                iso_pen.setCosmetic(True)
                self._perf_line = InfiniteLine(pen=iso_pen, antialias=True)
                self.plot.addItem(self._perf_line)
            return hull_curves

        def vertical_averaging():
            """Plot the vertically averaged curve and its confidence item."""
            for curve in curves:
                graphics = curve.avg_vertical()

                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_vertical.hull for curve in selected]

        def threshold_averaging():
            """Plot the threshold averaged curve and its confidence item."""
            for curve in curves:
                graphics = curve.avg_threshold()
                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_threshold.hull for curve in selected]

        def no_averaging():
            """Plot one curve per fold for every selected classifier."""
            for curve in curves:
                graphics = curve.folds()
                for fold in graphics:
                    self.plot.addItem(fold.curve_item)
                    if self.display_convex_curve:
                        self.plot.addItem(fold.hull_item)
            return [fold.hull for curve in selected for fold in curve.folds]

        # Dispatch table from the averaging mode to its plotting helper.
        averagings = {
            OWROCAnalysis.Merge: merge_averaging,
            OWROCAnalysis.Vertical: vertical_averaging,
            OWROCAnalysis.Threshold: threshold_averaging,
            OWROCAnalysis.NoAveraging: no_averaging
        }
        foreground = self.plotview.scene().palette().color(QPalette.Text)
        target = self.target_index
        selected = self.selected_classifiers

        # `curves` holds the plot items, `selected` is rebound from the
        # classifier indices to the corresponding curve data.
        curves = [self.plot_curves(target, i) for i in selected]
        selected = [self.curve_data(target, i) for i in selected]
        hull_curves = averagings[self.roc_averaging]()

        if self.display_convex_hull and hull_curves:
            hull = convex_hull(hull_curves)
            hull_color = QColor(foreground)
            hull_color.setAlpha(100)
            hull_pen = QPen(hull_color, 2)
            hull_pen.setCosmetic(True)
            # The pen was built with alpha 100; the fill brush below uses
            # the lighter alpha 50.
            hull_color.setAlpha(50)
            item = self.plot.plot(hull.fpr,
                                  hull.tpr,
                                  pen=hull_pen,
                                  brush=QBrush(hull_color),
                                  fillLevel=0)
            item.setZValue(-10000)
        # Dashed diagonal from (0, 0) to (1, 1).
        line_color = self.palette().color(QPalette.Disabled, QPalette.Text)
        pen = QPen(QColor(*line_color.getRgb()[:3], 200), 1.0, Qt.DashLine)
        pen.setCosmetic(True)
        self.plot.plot([0, 1], [0, 1], pen=pen, antialias=True)

        if self.roc_averaging == OWROCAnalysis.Merge:
            self._update_perf_line()

        self._update_axes_ticks()

        # An empty string is passed to `self.warning` when every hull curve
        # is valid (or when there are none).
        warning = ""
        if not all(c.is_valid for c in hull_curves):
            if any(c.is_valid for c in hull_curves):
                warning = "Some ROC curves are undefined"
            else:
                warning = "All ROC curves are undefined"
        self.warning(warning)

    def _update_axes_ticks(self):
        """Set explicit axis ticks for a single merged curve.

        When exactly one classifier is selected and the averaging mode is
        ``Merge``, the bottom/left axes get one tick per distinct FPR/TPR
        value of the merged curve (unless there are more than 15 distinct
        values). In every other case both axes are reset with
        ``setTicks(None)``.
        """
        def explicit_ticks(values):
            # One tick level with (value, label) pairs, in descending order;
            # None when there are too many distinct values to label.
            distinct = np.unique(values)
            if len(distinct) > 15:
                return None
            return [[(v, f"{v:.2f}") for v in distinct[::-1]]]

        bottom = self.plot.getAxis("bottom")
        left = self.plot.getAxis("left")

        chosen = self.selected_classifiers
        if chosen and len(chosen) <= 1 \
                and self.roc_averaging == OWROCAnalysis.Merge:
            curve = self.curve_data(self.target_index, chosen[0])
            merged_points = curve.merged.points
            bottom.setTicks(explicit_ticks(merged_points.fpr))
            left.setTicks(explicit_ticks(merged_points.tpr))
        else:
            bottom.setTicks(None)
            left.setTicks(None)

    def _on_mouse_moved(self, pos):
        """Show a tooltip with the thresholds of the curve points under `pos`.

        For every selected classifier the scatter item of the currently
        displayed (merged or averaged) curve is queried for points at the
        mouse position. For a hit, the threshold of the closest curve point
        is collected; thresholds that are NaN are ignored. The result is
        remembered in ``self._tooltip_cache`` as
        ``(point, thresholds, classifier indices, averaging mode)`` and
        reused while the mouse stays on the same point in the same mode.

        Nothing is shown in the 'Show Individual Curves' mode.

        :param pos: mouse position; mapped from scene coordinates into the
            scatter item's coordinates.
        """
        target = self.target_index
        selected = self.selected_classifiers
        curves = [(clf_idx, self.plot_curves(target, clf_idx))
                  for clf_idx in selected
                  ]  # type: List[Tuple[int, PlotCurves]]
        valid_thresh, valid_clf = [], []
        pt, ave_mode = None, self.roc_averaging

        for clf_idx, crv in curves:
            if self.roc_averaging == OWROCAnalysis.Merge:
                curve = crv.merge()
            elif self.roc_averaging == OWROCAnalysis.Vertical:
                curve = crv.avg_vertical()
            elif self.roc_averaging == OWROCAnalysis.Threshold:
                curve = crv.avg_threshold()
            else:
                # currently not implemented for 'Show Individual Curves'
                return

            sp = curve.curve_item.childItems()[0]  # type: pg.ScatterPlotItem
            act_pos = sp.mapFromScene(pos)
            pts = list(sp.pointsAt(act_pos))

            if pts:
                mouse_pt = pts[0].pos()
                if self._tooltip_cache:
                    cache_pt, cache_thresh, cache_clf, cache_ave = \
                        self._tooltip_cache
                    curr_thresh, curr_clf = [], []
                    if np.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
                            and cache_ave == self.roc_averaging:
                        # Same point, same mode: reuse the cached entries
                        # that belong to this classifier.
                        mask = np.equal(cache_clf, clf_idx)
                        curr_thresh = np.compress(mask, cache_thresh).tolist()
                        curr_clf = np.compress(mask, cache_clf).tolist()
                    else:
                        # The mouse moved elsewhere: hide the tooltip and
                        # drop the stale cache.
                        QToolTip.showText(QCursor.pos(), "")
                        self._tooltip_cache = None

                    if curr_thresh:
                        # `extend` accepts any number of cached entries;
                        # `append(*seq)` would raise TypeError for more
                        # than one.
                        valid_thresh.extend(curr_thresh)
                        valid_clf.extend(curr_clf)
                        pt = cache_pt
                        continue

                curve_pts = curve.curve.points
                roc_points = np.column_stack((curve_pts.fpr, curve_pts.tpr))
                diff = np.subtract(roc_points, mouse_pt)
                # Find closest point on curve and save the corresponding threshold
                idx_closest = np.argmin(np.linalg.norm(diff, axis=1))

                thresh = curve_pts.thresholds[idx_closest]
                if not np.isnan(thresh):
                    valid_thresh.append(thresh)
                    valid_clf.append(clf_idx)
                    pt = [
                        curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]
                    ]

        if valid_thresh:
            clf_names = self.classifier_names
            msg = "Thresholds:\n" + "\n".join([
                "({:s}) {:.3f}".format(clf_names[i], thresh)
                for i, thresh in zip(valid_clf, valid_thresh)
            ])
            QToolTip.showText(QCursor.pos(), msg)
            self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)

    def _on_target_changed(self):
        """Redraw the plot from scratch for the newly chosen target.

        Clears the plot, calls ``_set_target_prior`` and rebuilds the plot.
        Unlike ``_replot`` this does not check ``self.results`` first.
        """
        self.plot.clear()
        self._set_target_prior()
        self._setup_plot()

    def _on_classifiers_changed(self):
        """Redraw the plot after the classifier selection changed."""
        # `_replot` performs exactly the clear-then-rebuild sequence needed
        # here (and skips the rebuild when there are no results).
        self._replot()

    def _on_target_prior_changed(self):
        """Handle a change of the target prior value.

        Sets the spin box text color to black and then refreshes the
        performance line.
        """
        self.target_prior_sp.setStyleSheet("color: black;")
        self._on_display_perf_line_changed()

    def _on_display_perf_line_changed(self):
        """Refresh the performance line after a related setting changed."""
        # The line is only repositioned in the Merge averaging mode.
        if self.roc_averaging == OWROCAnalysis.Merge:
            self._update_perf_line()

        # NOTE(review): this reads ``self.perf_line`` (no leading
        # underscore) while the other visible methods use
        # ``self._perf_line`` — confirm whether both attributes are intended.
        if self.perf_line is not None:
            self.perf_line.setVisible(self.display_perf_line)

    def _on_display_def_threshold_changed(self):
        """Redraw the plot after ``display_def_threshold`` was toggled."""
        self._replot()

    def _replot(self):
        """Clear the plot and, when results are present, rebuild it."""
        self.plot.clear()
        if self.results is None:
            return
        self._setup_plot()

    def _update_perf_line(self):
        """Update visibility, angle and position of the performance line.

        Does nothing when no line has been created. When the line should be
        displayed, its slope is computed from the FP/FN costs and the target
        prior (given in percent), and it is anchored at the point of the
        stored convex hull returned by ``roc_iso_performance_line``. The
        line is hidden when that hull is not valid.
        """
        line = self._perf_line
        if line is None:
            return

        line.setVisible(self.display_perf_line)
        if not self.display_perf_line:
            return

        slope = roc_iso_performance_slope(self.fp_cost, self.fn_cost,
                                          self.target_prior / 100.0)
        rocch = self._rocch
        if not rocch.is_valid:
            line.setVisible(False)
            return

        touching = roc_iso_performance_line(slope, rocch)
        radians = np.arctan2(slope, 1)
        # InfiniteLine.setAngle is given degrees here.
        line.setAngle(radians * 180 / np.pi)
        first = touching[0]
        line.setPos((rocch.fpr[first], rocch.tpr[first]))

    def onDeleteWidget(self):
        """Clear the widget's state when the widget is deleted."""
        self.clear()

    def send_report(self):
        """Add the target, cost settings, plot and legend to the report.

        Nothing is reported when there are no results. The cost and target
        probability entries are included only when the performance line is
        displayed.
        """
        if self.results is None:
            return

        entries = [("Target class", self.target_cb.currentText())]
        if self.display_perf_line:
            entries.append(
                ("Costs",
                 "FP = {}, FN = {}".format(self.fp_cost, self.fn_cost)))
            entries.append(
                ("Target probability", "{} %".format(self.target_prior)))

        legend = report.list_legend(self.classifiers_list_box,
                                    self.selected_classifiers)
        self.report_items(OrderedDict(entries))
        self.report_plot()
        self.report_caption(legend)
        self.report_caption(caption)
Exemple #29
0
class OWPythagorasTree(OWWidget):
    """Widget that draws a tree model as a Pythagorean tree.

    The widget receives a ``TreeModel``, displays it in a graphics scene and
    outputs the data instances that belong to the selected nodes, together
    with an annotated copy of the model's instances.
    """

    name = 'Pythagorean Tree'
    description = 'Pythagorean Tree visualization for tree like-structures.'
    icon = 'icons/PythagoreanTree.svg'
    keywords = ["fractal"]

    priority = 1000

    class Inputs:
        tree = Input("Tree", TreeModel)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    # Index into `SIZE_CALCULATION`
    size_calc_idx = settings.Setting(0)
    # Factor used by the 'Logarithmic' size calculation
    size_log_scale = settings.Setting(2)
    tooltips_enabled = settings.Setting(True)
    show_legend = settings.Setting(False)

    # Keyword arguments passed to every legend that is created
    LEGEND_OPTIONS = {
        'corner': Anchorable.BOTTOM_RIGHT,
        'offset': (10, 10),
    }

    def __init__(self):
        """Build the control area, the graphics scene and the tree viewer."""
        super().__init__()
        # Instance variables
        self.model = None
        self.instances = None
        self.clf_dataset = None
        # The tree adapter instance which is passed from the outside
        self.tree_adapter = None
        self.legend = None

        self.color_palette = None

        # Different methods to calculate the size of squares
        self.SIZE_CALCULATION = [
            ('Normal', lambda x: x),
            ('Square root', lambda x: sqrt(x)),
            ('Logarithmic', lambda x: log(x * self.size_log_scale + 1)),
        ]

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Tree Info')
        self.info = gui.widgetLabel(box_info)

        # Display settings area
        box_display = gui.widgetBox(self.controlArea, 'Display Settings')
        self.depth_slider = gui.hSlider(box_display,
                                        self,
                                        'depth_limit',
                                        label='Depth',
                                        ticks=False,
                                        callback=self.update_depth)
        self.target_class_combo = gui.comboBox(box_display,
                                               self,
                                               'target_class_index',
                                               label='Target class',
                                               orientation=Qt.Horizontal,
                                               items=[],
                                               contentsLength=8,
                                               callback=self.update_colors)
        self.size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
            callback=self.update_size_calc)
        self.log_scale_box = gui.hSlider(box_display,
                                         self,
                                         'size_log_scale',
                                         label='Log scale factor',
                                         minValue=1,
                                         maxValue=100,
                                         ticks=False,
                                         callback=self.invalidate_tree)

        # Plot properties area
        box_plot = gui.widgetBox(self.controlArea, 'Plot Properties')
        self.cb_show_tooltips = gui.checkBox(
            box_plot,
            self,
            'tooltips_enabled',
            label='Enable tooltips',
            callback=self.update_tooltip_enabled)
        self.cb_show_legend = gui.checkBox(box_plot,
                                           self,
                                           'show_legend',
                                           label='Show legend',
                                           callback=self.update_show_legend)

        gui.button(self.controlArea,
                   self,
                   label="Redraw",
                   callback=self.redraw)

        # Stretch to fit the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.scene = TreeGraphicsScene(self)
        self.scene.selectionChanged.connect(self.commit)
        self.view = TreeGraphicsView(self.scene, padding=(150, 150))
        self.view.setRenderHint(QPainter.Antialiasing, True)
        self.mainArea.layout().addWidget(self.view)

        self.ptree = PythagorasTreeViewer(self)
        self.scene.addItem(self.ptree)
        self.view.set_central_widget(self.ptree)

        self.resize(800, 500)
        # Clear the widget to correctly set the initial values
        self.clear()

    @Inputs.tree
    def set_tree(self, model=None):
        """When a different tree is given.

        Clears the widget, stores the model and, when a model is present,
        rebuilds the tree view, legend, info box and controls from it.
        Finally sends an annotated table without any selection.
        """
        self.clear()
        self.model = model

        if model is not None:
            self.instances = model.instances
            # this bit is important for the regression classifier
            if self.instances is not None and \
                    self.instances.domain != model.domain:
                self.clf_dataset = self.instances.transform(self.model.domain)
            else:
                self.clf_dataset = self.instances

            self.tree_adapter = self._get_tree_adapter(self.model)
            self.ptree.clear()

            self.ptree.set_tree(
                self.tree_adapter,
                weight_adjustment=self.SIZE_CALCULATION[self.size_calc_idx][1],
                target_class_index=self.target_class_index,
            )

            self._update_depth_slider()
            self.color_palette = self.ptree.root.color_palette
            self._update_legend_colors()
            self._update_legend_visibility()
            self._update_info_box()
            self._update_target_class_combo()

            self._update_main_area()

            # The target class can also be passed from the meta properties
            # This must be set after `_update_target_class_combo`
            if hasattr(model, 'meta_target_class_index'):
                self.target_class_index = model.meta_target_class_index
                self.update_colors()

            # Get meta variables describing what the settings should look like
            # if the tree is passed from the Pythagorean forest widget.
            if hasattr(model, 'meta_size_calc_idx'):
                self.size_calc_idx = model.meta_size_calc_idx
                self.update_size_calc()

            # TODO There is still something wrong with this
            # if hasattr(model, 'meta_depth_limit'):
            #     self.depth_limit = model.meta_depth_limit
            #     self.update_depth()

        self.Outputs.annotated_data.send(
            create_annotated_table(self.instances, None))

    def clear(self):
        """Clear all relevant data from the widget."""
        self.model = None
        self.instances = None
        self.clf_dataset = None
        self.tree_adapter = None

        if self.legend is not None:
            self.scene.removeItem(self.legend)
        self.legend = None

        self.ptree.clear()
        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()
        self._update_log_scale_slider()

    def update_depth(self):
        """This method should be called when the depth changes"""
        self.ptree.set_depth_limit(self.depth_limit)

    def update_colors(self):
        """When the target class / node coloring needs to be updated."""
        self.ptree.target_class_changed(self.target_class_index)
        self._update_legend_colors()

    def update_size_calc(self):
        """When the tree size calculation is updated."""
        self._update_log_scale_slider()
        self.invalidate_tree()

    def redraw(self):
        """Shuffle the children of the tree nodes and redraw the tree.

        Does nothing when there is no tree on input; the "Redraw" button
        can be pressed in that state as well.
        """
        if self.tree_adapter is None:
            return
        self.tree_adapter.shuffle_children()
        self.invalidate_tree()

    def invalidate_tree(self):
        """When the tree needs to be completely recalculated."""
        if self.model is not None:
            self.ptree.set_tree(
                self.tree_adapter,
                weight_adjustment=self.SIZE_CALCULATION[self.size_calc_idx][1],
                target_class_index=self.target_class_index,
            )
            self.ptree.set_depth_limit(self.depth_limit)
            self._update_main_area()

    def update_tooltip_enabled(self):
        """When the tooltip visibility is changed and need to be updated."""
        self.ptree.tooltip_changed(self.tooltips_enabled)

    def update_show_legend(self):
        """When the legend visibility needs to be updated."""
        self._update_legend_visibility()

    def _update_info_box(self):
        """Show the node count and depth of the current tree."""
        self.info.setText('Nodes: {}\nDepth: {}'.format(
            self.tree_adapter.num_nodes, self.tree_adapter.max_depth))

    def _update_depth_slider(self):
        """Enable the depth slider and set it to the tree's maximal depth."""
        self.depth_slider.parent().setEnabled(True)
        self.depth_slider.setMaximum(self.tree_adapter.max_depth)
        self._set_max_depth()

    def _update_legend_visibility(self):
        """Apply the `show_legend` setting to the legend, if there is one."""
        if self.legend is not None:
            self.legend.setVisible(self.show_legend)

    def _update_log_scale_slider(self):
        """On calc method combo box changed."""
        # The slider is only meaningful for the 'Logarithmic' calculation.
        self.log_scale_box.parent().setEnabled(
            self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic')

    def _clear_info_box(self):
        """Show the placeholder text in the info box."""
        self.info.setText('No tree on input')

    def _clear_depth_slider(self):
        """Disable the depth slider and reset its maximum to 0."""
        self.depth_slider.parent().setEnabled(False)
        self.depth_slider.setMaximum(0)

    def _clear_target_class_combo(self):
        """Remove all combo items and reset the target class index to 0."""
        self.target_class_combo.clear()
        self.target_class_index = 0
        self.target_class_combo.setCurrentIndex(self.target_class_index)

    def _set_max_depth(self):
        """Set the depth to the max depth and update appropriate actors."""
        self.depth_limit = self.tree_adapter.max_depth
        self.depth_slider.setValue(self.depth_limit)

    def _update_main_area(self):
        """Refit the scene and the view to the current tree."""
        # refresh the scene rect, cuts away the excess whitespace, and adds
        # padding for panning.
        self.scene.setSceneRect(self.view.central_widget_rect())
        # reset the zoom level
        self.view.recalculate_and_fit()
        self.view.update_anchored_items()

    def _get_tree_adapter(self, model):
        """Return the adapter matching the type of `model`."""
        if isinstance(model, SklModel):
            return SklTreeAdapter(model)
        return TreeAdapter(model)

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self):
        """Commit the selected data to output."""
        if self.instances is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            return
        # Labels of the tree nodes behind the selected squares
        nodes = [
            i.tree_node.label for i in self.scene.selectedItems()
            if isinstance(i, SquareGraphicsItem)
        ]
        data = self.tree_adapter.get_instances_in_nodes(nodes)
        self.Outputs.selected_data.send(data)
        selected_indices = self.tree_adapter.get_indices(nodes)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.instances, selected_indices))

    def send_report(self):
        """Send report."""
        self.report_plot()

    def _update_target_class_combo(self):
        """Fill the target combo according to the type of the class.

        For a discrete class the items are 'None' followed by the title-cased
        class values; otherwise they are the keys of
        ``ContinuousTreeNode.COLOR_METHODS``. The label next to the combo is
        changed accordingly.
        """
        self._clear_target_class_combo()
        # The first QLabel among the siblings of the combo is its label
        label = [
            x for x in self.target_class_combo.parent().children()
            if isinstance(x, QLabel)
        ][0]

        if self.instances.domain.has_discrete_class:
            label_text = 'Target class'
            values = [
                c.title() for c in self.instances.domain.class_vars[0].values
            ]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.target_class_combo.addItems(values)
        self.target_class_combo.setCurrentIndex(self.target_class_index)

    def _update_legend_colors(self):
        """Replace the legend with one matching the current coloring."""
        if self.legend is not None:
            self.scene.removeItem(self.legend)

        if self.instances.domain.has_discrete_class:
            self._classification_update_legend_colors()
        else:
            self._regression_update_legend_colors()

    def _classification_update_legend_colors(self):
        """Create the legend for a discrete class and add it to the scene."""
        if self.target_class_index == 0:
            self.legend = OWDiscreteLegend(domain=self.model.domain,
                                           **self.LEGEND_OPTIONS)
        else:
            # Combo index 0 is 'None', hence the -1 offset into the palette
            items = ((self.target_class_combo.itemText(
                self.target_class_index),
                      self.color_palette[self.target_class_index - 1]),
                     ('other', QColor('#ffffff')))
            self.legend = OWDiscreteLegend(items=items, **self.LEGEND_OPTIONS)

        self.legend.setVisible(self.show_legend)
        self.scene.addItem(self.legend)

    def _regression_update_legend_colors(self):
        """Create the legend for a continuous class and add it to the scene.

        The legend values depend on the selected node coloring: the range of
        the class values for index 1, zero to the standard deviation for
        index 2, and no items otherwise.
        """
        def _get_colors_domain(domain):
            class_var = domain.class_var
            start, end, pass_through_black = class_var.colors
            if pass_through_black:
                lst_colors = [QColor(*c) for c in [start, (0, 0, 0), end]]
            else:
                lst_colors = [QColor(*c) for c in [start, end]]
            return lst_colors

        # `values` must be a list: it is padded in place below when the
        # palette has more colors than there are values.
        # The colors are the class mean
        if self.target_class_index == 1:
            values = [np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y)]
        # Colors are the stddev
        elif self.target_class_index == 2:
            values = [0, np.std(self.clf_dataset.Y)]
        else:
            values = None

        if values is None:
            items = None
        else:
            colors = _get_colors_domain(self.model.domain)
            # Pad with -1 in the middle so that every color gets a value
            while len(values) < len(colors):
                values.insert(1, -1)
            items = list(zip(values, colors))

        self.legend = OWContinuousLegend(items=items, **self.LEGEND_OPTIONS)
        self.legend.setVisible(self.show_legend)
        self.scene.addItem(self.legend)
Exemple #30
0
class OWCalibrationPlot(widget.OWWidget):
    name = "Calibration Plot"
    description = "Calibration plot based on evaluation of classifiers."
    icon = "icons/CalibrationPlot.svg"
    priority = 1030
    keywords = []

    class Inputs:
        """Input signals of the widget."""
        evaluation_results = Input("Evaluation Results", Results)

    class Outputs:
        """Output signals of the widget."""
        calibrated_model = Output("Calibrated Model", Model)

    class Error(widget.OWWidget.Error):
        """Error messages shown by the widget."""
        # NOTE(review): both `Msg` and `widget.Msg` are used in this class;
        # presumably they refer to the same class — confirm.
        non_discrete_target = Msg("Calibration plot requires a categorical "
                                  "target variable.")
        empty_input = widget.Msg("Empty result on input. Nothing to display.")
        nan_classes = \
            widget.Msg("Remove test data instances with unknown classes.")
        all_target_class = widget.Msg(
            "All data instances belong to target class.")
        no_target_class = widget.Msg(
            "No data instances belong to target class.")

    class Warning(widget.OWWidget.Warning):
        """Warning messages shown by the widget."""
        omitted_folds = widget.Msg(
            "Test folds where all data belongs to (non)-target are not shown.")
        # The trailing space matters: the two literals are concatenated, and
        # without it the message reads "... probabilities areskipped."
        omitted_nan_prob_points = widget.Msg(
            "Instance for which the model couldn't compute probabilities are "
            "skipped.")
        no_valid_data = widget.Msg("No valid data for model(s) {}")

    class Information(widget.OWWidget.Information):
        """Informational messages shown by the widget."""
        no_output = Msg("Can't output a model: {}")

    settingsHandler = EvaluationResultsContextHandler()
    # Index of the target class; bound to the "Target:" combo box
    target_index = settings.ContextSetting(0)
    # Indices selected in the "Classifier" list box
    selected_classifiers = settings.ContextSetting([])
    # Index into `Metrics`; bound to the metrics combo box
    score = settings.Setting(0)
    # Index of the selected "Output model calibration" radio button
    # (0: "Sigmoid calibration", 1: "Isotonic calibration")
    output_calibration = settings.Setting(0)
    # State of the "Curves for individual folds" check box
    fold_curves = settings.Setting(False)
    # State of the "Show rug" check box
    display_rug = settings.Setting(True)
    # Replaced by `1 - threshold` in `target_index_changed` when the class
    # has two values
    threshold = settings.Setting(0.5)
    visual_settings = settings.Setting({}, schema_only=True)
    # Attribute name passed to `gui.auto_apply`
    auto_commit = settings.Setting(True)

    graph_name = "plot"

    def __init__(self):
        """Create the widget's state, its controls and the plot."""
        super().__init__()

        self.results = None
        self.scores = None
        self.classifier_names = []
        self.colors = []
        self.line = None

        # Compared with `self.score` in `score_changed`
        self._last_score_value = -1

        # "Settings" box: target selection and display options
        box = gui.vBox(self.controlArea, box="Settings")
        self.target_cb = gui.comboBox(box,
                                      self,
                                      "target_index",
                                      label="Target:",
                                      orientation=Qt.Horizontal,
                                      callback=self.target_index_changed,
                                      contentsLength=8,
                                      searchable=True)
        gui.checkBox(box,
                     self,
                     "display_rug",
                     "Show rug",
                     callback=self._on_display_rug_changed)
        gui.checkBox(box,
                     self,
                     "fold_curves",
                     "Curves for individual folds",
                     callback=self._replot)

        # List of classifiers; multiple items can be selected
        self.classifiers_list_box = gui.listBox(
            self.controlArea,
            self,
            "selected_classifiers",
            "classifier_names",
            box="Classifier",
            selectionMode=QListWidget.ExtendedSelection,
            sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred),
            sizeHint=QSize(150, 40),
            callback=self._on_selection_changed)

        # "Metrics" box: metric selection, its explanation and the
        # calibration method for the output model
        box = gui.vBox(self.controlArea, "Metrics")
        combo = gui.comboBox(box,
                             self,
                             "score",
                             items=(metric.name for metric in Metrics),
                             callback=self.score_changed)

        # The explanation label is as wide as the combo and uses a font
        # scaled to 85 % of the default size
        self.explanation = gui.widgetLabel(box,
                                           wordWrap=True,
                                           fixedWidth=combo.sizeHint().width())
        self.explanation.setContentsMargins(8, 8, 0, 0)
        font = self.explanation.font()
        font.setPointSizeF(0.85 * font.pointSizeF())
        self.explanation.setFont(font)

        gui.radioButtons(box,
                         self,
                         value="output_calibration",
                         btnLabels=("Sigmoid calibration",
                                    "Isotonic calibration"),
                         label="Output model calibration",
                         callback=self.commit.deferred)

        self.info_box = gui.widgetBox(self.controlArea, "Info")
        self.info_label = gui.widgetLabel(self.info_box)

        gui.auto_apply(self.buttonsArea, self, "auto_commit")

        # Plot with custom axis items; mouse interaction, the context menu
        # and the plot buttons are turned off
        self.plotview = pg.GraphicsView(background="w")
        axes = {
            "bottom": AxisItem(orientation="bottom"),
            "left": AxisItem(orientation="left")
        }
        self.plot = pg.PlotItem(enableMenu=False, axisItems=axes)
        self.plot.parameter_setter = ParameterSetter(self.plot)
        self.plot.setMouseEnabled(False, False)
        self.plot.hideButtons()

        for axis_name in ("bottom", "left"):
            axis = self.plot.getAxis(axis_name)
            axis.setPen(pg.mkPen(color=0.0))
            # Remove the condition (that is, allow setting this for bottom
            # axis) when pyqtgraph is fixed
            # Issue: https://github.com/pyqtgraph/pyqtgraph/issues/930
            # Pull request: https://github.com/pyqtgraph/pyqtgraph/pull/932
            if axis_name != "bottom":  # remove if when pyqtgraph is fixed
                axis.setStyle(stopAxisAtTick=(True, True))

        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)
        self.plotview.setCentralItem(self.plot)

        self.mainArea.layout().addWidget(self.plotview)
        # Sets the explanation text, control visibility and axis labels for
        # the initial score
        self._set_explanation()

        VisualSettingsDialog(self, self.plot.parameter_setter.initial_settings)

    @Inputs.evaluation_results
    def set_results(self, results):
        """Handle new evaluation results on the input.

        The previous state, errors and information messages are cleared.
        Results are accepted only when the target is discrete, there is at
        least one actual value and no actual value is NaN; otherwise the
        matching error is shown. Accepted results initialize the controls,
        open the settings context and replot. The output is committed
        immediately in every case.
        """
        self.closeContext()
        self.clear()
        self.Error.clear()
        self.Information.clear()
        self.results = None

        usable = False
        if results is not None:
            if not results.domain.has_discrete_class:
                self.Error.non_discrete_target()
            elif not results.actual.size:
                self.Error.empty_input()
            elif np.any(np.isnan(results.actual)):
                self.Error.nan_classes()
            else:
                usable = True

        if usable:
            self.results = results
            self._initialize(results)
            target_var = self.results.domain.class_var
            # Default target: index 1 for a two-valued class, else index 0
            self.target_index = int(len(target_var.values) == 2)
            self.openContext(target_var, self.classifier_names)
            self._replot()

        self.commit.now()

    def clear(self):
        """Reset the plot, the stored results and the classifier controls."""
        self.plot.clear()
        self.results = None
        self.classifier_names = []
        self.selected_classifiers = []
        self.target_cb.clear()
        self.colors = []

    def target_index_changed(self):
        """Update the widget after a different target class was chosen."""
        two_valued = len(self.results.domain.class_var.values) == 2
        if two_valued:
            # Mirror the threshold for the other class of a binary target
            self.threshold = 1 - self.threshold
        self._set_explanation()
        self._replot()
        self.commit.deferred()

    def score_changed(self):
        """Update the widget after a different metric was chosen.

        The output is committed only when the score differs from the one
        recorded on the previous commit triggered here.
        """
        self._set_explanation()
        self._replot()
        if self._last_score_value == self.score:
            return
        self.commit.deferred()
        self._last_score_value = self.score

    def _set_explanation(self):
        """Adapt the explanation, controls and axis labels to the score.

        The explanation label is shown only when the metric has a non-empty
        explanation. For score 0 the calibration radio buttons are shown and
        the info box hidden; for any other score it is the other way round.
        """
        text = Metrics[self.score].explanation
        if not text:
            self.explanation.hide()
        else:
            self.explanation.setText(text)
            self.explanation.show()

        calibration_controls = self.controls.output_calibration
        if self.score != 0:
            calibration_controls.hide()
            self.info_box.show()
        else:
            calibration_controls.show()
            self.info_box.hide()

        if self.score == 0:
            bottom_label = "Predicted probability"
        else:
            bottom_label = "Threshold probability to classify as positive"
        self.plot.getAxis("bottom").setLabel(bottom_label)

        self.plot.getAxis("left").setLabel(Metrics[self.score].name)

    def _initialize(self, results):
        """Fill the classifier list and the target combo from `results`.

        Classifier names come from ``results.learner_names`` when that
        attribute exists and is not None; otherwise they are "#1", "#2", ...
        Every classifier gets a color icon, all classifiers are selected and
        the target index is reset to 0.
        """
        n_models = len(results.predicted)
        labels = getattr(results, "learner_names", None)
        if labels is None:
            labels = ["#{}".format(k) for k in range(1, n_models + 1)]

        self.classifier_names = labels
        self.colors = colorpalettes.get_default_curve_colors(n_models)

        for row in range(n_models):
            icon = colorpalettes.ColorIcon(self.colors[row])
            self.classifiers_list_box.item(row).setIcon(icon)

        self.selected_classifiers = list(range(n_models))
        self.target_cb.addItems(results.domain.class_var.values)
        self.target_index = 0

    def _rug(self, data, pen_args):
        """Draw rug ticks for the probabilities in `data`.

        Each probability (the last element of ``data.probs`` is left out)
        gets a short vertical tick: entries where ``data.ytrue`` is false are
        drawn upwards from y=0, entries where it is true downwards from y=1.
        The ticks use the color of ``pen_args["pen"]``.
        """
        tick_height = 0.025
        color = pen_args["pen"].color()
        probs = data.probs[:-1]
        # Every x coordinate appears twice: once per end point of its tick.
        paired_x = np.column_stack((probs, probs))

        tick_groups = (
            (~data.ytrue, 0, tick_height),      # negatives: bottom edge
            (data.ytrue, 1, 1 - tick_height),   # positives: top edge
        )
        for mask, base, tip in tick_groups:
            xs = paired_x[mask].ravel()
            ys = np.full_like(xs, base)
            ys[1::2] = tip
            self.plot.plot(xs,
                           ys,
                           pen=color,
                           connect="pairs",
                           antialias=True)

    def plot_metrics(self, data, metrics, pen_args):
        """Plot the curves for `data` and return their coordinates.

        With ``metrics`` set to None the work is delegated to
        ``_prob_curve``. Otherwise every callable in `metrics` is applied to
        `data` and each result is plotted against ``data.probs``.

        Returns a pair ``(x values, curves)``.
        """
        if metrics is not None:
            # Evaluate all metrics first, then draw them.
            curves = [compute(data) for compute in metrics]
            for curve in curves:
                self.plot.plot(data.probs, curve, **pen_args)
            return data.probs, curves
        return self._prob_curve(data.ytrue, data.probs[:-1], pen_args)

    def _prob_curve(self, ytrue, probs, pen_args):
        """Plot a smoothed curve of `ytrue` over `probs`.

        The curve is sampled at 100 evenly spaced points between the smallest
        and the largest probability. When all probabilities are equal, no
        smoothing is done and the curve is constant at that probability.

        Returns ``(x, (y,))``; the second item is a one-element tuple.
        """
        n_points = 100
        lowest, highest = probs.min(), probs.max()
        x = np.linspace(lowest, highest, n_points)
        if highest != lowest:
            # Kernel width is proportional to the range of probabilities.
            spread = highest - lowest
            smooth = gaussian_smoother(probs, ytrue, sigma=0.15 * spread)
            y = smooth(x)
        else:
            y = np.full(n_points, highest)

        self.plot.plot(x, y, symbol="+", symbolSize=4, **pen_args)
        return x, (y, )

    def _setup_plot(self):
        """Draw the curves of all selected classifiers onto the plot.

        ``self.scores`` is reset and then filled with one
        ``(classifier name, plot_metrics(...) result)`` pair for every
        selected classifier that has usable data. Nothing is drawn when
        ``_check_class_presence`` rejects the target class. Depending on
        ``display_rug`` and ``fold_curves``, a rug and dashed per-fold curves
        are added for each classifier. At the end, either a diagonal line
        (score 0) or a movable vertical threshold line (any other score) is
        added to the plot.
        """
        target = self.target_index
        results = self.results
        metrics = Metrics[self.score].functions
        # Fold curves need both the user setting and fold information.
        plot_folds = self.fold_curves and results.folds is not None
        self.scores = []

        # Boolean mask of instances whose actual class is the target.
        if not self._check_class_presence(results.actual == target):
            return

        self.Warning.omitted_folds.clear()
        self.Warning.omitted_nan_prob_points.clear()
        no_valid_models = []
        # The shadow pen is twice as wide (8 instead of 4) with fold curves.
        shadow_width = 4 + 4 * plot_folds
        for clsf in self.selected_classifiers:
            data = Curves.from_results(results, target, clsf)
            if data.tot == 0:  # all probabilities are nan
                no_valid_models.append(clsf)
                continue
            if data.tot != results.probabilities.shape[1]:  # some are nan
                self.Warning.omitted_nan_prob_points()

            color = self.colors[clsf]
            pen_args = dict(pen=pg.mkPen(color, width=1),
                            antiAlias=True,
                            shadowPen=pg.mkPen(color.lighter(160),
                                               width=shadow_width))
            # Keep the plotted coordinates; get_info_text reads them later.
            self.scores.append((self.classifier_names[clsf],
                                self.plot_metrics(data, metrics, pen_args)))

            if self.display_rug:
                self._rug(data, pen_args)

            if plot_folds:
                # Per-fold curves: dashed, same color, without a shadow pen.
                pen_args = dict(pen=pg.mkPen(color, width=1,
                                             style=Qt.DashLine),
                                antiAlias=True)
                for fold in range(len(results.folds)):
                    fold_results = results.get_fold(fold)
                    fold_curve = Curves.from_results(fold_results, target,
                                                     clsf)
                    # Can't check this before: p and n can be 0 because of
                    # nan probabilities
                    # NOTE(review): the warning is raised, but the fold curve
                    # is still passed to plot_metrics below - confirm that
                    # this is intended.
                    if fold_curve.p * fold_curve.n == 0:
                        self.Warning.omitted_folds()
                    self.plot_metrics(fold_curve, metrics, pen_args)

        if no_valid_models:
            self.Warning.no_valid_data(", ".join(self.classifier_names[i]
                                                 for i in no_valid_models))

        if self.score == 0:
            # Diagonal from (0, 0) to (1, 1).
            self.plot.plot([0, 1], [0, 1], antialias=True)
        else:
            # Draggable vertical line at the threshold, limited to [0, 1].
            self.line = pg.InfiniteLine(
                pos=self.threshold,
                movable=True,
                pen=pg.mkPen(color="k", style=Qt.DashLine, width=2),
                hoverPen=pg.mkPen(color="k", style=Qt.DashLine, width=3),
                bounds=(0, 1),
            )
            self.line.sigPositionChanged.connect(self.threshold_change)
            self.line.sigPositionChangeFinished.connect(
                self.threshold_change_done)
            self.plot.addItem(self.line)

    def _check_class_presence(self, ytrue):
        """Check that `ytrue` contains both target and non-target instances.

        Both related errors are cleared first. If no element of `ytrue` is
        true, the ``no_target_class`` error is shown; if all of them are
        true, the ``all_target_class`` error is shown.

        Args:
            ytrue: boolean array-like marking instances of the target class.

        Returns:
            True when both kinds of instances are present, False otherwise.
            Empty input counts as having no target class.
        """
        self.Error.all_target_class.clear()
        self.Error.no_target_class.clear()
        # np.any / np.all are defined for empty arrays, whereas np.max and
        # np.min raise ValueError on them. np.all is vacuously true for empty
        # input, so the np.any test has to come first.
        if not np.any(ytrue):
            self.Error.no_target_class()
            return False
        if np.all(ytrue):
            self.Error.all_target_class()
            return False
        return True

    def _replot(self):
        """Redraw the plot from scratch and refresh the info text.

        The plot is always cleared; curves are drawn again only when results
        are available.
        """
        self.plot.clear()
        has_results = self.results is not None
        if has_results:
            self._setup_plot()
        self._update_info()

    def _on_display_rug_changed(self):
        """Redraw the plot (presumably called when the rug option toggles)."""
        self._replot()

    def _on_selection_changed(self):
        """Redraw the plot and request a deferred commit.

        Presumably invoked when the selection of classifiers changes.
        """
        self._replot()
        self.commit.deferred()

    def threshold_change(self):
        """Follow the threshold line while it is being moved.

        The line's x position is rounded to two decimals and stored as the
        threshold; the line is then moved onto the rounded value and the
        info text is refreshed.
        """
        line_x = self.line.pos().x()
        self.threshold = round(line_x, 2)
        # Snap the line itself onto the rounded position.
        self.line.setPos(self.threshold)
        self._update_info()

    def get_info_text(self, short):
        """Return an HTML table with the threshold and the scores at it.

        The first row shows the threshold. If the current metric defines
        short names, a header row with them follows. Then comes one row per
        entry of ``self.scores`` with the values of its curves at the
        threshold (the first point whose probability is not below the
        threshold, or the last point if there is none).

        Args:
            short: if true, use the compact header and shorten classifier
                names longer than 20 characters to 17 characters plus "...".

        Returns:
            The HTML text, or None when ``self.scores`` is None.
        """
        # Local import: classifier names are arbitrary text and must not be
        # interpreted as markup when inserted into the table.
        from html import escape

        if short:

            def elided(s):
                return s[:17] + "..." if len(s) > 20 else s

            text = f"""<table>
                            <tr>
                                <th align='right'>Threshold: p=</th>
                                <td colspan='4'>{self.threshold:.2f}<br/></td>
                            </tr>"""

        else:

            def elided(s):
                return s

            text = f"""<table>
                            <tr>
                                <th align='right'>Threshold:</th>
                                <td colspan='4'>p = {self.threshold:.2f}<br/>
                                </td>
                                <tr/>
                            </tr>"""

        if self.scores is not None:
            short_names = Metrics[self.score].short_names
            if short_names:
                text += f"""<tr>
                                <th></th>
                                {"<td></td>".join(f"<td align='right'>{n}</td>"
                                                  for n in short_names)}
                            </tr>"""
            for name, (probs, curves) in self.scores:
                # Index of the curve point that corresponds to the threshold,
                # clipped to the last point.
                ind = min(np.searchsorted(probs, self.threshold),
                          len(probs) - 1)
                # Shorten first, then escape, so entities are never cut.
                text += f"<tr><th align='right'>{escape(elided(name))}:</th>"
                text += "<td>/</td>".join(f'<td>{curve[ind]:.3f}</td>'
                                          for curve in curves)
                text += "</tr>"
            # Close the table; the tag used to be an opening "<table>".
            text += "</table>"
            return text
        return None

    def _update_info(self):
        """Show the short form of the info text in the info label."""
        summary = self.get_info_text(short=True)
        self.info_label.setText(summary)

    def threshold_change_done(self):
        """Request a deferred commit once moving the threshold line ends."""
        self.commit.deferred()

    @gui.deferred
    def commit(self):
        """Build the output model and send it.

        A model is produced only when the results come from at most one
        fold, contain stored models, exactly one classifier is selected and,
        for scores other than 0, the class is binary. Otherwise the reasons
        are shown in the ``no_output`` information message and None is sent.

        For score 0 the selected model is wrapped into a calibrated model;
        for other scores it is wrapped into a ``ThresholdClassifier``.
        """
        self.Information.no_output.clear()
        wrapped = None
        results = self.results
        if results is not None:
            problems = [
                msg for condition, msg in (
                    # folds may be None (_setup_plot checks for it as well);
                    # len(None) would raise a TypeError here.
                    (results.folds is not None and len(results.folds) > 1,
                     "each training data sample produces a different model"),
                    (results.models is None,
                     "test results do not contain stored models - try testing "
                     "on separate data or on training data"),
                    (len(self.selected_classifiers) != 1,
                     "select a single model - the widget can output only one"),
                    (self.score != 0
                     and len(results.domain.class_var.values) != 2,
                     "cannot calibrate non-binary classes")) if condition
            ]
            if len(problems) == 1:
                self.Information.no_output(problems[0])
            elif problems:
                # Several problems are shown as a list, one per line.
                self.Information.no_output("".join(f"\n - {problem}"
                                                   for problem in problems))
            else:
                clsf_idx = self.selected_classifiers[0]
                model = results.models[0, clsf_idx]
                if self.score == 0:
                    cal_learner = CalibratedLearner(None,
                                                    self.output_calibration)
                    wrapped = cal_learner.get_model(
                        model, results.actual, results.probabilities[clsf_idx])
                else:
                    # Target index 0 uses 1 - threshold, index 1 the
                    # threshold itself.
                    threshold = [1 - self.threshold,
                                 self.threshold][self.target_index]
                    wrapped = ThresholdClassifier(model, threshold)

        self.Outputs.calibrated_model.send(wrapped)

    def send_report(self):
        """Add the widget's state to the report.

        Nothing is reported without results. Otherwise the report gets the
        target class, the calibration method (a value is given only for
        score 0), the plot, a legend of the selected classifiers and the
        name of the score. For scores other than 0 the long info text is
        appended as well.
        """
        if self.results is None:
            return

        calibration_names = ("Sigmoid calibration", "Isotonic calibration")
        # Evaluates to False for any score other than 0.
        calibration = (self.score == 0
                       and calibration_names[self.output_calibration])
        self.report_items((
            ("Target class", self.target_cb.currentText()),
            ("Output model calibration", calibration),
        ))

        caption = report.list_legend(self.classifiers_list_box,
                                     self.selected_classifiers)
        self.report_plot()
        self.report_caption(caption)
        self.report_caption(self.controls.score.currentText())

        if self.score == 0:
            return
        self.report_raw(self.get_info_text(short=False))

    def set_visual_settings(self, key, value):
        """Apply one visual setting to the plot and remember its value.

        The setting is passed to the plot's parameter setter and then stored
        under `key` in ``self.visual_settings``.
        """
        setter = self.plot.parameter_setter
        setter.set_parameter(key, value)
        self.visual_settings[key] = value