コード例 #1
0
    def __init__(self):
        """Create the widget state and build its GUI.

        Control area: a multi-selection list choosing which class
        probabilities to show, and a button that restores the original row
        order. Main area: the predictions table and the data table side by
        side in a horizontal splitter, with the score table below them.
        """
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        self._delegates = []

        # Multi-selection list: items come from `class_values`, the selection
        # is bound to `selected_classes`. The box title means
        # "Show probabilities".
        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="显示概率",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        # Button label means "Restore original order"; the tooltip means
        # "Show rows in the original order".
        self.reset_button = gui.button(self.controlArea,
                                       self,
                                       "恢复原始顺序",
                                       callback=self._reset_order,
                                       tooltip="以原始顺序显示行")

        # Options shared by both table views.
        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.NoSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(sortingEnabled=True,
                                  verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        # The predictions view shows no vertical scrollbar of its own; its
        # scroll position is linked to the data view's scrollbar below.
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Mirror the vertical scroll position in both directions so the two
        # tables stay row-aligned.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Same default row height in both views, and any row resize in the
        # data view is applied to the same row of the predictions view.
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        # Predictions on the left, data on the right; panes cannot be
        # collapsed.
        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        # Despite its name, `vsplitter` is a plain vertical box (gui.vBox)
        # stacking the splitter above the score table.
        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)
コード例 #2
0
ファイル: test_utils.py プロジェクト: markotoplak/orange3
    def test_update_shown_columns(self):
        """Column 0 is always visible; the others follow `shown_scores`.

        A column other than the first is visible exactly when its header
        label is in `shown_scores`; with an empty selection only the first
        column remains visible.
        """
        score_table = ScoreTable(None)
        view = score_table.view
        # Named `all_names` (not `all`) so the builtin is not shadowed.
        all_names, shown = "MABDEFG", "ABDF"
        header = view.horizontalHeader()
        score_table.shown_scores = set(shown)
        score_table.model.setHorizontalHeaderLabels(list(all_names))
        score_table._update_shown_columns()
        for i, name in enumerate(all_names):
            self.assertEqual(name == "M" or name in shown,
                             not header.isSectionHidden(i),
                             msg="error in section {}({})".format(i, name))

        # Nothing selected: only the first column may stay visible.
        score_table.shown_scores = set()
        score_table._update_shown_columns()
        for i, name in enumerate(all_names):
            self.assertEqual(i == 0,
                             not header.isSectionHidden(i),
                             msg="error in section {}({})".format(i, name))
コード例 #3
0
ファイル: test_utils.py プロジェクト: qeryq/SFECOMLA
    def test_sorting(self):
        """Text columns sort case-insensitively; numeric columns by value.

        For the numeric columns only the first three rows are compared,
        because the remaining rows hold missing, NaN or empty values.
        """
        score_table = ScoreTable(None)

        rows = [
            ["D", 11.0, 15.3],
            ["C", 5.0, -15.4],
            ["b", 20.0, np.nan],
            ["A", None, None],
            ["E", "", 0.0]
        ]

        def make_item(value):
            # A None value leaves the item without any display data.
            item = QStandardItem()
            if value is not None:
                item.setData(value, Qt.DisplayRole)
            return item

        for values in rows:
            score_table.model.appendRow([make_item(v) for v in values])

        model = score_table.view.model()

        def leading_names(count=5):
            return "".join(model.index(row, 0).data() for row in range(count))

        expectations = [
            (0, Qt.AscendingOrder, 5, "AbCDE"),
            (0, Qt.DescendingOrder, 5, "EDCbA"),
            (1, Qt.AscendingOrder, 3, "CDb"),
            (1, Qt.DescendingOrder, 3, "bDC"),
            (2, Qt.AscendingOrder, 3, "CED"),
            (2, Qt.DescendingOrder, 3, "DEC"),
        ]
        for column, direction, count, expected in expectations:
            model.sort(column, direction)
            self.assertEqual(leading_names(count), expected)
コード例 #4
0
ファイル: test_utils.py プロジェクト: markotoplak/orange3
    def test_show_column_chooser(self):
        """The column-chooser menu offers every column except the first.

        Each action's checked state reflects the current visibility, and
        toggling an action updates both `shown_scores` and the header.
        """
        score_table = ScoreTable(None)
        view = score_table.view
        # Named `all_names` (not `all`) so the builtin is not shadowed.
        all_names, shown = "MABDEFG", "ABDF"
        header = view.horizontalHeader()
        score_table.shown_scores = set(shown)
        score_table.model.setHorizontalHeaderLabels(list(all_names))
        score_table._update_shown_columns()

        # Actions created by the menu, keyed by their text; the real
        # QMenu.addAction still does the work.
        actions = collections.OrderedDict()
        menu_add_action = QMenu.addAction

        def addAction(menu, a):
            action = menu_add_action(menu, a)
            actions[a] = action
            return action

        def execmenu(*_):
            # The first column ("M") is not offered in the chooser.
            self.assertEqual(list(actions), list(all_names)[1:])
            for name, action in actions.items():
                self.assertEqual(action.isChecked(), name in shown)
            actions["E"].triggered.emit(True)
            self.assertEqual(score_table.shown_scores, set("ABDEF"))
            actions["B"].triggered.emit(False)
            self.assertEqual(score_table.shown_scores, set("ADEF"))
            for i, name in enumerate(all_names):
                self.assertEqual(name == "M" or name in "ADEF",
                                 not header.isSectionHidden(i),
                                 msg="error in section {}({})".format(i, name))

        # We must patch `QMenu.exec` because Qt would otherwise (invisibly)
        # show the popup and wait for the user. `QMenu.addAction` is patched
        # to capture the actions the chooser creates.
        # Assertions are made within `execmenu` since they check the
        # instances of `QAction`, which are invalid (destroyed by Qt?) after
        # `execmenu` finishes.
        with unittest.mock.patch("AnyQt.QtWidgets.QMenu.addAction", addAction), \
             unittest.mock.patch("AnyQt.QtWidgets.QMenu.exec", execmenu):
            score_table.show_column_chooser(QPoint(0, 0))
コード例 #5
0
    def __init__(self):
        """Initialize the widget's state and build its GUI.

        The control area holds the "Sampling" box (resampling method and its
        parameters), the "Target Class" box and the "Model Comparison" box.
        The main area holds the score table ("Evaluation Results") and the
        pairwise model-comparison table with an explanatory note.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        # Did the train / test data contain instances with unknown targets?
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # The criterion combo (below) starts with an empty model, so the
        # stored setting is kept aside to be re-applied once scorers are known.
        self.__pending_comparison_criterion = self.comparison_criterion

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        # Resampling method. The radio buttons are bound to `resampling`;
        # NOTE(review): assumes the button index is the stored value, so the
        # order of the appended buttons must match the resampling constants.
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        # k-fold cross validation: number of folds (choices from NFolds) and
        # stratification.
        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        # Cross validation by feature: the combo lists discrete meta
        # attributes of the data.
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        # Random sampling: number of repetitions, training-set size (percent,
        # choices from SampleSizes) and stratification.
        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        # Methods without parameters.
        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class; the combo stores the selected text (str), not the
        # index. Its items are filled in later.
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        # Model comparison: criterion combo (its model starts empty) and an
        # optional "negligible difference" threshold `rope`.
        self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
        gui.comboBox(box,
                     self,
                     "comparison_criterion",
                     model=PyListModel(),
                     callback=self.update_comparison_table)

        hbox = gui.hBox(box)
        gui.checkBox(hbox,
                     self,
                     "use_rope",
                     "Negligible difference: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(hbox,
                     self,
                     "rope",
                     validator=QDoubleValidator(),
                     controlWidth=70,
                     callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # The threshold is editable only while `use_rope` is checked.
        self.controls.rope.setEnabled(self.use_rope)

        gui.rubber(self.controlArea)
        # Score table; statistics are refreshed whenever the set of shown
        # score columns changes.
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)

        # Pairwise comparison table: read-only, no selection, no word wrap.
        self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
        table = self.comparison_table = QTableWidget(
            wordWrap=False,
            editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        # Column widths follow their contents but stay between 8 and 15
        # average character widths; long labels are elided on the right.
        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        box.layout().addWidget(
            QLabel(
                "<small>Table shows probabilities that the score for the model in "
                "the row is higher than that of the model in the column. "
                "Small numbers show the probability that the difference is "
                "negligible.</small>",
                wordWrap=True))
コード例 #6
0
class OWTestLearners(OWWidget):
    # Widget metadata: display name, description, icon, ordering priority
    # and search keywords.
    name = "Test and Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100
    keywords = ['Cross Validation', 'CV']

    class Inputs:
        """Input signals: training data, separate test data, learners and a
        preprocessor."""
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        # multiple=True: several learners may be connected at once; each one
        # is delivered together with its own key (see `set_learner`).
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        """Output signals: a table of predictions and the evaluation results."""
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    settings_version = 3
    # Usage hint: clicking the score-table header chooses the shown columns.
    UserAdviceMessages = [
        widget.Message("Click on the table header to select shown columns",
                       "click_header")
    ]

    settingsHandler = settings.PerfectDomainContextHandler()
    # NOTE(review): presumably makes ScoreTable's own settings part of this
    # widget's settings.
    score_table = settings.SettingProvider(ScoreTable)

    #: Resampling/testing types; the values follow the order of the radio
    #: buttons in the "Sampling" box.
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    #: (NOTE(review): appears to be an index into NFolds, so 3 -> 10 folds)
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    #: (NOTE(review): appears to be an index into NRepeats, so 3 -> 10)
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    #: (NOTE(review): appears to be an index into SampleSizes, so 9 -> 66 %)
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    # When True, reopening a matching context switches back to FeatureFold.
    fold_feature_selected = settings.ContextSetting(False)

    # Optional "negligible difference" threshold for model comparison.
    use_rope = settings.Setting(False)
    rope = settings.Setting(0.1)
    # Index of the comparison scorer; stored only with the workflow schema.
    comparison_criterion = settings.Setting(0, schema_only=True)

    # Target-class choice meaning "average over all classes".
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    class Error(OWWidget.Error):
        """Error messages shown by the widget."""
        # Input-data problems detected by the train/test data handlers.
        train_data_empty = Msg("Train dataset is empty.")
        test_data_empty = Msg("Test dataset is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg(
            "Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train datasets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        no_class_values = Msg("Target variable has no values.")
        only_one_class_var_value = Msg("Target variable has only one value.")
        # Set when a learner fails with DomainTransformationError while
        # testing on separate test data.
        test_data_incompatible = Msg(
            "Test data may be incompatible with train data.")

    class Warning(OWWidget.Warning):
        """Warning messages shown by the widget."""
        # `{}` is filled with " train ", " test " or " " (both) so that the
        # sentence reads naturally.
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        """Informational messages shown by the widget."""
        # Set when a large SQL table had to be sampled before download.
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")
        test_data_transformed = Msg(
            "Test data has been transformed to match the train data.")

    def __init__(self):
        """Initialize the widget's state and build its GUI.

        The control area holds the "Sampling" box (resampling method and its
        parameters), the "Target Class" box and the "Model Comparison" box.
        The main area holds the score table ("Evaluation Results") and the
        pairwise model-comparison table with an explanatory note.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        # Did the train / test data contain instances with unknown targets?
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # The criterion combo (below) starts with an empty model, so the
        # stored setting is kept aside and re-applied in _update_scorers.
        self.__pending_comparison_criterion = self.comparison_criterion

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        # Resampling method. The radio buttons are bound to `resampling`, and
        # their order matches the KFold ... TestOnTest constants (0 ... 5).
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        # k-fold cross validation: number of folds (choices from NFolds) and
        # stratification; changing either selects k-fold.
        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        # Cross validation by feature: the combo lists discrete meta
        # attributes of the data.
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        # Random sampling: number of repetitions, training-set size (percent,
        # choices from SampleSizes) and stratification.
        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        # Methods without parameters.
        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class; the combo stores the selected text (str), not the
        # index. Its items are filled in later.
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        # Model comparison: criterion combo (its model starts empty and is
        # filled in _update_scorers) and an optional "negligible difference"
        # threshold `rope`.
        self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
        gui.comboBox(box,
                     self,
                     "comparison_criterion",
                     model=PyListModel(),
                     callback=self.update_comparison_table)

        hbox = gui.hBox(box)
        gui.checkBox(hbox,
                     self,
                     "use_rope",
                     "Negligible difference: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(hbox,
                     self,
                     "rope",
                     validator=QDoubleValidator(),
                     controlWidth=70,
                     callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # The threshold is editable only while `use_rope` is checked.
        self.controls.rope.setEnabled(self.use_rope)

        gui.rubber(self.controlArea)
        # Score table; statistics are refreshed whenever the set of shown
        # score columns changes.
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)

        # Pairwise comparison table: read-only, no selection, no word wrap.
        self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
        table = self.comparison_table = QTableWidget(
            wordWrap=False,
            editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        # Column widths follow their contents but stay between 8 and 15
        # average character widths; long labels are elided on the right.
        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        box.layout().addWidget(
            QLabel(
                "<small>Table shows probabilities that the score for the model in "
                "the row is higher than that of the model in the column. "
                "Small numbers show the probability that the difference is "
                "negligible.</small>",
                wordWrap=True))

    @staticmethod
    def sizeHint():
        """Return the preferred widget size: 780 px wide, nominal height 1."""
        preferred_width, nominal_height = 780, 1
        return QSize(preferred_width, nominal_height)

    def _update_controls(self):
        """Refresh the "Cross validation by feature" controls for new data.

        Reloads the fold-feature model from the current domain and picks its
        first entry when nothing is selected. When there are no usable
        features, the option is disabled and a FeatureFold selection falls
        back to k-fold cross validation.
        """
        feature_model = self.feature_model
        self.fold_feature = None
        feature_model.set_domain(None)
        if self.data:
            feature_model.set_domain(self.data.domain)
            if self.fold_feature is None and feature_model:
                self.fold_feature = feature_model[0]

        has_features = bool(feature_model)
        feature_fold_button = \
            self.controls.resampling.buttons[OWTestLearners.FeatureFold]
        feature_fold_button.setEnabled(has_features)
        self.features_combo.setEnabled(has_features)
        if not has_features and self.resampling == OWTestLearners.FeatureFold:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Add, replace or remove the learner connected under `key`.

        A `None` learner for a known key removes that learner; a `None` for
        an unknown key is ignored. Affected results are invalidated.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if learner is None:
            # Disconnected input: drop it only if we know about it.
            if key in self.learners:
                self._invalidate([key])
                del self.learners[key]
            return

        self.learners[key] = InputLearner(learner, None, None)
        self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Set the input training dataset.

        Data that is empty, has no or several target variables, has only
        unknown target values, or has a discrete target with a single value
        is reported as an error and treated as no data. SQL tables are
        converted to in-memory tables (sampled when large). Instances with
        unknown targets are removed, with a warning. The settings context,
        scorers and controls are then refreshed and results invalidated.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.cancel()
        # Reset all messages this handler may set.
        self.Information.data_sampled.clear()
        self.Error.train_data_empty.clear()
        self.Error.class_required.clear()
        self.Error.too_many_classes.clear()
        self.Error.no_class_values.clear()
        self.Error.only_one_class_var_value.clear()
        if data is not None and not data:
            self.Error.train_data_empty()
            data = None
        if data:
            # Target-variable checks, paired with their errors; all
            # conditions are evaluated up front, but only the first one that
            # holds is reported.
            conds = [
                not data.domain.class_vars,
                len(data.domain.class_vars) > 1,
                np.isnan(data.Y).all(), data.domain.has_discrete_class
                and len(data.domain.class_var.values) == 1
            ]
            errors = [
                self.Error.class_required, self.Error.too_many_classes,
                self.Error.no_class_values, self.Error.only_one_class_var_value
            ]
            for cond, error in zip(conds, errors):
                if cond:
                    error()
                    data = None
                    break

        # Download SQL data; above AUTO_DL_LIMIT only a partial sample is
        # downloaded.
        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.Information.data_sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(data_sample)

        # The warning covers both train and test data, so it stays on while
        # either of them has unknown targets.
        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # A context that remembered by-feature CV restores it, provided
            # the new data offers a usable fold feature.
            if self.fold_feature_selected and bool(self.feature_model):
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Optional[Orange.data.Table]) -> None
        """
        Set the input separate testing dataset.

        Empty data, or data without a target variable, is reported as an
        error and treated as no data. SQL tables are converted to in-memory
        tables (sampled when large). Instances with unknown targets are
        removed, with a warning. Results are invalidated only when
        "Test on test data" is the selected resampling method.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not data:
            self.Error.test_data_empty()
            data = None
        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        # Download SQL data; above AUTO_DL_LIMIT only a partial sample is
        # downloaded.
        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.Information.test_data_sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(data_sample)

        # The warning covers both train and test data, so it stays on while
        # either of them has unknown targets.
        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        # Test data is only used by the "Test on test data" method.
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test "
        }[(self.train_data_missing_vals, self.test_data_missing_vals)]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        """Refresh the scorers usable for the current target variable.

        The criterion combo's model is replaced only when the scorer list
        actually changes, because replacing it resets `comparison_criterion`
        (it is also set to 0 explicitly, for clarity). A criterion restored
        from settings is applied once, if it is still a valid index.
        """
        data = self.data
        target = data.domain.class_var if data else None
        new_scorers = usable_scorers(target) if target else []

        if new_scorers != self.scorers:
            self.scorers = new_scorers
            criterion_model = self.controls.comparison_criterion.model()
            criterion_model[:] = [s.long_name or s.name for s in new_scorers]
            self.comparison_criterion = 0

        pending = self.__pending_comparison_criterion
        if pending is not None:
            # Guard against the unlikely case that scorers were removed from
            # modules since the setting was saved.
            if pending < len(self.scorers):
                self.comparison_criterion = pending
            self.__pending_comparison_criterion = None
        self._update_compbox_title()

    def _update_compbox_title(self):
        criterion = self.comparison_criterion
        if criterion < len(self.scorers):
            scorer = self.scorers[criterion]()
            self.compbox.setTitle(f"Model Comparison by {scorer.name}")
        else:
            self.compbox.setTitle(f"Model Comparison")

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Store the preprocessor to be applied to the training data and
        invalidate the current results.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals.

        Refresh the target-class choices, the score-table header (from the
        current scorers), view enablement and statistics. Then run the
        evaluation if `_invalidate` has flagged that learners need
        [re]testing.
        """
        self._update_class_selection()
        self.score_table.update_header(self.scorers)
        self._update_view_enabled()
        self.update_stats_model()
        if self.__needupdate:
            self.__update()

    def _switch_resampling_to(self, method):
        """Select resampling `method` and apply the parameter change."""
        self.resampling = method
        self._param_changed()

    def kfold_changed(self):
        """A k-fold option changed: switch to k-fold cross validation."""
        self._switch_resampling_to(OWTestLearners.KFold)

    def fold_feature_changed(self):
        """The fold feature changed: switch to cross validation by feature."""
        self._switch_resampling_to(OWTestLearners.FeatureFold)

    def shuffle_split_changed(self):
        """A random-sampling option changed: switch to random sampling."""
        self._switch_resampling_to(OWTestLearners.ShuffleSplit)

    def _param_changed(self):
        """React to a change of the resampling method or its parameters.

        Model comparison is offered only for k-fold cross validation. The
        views are re-enabled accordingly, results are invalidated, and the
        update is run.
        """
        is_kfold = self.resampling == OWTestLearners.KFold
        self.modcompbox.setEnabled(is_kfold)
        self._update_view_enabled()
        self._invalidate()
        self.__update()

    def _update_view_enabled(self):
        """Enable each results table only when its contents are meaningful.

        The comparison table needs k-fold cross validation, more than one
        learner and data; the score table only needs data.
        """
        has_data = self.data is not None
        can_compare = (self.resampling == OWTestLearners.KFold
                       and len(self.learners) > 1
                       and has_data)
        self.comparison_table.setEnabled(can_compare)
        self.score_table.view.setEnabled(has_data)

    def update_stats_model(self):
        """Rebuild the rows of the score table from the learner slots.

        One row is appended per entry of `self.learners`: the learner name,
        then train and test time (only for successful results), then one
        cell per scorer. If a specific target class is selected for a
        discrete class variable, scores are recomputed here on one-vs-rest
        results; otherwise the cached `slot.stats` are shown. Failed
        learners are painted red with the exception as tooltip and their
        errors are shown through `self.error`; failed scores in shown
        columns raise the `scores_not_computed` warning. Finally the rows
        are re-sorted by the view's sort indicator and the comparison table
        headers are relabelled with the learner names.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.score_table.model
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # Index of the selected target class; stays None when there is no
        # data, the class is not discrete, or the average is selected.
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        names = []
        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            names.append(name)
            head = QStandardItem(name)
            # The name and time cells carry the learner key so that rows can
            # be mapped back to `self.learners` (column 0 is read by
            # `_successful_slots` and `_invalidate`).
            head.setData(key, Qt.UserRole)
            results = slot.results
            if results is not None and results.success:
                train = QStandardItem("{:.3f}".format(
                    results.value.train_time))
                train.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                train.setData(key, Qt.UserRole)
                test = QStandardItem("{:.3f}".format(results.value.test_time))
                test.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                test.setData(key, Qt.UserRole)
                row = [head, train, test]
            else:
                row = [head]
            if isinstance(results, Try.Fail):
                head.setToolTip(str(results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                # With "test on test data", a DomainTransformationError gets a
                # dedicated widget error instead of the per-learner message.
                if isinstance(results.exception, DomainTransformationError) \
                        and self.resampling == self.TestOnTest:
                    self.Error.test_data_incompatible()
                    self.Information.test_data_transformed.clear()
                else:
                    errors.append("{name} failed with error:\n"
                                  "{exc.__class__.__name__}: {exc!s}".format(
                                      name=name, exc=slot.results.exception))

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                # Target-specific scores: score the one-vs-rest results with
                # the selected class as the positive class (target=1).
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(slot.results.value,
                                                      target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [
                        Try(scorer_caller(scorer, ovr_results, target=1))
                        for scorer in self.scorers
                    ]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat, scorer in zip(stats, self.scorers):
                    item = QStandardItem()
                    item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                    if stat.success:
                        # Only the first element of the score is displayed.
                        item.setData(float(stat.value[0]), Qt.DisplayRole)
                    else:
                        item.setToolTip(str(stat.exception))
                        if scorer.name in self.score_table.shown_scores:
                            has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        # Resort rows based on current sorting
        header = self.score_table.view.horizontalHeader()
        model.sort(header.sortIndicatorSection(), header.sortIndicatorOrder())
        self._set_comparison_headers(names)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _on_use_rope_changed(self):
        self.controls.rope.setEnabled(self.use_rope)
        self.update_comparison_table()

    def update_comparison_table(self):
        """Recompute the pairwise model comparison table.

        The table is always cleared first. Headers are set whenever there
        are successful learners and scorers. Cells are filled only for
        k-fold cross validation, which provides per-fold scores.
        """
        self.comparison_table.clearContents()
        slots = self._successful_slots()
        if not slots or not self.scorers:
            return
        names = [learner_name(slot.learner) for slot in slots]
        self._set_comparison_headers(names)
        if self.resampling != OWTestLearners.KFold:
            return
        fold_scores = self._scores_by_folds(slots)
        self._fill_table(names, fold_scores)

    def _successful_slots(self):
        """Return the slots with successful results, in displayed row order.

        Rows are read through the sort proxy so the order matches what the
        score table shows. Each row's learner key is read from column 0 of
        the source model under `Qt.UserRole`.
        """
        source = self.score_table.model
        proxy = self.score_table.sorted_model
        successful = []
        for proxy_row in range(proxy.rowCount()):
            source_index = proxy.mapToSource(proxy.index(proxy_row, 0))
            slot = self.learners[source.data(source_index, Qt.UserRole)]
            if slot.results is not None and slot.results.success:
                successful.append(slot)
        return successful

    def _set_comparison_headers(self, names):
        """Make the comparison table len(names) x len(names) and label it.

        Columns stretch when there are more than two learners and are fixed
        otherwise. Repainting is suspended while the table is reshaped and
        is re-enabled even if one of the calls fails.
        """
        table = self.comparison_table
        try:
            # Prevent glitching during update
            table.setUpdatesEnabled(False)
            resize_mode = (QHeaderView.Stretch if len(names) > 2
                           else QHeaderView.Fixed)
            table.horizontalHeader().setSectionResizeMode(resize_mode)
            size = len(names)
            table.setRowCount(size)
            table.setColumnCount(size)
            table.setVerticalHeaderLabels(names)
            table.setHorizontalHeaderLabels(names)
        finally:
            table.setUpdatesEnabled(True)

    def _scores_by_folds(self, slots):
        """Return per-fold scores of the comparison criterion for each slot.

        The scorer selected by `comparison_criterion` is applied to every
        slot's results. Binary scorers use the selected target class, or a
        weighted average over classes when `TARGET_AVERAGE` is selected.
        A slot whose scoring fails yields None and triggers the
        `scores_not_computed` warning; the others yield a flattened array.
        """
        scorer = self.scorers[self.comparison_criterion]()
        self._update_compbox_title()
        kw = {}
        if scorer.is_binary:
            if self.class_selection == self.TARGET_AVERAGE:
                kw = dict(average='weighted')
            else:
                class_values = self.data.domain.class_var.values
                kw = dict(target=class_values.index(self.class_selection))

        def fold_scores_of(results):
            # Deferred so that Try() can capture any exception it raises
            return lambda: scorer.scores_by_folds(
                results.value, **kw).flatten()

        attempts = [Try(fold_scores_of(slot.results)) for slot in slots]
        scores = [attempt.value if attempt.success else None
                  for attempt in attempts]
        # `None in scores` doesn't work -- these are np.arrays
        if any(score is None for score in scores):
            self.Warning.scores_not_computed()
        return scores

    def _fill_table(self, names, scores):
        """Fill the comparison table with pairwise Bayesian comparisons.

        For each pair (row, col) with col < row where both score vectors
        are available, `baycomp.two_on_single` is applied. Cell (row, col)
        gets the probability that the row learner is better, and cell
        (col, row) gets the mirror probability. When `use_rope` is on and
        `rope` is non-zero, the equivalence probability is shown as well.
        Pairs for which baycomp returns NaN are marked NA.
        """
        table = self.comparison_table
        entries = list(zip(names, scores))
        for row, (row_name, row_scores) in enumerate(entries):
            if row_scores is None:
                continue
            for col, (col_name, col_scores) in enumerate(entries[:row]):
                if col_scores is None:
                    continue
                if self.use_rope and self.rope:
                    p0, rope, p1 = baycomp.two_on_single(
                        row_scores, col_scores, self.rope)
                    if any(np.isnan(p) for p in (p0, rope, p1)):
                        self._set_cells_na(table, row, col)
                        continue
                    self._set_cell(
                        table, row, col,
                        f"{p0:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({row_name} > {col_name}) = {p0:.3f}\n"
                        f"p({row_name} = {col_name}) = {rope:.3f}")
                    self._set_cell(
                        table, col, row,
                        f"{p1:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({col_name} > {row_name}) = {p1:.3f}\n"
                        f"p({col_name} = {row_name}) = {rope:.3f}")
                    continue
                p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                if any(np.isnan(p) for p in (p0, p1)):
                    self._set_cells_na(table, row, col)
                    continue
                self._set_cell(table, row, col, f"{p0:.3f}",
                               f"p({row_name} > {col_name}) = {p0:.3f}")
                self._set_cell(table, col, row, f"{p1:.3f}",
                               f"p({col_name} > {row_name}) = {p1:.3f}")

    @classmethod
    def _set_cells_na(cls, table, row, col):
        cls._set_cell(table, row, col, "NA", "comparison cannot be computed")
        cls._set_cell(table, col, row, "NA", "comparison cannot be computed")

    @staticmethod
    def _set_cell(table, row, col, label, tooltip):
        """Put a centered QLabel showing `label` with `tooltip` at (row, col)."""
        cell = QLabel(label)
        cell.setAlignment(Qt.AlignCenter)
        cell.setToolTip(tooltip)
        table.setCellWidget(row, col, cell)

    def _update_class_selection(self):
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            items = [self.TARGET_AVERAGE] + class_var.values
            self.class_selection_combo.addItems(items)

            class_index = 0
            if self.class_selection in class_var.values:
                class_index = class_var.values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        self.update_stats_model()
        self.update_comparison_table()

    def _invalidate(self, which=None):
        """Discard evaluation results so that they get recomputed.

        Cancels any running evaluation. For each learner key in `which`
        (all learners when `which` is None), clears the slot's `results` and
        `stats` and blanks that learner's non-name cells in the score table
        (the row itself is kept). Also clears the comparison table. This
        only sets the update-needed flag; the evaluation is started
        elsewhere (`handleNewSignals` / `_param_changed` call `__update`).
        """
        self.cancel()
        # Mirror whether feature-based CV is selected (read outside this
        # chunk).
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        # Invalidate learner results for `which` input keys
        # (if None then all learner results are invalidated)
        if which is None:
            which = self.learners.keys()

        model = self.score_table.model
        # Learner keys in score-table row order; update_stats_model stores
        # them in column 0 under Qt.UserRole.
        statmodelkeys = [
            model.item(row, 0).data(Qt.UserRole)
            for row in range(model.rowCount())
        ]

        # Only values of existing keys are replaced, so iterating `which`
        # (possibly the live keys view) while assigning adds/removes nothing.
        for key in which:
            self.learners[key] = \
                self.learners[key]._replace(results=None, stats=None)

            if key in statmodelkeys:
                row = statmodelkeys.index(key)
                # Blank the time/score cells, keep the learner name.
                for c in range(1, model.columnCount()):
                    item = model.item(row, c)
                    if item is not None:
                        item.setData(None, Qt.DisplayRole)
                        item.setData(None, Qt.ToolTipRole)

        self.comparison_table.clearContents()

        self.__needupdate = True

    def commit(self):
        """
        Commit the results to output.

        Results of all successfully evaluated learners are merged into one
        Results object (annotated with learner names) and turned into an
        augmented data table of predictions. Without successful results,
        both outputs receive None. If building predictions runs out of
        memory, only the predictions output receives None.
        """
        self.Error.memory_error.clear()
        done = [s for s in self.learners.values()
                if s.results is not None and s.results.success]
        combined, predictions = None, None
        if done:
            # Evaluation results
            combined = results_merge([s.results.value for s in done])
            combined.learner_names = [learner_name(s.learner) for s in done]
            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()
        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """Report on the testing schema and results.

        Adds a "Settings" section with the sampling scheme (and, for a
        discrete class, the selected target class), followed by the score
        table. Nothing is reported without data or learners.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".format(
                stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            # Name the feature whose values define the folds.
            feature = self.fold_feature
            feature_name = feature.name if feature is not None else "none"
            items = [("Sampling type",
                      "Cross validation by feature: {}".format(feature_name))]
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [
                ("Sampling type",
                 "{}Shuffle split, {} random samples with {}% data ".format(
                     stratified, self.NRepeats[self.n_repeats],
                     self.SampleSizes[self.sample_size]))
            ]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.score_table.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade settings stored by older versions of the widget.

        * version < 2: an option was added after the first resampling
          choice, so stored `resampling` indices above 0 are shifted by one.
          Settings without a stored `resampling` are left untouched instead
          of raising KeyError.
        * version < 3: drop context settings created by the old,
          incompatible context handler (those with a `classes` attribute).
        """
        if version < 2:
            resampling = settings_.get("resampling")
            if resampling is not None and resampling > 0:
                settings_["resampling"] = resampling + 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')
            ]

    @Slot(float)
    def setProgressValue(self, value):
        """Forward a progress update from the evaluation task to the widget.

        Connected to the task's `progress_changed` signal in `__submit`.
        """
        self.progressBarSet(value)

    def __update(self):
        """Validate inputs and start evaluating learners that lack results.

        Any running task is cancelled and input-related messages are reset.
        If a precondition fails (no data, no learners, more folds than
        instances, or missing/inconsistent test data for "test on test
        data"), the widget returns to the waiting state and commits the
        available results. Otherwise, only learners whose `results` is None
        are evaluated on deep copies in a background task (see `__submit`).
        The originals are put back into the returned Results.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Error.test_data_incompatible.clear()
        self.Warning.test_data_missing.clear()
        self.Information.test_data_transformed(
            shown=self.resampling == self.TestOnTest and self.data is not None
            and self.test_data is not None and
            self.data.domain.attributes != self.test_data.domain.attributes)
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            self.Warning.test_data_unused()

        # Fixed seed so repeated runs give the same splits
        rstate = 42
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # replace_learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData(store_data=True,
                                                 store_models=True), self.data,
                self.test_data, learners_c, self.preprocessor)
        else:
            if self.resampling == OWTestLearners.KFold:
                # Honour the "Stratified" checkbox, as ShuffleSplit does below
                sampler = Orange.evaluation.CrossValidation(
                    k=self.NFolds[self.n_folds], random_state=rstate,
                    stratified=self.cv_stratified)
            elif self.resampling == OWTestLearners.FeatureFold:
                sampler = Orange.evaluation.CrossValidationFeature(
                    feature=self.fold_feature)
            elif self.resampling == OWTestLearners.LeaveOneOut:
                sampler = Orange.evaluation.LeaveOneOut()
            elif self.resampling == OWTestLearners.ShuffleSplit:
                sampler = Orange.evaluation.ShuffleSplit(
                    n_resamples=self.NRepeats[self.n_repeats],
                    train_size=self.SampleSizes[self.sample_size] / 100,
                    test_size=None,
                    stratified=self.shuffle_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestLearners.TestOnTrain:
                sampler = Orange.evaluation.TestOnTrainingData(
                    store_models=True)
            else:
                assert False, "self.resampling %s" % self.resampling

            sampler.store_data = True
            test_f = partial(sampler, self.data, learners_c, self.preprocessor)

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation on the copies, then swap the original
            # learner objects back in so they can be matched to their slots.
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[[float], None]], Results]) -> None
        """
        Submit a testing function for evaluation

        MUST not be called if an evaluation is already pending/running.
        Cancel the existing task first.

        The function runs on `self.__executor`. Its progress is forwarded to
        `setProgressValue` and its completion to `__task_complete`. Both
        outputs are invalidated until the task finishes.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            Must be a callable taking a single `callback` argument and
            returning a Results instance
        """
        assert self.__state != State.Running
        # Setup the task
        task = TaskState()

        def progress_callback(finished):
            # Called by the evaluation with the finished fraction; raising
            # here aborts the evaluation once cancellation is requested.
            if task.is_interruption_requested():
                raise UserInterrupt()
            task.set_progress_value(100 * finished)

        testfunc = partial(testfunc, callback=progress_callback)
        task.start(self.__executor, testfunc)

        task.progress_changed.connect(self.setProgressValue)
        task.watcher.finished.connect(self.__task_complete)

        self.Outputs.evaluations_results.invalidate()
        self.Outputs.predictions.invalidate()
        self.progressBarInit()
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, f: 'Future[Results]'):
        """Handle completion of the evaluation task submitted by `__submit`.

        Runs in the widget's thread. On failure, the exception is logged and
        shown as a widget error. On success, the combined Results are split
        per model. Each learner slot gets its result wrapped in
        `Try.Success`/`Try.Fail` plus one `Try` per scorer. The score table,
        comparison table and outputs are then refreshed.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        assert self.__task is not None and self.__task.future is f
        self.progressBarFinished()
        self.setStatusMessage("")
        assert f.done()
        self.__task = None
        self.__state = State.Done
        try:
            results = f.result()  # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:  # pylint: disable=broad-except
            log.exception("testing error (in __task_complete):", exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er),
                                                                 er)))
            return

        # Map original learner objects back to their input keys
        learner_key = {
            slot.learner: key
            for key, slot in self.learners.items()
        }
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            stats = None
            # NOTE(review): if the class variable is not primitive, `result`
            # is stored unwrapped (not a Try) although update_stats_model
            # reads `results.success`; confirm this case cannot occur.
            if class_var.is_primitive():
                ex = result.failed[0]
                if ex:
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [
                        Try(scorer_caller(scorer, result))
                        for scorer in self.scorers
                    ]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self.score_table.update_header(self.scorers)
        self.update_stats_model()
        self.update_comparison_table()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation (if any).

        The task's progress and completion signals are disconnected, so a
        late completion is ignored. The progress bar and status message are
        reset.
        """
        if self.__task is None:
            return
        assert self.__state == State.Running
        self.__state = State.Cancelled
        task = self.__task
        self.__task = None
        task.cancel()
        task.progress_changed.disconnect(self.setProgressValue)
        task.watcher.finished.disconnect(self.__task_complete)
        self.progressBarFinished()
        self.setStatusMessage("")

    def onDeleteWidget(self):
        """Cancel any running evaluation, then run the base class cleanup."""
        self.cancel()
        super().onDeleteWidget()
コード例 #7
0
    def __init__(self):
        """Initialize evaluation state and build the control/main areas.

        The control area holds the sampling scheme radio buttons with their
        parameters and the target class selector. The main area holds the
        score table.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[Task]
        # Background executor on which evaluation tasks are started
        self.__executor = ThreadExecutor()

        sbox = gui.vBox(self.controlArea, "Sampling")
        # Button order below determines the index stored in `resampling`.
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # Fold features are chosen among discrete meta attributes.
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class selector; populated by _update_class_selection.
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)
コード例 #8
0
ファイル: owpredictions.py プロジェクト: erelin6613/orange3
    def __init__(self):
        """Initialize state and build the control area and prediction views.

        The control area holds the class-probability selector and a button
        that restores the original row order. The main area holds a
        horizontal splitter with the predictions view left of the data view,
        followed by the score table.
        """
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        self._delegates = []
        # NOTE(review): initial value only; its use is outside this method.
        self.left_width = 10
        self.selection_store = None
        # Selection from the schema-only setting; presumably restored once
        # data arrives (handled outside this method).
        self.__pending_selection = self.selection

        self._set_input_summary()
        self._set_output_summary(None)

        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.ExtendedSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        self.reset_button = gui.button(
            self.controlArea,
            self,
            "Restore Original Order",
            callback=self._reset_order,
            tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.ExtendedSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(sortingEnabled=True,
                                  verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Keep both views scrolled vertically in lockstep.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        # Mirror row-height changes of the data view in the predictions view.
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.dataview.setItemDelegate(DataItemDelegate(self.dataview))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.splitterMoved.connect(self.splitter_resized)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)
コード例 #9
0
ファイル: owpredictions.py プロジェクト: erelin6613/orange3
class OWPredictions(OWWidget):
    """Show predictions of the connected models for an input dataset.

    Predictions (and, for classifiers, probabilities of the selected class
    values) are displayed in a table to the left of the data table; both
    tables share sorting and selection. When the data has a class variable,
    scores of the models are shown below the tables. The data with the
    predictions appended as meta columns is sent on the ``Predictions``
    output (restricted to / reordered by the current selection or sort
    order), and ``Results`` for models whose target matches the data's
    class variable are sent on ``Evaluation Results``.
    """
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display predictions of models for an input dataset."
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results", Results)

    class Warning(OWWidget.Warning):
        empty_data = Msg("Empty dataset")
        wrong_targets = Msg(
            "Some model(s) predict a different target (see more ...)\n{}")

    class Error(OWWidget.Error):
        predictor_failed = Msg("Some predictor(s) failed (see more ...)\n{}")
        scorer_failed = Msg("Some scorer(s) failed (see more ...)\n{}")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: List of selected class value indices in the `class_values` list
    selected_classes = settings.ContextSetting([])
    #: Indices of selected (source) rows; stored only with the workflow
    selection = settings.Setting([], schema_only=True)

    def __init__(self):
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        self._delegates = []
        # Width of the left (predictions) pane of the splitter
        self.left_width = 10
        self.selection_store = None
        # Selection restored from settings; applied when data arrives
        self.__pending_selection = self.selection

        self._set_input_summary()
        self._set_output_summary(None)

        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.ExtendedSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        self.reset_button = gui.button(
            self.controlArea,
            self,
            "Restore Original Order",
            callback=self._reset_order,
            tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.ExtendedSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(sortingEnabled=True,
                                  verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Keep both tables vertically in sync
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.dataview.setItemDelegate(DataItemDelegate(self.dataview))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.splitterMoved.connect(self.splitter_resized)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)

    def get_selection_store(self, proxy):
        """Return the selection store shared by both views.

        The store is created from `proxy` on first use and reused until
        `set_data` discards it.
        """
        # Both proxies map the same, so it doesn't matter which one is used
        # to initialize SharedSelectionStore
        if self.selection_store is None:
            self.selection_store = SharedSelectionStore(proxy)
        return self.selection_store

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive the input data and rebuild the data view.

        A new shared selection store and model are created for the data
        view. A selection restored from settings (if any) is applied once.
        Cached predictions of all predictors are invalidated.
        """
        self.Warning.empty_data(shown=data is not None and not data)
        self.data = data
        self.selection_store = None
        if not data:
            self.dataview.setModel(None)
            self.predictionsview.setModel(None)
        else:
            # force full reset of the view's HeaderView state
            self.dataview.setModel(None)
            model = TableModel(data, parent=None)
            modelproxy = SortProxyModel()
            modelproxy.setSourceModel(model)
            self.dataview.setModel(modelproxy)
            sel_model = SharedSelectionModel(
                self.get_selection_store(modelproxy), modelproxy,
                self.dataview)
            self.dataview.setSelectionModel(sel_model)
            if self.__pending_selection is not None:
                self.selection = self.__pending_selection
                self.__pending_selection = None
                self.selection_store.select_rows(
                    set(self.selection), QItemSelectionModel.ClearAndSelect)
            sel_model.selectionChanged.connect(self.commit)
            sel_model.selectionChanged.connect(self._store_selection)

            self.dataview.model().list_sorted.connect(
                partial(self._update_data_sort_order, self.dataview,
                        self.predictionsview))

        self._invalidate_predictions()

    def _store_selection(self):
        """Copy the rows of the shared selection into the setting."""
        self.selection = list(self.selection_store.rows)

    @property
    def class_var(self):
        """Class variable of the input data (falsy when there is no data)."""
        return self.data and self.data.domain.class_var

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or remove the predictor connected with input `id`.

        Replacing a predictor discards its cached results.
        """
        if id in self.predictors:
            if predictor is not None:
                self.predictors[id] = self.predictors[id]._replace(
                    predictor=predictor, name=predictor.name, results=None)
            else:
                del self.predictors[id]
        elif predictor is not None:
            self.predictors[id] = PredictorSlot(predictor, predictor.name,
                                                None)

    def _set_class_values(self):
        """Collect discrete class values of all predictors for the list box.

        When the data has a discrete class variable, values known to it
        are moved to the front, and exactly those are pre-selected.
        `selected_classes` holds indices into the (sorted)
        `self.class_values`, so the indices must be computed from that
        list rather than from the unsorted union.
        """
        class_values = []
        for slot in self.predictors.values():
            class_var = slot.predictor.domain.class_var
            if class_var and class_var.is_discrete:
                for value in class_var.values:
                    if value not in class_values:
                        class_values.append(value)

        if self.class_var and self.class_var.is_discrete:
            values = self.class_var.values
            # sorted is stable: values of the data's class come first,
            # in the order in which predictors listed them
            self.class_values = sorted(class_values,
                                       key=lambda val: val not in values)
            self.selected_classes = [
                i for i, name in enumerate(self.class_values)
                if name in values
            ]
        else:
            self.class_values = class_values  # This assignment updates listview
            self.selected_classes = []

    def handleNewSignals(self):
        """Recompute predictions, scores, views and outputs after inputs change."""
        self._set_class_values()
        self._call_predictors()
        self._update_scores()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._set_input_summary()
        self.commit()

    def _call_predictors(self):
        """Run predictors that have no cached `Results` on the data.

        Predictors are applied to the data without its class variable.
        On failure the slot's `results` is set to an error string.
        Otherwise it gets a `Results` whose `unmapped_*` fields hold the
        raw output; `predicted` and `probabilities` are filled only for
        predictors whose target equals the data's class variable (mapped
        back to it when it is a different but equal variable object).
        """
        if not self.data:
            return
        if self.class_var:
            domain = self.data.domain
            classless_data = self.data.transform(
                Domain(domain.attributes, None, domain.metas))
        else:
            classless_data = self.data

        for inputid, slot in self.predictors.items():
            if isinstance(slot.results, Results):
                continue

            predictor = slot.predictor
            try:
                if predictor.domain.class_var.is_discrete:
                    pred, prob = predictor(classless_data, Model.ValueProbs)
                else:
                    pred = predictor(classless_data, Model.Value)
                    prob = numpy.zeros((len(pred), 0))
            except (ValueError, DomainTransformationError) as err:
                self.predictors[inputid] = \
                    slot._replace(results=f"{predictor.name}: {err}")
                continue

            results = Results()
            results.data = self.data
            results.domain = self.data.domain
            results.row_indices = numpy.arange(len(self.data))
            results.folds = (Ellipsis, )
            results.actual = self.data.Y
            results.unmapped_probabilities = prob
            results.unmapped_predicted = pred
            results.probabilities = results.predicted = None
            self.predictors[inputid] = slot._replace(results=results)

            target = predictor.domain.class_var
            if target != self.class_var:
                continue

            if target is not self.class_var and target.is_discrete:
                backmappers, n_values = predictor.get_backmappers(self.data)
                prob = predictor.backmap_probs(prob, n_values, backmappers)
                pred = predictor.backmap_value(pred, prob, n_values,
                                               backmappers)
            results.predicted = pred.reshape((1, len(self.data)))
            results.probabilities = prob.reshape((1, ) + prob.shape)

    def _update_scores(self):
        """Fill the score table with scores of predictors with mapped results.

        Scorer exceptions are shown in the cell's tooltip; those of scorers
        currently shown in the table are also reported as an error.
        """
        model = self.score_table.model
        model.clear()
        scorers = usable_scorers(self.class_var) if self.class_var else []
        self.score_table.update_header(scorers)
        errors = []
        for pred in self.predictors.values():
            results = pred.results
            if not isinstance(results, Results) or results.predicted is None:
                continue
            row = [
                QStandardItem(learner_name(pred.predictor)),
                QStandardItem("N/A"),
                QStandardItem("N/A")
            ]
            for scorer in scorers:
                item = QStandardItem()
                try:
                    score = scorer_caller(scorer, results)()[0]
                    item.setText(f"{score:.3f}")
                except Exception as exc:  # pylint: disable=broad-except
                    item.setToolTip(str(exc))
                    if scorer.name in self.score_table.shown_scores:
                        errors.append(str(exc))
                row.append(item)
            self.score_table.model.appendRow(row)

        view = self.score_table.view
        if model.rowCount():
            view.setVisible(True)
            view.ensurePolished()
            view.setFixedHeight(5 + view.horizontalHeader().height() +
                                view.verticalHeader().sectionSize(0) *
                                model.rowCount())
        else:
            view.setVisible(False)

        self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

    def _set_errors(self):
        """Report failed predictors and predictors with a different target."""
        # Not all predictors are run every time, so errors can't be collected
        # in _call_predictors
        errors = "\n".join(f"- {p.predictor.name}: {p.results}"
                           for p in self.predictors.values()
                           if isinstance(p.results, str) and p.results)
        self.Error.predictor_failed(errors, shown=bool(errors))

        if self.class_var:
            # Results without mapped probabilities belong to predictors
            # whose target differs from the data's class variable
            inv_targets = "\n".join(
                f"- {pred.name} predicts '{pred.domain.class_var.name}'"
                for pred in (p.predictor for p in self.predictors.values()
                             if isinstance(p.results, Results)
                             and p.results.probabilities is None))
            self.Warning.wrong_targets(inv_targets, shown=bool(inv_targets))
        else:
            self.Warning.wrong_targets.clear()

    def _set_input_summary(self):
        """Show the number of data instances and details about the inputs."""
        if not self.data and not self.predictors:
            self.info.set_input_summary(self.info.NoInput)
            return

        summary = len(self.data) if self.data else 0
        details = self._get_details()
        self.info.set_input_summary(summary, details, format=Qt.RichText)

    def _get_details(self):
        """Return an HTML description of the input data and models."""
        details = "Data:<br>"
        details += format_summary_details(self.data).replace('\n', '<br>') if \
            self.data else "No data on input."
        details += "<hr>"
        pred_names = [v.name for v in self.predictors.values()]
        n_predictors = len(self.predictors)
        if n_predictors:
            n_valid = len(self._non_errored_predictors())
            details += plural("Model: {number} model{s}", n_predictors)
            if n_valid != n_predictors:
                details += f" ({n_predictors - n_valid} failed)"
            details += "<ul>"
            for name in pred_names:
                details += f"<li>{name}</li>"
            details += "</ul>"
        else:
            details += "Model:<br>No model on input."
        return details

    def _set_output_summary(self, output):
        """Show the number of output instances (or no output)."""
        summary = len(output) if output else self.info.NoOutput
        details = format_summary_details(output) if output else ""
        self.info.set_output_summary(summary, details)

    def _invalidate_predictions(self):
        """Discard cached results (and errors) of all predictors."""
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _non_errored_predictors(self):
        """Return slots of predictors that were successfully applied."""
        return [
            p for p in self.predictors.values()
            if isinstance(p.results, Results)
        ]

    def _reordered_probabilities(self, prediction):
        """Return the slot's probabilities laid out by `self.class_values`.

        Column `i` holds the probability of `self.class_values[i]`;
        columns of values unknown to the predictor are NaN.
        """
        cur_values = prediction.predictor.domain.class_var.values
        new_ind = [self.class_values.index(x) for x in cur_values]
        probs = prediction.results.unmapped_probabilities
        new_probs = numpy.full((probs.shape[0], len(self.class_values)),
                               numpy.nan)
        new_probs[:, new_ind] = probs
        return new_probs

    def _update_predictions_model(self):
        """Rebuild the predictions table model.

        The model has one column per non-errored predictor, in the order
        of `_non_errored_predictors()`. It is wrapped in a sort proxy and
        given a selection model that shares the selection with the data
        view.
        """
        results = []
        headers = []
        for p in self._non_errored_predictors():
            values = p.results.unmapped_predicted
            target = p.predictor.domain.class_var
            if target.is_discrete:
                # order probabilities in order from Show prob. for
                prob = self._reordered_probabilities(p)
                values = [Value(target, v) for v in values]
            else:
                prob = numpy.zeros((len(values), 0))
            results.append((values, prob))
            headers.append(p.predictor.name)

        if results:
            # Transpose per-predictor (values, probs) into per-row tuples
            results = list(zip(*(zip(*res) for res in results)))
            model = PredictionsModel(results, headers)
        else:
            model = None

        if self.selection_store is not None:
            self.selection_store.unregister(
                self.predictionsview.selectionModel())

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        self.predictionsview.setModel(predmodel)

        self.predictionsview.setSelectionModel(
            SharedSelectionModel(self.get_selection_store(predmodel),
                                 predmodel, self.predictionsview))

        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        self.predictionsview.model().list_sorted.connect(
            partial(self._update_data_sort_order, self.predictionsview,
                    self.dataview))

        self.predictionsview.resizeColumnsToContents()

    def _update_data_sort_order(self, sort_source_view, sort_dest_view):
        """Apply the row order of `sort_source_view` to `sort_dest_view`.

        Only the source view keeps a visible sort indicator. Outputs are
        re-committed because their row order follows the views.
        """
        sort_dest = sort_dest_view.model()
        sort_source = sort_source_view.model()
        sortindicatorshown = False
        if sort_dest is not None:
            assert isinstance(sort_dest, QSortFilterProxyModel)
            n = sort_dest.rowCount()
            if sort_source is not None and sort_source.sortColumn() >= 0:
                sortind = numpy.argsort([
                    sort_source.mapToSource(sort_source.index(i, 0)).row()
                    for i in range(n)
                ])
                # `numpy.int` was only an alias of the builtin `int` and
                # has been removed from NumPy (>= 1.24)
                sortind = numpy.array(sortind, dtype=int)
                sortindicatorshown = True
            else:
                sortind = None

            sort_dest.setSortIndices(sortind)

        sort_dest_view.horizontalHeader().setSortIndicatorShown(False)
        sort_source_view.horizontalHeader().setSortIndicatorShown(
            sortindicatorshown)
        self.commit()

    def _reset_order(self):
        """Remove sorting from both views and hide the sort indicators."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.setSortIndices(None)
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.setSortIndices(None)
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)
        self.dataview.horizontalHeader().setSortIndicatorShown(False)

    def _all_color_values(self):
        """
        Return list of colors together with their values from all predictors
        classes. Colors and values are sorted according to the values order
        for simpler comparison.
        """
        predictors = self._non_errored_predictors()
        color_values = [
            list(
                zip(*sorted(zip(p.predictor.domain.class_var.colors,
                                p.predictor.domain.class_var.values),
                            key=itemgetter(1)))) for p in predictors
            if p.predictor.domain.class_var.is_discrete
        ]
        return color_values if color_values else [([], [])]

    @staticmethod
    def _colors_match(colors1, values1, color2, values2):
        """
        Test whether colors for values match. Colors match when the values
        of the shorter list match, and so do the colors of the shorter list.
        It is assumed that values are sorted together with their colors.
        """
        shorter_length = min(len(colors1), len(color2))
        return (values1[:shorter_length] == values2[:shorter_length]
                and (numpy.array(colors1[:shorter_length]) == numpy.array(
                    color2[:shorter_length])).all())

    def _get_colors(self):
        """
        Define colors for values in `self.class_values`. If colors match in
        all models, use their union; otherwise use standard colors.
        """
        all_colors_values = self._all_color_values()
        base_color, base_values = all_colors_values[0]
        for c, v in all_colors_values[1:]:
            if not self._colors_match(base_color, base_values, c, v):
                base_color = []
                break
            # replace base_color if longer
            if len(v) > len(base_color):
                base_color = c
                base_values = v

        if len(base_color) != len(self.class_values):
            return LimitedDiscretePalette(len(self.class_values)).palette
        # reorder colors to widgets order
        colors = [None] * len(self.class_values)
        for c, v in zip(base_color, base_values):
            colors[self.class_values.index(v)] = c
        return colors

    def _update_prediction_delegate(self):
        """Install a delegate on every column of the predictions view.

        Columns of the predictions model correspond to the non-errored
        predictors, in the order of `_non_errored_predictors()` (see
        `_update_predictions_model`). Delegates must be created for exactly
        those predictors; enumerating all predictors would shift delegates
        onto wrong columns as soon as one predictor fails.
        """
        self._delegates.clear()
        colors = self._get_colors()
        for col, slot in enumerate(self._non_errored_predictors()):
            target = slot.predictor.domain.class_var
            # Indices of selected class values; None for values this
            # predictor does not know
            shown_probs = (() if target.is_continuous else [
                val if self.class_values[val] in target.values else None
                for val in self.selected_classes
            ])
            delegate = PredictionsItemDelegate(
                None if target.is_continuous else self.class_values,
                colors,
                shown_probs,
                target.format_str if target.is_continuous else None,
                parent=self.predictionsview)
            # QAbstractItemView does not take ownership of delegates, so we
            # must keep references to them ourselves
            self._delegates.append(delegate)
            self.predictionsview.setItemDelegateForColumn(col, delegate)
            self.predictionsview.setColumnHidden(col, False)

        self.predictionsview.resizeColumnsToContents()
        self._recompute_splitter_sizes()
        if self.predictionsview.model() is not None:
            self.predictionsview.model().setProbInd(self.selected_classes)

    def _recompute_splitter_sizes(self):
        """Size the left pane to fit the predictions view (if there is data)."""
        if not self.data:
            return
        view = self.predictionsview
        self.left_width = \
            view.horizontalHeader().length() + view.verticalHeader().width()
        self._update_splitter()

    def _update_splitter(self):
        """Apply `left_width` to the splitter, giving the rest to the data."""
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([self.left_width, w1 + w2 - self.left_width])

    def splitter_resized(self):
        """Remember the left pane width after the user moves the splitter."""
        self.left_width = self.splitter.sizes()[0]

    def commit(self):
        """Send both outputs."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        """Send `Results` of predictors whose target is the data's class.

        Rows with a missing class value are excluded.
        """
        slots = [
            p for p in self._non_errored_predictors()
            if p.results.predicted is not None
        ]
        if not slots:
            self.Outputs.evaluation_results.send(None)
            return

        nanmask = numpy.isnan(self.data.get_column_view(self.class_var)[0])
        data = self.data[~nanmask]
        results = Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(p.results.predicted[0][~nanmask] for p in slots))
        if self.class_var and self.class_var.is_discrete:
            results.probabilities = numpy.array(
                [p.results.probabilities[0][~nanmask] for p in slots])
        results.learner_names = [p.name for p in slots]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """Send the data with predictions appended as meta attributes.

        If rows are selected, only those are sent, in view order. Otherwise,
        if either view is sorted, all rows are sent in view order.
        Names of new columns are made unique with respect to existing ones.
        """
        if not self.data:
            self._set_output_summary(None)
            self.Outputs.predictions.send(None)
            return

        newmetas = []
        newcolumns = []
        for slot in self._non_errored_predictors():
            if slot.predictor.domain.class_var.is_discrete:
                self._add_classification_out_columns(slot, newmetas,
                                                     newcolumns)
            else:
                self._add_regression_out_columns(slot, newmetas, newcolumns)

        attrs = list(self.data.domain.attributes)
        metas = list(self.data.domain.metas)
        names = [
            var.name
            for var in chain(attrs, self.data.domain.class_vars, metas) if var
        ]
        uniq_newmetas = []
        for new_ in newmetas:
            uniq = get_unique_names(names, new_.name)
            if uniq != new_.name:
                new_ = new_.copy(name=uniq)
            uniq_newmetas.append(new_)
            names.append(uniq)

        metas += uniq_newmetas
        domain = Orange.data.Domain(attrs, self.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns

        index = self.dataview.model().index
        map_to = self.dataview.model().mapToSource
        assert self.selection_store is not None
        rows = None
        if self.selection_store.rows:
            rows = [
                ind.row()
                for ind in self.dataview.selectionModel().selectedRows(0)
            ]
            rows.sort()
        elif self.dataview.model().isSorted() \
                or self.predictionsview.model().isSorted():
            rows = list(range(len(self.data)))
        if rows:
            # View rows -> source rows, so output follows the view order
            source_rows = [map_to(index(row, 0)).row() for row in rows]
            predictions = predictions[source_rows]
        self.Outputs.predictions.send(predictions)
        self._set_output_summary(predictions)

    @staticmethod
    def _add_classification_out_columns(slot, newmetas, newcolumns):
        """Append a predicted-value column and one probability column per
        class value of the slot's predictor (unmapped predictions)."""
        # TODO: Mapped or unmapped predictions?!
        # Or provide a checkbox so the user decides?
        pred = slot.predictor
        name = pred.name
        values = pred.domain.class_var.values
        newmetas.append(DiscreteVariable(name=name, values=values))
        newcolumns.append(slot.results.unmapped_predicted.reshape(-1, 1))
        newmetas += [
            ContinuousVariable(name=f"{name} ({value})") for value in values
        ]
        newcolumns.append(slot.results.unmapped_probabilities)

    @staticmethod
    def _add_regression_out_columns(slot, newmetas, newcolumns):
        """Append a single column with the regressor's predictions."""
        newmetas.append(ContinuousVariable(name=slot.predictor.name))
        newcolumns.append(slot.results.unmapped_predicted.reshape((-1, 1)))

    def send_report(self):
        """Report input details, shown probabilities, the data with
        predictions (as displayed) and the scores."""
        def merge_data_with_predictions():
            data_model = self.dataview.model()
            predictions_view = self.predictionsview
            predictions_model = predictions_view.model()

            # use ItemDelegate to style prediction values
            delegates = [
                predictions_view.itemDelegateForColumn(i)
                for i in range(predictions_model.columnCount())
            ]

            # iterate only over visible columns of data's QTableView
            iter_data_cols = list(
                filter(lambda x: not self.dataview.isColumnHidden(x),
                       range(data_model.columnCount())))

            # print header
            yield [''] + \
                  [predictions_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in range(predictions_model.columnCount())] + \
                  [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical, Qt.DisplayRole)] + \
                      [delegate.displayText(
                          predictions_model.data(predictions_model.index(i, j)),
                          QLocale())
                       for j, delegate in enumerate(delegates)] + \
                      [data_model.data(data_model.index(i, j))
                       for j in iter_data_cols]

        if self.data:
            text = self._get_details().replace('\n', '<br>')
            if self.selected_classes:
                text += '<br>Showing probabilities for: '
                text += ', '.join(
                    [self.class_values[i] for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions",
                              merge_data_with_predictions(),
                              header_rows=1,
                              header_columns=1)

            self.report_table("Scores", self.score_table.view)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self._update_splitter()

    def showEvent(self, event):
        super().showEvent(event)
        # Defer until the splitter has its final geometry
        QTimer.singleShot(0, self._update_splitter)
コード例 #10
0
class OWPredictions(OWWidget):
    """Display the predictions of one or more models on an input dataset.

    Predictions (and, for discrete targets, class probabilities) are shown
    in a table next to the input data. Models whose target equals the
    data's class variable are also scored in the score table, and their
    evaluation results are sent to the "Evaluation Results" output.
    """
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display predictions of models for an input dataset."
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results", Results)

    class Warning(OWWidget.Warning):
        empty_data = Msg("Empty dataset")
        wrong_targets = Msg(
            "Some model(s) predict a different target (see more ...)\n{}")

    class Error(OWWidget.Error):
        predictor_failed = Msg("Some predictor(s) failed (see more ...)\n{}")
        scorer_failed = Msg("Some scorer(s) failed (see more ...)\n{}")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: List of selected class value indices in the `class_values` list
    selected_classes = settings.ContextSetting([])

    def __init__(self):
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        #: Maps input ids to PredictorSlot(predictor, name, results), where
        #: `results` is None (not computed yet), a Results instance, or an
        #: error message string if the predictor failed
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        #: Union of the class values of all discrete predictors; these are
        #: the items of the "Show probabilities for" list
        self.class_values = []  # type: List[str]
        # Item delegates of the prediction columns; QAbstractItemView does
        # not take ownership of delegates, so references are kept here
        self._delegates = []

        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        gui.button(self.controlArea,
                   self,
                   "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.NoSelection,
                          focusPolicy=Qt.StrongFocus)
        # Only the predictions view is sortable; the data view follows its
        # row order (see _update_data_sort_order)
        self.dataview = TableView(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Keep the vertical scroll positions of both views in sync
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Equal, synchronized row heights so that rows line up across views
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input dataset and invalidate all cached predictions.

        A present but empty table shows the `empty_data` warning.
        """
        self.Warning.empty_data(shown=data is not None and not data)
        self.data = data
        if not data:
            self.dataview.setModel(None)
            self.predictionsview.setModel(None)
        else:
            # force full reset of the view's HeaderView state
            self.dataview.setModel(None)
            model = TableModel(data, parent=None)
            modelproxy = TableSortProxyModel()
            modelproxy.setSourceModel(model)
            self.dataview.setModel(modelproxy)

        self._invalidate_predictions()

    @property
    def class_var(self):
        """The data's class variable; falsy if there is no (usable) data."""
        return self.data and self.data.domain.class_var

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or (if `predictor` is None) remove a predictor.

        Adding or replacing a predictor resets its cached results.
        """
        if id in self.predictors:
            if predictor is not None:
                self.predictors[id] = self.predictors[id]._replace(
                    predictor=predictor, name=predictor.name, results=None)
            else:
                del self.predictors[id]
        elif predictor is not None:
            self.predictors[id] = PredictorSlot(predictor, predictor.name,
                                                None)

    def _set_class_values(self):
        """Collect the class values of all discrete predictors.

        When the data has a discrete class variable, its values are moved to
        the front of the list and are selected by default.
        """
        class_values = []
        for slot in self.predictors.values():
            class_var = slot.predictor.domain.class_var
            if class_var and class_var.is_discrete:
                for value in class_var.values:
                    if value not in class_values:
                        class_values.append(value)

        if self.class_var and self.class_var.is_discrete:
            values = self.class_var.values
            # This assignment updates listview; sorting is stable, so the
            # data's class values come first, each group in original order
            self.class_values = sorted(class_values,
                                       key=lambda val: val not in values)
            # Indices must refer to the sorted list shown in the list view
            self.selected_classes = [
                i for i, name in enumerate(self.class_values)
                if name in values
            ]
        else:
            self.class_values = class_values  # This assignment updates listview
            self.selected_classes = []

    def handleNewSignals(self):
        """Recompute predictions, scores and views, then send the outputs."""
        self._set_class_values()
        self._call_predictors()
        self._update_scores()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._update_info()
        self.commit()

    def _call_predictors(self):
        """Compute predictions for every predictor without cached results.

        Each slot's `results` becomes a Results instance or, if the
        predictor raised ValueError/DomainTransformationError, an error
        message string. `predicted` and `probabilities` are filled in only
        for predictors whose target equals the data's class variable; for
        the others they stay None.
        """
        if not self.data:
            return
        if self.class_var:
            # Predict on the data without its class column
            domain = self.data.domain
            classless_data = self.data.transform(
                Domain(domain.attributes, None, domain.metas))
        else:
            classless_data = self.data

        for inputid, slot in self.predictors.items():
            if isinstance(slot.results, Results):
                continue

            predictor = slot.predictor
            try:
                if predictor.domain.class_var.is_discrete:
                    pred, prob = predictor(classless_data, Model.ValueProbs)
                else:
                    pred = predictor(classless_data, Model.Value)
                    prob = numpy.zeros((len(pred), 0))
            except (ValueError, DomainTransformationError) as err:
                self.predictors[inputid] = \
                    slot._replace(results=f"{predictor.name}: {err}")
                continue

            results = Results()
            results.data = self.data
            results.domain = self.data.domain
            results.row_indices = numpy.arange(len(self.data))
            results.folds = (Ellipsis, )
            results.actual = self.data.Y
            results.unmapped_probabilities = prob
            results.unmapped_predicted = pred
            results.probabilities = results.predicted = None
            self.predictors[inputid] = slot._replace(results=results)

            target = predictor.domain.class_var
            if target != self.class_var:
                continue

            if target is not self.class_var and target.is_discrete:
                # Equal but distinct variable: map values and probabilities
                # onto the data's class variable via the backmappers
                backmappers, n_values = predictor.get_backmappers(self.data)
                prob = predictor.backmap_probs(prob, n_values, backmappers)
                pred = predictor.backmap_value(pred, prob, n_values,
                                               backmappers)
            results.predicted = pred.reshape((1, len(self.data)))
            results.probabilities = prob.reshape((1, ) + prob.shape)

    def _update_scores(self):
        """Refill the score table with scores of predictors that have
        mapped predictions; errors of shown scorers are reported."""
        model = self.score_table.model
        model.clear()
        scorers = usable_scorers(self.class_var) if self.class_var else []
        self.score_table.update_header(scorers)
        errors = []
        for slot in self.predictors.values():
            results = slot.results
            if not isinstance(results, Results) or results.predicted is None:
                continue
            row = [
                QStandardItem(learner_name(slot.predictor)),
                QStandardItem("N/A"),
                QStandardItem("N/A")
            ]
            for scorer in scorers:
                item = QStandardItem()
                try:
                    score = scorer_caller(scorer, results)()[0]
                    item.setText(f"{score:.3f}")
                except Exception as exc:  # pylint: disable=broad-except
                    item.setToolTip(str(exc))
                    if scorer.name in self.score_table.shown_scores:
                        errors.append(str(exc))
                row.append(item)
            model.appendRow(row)

        view = self.score_table.view
        if model.rowCount():
            view.setVisible(True)
            view.ensurePolished()
            # Fit the view's height to the header plus all rows
            view.setFixedHeight(5 + view.horizontalHeader().height() +
                                view.verticalHeader().sectionSize(0) *
                                model.rowCount())
        else:
            view.setVisible(False)

        self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

    def _set_errors(self):
        """Show failed predictors and predictors with a different target."""
        # Not all predictors are run every time, so errors can't be collected
        # in _call_predictors
        errors = "\n".join(f"- {p.predictor.name}: {p.results}"
                           for p in self.predictors.values()
                           if isinstance(p.results, str) and p.results)
        self.Error.predictor_failed(errors, shown=bool(errors))

        if self.class_var:
            # Results without mapped probabilities belong to predictors
            # whose target differs from the data's class variable
            inv_targets = "\n".join(
                f"- {pred.name} predicts '{pred.domain.class_var.name}'"
                for pred in (p.predictor for p in self.predictors.values()
                             if isinstance(p.results, Results)
                             and p.results.probabilities is None))
            self.Warning.wrong_targets(inv_targets, shown=bool(inv_targets))
        else:
            self.Warning.wrong_targets.clear()

    def _get_details(self):
        """Return a two-line textual summary of the data and the models."""
        n_predictors = len(self.predictors)
        n_valid = len(self._non_errored_predictors())
        details = f"{len(self.data)} instances" if self.data else "No data"
        details += f"\n{n_predictors} models" if n_predictors \
            else "\nNo models"
        if n_valid != n_predictors:
            details += f" ({n_predictors - n_valid} failed)"
        return details

    def _update_info(self):
        """Update the widget's input summary."""
        if not self.data and not self.predictors:
            self.info.set_input_summary(self.info.NoInput)
            return

        summary = str(len(self.data)) if self.data else "0"
        self.info.set_input_summary(summary, self._get_details())

    def _invalidate_predictions(self):
        """Drop cached results of all predictors."""
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _non_errored_predictors(self):
        """Return the slots whose predictions were computed successfully."""
        return [
            p for p in self.predictors.values()
            if isinstance(p.results, Results)
        ]

    def _update_predictions_model(self):
        """Rebuild the predictions view's model from unmapped predictions.

        There is one column per non-errored predictor, in the order of
        `_non_errored_predictors()`.
        """
        results = []
        headers = []
        for p in self._non_errored_predictors():
            values = p.results.unmapped_predicted
            target = p.predictor.domain.class_var
            if target.is_discrete:
                prob = p.results.unmapped_probabilities
                values = [Value(target, v) for v in values]
            else:
                prob = numpy.zeros((len(values), 0))
            results.append((values, prob))
            headers.append(p.predictor.name)

        if results:
            # Transpose into rows: results[row][col] == (value, probs)
            results = list(zip(*(zip(*res) for res in results)))
            model = PredictionsModel(results, headers)
        else:
            model = None

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        self.predictionsview.setModel(predmodel)
        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        predmodel.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        self.predictionsview.resizeColumnsToContents()

    def _update_data_sort_order(self):
        """Make the data view follow the row order of the predictions view."""
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            if predmodel is not None and predmodel.sortColumn() >= 0:
                sortind = numpy.argsort([
                    predmodel.mapToSource(predmodel.index(i, 0)).row()
                    for i in range(n)
                ])
                # `numpy.int` was only an alias for `int` and is removed in
                # recent NumPy versions
                sortind = numpy.array(sortind, int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Restore the original row order in both views."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """(Re)create the item delegates of the prediction columns.

        Columns of the predictions model correspond to the non-errored
        predictors (see _update_predictions_model), so delegates are
        assigned in exactly that order.
        """
        selected = {self.class_values[i] for i in self.selected_classes}
        self._delegates.clear()
        for col, slot in enumerate(self._non_errored_predictors()):
            target = slot.predictor.domain.class_var
            shown_probs = () if target.is_continuous else \
                [i for i, name in enumerate(target.values) if name in selected]
            delegate = PredictionsItemDelegate(target, shown_probs)
            # QAbstractItemView does not take ownership of delegates, so we
            # must keep a reference to them ourselves
            self._delegates.append(delegate)
            self.predictionsview.setItemDelegateForColumn(col, delegate)
            self.predictionsview.setColumnHidden(col, False)

        self.predictionsview.resizeColumnsToContents()
        self._update_splitter()

    def _update_splitter(self):
        """Give the predictions view just enough width for its columns."""
        if not self.data:
            return

        def width(view):
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            return h_header.length() + v_header.width()

        w = width(self.predictionsview) + 4
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([w, w1 + w2 - w])

    # Backward-compatible alias for the previous (misspelled) method name
    _update_spliter = _update_splitter

    def commit(self):
        """Send both outputs."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        """Send Results of predictors that predict the data's class variable.

        Rows with a missing class value are excluded; None is sent if no
        predictor has mapped predictions.
        """
        slots = [
            p for p in self._non_errored_predictors()
            if p.results.predicted is not None
        ]
        if not slots:
            self.Outputs.evaluation_results.send(None)
            return

        nanmask = numpy.isnan(self.data.get_column_view(self.class_var)[0])
        data = self.data[~nanmask]
        results = Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(p.results.predicted[0][~nanmask] for p in slots))
        if self.class_var and self.class_var.is_discrete:
            results.probabilities = numpy.array(
                [p.results.probabilities[0][~nanmask] for p in slots])
        results.learner_names = [p.name for p in slots]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """Send the data with predictions appended as meta columns."""
        if not self.data:
            self.Outputs.predictions.send(None)
            return

        newmetas = []
        newcolumns = []
        for slot in self._non_errored_predictors():
            if slot.predictor.domain.class_var.is_discrete:
                self._add_classification_out_columns(slot, newmetas,
                                                     newcolumns)
            else:
                self._add_regression_out_columns(slot, newmetas, newcolumns)

        attrs = list(self.data.domain.attributes)
        metas = list(self.data.domain.metas) + newmetas
        domain = Orange.data.Domain(attrs, self.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            # The new metas are the last columns; fill them with the values
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns
        self.Outputs.predictions.send(predictions)

    @staticmethod
    def _add_classification_out_columns(slot, newmetas, newcolumns):
        """Append a predicted-class meta and one probability meta per value."""
        # TODO: decide whether to output mapped or unmapped predictions,
        # or provide a checkbox so the user decides
        pred = slot.predictor
        name = pred.name
        values = pred.domain.class_var.values
        newmetas.append(DiscreteVariable(name=name, values=values))
        newcolumns.append(slot.results.unmapped_predicted.reshape(-1, 1))
        newmetas += [
            ContinuousVariable(name=f"{name} ({value})") for value in values
        ]
        newcolumns.append(slot.results.unmapped_probabilities)

    @staticmethod
    def _add_regression_out_columns(slot, newmetas, newcolumns):
        """Append a single continuous meta with the predicted values."""
        newmetas.append(ContinuousVariable(name=slot.predictor.name))
        newcolumns.append(slot.results.unmapped_predicted.reshape((-1, 1)))

    def send_report(self):
        """Report the input summary, the predictions table and the scores."""
        def merge_data_with_predictions():
            data_model = self.dataview.model()
            predictions_model = self.predictionsview.model()

            # Style each prediction column with the delegate that draws it
            # in the view; fall back to the default delegate otherwise
            delegates = [
                self.predictionsview.itemDelegateForColumn(col)
                or self.predictionsview.itemDelegate()
                for col in range(predictions_model.columnCount())
            ]

            # iterate only over visible columns of data's QTableView
            iter_data_cols = list(
                filter(lambda x: not self.dataview.isColumnHidden(x),
                       range(data_model.columnCount())))

            # print header
            yield [''] + \
                  [predictions_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in range(predictions_model.columnCount())] + \
                  [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical, Qt.DisplayRole)] + \
                      [delegate.displayText(
                          predictions_model.data(predictions_model.index(i, j)),
                          QLocale())
                       for j, delegate in enumerate(delegates)] + \
                      [data_model.data(data_model.index(i, j))
                       for j in iter_data_cols]

        if self.data:
            text = self._get_details().replace('\n', '<br>')
            if self.selected_classes:
                text += '<br>Showing probabilities for: '
                text += ', '.join(
                    [self.class_values[i] for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions",
                              merge_data_with_predictions(),
                              header_rows=1,
                              header_columns=1)

            self.report_table("Scores", self.score_table.view)
コード例 #11
0
    def __init__(self):
        """Initialize widget state and build the control and main areas.

        The control area holds the sampling method, target class and model
        comparison settings. The main area holds the evaluation results
        (score table) and the pairwise model comparison table.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # Remember the stored comparison criterion; presumably re-applied once
        # the criterion combo is populated (its model starts empty) -- TODO
        # confirm
        self.__pending_comparison_criterion = self.comparison_criterion

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        # "Sampling" box: radio buttons (bound to `resampling`) choose the
        # resampling method; each method's options are indented below it
        sbox = gui.vBox(self.controlArea, "抽样")
        rbox = gui.radioButtons(
            sbox, self, "resampling", callback=self._param_changed)

        # Cross validation: number of folds and stratification
        gui.appendRadioButton(rbox, "交叉验证(Cross validation)")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_folds", label="折叠次数(Number of folds): ",
            items=[str(x) for x in self.NFolds], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.kfold_changed)
        gui.checkBox(
            ibox, self, "cv_stratified", "分层(Stratified)",
            callback=self.kfold_changed)
        # Cross validation by feature: folds come from a discrete meta
        # variable chosen in the combo below
        gui.appendRadioButton(rbox, "按特征交叉验证(Cross validation by feature)")
        ibox = gui.indentedBox(rbox)
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(
            ibox, self, "fold_feature", model=self.feature_model,
            orientation=Qt.Horizontal, callback=self.fold_feature_changed)

        # Random sampling: number of train/test repeats ("重复训练/测试"),
        # training set size ("训练集大小") and stratification
        gui.appendRadioButton(rbox, "随机抽样(Random sampling)")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_repeats", label="重复训练/测试: ",
            items=[str(x) for x in self.NRepeats], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.shuffle_split_changed)
        gui.comboBox(
            ibox, self, "sample_size", label="训练集大小: ",
            items=["{} %".format(x) for x in self.SampleSizes],
            maximumContentsLength=5, orientation=Qt.Horizontal,
            callback=self.shuffle_split_changed)
        gui.checkBox(
            ibox, self, "shuffle_stratified", "分层(Stratified)",
            callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "留一法(Leave one out)")

        # "Test on train data" and "Test on test data"
        gui.appendRadioButton(rbox, "测试训练数据")
        gui.appendRadioButton(rbox, "测试测试数据")

        # "Target class" box; the combo stores the selected value as a string
        self.cbox = gui.vBox(self.controlArea, "目标类别")
        self.class_selection_combo = gui.comboBox(
            self.cbox, self, "class_selection", items=[],
            sendSelectedValue=True, valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        # "Model comparison" box: comparison criterion (combo with an
        # initially empty model) and the "negligible difference" threshold
        self.modcompbox = box = gui.vBox(self.controlArea, "模型比较")
        gui.comboBox(
            box, self, "comparison_criterion", model=PyListModel(),
            callback=self.update_comparison_table)

        hbox = gui.hBox(box)
        gui.checkBox(hbox, self, "use_rope",
                     "可忽略区别: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(),
                     controlWidth=70, callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # The threshold edit starts enabled only if `use_rope` is set
        self.controls.rope.setEnabled(self.use_rope)

        gui.rubber(self.controlArea)
        self.score_table = ScoreTable(self)
        # Refresh the statistics when the set of shown scores changes
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        # "Evaluation results" box in the main area
        box = gui.vBox(self.mainArea, "评价结果")
        box.layout().addWidget(self.score_table.view)

        # "Model comparison" table: read-only, no selection, headers are not
        # clickable; column widths fit contents within 8-15 average chars
        self.compbox = box = gui.vBox(self.mainArea, box='模型比较')
        table = self.comparison_table = QTableWidget(
            wordWrap=False, editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        # Explanation: each cell is the probability that the row model's
        # score is higher than the column model's; small values mean the
        # probability is small and the difference is negligible
        box.layout().addWidget(QLabel(
            "<small>此表显示数值为一概率. 此概率表示\"行模型\"的评价指标大于"
            "\"列模型\"的评价指标. "
            "值小说明概率很小, 区别可以忽略不计.</small>", wordWrap=True))
コード例 #12
0
    def __init__(self):
        """Initialize widget state and build the GUI.

        The control area holds the display options (classes whose
        probabilities are shown, predicted class, distribution bars), the
        data view options and the output options. The main area shows the
        predictions and data tables side by side in a splitter, above the
        score table.
        """
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        # References to the item delegates of the predictions view
        self._delegates = []

        box = gui.vBox(self.controlArea, "Show", spacing=-1, addSpace=False)

        gui.widgetLabel(box, "Show probabilities for:")
        gui.listBox(box, self, "selected_classes", "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box, self, "show_predictions", "Show predicted class",
                     callback=self._update_prediction_delegate)
        gui.checkBox(box, self, "draw_dist", "Distribution bars",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "Data View")
        gui.checkBox(box, self, "show_attrs", "Show full dataset",
                     callback=self._update_column_visibility)
        gui.button(box, self, "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        box = gui.vBox(self.controlArea, "Output", spacing=-1)
        # Each output checkbox gets its own attribute; `checkbox_class` is
        # the "Predictions" checkbox, "Original data" is `checkbox_attrs`
        self.checkbox_attrs = gui.checkBox(
            box, self, "output_attrs", "Original data",
            callback=self.commit)
        self.checkbox_class = gui.checkBox(
            box, self, "output_predictions", "Predictions",
            callback=self.commit)
        self.checkbox_prob = gui.checkBox(
            box, self, "output_probabilities", "Probabilities",
            callback=self.commit)

        gui.rubber(self.controlArea)

        self.vsplitter = gui.vBox(self.mainArea)

        self.splitter = QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        # Options shared by both tables
        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.NoSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            **table_opts
        )
        # Only the predictions view is sortable; its own vertical scroll bar
        # is hidden since the two views' bars are kept in sync below
        self.predictionsview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts,
            sortingEnabled=True,
        )

        self.dataview.verticalHeader().hide()

        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Equal, synchronized row heights so that rows line up across views
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size:
            self.predictionsview.verticalHeader().resizeSection(index, size)
        )

        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)
コード例 #13
0
    def __init__(self):
        """Initialize widget state and build the GUI.

        The control area holds an info label with the "Restore Original
        Order" button, the classification display options, the data view
        option and the output options. The main area holds a vertical
        splitter with the predictions/data tables above the score table.
        """
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: A dict mapping input ids to PredictorSlot
        self.predictors = OrderedDict()  # type: Dict[object, PredictorSlot]
        #: A class variable (prediction target)
        self.class_var = None  # type: Optional[Orange.data.Variable]
        #: List of (discrete) class variable's values
        self.class_values = []  # type: List[str]

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(
            box, "No data on input.\nPredictors: 0\nTask: N/A")
        self.infolabel.setMinimumWidth(150)
        gui.button(box,
                   self,
                   "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        self.classification_options = box = gui.vBox(self.controlArea,
                                                     "Show",
                                                     spacing=-1,
                                                     addSpace=False)

        gui.checkBox(box,
                     self,
                     "show_predictions",
                     "Predicted class",
                     callback=self._update_prediction_delegate)
        b = gui.checkBox(box,
                         self,
                         "show_probabilities",
                         "Predicted probabilities for:",
                         callback=self._update_prediction_delegate)
        # Indent the class list by the checkbox's indicator offset
        ibox = gui.indentedBox(box,
                               sep=gui.checkButtonOffsetHint(b),
                               addSpace=False)
        gui.listBox(ibox,
                    self,
                    "selected_classes",
                    "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box,
                     self,
                     "draw_dist",
                     "Draw distribution bars",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "Data View")
        gui.checkBox(box,
                     self,
                     "show_attrs",
                     "Show full dataset",
                     callback=self._update_column_visibility)

        box = gui.vBox(self.controlArea, "Output", spacing=-1)
        # Each output checkbox gets its own attribute; `checkbox_class` is
        # the "Predictions" checkbox, "Original data" is `checkbox_attrs`
        self.checkbox_attrs = gui.checkBox(box,
                                           self,
                                           "output_attrs",
                                           "Original data",
                                           callback=self.commit)
        self.checkbox_class = gui.checkBox(box,
                                           self,
                                           "output_predictions",
                                           "Predictions",
                                           callback=self.commit)
        self.checkbox_prob = gui.checkBox(box,
                                          self,
                                          "output_probabilities",
                                          "Probabilities",
                                          callback=self.commit)

        gui.rubber(self.controlArea)

        self.vsplitter = QSplitter(orientation=Qt.Vertical,
                                   childrenCollapsible=True,
                                   handleWidth=2)

        self.splitter = QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        # Options shared by both tables
        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.NoSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            **table_opts)
        # Only the predictions view is sortable; its own vertical scroll bar
        # is hidden since the two views' bars are kept in sync below
        self.predictionsview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts,
            sortingEnabled=True,
        )

        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.dataview.verticalHeader().hide()

        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Equal, synchronized row heights so that rows line up across views
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        # Tables get five times the stretch of the score table
        self.score_table = ScoreTable(self)
        self.vsplitter.addWidget(self.splitter)
        self.vsplitter.addWidget(self.score_table.view)
        self.vsplitter.setStretchFactor(0, 5)
        self.vsplitter.setStretchFactor(1, 1)
        self.mainArea.layout().addWidget(self.vsplitter)
コード例 #14
0
class OWPredictions(OWWidget):
    """Widget that shows the predictions of connected models on a dataset.

    The main area holds two table views with synchronised vertical scroll
    bars (predictions on the left, input data on the right) above a table of
    evaluation scores. Predictions and evaluation results are sent to the
    widget's outputs.
    """
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display the predictions of models for an input dataset."
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results",
                                    Orange.evaluation.Results,
                                    dynamic=False)

    class Warning(OWWidget.Warning):
        empty_data = Msg("Empty dataset")

    class Error(OWWidget.Error):
        predictor_failed = \
            Msg("One or more predictors failed (see more...)\n{}")
        scorer_failed = \
            Msg("One or more scorers failed (see more...)\n{}")
        predictors_target_mismatch = \
            Msg("Predictors do not have the same target.")
        data_target_mismatch = \
            Msg("Data does not have the same target as predictors.")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Display the full input dataset or only the target variable columns (if
    #: available)
    show_attrs = settings.Setting(True)
    #: Show predicted values (for discrete target variable)
    show_predictions = settings.Setting(True)
    #: Show predictions probabilities (for discrete target variable)
    show_probabilities = settings.Setting(True)
    #: List of selected class value indices in the "Show probabilities" list
    selected_classes = settings.ContextSetting([])
    #: Draw colored distribution bars
    draw_dist = settings.Setting(True)

    #: Include the original attributes in the "Predictions" output
    output_attrs = settings.Setting(True)
    #: Include predicted values in the "Predictions" output
    output_predictions = settings.Setting(True)
    #: Include predicted probabilities in the "Predictions" output
    output_probabilities = settings.Setting(True)

    def __init__(self):
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: A dict mapping input ids to PredictorSlot
        self.predictors = OrderedDict()  # type: Dict[object, PredictorSlot]
        #: A class variable (prediction target)
        self.class_var = None  # type: Optional[Orange.data.Variable]
        #: List of (discrete) class variable's values
        self.class_values = []  # type: List[str]

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(
            box, "No data on input.\nPredictors: 0\nTask: N/A")
        self.infolabel.setMinimumWidth(150)
        gui.button(box,
                   self,
                   "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        # Only shown for a discrete target (see `_set_class_var`)
        self.classification_options = box = gui.vBox(self.controlArea,
                                                     "Show",
                                                     spacing=-1,
                                                     addSpace=False)

        gui.checkBox(box,
                     self,
                     "show_predictions",
                     "Predicted class",
                     callback=self._update_prediction_delegate)
        b = gui.checkBox(box,
                         self,
                         "show_probabilities",
                         "Predicted probabilities for:",
                         callback=self._update_prediction_delegate)
        ibox = gui.indentedBox(box,
                               sep=gui.checkButtonOffsetHint(b),
                               addSpace=False)
        gui.listBox(ibox,
                    self,
                    "selected_classes",
                    "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box,
                     self,
                     "draw_dist",
                     "Draw distribution bars",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "Data View")
        gui.checkBox(box,
                     self,
                     "show_attrs",
                     "Show full dataset",
                     callback=self._update_column_visibility)

        box = gui.vBox(self.controlArea, "Output", spacing=-1)
        # Each output checkbox gets its own attribute; `checkbox_class`
        # (predicted values) and `checkbox_prob` are enabled/disabled
        # depending on the task in `_update_info`.
        self.checkbox_attrs = gui.checkBox(box,
                                           self,
                                           "output_attrs",
                                           "Original data",
                                           callback=self.commit)
        self.checkbox_class = gui.checkBox(box,
                                           self,
                                           "output_predictions",
                                           "Predictions",
                                           callback=self.commit)
        self.checkbox_prob = gui.checkBox(box,
                                          self,
                                          "output_probabilities",
                                          "Probabilities",
                                          callback=self.commit)

        gui.rubber(self.controlArea)

        self.vsplitter = QSplitter(orientation=Qt.Vertical,
                                   childrenCollapsible=True,
                                   handleWidth=2)

        self.splitter = QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        self.dataview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus)
        self.predictionsview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus,
            sortingEnabled=True,
        )

        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.dataview.verticalHeader().hide()

        # Keep both views scrolled to the same rows
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Equal row heights so rows in both views line up
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter.addWidget(self.splitter)
        self.vsplitter.addWidget(self.score_table.view)
        self.vsplitter.setStretchFactor(0, 5)
        self.vsplitter.setStretchFactor(1, 1)
        self.mainArea.layout().addWidget(self.vsplitter)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input dataset.

        An empty table is treated as no data and raises the `empty_data`
        warning. Installs a sortable model for `data` in the data view and
        invalidates all stored predictions.
        """
        if data is not None and not data:
            data = None
            self.Warning.empty_data()
        else:
            self.Warning.empty_data.clear()

        self.data = data
        if data is None:
            self.dataview.setModel(None)
            self.predictionsview.setModel(None)
            self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        else:
            # force full reset of the view's HeaderView state
            self.dataview.setModel(None)
            model = TableModel(data, parent=None)
            modelproxy = TableSortProxyModel()
            modelproxy.setSourceModel(model)
            self.dataview.setModel(modelproxy)
            self._update_column_visibility()

        self._invalidate_predictions()

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or (when `predictor` is None) remove the predictor
        connected under input `id`.

        Replacing a predictor drops its previously computed results.
        """
        if id in self.predictors:
            if predictor is not None:
                self.predictors[id] = self.predictors[id]._replace(
                    predictor=predictor, name=predictor.name, results=None)
            else:
                del self.predictors[id]
        elif predictor is not None:
            self.predictors[id] = \
                PredictorSlot(predictor, predictor.name, None)

    def _set_class_var(self):
        """Determine the common target of all predictors.

        Sets `self.class_var` to the predictors' shared class variable, or to
        None (with an error) when predictors disagree or the data has a
        different class variable. For a discrete target, fills
        `class_values` and reopens the settings context keyed on it.
        """
        pred_classes = {pred.predictor.domain.class_var
                        for pred in self.predictors.values()}
        self.Error.predictors_target_mismatch.clear()
        self.Error.data_target_mismatch.clear()
        self.class_var = None
        if len(pred_classes) > 1:
            self.Error.predictors_target_mismatch()
        if len(pred_classes) == 1:
            self.class_var = pred_classes.pop()
            if self.data is not None and \
                    self.data.domain.class_var is not None and \
                    self.class_var != self.data.domain.class_var:
                self.Error.data_target_mismatch()
                self.class_var = None

        discrete_class = self.class_var is not None \
                         and self.class_var.is_discrete
        self.classification_options.setVisible(discrete_class)
        self.closeContext()
        if discrete_class:
            self.class_values = list(self.class_var.values)
            # by default show probabilities of all classes; the context may
            # override this with a stored selection
            self.selected_classes = list(range(len(self.class_values)))
            self.openContext(self.class_var)
        else:
            self.class_values = []
            self.selected_classes = []

    def handleNewSignals(self):
        """Recompute predictions, views, scores and outputs after inputs
        change."""
        self._set_class_var()
        # `set_data` may run before the class variable is known (e.g. when
        # predictors arrive later), so re-apply the "Show full dataset"
        # setting now that `class_var` is up to date.
        self._update_column_visibility()
        self._call_predictors()
        self._update_scores()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._update_info()
        self.commit()

    def _call_predictors(self):
        """Run each predictor on `self.data` and store the outcome in its slot.

        A slot is skipped when it already holds `Results` whose predictions
        are not all NaN. On success the slot's `results` is an
        `Orange.evaluation.Results` with a single fold; on failure it is an
        error message string (collected later by `_set_errors`).
        """
        if not self.data:
            return
        for inputid, slot in self.predictors.items():
            if slot.results is not None \
                    and not isinstance(slot.results, str) \
                    and not numpy.isnan(slot.results.predicted[0]).all():
                continue
            try:
                prediction = self.predict(slot.predictor, self.data)
            except (ValueError, DomainTransformationError) as err:
                results = "{}: {}".format(slot.predictor.name, err)
            else:
                if prediction is None:
                    # `predict` returns None for a model without a class
                    # variable; report it instead of failing on unpacking
                    results = "{}: predictor has no target variable".format(
                        slot.predictor.name)
                else:
                    pred, prob = prediction
                    results = Orange.evaluation.Results()
                    results.data = self.data
                    results.domain = self.data.domain
                    results.row_indices = numpy.arange(len(self.data))
                    results.folds = (Ellipsis, )
                    results.actual = self.data.Y
                    results.predicted = pred.reshape((1, len(self.data)))
                    results.probabilities = prob.reshape((1, ) + prob.shape)
            self.predictors[inputid] = slot._replace(results=results)

    def _update_scores(self):
        """Fill the score table with one row per predictor.

        Scorers come from `usable_scorers` for the data's class variable (no
        scorers without data or class). Failed predictors are shown in red
        with the error as tooltip; failures of scorers that are currently
        shown are reported through `Error.scorer_failed`.
        """
        model = self.score_table.model
        model.clear()

        if self.data is None or self.data.domain.class_var is None:
            scorers = []
        else:
            scorers = usable_scorers(self.data.domain.class_var)
        self.score_table.update_header(scorers)

        errors = []
        for pred in self.predictors.values():
            name = learner_name(pred.predictor)
            head = QStandardItem(name)
            row = [head]
            results = pred.results
            if isinstance(results, str):
                head.setToolTip(results)
                head.setText("{} (error)".format(name))
                head.setForeground(QBrush(Qt.red))
            else:
                for scorer in scorers:
                    item = QStandardItem()
                    try:
                        score = scorer_caller(scorer, results)()[0]
                        item.setText(f"{score:.3f}")
                    except Exception as exc:  # pylint: disable=broad-except
                        item.setToolTip(str(exc))
                        if scorer.name in self.score_table.shown_scores:
                            errors.append(str(exc))
                    row.append(item)
            self.score_table.model.appendRow(row)
        self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

    def _set_errors(self):
        """Show the error messages of all failed predictors (if any)."""
        # Not all predictors are run every time, so errors can't be collected
        # in _call_predictors
        errors = "\n".join(p.results for p in self.predictors.values()
                           if isinstance(p.results, str))
        self.Error.predictor_failed(errors, shown=bool(errors))

    def _update_info(self):
        """Update the info label and enable output options for the task.

        For regression the "Predictions" and "Probabilities" output
        checkboxes are disabled: regression output always contains the
        predicted values (see `_regression_output_columns`).
        """
        info = []
        if self.data is not None:
            info.append("Data: {} instances.".format(len(self.data)))
        else:
            info.append("Data: N/A")

        n_predictors = len(self.predictors)
        n_valid = len(self._valid_predictors())
        if n_valid != n_predictors:
            info.append("Predictors: {} (+ {} failed)".format(
                n_valid, n_predictors - n_valid))
        else:
            info.append("Predictors: {}".format(n_predictors or "N/A"))

        if self.class_var is None:
            info.append("Task: N/A")
        elif self.class_var.is_discrete:
            info.append("Task: Classification")
            self.checkbox_class.setEnabled(True)
            self.checkbox_prob.setEnabled(True)
        else:
            info.append("Task: Regression")
            self.checkbox_class.setEnabled(False)
            self.checkbox_prob.setEnabled(False)

        self.infolabel.setText("\n".join(info))

    def _invalidate_predictions(self):
        """Drop stored results of all predictors so they are recomputed."""
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _valid_predictors(self):
        """Return slots with successfully computed results.

        Returns an empty list when there is no data or no common class
        variable.
        """
        if self.class_var is not None and self.data is not None:
            return [
                p for p in self.predictors.values()
                if p.results is not None and not isinstance(p.results, str)
            ]
        else:
            return []

    def _padded_probabilities(self, slot):
        """Return the slot's class probabilities, padded with zero columns
        up to the number of values of `self.class_var`.

        If values were added to the class variable between building the
        model and predicting, the predictor returns fewer columns; the new
        values are always at the end, so zeros are appended for them.
        Must only be called for a discrete `self.class_var`.
        """
        prob = slot.results.probabilities[0]
        n_missing = len(self.class_var.values) - prob.shape[1]
        if n_missing > 0:
            prob = numpy.hstack(
                (prob, numpy.zeros((prob.shape[0], n_missing))))
        return prob

    def _update_predictions_model(self):
        """Update the prediction view model.

        Builds a `PredictionsModel` with one column per valid predictor (no
        source model when there is no data or class variable), wraps it in a
        `PredictionsSortProxyModel` and installs it in the predictions view.
        """
        if self.data is not None and self.class_var is not None:
            slots = self._valid_predictors()
            results = []
            class_var = self.class_var
            for p in slots:
                values = p.results.predicted[0]
                if class_var.is_discrete:
                    prob = self._padded_probabilities(p)
                    values = [Value(class_var, v) for v in values]
                else:
                    prob = numpy.zeros((len(values), 0))
                results.append((values, prob))
            # Transpose per-predictor (values, prob) into per-row tuples:
            # row i is ((value_1[i], prob_1[i]), (value_2[i], prob_2[i]), ...)
            results = list(zip(*(zip(*res) for res in results)))
            headers = [p.name for p in slots]
            model = PredictionsModel(results, headers)
        else:
            model = None

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.predictionsview.setModel(predmodel)
        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        predmodel.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        self.predictionsview.resizeColumnsToContents()

    def _update_column_visibility(self):
        """Show or hide the attribute columns of the data view according to
        `show_attrs`.

        Does nothing without data or without a class variable.
        """
        if self.data is not None and self.class_var is not None:
            domain = self.data.domain
            # assumes TableModel orders columns as class vars, metas,
            # attributes -- TODO confirm against TableModel
            first_attr = len(domain.class_vars) + len(domain.metas)

            for i in range(first_attr, first_attr + len(domain.attributes)):
                self.dataview.setColumnHidden(i, not self.show_attrs)
            if domain.class_var:
                self.dataview.setColumnHidden(0, False)

    def _update_data_sort_order(self):
        """Update data row order to match the current predictions view order.

        When the predictions proxy is sorted and has one row per data row,
        a permutation derived from its proxy-to-source row mapping is passed
        to the data proxy's `setSortIndices`; otherwise the data order is
        reset (`None`).
        """
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            # Without valid predictors the predictions model has no rows;
            # mapping invalid indices would yield a meaningless permutation.
            if predmodel is not None and predmodel.sortColumn() >= 0 \
                    and predmodel.rowCount() == n:
                sortind = numpy.argsort([
                    predmodel.mapToSource(predmodel.index(i, 0)).row()
                    for i in range(n)
                ])
                # `numpy.int` was removed in NumPy 1.24; use the builtin
                sortind = numpy.array(sortind, dtype=int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Reset the row sorting to original input order."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """Update the predicted probability visibility state.

        Installs a new `PredictionsItemDelegate` configured for the current
        target type and, for a discrete target, passes the selected class
        indices to the predictions proxy. Finally resizes the splitter.
        """
        if self.class_var is not None:
            delegate = PredictionsItemDelegate()
            if self.class_var.is_continuous:
                self._setup_delegate_continuous(delegate)
            else:
                self._setup_delegate_discrete(delegate)
                proxy = self.predictionsview.model()
                if proxy is not None:
                    proxy.setProbInd(
                        numpy.array(self.selected_classes, dtype=int))
            self.predictionsview.setItemDelegate(delegate)
            self.predictionsview.resizeColumnsToContents()
        self._update_spliter()

    def _setup_delegate_discrete(self, delegate):
        """Configure `delegate` to show selected class probabilities and/or
        the predicted value, and optionally colored distribution bars.

        Returns the same delegate.
        """
        colors = [QColor(*rgb) for rgb in self.class_var.colors]
        fmt = []
        if self.show_probabilities:
            fmt.append(" : ".join("{{dist[{}]:.2f}}".format(i)
                                  for i in sorted(self.selected_classes)))
        if self.show_predictions:
            fmt.append("{value!s}")
        delegate.setFormat(" \N{RIGHTWARDS ARROW} ".join(fmt))
        if self.draw_dist:
            delegate.setColors(colors)
        return delegate

    def _setup_delegate_continuous(self, delegate):
        """Configure `delegate` to format values with the class variable's
        format string (its first character, presumably '%', is dropped)."""
        delegate.setFormat("{{value:{}}}".format(
            self.class_var.format_str[1:]))

    def _update_spliter(self):
        """Resize the horizontal splitter so the predictions view is just
        wide enough for its columns."""
        if self.data is None:
            return

        def width(view):
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            return h_header.length() + v_header.width()

        w = width(self.predictionsview) + 4
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([w, w1 + w2 - w])

    def commit(self):
        """Send both outputs."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        """Send evaluation results for rows with a known target value.

        Sends None when there are no valid predictors or the data has no
        class variable.
        """
        slots = self._valid_predictors()
        if not slots or self.data.domain.class_var is None:
            self.Outputs.evaluation_results.send(None)
            return

        class_var = self.class_var
        nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
        data = self.data[~nanmask]
        results = Orange.evaluation.Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(p.results.predicted[0][~nanmask] for p in slots))
        if class_var and class_var.is_discrete:
            # pad so all predictors have the same number of columns
            results.probabilities = numpy.array(
                [self._padded_probabilities(p)[~nanmask] for p in slots])
        results.learner_names = [p.name for p in slots]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """Send the input data extended with prediction meta columns.

        Original attributes are kept only if `output_attrs` is set; the
        class variable and existing metas are always kept.
        """
        slots = self._valid_predictors()
        if not slots:
            self.Outputs.predictions.send(None)
            return

        if self.class_var and self.class_var.is_discrete:
            newmetas, newcolumns = self._classification_output_columns()
        else:
            newmetas, newcolumns = self._regression_output_columns()

        attrs = list(self.data.domain.attributes) if self.output_attrs else []
        metas = list(self.data.domain.metas) + newmetas
        domain = \
            Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            # the new metas are appended last, so fill the trailing columns
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns
        self.Outputs.predictions.send(predictions)

    def _classification_output_columns(self):
        """Return (new meta variables, matching column arrays) for a
        discrete target, honouring `output_predictions` and
        `output_probabilities`."""
        newmetas = []
        newcolumns = []
        slots = self._valid_predictors()
        if self.output_predictions:
            newmetas += [
                DiscreteVariable(name=p.name, values=self.class_values)
                for p in slots
            ]
            newcolumns += [
                p.results.predicted[0].reshape((-1, 1)) for p in slots
            ]

        if self.output_probabilities:
            newmetas += [
                ContinuousVariable(name="%s (%s)" % (p.name, value))
                for p in slots for value in self.class_values
            ]
            # padded so each predictor contributes one column per class value,
            # matching the meta variables created above
            newcolumns += [self._padded_probabilities(p) for p in slots]
        return newmetas, newcolumns

    def _regression_output_columns(self):
        """Return (new meta variables, matching column arrays) with one
        predicted-value column per valid predictor."""
        slots = self._valid_predictors()
        newmetas = [ContinuousVariable(name=p.name) for p in slots]
        newcolumns = [p.results.predicted[0].reshape((-1, 1)) for p in slots]
        return newmetas, newcolumns

    def send_report(self):
        """Report the info text and a table merging predictions with the
        visible data columns."""
        def style(value):
            # use the predictions view's ItemDelegate to style values
            return self.predictionsview.itemDelegate().displayText(
                value, QLocale())

        def merge_data_with_predictions():
            data_model = self.dataview.model()
            predictions_model = self.predictionsview.model()
            n_pred_cols = predictions_model.columnCount()

            # iterate only over visible columns of data's QTableView
            iter_data_cols = [col for col in range(data_model.columnCount())
                              if not self.dataview.isColumnHidden(col)]

            # print header
            yield [''] + \
                [predictions_model.headerData(col, Qt.Horizontal,
                                              Qt.DisplayRole)
                 for col in range(n_pred_cols)] + \
                [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                 for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical,
                                             Qt.DisplayRole)] + \
                    [style(predictions_model.data(
                        predictions_model.index(i, j)))
                     for j in range(n_pred_cols)] + \
                    [data_model.data(data_model.index(i, j))
                     for j in iter_data_cols]

        if self.data is not None and self.class_var is not None:
            text = self.infolabel.text().replace('\n', '<br>')
            if self.show_probabilities and self.selected_classes:
                text += '<br>Showing probabilities for: '
                text += ', '.join(
                    [self.class_values[i] for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions",
                              merge_data_with_predictions(),
                              header_rows=1,
                              header_columns=1)

    @classmethod
    def predict(cls, predictor, data):
        """Return `(values, probabilities)` of `predictor` on `data`, or None
        if the predictor's domain has no class variable."""
        class_var = predictor.domain.class_var
        if class_var:
            if class_var.is_discrete:
                return cls.predict_discrete(predictor, data)
            else:
                return cls.predict_continuous(predictor, data)
        return None

    @staticmethod
    def predict_discrete(predictor, data):
        """Return predicted values and class probabilities."""
        return predictor(data, Model.ValueProbs)

    @staticmethod
    def predict_continuous(predictor, data):
        """Return predicted values and an empty (n, 0) probability array."""
        values = predictor(data, Model.Value)
        return values, numpy.zeros((len(data), 0))