Пример #1
0
class OWClassificationTree(widget.OWWidget):
    """Widget that configures a ``tree.TreeLearner`` and, when input data is
    available, also fits and outputs a classifier (see ``set_learner``)."""
    name = "Classification Tree"
    icon = "icons/ClassificationTree.svg"
    priority = 30

    # Input signals: (signal name, type, name of the handler method).
    inputs = [("Data", Orange.data.Table, "set_data"),
              ("Preprocessor", Orange.preprocess.Preprocess,
               "set_preprocessor")]

    # Output signals: (signal name, type); both are sent from set_learner().
    outputs = [("Learner", tree.TreeLearner),
               ("Classification Tree", tree.TreeClassifier)]
    want_main_area = False

    # Persisted settings bound to the GUI controls created in __init__.
    model_name = Setting("Classification Tree")
    # Index into ``scores``.
    attribute_score = Setting(0)
    # Each ``limit_*`` flag is the check box attached to the spin box of the
    # setting that follows it.
    limit_min_leaf = Setting(True)
    min_leaf = Setting(2)
    limit_min_internal = Setting(True)
    min_internal = Setting(5)
    limit_depth = Setting(True)
    max_depth = Setting(100)

    # (label shown in the combo box, ``criterion`` value given to the learner)
    scores = (("Entropy", "entropy"), ("Gini Index", "gini"))

    def __init__(self):
        """Initialise the runtime state, build the controls and emit the
        initial learner."""
        super().__init__()

        # Runtime state filled in by the input handlers and set_learner().
        self.data = None
        self.learner = None
        self.preprocessors = None
        self.classifier = None

        gui.lineEdit(
            self.controlArea, self, 'model_name', box='Name',
            tooltip='The name will identify this model in other '
                    'widgets')

        score_labels = [label for label, _ in self.scores]
        gui.comboBox(
            self.controlArea, self, "attribute_score",
            box='Feature selection', items=score_labels)

        # One checkable spin box per pruning limit:
        # (value setting, caption, check-box setting).
        pruning_box = gui.widgetBox(self.controlArea, 'Pruning')
        spin_specs = (
            ("min_leaf",
             "Min. instances in leaves ",
             "limit_min_leaf"),
            ("min_internal",
             "Stop splitting nodes with less instances than ",
             "limit_min_internal"),
            ("max_depth",
             "Limit the depth to ",
             "limit_depth"),
        )
        for value_attr, caption, check_attr in spin_specs:
            gui.spin(pruning_box, self, value_attr, 1, 1000,
                     label=caption, checked=check_attr)

        self.btn_apply = gui.button(
            self.controlArea, self, "&Apply",
            callback=self.set_learner, disabled=0, default=True)

        gui.rubber(self.controlArea)
        self.resize(100, 100)

        # Publish a learner built from the stored settings right away.
        self.set_learner()

    def sendReport(self):
        """Report the selected score, the enabled pruning limits and the
        input data."""
        score_label = self.scores[self.attribute_score][0]

        # (description, enabled flag) for each pruning limit.
        limits = (
            ("%i instances in leaves" % self.min_leaf,
             self.limit_min_leaf),
            ("%i instance in internal node" % self.min_internal,
             self.limit_min_internal),
            ("maximum depth %i" % self.max_depth,
             self.limit_depth),
        )
        enabled_limits = [text for text, enabled in limits if enabled]
        # Falls back to ": None" when no limit is enabled.
        pruning_text = ", ".join(enabled_limits) or ": None"

        self.reportSettings(
            "Model parameters",
            [("Attribute selection", score_label),
             ("Pruning", pruning_text)])
        self.reportData(self.data)

    def set_learner(self):
        """Rebuild the learner from the current settings and send the outputs.

        The learner is always sent on "Learner".  When input data is present
        the learner is also fitted and the resulting classifier is sent on
        "Classification Tree"; otherwise (or when fitting fails) ``None`` is
        sent there.  Fitting errors are shown as widget error 1.

        A pruning limit is only applied when its check box is ticked.  An
        unticked limit falls back to the unrestricted value (no depth limit,
        split at 2 instances, 1 instance per leaf).
        NOTE(review): these fallbacks assume ``tree.TreeLearner`` forwards the
        arguments to scikit-learn's decision tree, whose defaults they
        mirror -- confirm.
        """
        max_depth = self.max_depth if self.limit_depth else None
        min_split = self.min_internal if self.limit_min_internal else 2
        min_leaf = self.min_leaf if self.limit_min_leaf else 1

        self.learner = tree.TreeLearner(
            criterion=self.scores[self.attribute_score][1],
            max_depth=max_depth,
            min_samples_split=min_split,
            min_samples_leaf=min_leaf,
            preprocessors=self.preprocessors)
        self.learner.name = self.model_name

        self.send("Learner", self.learner)

        # Clear a fitting error left over from a previous run.
        self.error(1)
        self.classifier = None
        if self.data is not None:
            try:
                classifier = self.learner(self.data)
                classifier.name = self.model_name
                classifier.instances = self.data
            except Exception as err_value:
                # Any failure while fitting is reported to the user instead
                # of propagating out of the signal handler.
                self.error(1, str(err_value))
            else:
                self.classifier = classifier
        self.send("Classification Tree", self.classifier)

    def set_data(self, data):
        """Handle the "Data" input: store it and rebuild the outputs.

        Data without a class variable is rejected with widget error 0 and
        treated as if no data were connected.
        """
        self.error(0)
        self.data = data
        lacks_target = data is not None and data.domain.class_var is None
        if lacks_target:
            self.data = None
            self.error(0, "Data has no target variable")
        self.set_learner()

    def set_preprocessor(self, preproc):
        """Handle the "Preprocessor" input and rebuild the outputs.

        The preprocessor is stored as a one-element tuple, or ``None`` when
        the input is removed.
        """
        self.preprocessors = None if preproc is None else (preproc, )
        self.set_learner()
Пример #2
0
class OWSelectRows(widget.OWWidget):
    """Widget that filters the rows of a data table by a list of
    user-defined conditions on its variables."""
    name = "选择行(Select Rows)"
    # NOTE(review): this identifier names the *file* widget; it looks like a
    # copy-paste leftover -- confirm before changing, stored workflows may
    # depend on it.
    id = "Orange.widgets.data.file"
    description = "根据变量值从数据中选择行。"
    icon = "icons/SelectRows.svg"
    priority = 100
    category = "Data"
    keywords = ["filter"]

    class Inputs:
        data = Input("数据(Data)", Table, replaces=['Data'])

    class Outputs:
        matching_data = Output("匹配的数据(Matching Data)", Table, default=True, replaces=['Matching Data'])
        unmatched_data = Output("不匹配的数据(Unmatched Data)", Table, replaces=['Unmatched Data'])
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_Chinese_NAME, Table, replaces=['Data'])

    want_main_area = False

    settingsHandler = SelectRowsContextHandler()
    # List of (variable name, operator index, values) tuples; rebuilt from the
    # table widgets in conditions_changed().
    conditions = ContextSetting([])
    update_on_change = Setting(True)
    purge_attributes = Setting(False, schema_only=True)
    purge_classes = Setting(False, schema_only=True)
    auto_commit = Setting(True)

    settings_version = 2

    # Operators offered for each variable type, as
    # (filter operator, label shown in the operator combo) pairs.
    # The "is defined" operator is the last entry of every list;
    # send_report() identifies it by that position.
    Operators = {
        ContinuousVariable: [
            (FilterContinuous.Equal, "等于(equals)"),
            (FilterContinuous.NotEqual, "不是(is not)"),
            (FilterContinuous.Less, "低于(is below)"),
            (FilterContinuous.LessEqual, "最多(is at most)"),
            (FilterContinuous.Greater, "大于(is greater than)"),
            (FilterContinuous.GreaterEqual, "至少(is at least)"),
            (FilterContinuous.Between, "介于(is between)"),
            (FilterContinuous.Outside, "超出(is outside)"),
            (FilterContinuous.IsDefined, "已定义(is defined)"),
        ],
        DiscreteVariable: [
            (FilterDiscreteType.Equal, "是(is)"),
            (FilterDiscreteType.NotEqual, "不是(is not)"),
            (FilterDiscreteType.In, "是...中的一个(is one of)"),
            (FilterDiscreteType.IsDefined, "已定义(is defined)")
        ],
        StringVariable: [
            (FilterString.Equal, "等于(equals)"),
            (FilterString.NotEqual, "不是(is not)"),
            (FilterString.Less, "在...之前(is before)"),
            (FilterString.LessEqual, "等于或早于(is equal or before)"),
            (FilterString.Greater, "在之后(is after)"),
            (FilterString.GreaterEqual, "等于或大于(is equal or after)"),
            (FilterString.Between, "介于(is between)"),
            (FilterString.Outside, "超出(is outside)"),
            (FilterString.Contains, "包含(contains)"),
            (FilterString.StartsWith, "以...开始(begins with)"),
            (FilterString.EndsWith, "以...结尾(ends with)"),
            (FilterString.IsDefined, "已定义(is defined)"),
        ]
    }

    # Time variables share the continuous operator list (same list object).
    Operators[TimeVariable] = Operators[ContinuousVariable]

    # Pseudo-variables that stand for a whole group of variables.  AllTypes
    # maps their combo label to a type code (the widget compares these codes
    # with vartype() results: 2 is treated as continuous, 3 as string, and 0
    # is used for "any variable"); their operator lists are registered in
    # Operators under the same label.
    # NOTE(review): ``_plural`` is defined outside this block; presumably it
    # turns an operator label into its plural form -- verify.
    AllTypes = {}
    for _all_name, _all_type, _all_ops in (
            ("所有变量", 0,
             [(None, "are defined")]),
            ("所有数值变量", 2,
             [(v, _plural(t)) for v, t in Operators[ContinuousVariable]]),
            ("所有字符串变量", 3,
             [(v, _plural(t)) for v, t in Operators[StringVariable]])):
        Operators[_all_name] = _all_ops
        AllTypes[_all_name] = _all_type

    # Operator labels only, keyed like Operators; used to fill the combos.
    operator_names = {vtype: [name for _, name in filters]
                      for vtype, filters in Operators.items()}

    class Error(widget.OWWidget.Error):
        # Shows the message of a value-parsing failure verbatim (see commit()).
        parsing_error = Msg("{}")

    def __init__(self):
        """Build the condition table, the info/purge boxes and the button
        row, then initialise the widget as having no data."""
        super().__init__()

        self.old_purge_classes = True

        self.conditions = []
        self.last_output_conditions = None
        self.data = None
        self.data_desc = None
        self.match_desc = None
        self.nonmatch_desc = None

        # --- condition table: variable | operator | value(s) | remove ---
        cond_box = gui.vBox(self.controlArea, '条件', stretch=100)
        self.cond_list = QTableWidget(
            cond_box, showGrid=False,
            selectionMode=QTableWidget.NoSelection)
        cond_box.layout().addWidget(self.cond_list)
        self.cond_list.setColumnCount(4)
        self.cond_list.setRowCount(0)
        self.cond_list.verticalHeader().hide()
        h_header = self.cond_list.horizontalHeader()
        h_header.hide()
        # The first three columns share the width; the last one only holds
        # the remove button.
        for column in range(3):
            h_header.setSectionResizeMode(column, QHeaderView.Stretch)
        h_header.resizeSection(3, 30)
        self.cond_list.viewport().setBackgroundRole(QPalette.Window)

        # --- add / add all / remove all buttons, centred by rubber ---
        button_row = gui.hBox(cond_box)
        gui.rubber(button_row)
        # TODO: button style
        self.add_button = gui.button(
            button_row, self, "添加条件", callback=self.add_row)
        self.add_all_button = gui.button(
            button_row, self, "添加所有变量", callback=self.add_all)
        self.remove_all_button = gui.button(
            button_row, self, "删除全部", callback=self.remove_all)
        gui.rubber(button_row)

        # --- 2x2 grid: data info | purge options / report | auto send ---
        grid_holder = gui.widgetBox(self.controlArea,
                                    orientation=QGridLayout())
        grid = grid_holder.layout()
        grid.setColumnStretch(0, 1)
        grid.setColumnStretch(1, 1)

        info_box = gui.vBox(grid_holder, '数据', addToLayout=False)
        self.data_in_variables = gui.widgetLabel(info_box, " ")
        self.data_out_rows = gui.widgetLabel(info_box, " ")
        grid.addWidget(info_box, 0, 0)

        purge_box = gui.vBox(grid_holder, '清除', addToLayout=False)
        self.cb_pa = gui.checkBox(
            purge_box, self, "purge_attributes", "删除未使用的特征",
            callback=self.conditions_changed)
        gui.separator(purge_box, height=1)
        self.cb_pc = gui.checkBox(
            purge_box, self, "purge_classes", "删除未使用的分类",
            callback=self.conditions_changed)
        grid.addWidget(purge_box, 0, 1)

        self.report_button.setFixedWidth(120)
        gui.rubber(self.buttonsArea.layout())
        grid.addWidget(self.buttonsArea, 1, 0)

        auto_send_box = gui.auto_send(None, self, "auto_commit")
        grid.addWidget(auto_send_box, 1, 1)

        self.set_data(None)
        self.resize(600, 400)

    def add_row(self, attr=None, condition_type=None, condition_value=None):
        """Append a condition row to the table.

        Args:
            attr: the variable to preselect -- a combo text (``str``), a
                combo index (``int``), or ``None`` for the default choice
                (the first real variable, or the last pseudo-variable when
                the data has no visible variables).
            condition_type: index of the operator to preselect, if any.
            condition_value: values to put into the value editor, if any.

        An explicit index 0 is honoured.  A truthiness test here would
        confuse it with ``None`` and select the default item instead.
        """
        model = self.cond_list.model()
        row = model.rowCount()
        model.insertRow(row)

        attr_combo = gui.OrangeComboBox(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
        attr_combo.row = row
        # Pseudo-variables are plain strings; real variables get an icon.
        for var in self._visible_variables(self.data.domain):
            if isinstance(var, Variable):
                attr_combo.addItem(*gui.attributeItem(var))
            else:
                attr_combo.addItem(var)

        if isinstance(attr, str):
            attr_combo.setCurrentText(attr)
        elif attr is not None:
            attr_combo.setCurrentIndex(attr)
        else:
            n_all = len(self.AllTypes)
            # Step back by one when the combo holds pseudo-variables only.
            attr_combo.setCurrentIndex(
                n_all - (attr_combo.count() == n_all))
        self.cond_list.setCellWidget(row, 0, attr_combo)

        # A persistent index keeps pointing at this row after other rows
        # are removed, so the button always removes its own row.
        index = QPersistentModelIndex(model.index(row, 3))
        temp_button = QPushButton('×', self, flat=True,
                                  styleSheet='* {font-size: 16pt; color: silver}'
                                             '*:hover {color: black}')
        temp_button.clicked.connect(lambda: self.remove_one(index.row()))
        self.cond_list.setCellWidget(row, 3, temp_button)

        self.remove_all_button.setDisabled(False)
        self.set_new_operators(attr_combo, attr is not None,
                               condition_type, condition_value)
        attr_combo.currentIndexChanged.connect(
            lambda _: self.set_new_operators(attr_combo, False))

        self.cond_list.resizeRowToContents(row)

    @classmethod
    def _visible_variables(cls, domain):
        """Generate the combo entries in presentation order.

        The pseudo-variable names from ``AllTypes`` come first, followed by
        the domain's class variables, metas and attributes that pass
        ``filter_visible``.
        """
        domain_vars = chain(domain.class_vars, domain.metas, domain.attributes)
        return chain(cls.AllTypes, filter_visible(domain_vars))

    def add_all(self):
        """Replace the current conditions with one row per visible variable.

        Asks for confirmation first when conditions already exist.  Rows are
        addressed by combo index; the pseudo-variable entries that precede
        the real variables in every combo are skipped.
        """
        if self.cond_list.rowCount():
            Mb = QMessageBox
            if Mb.question(
                    self, "删除现有过滤器",
                    "这将用所有变量的过滤器替换现有的过滤器。", Mb.Ok | Mb.Cancel) != Mb.Ok:
                return
            self.remove_all()
        # Combo items are the AllTypes names followed by the visible
        # variables, so real variables start at index len(AllTypes).
        first_var = len(self.AllTypes)
        n_items = sum(1 for _ in self._visible_variables(self.data.domain))
        for index in range(first_var, n_items):
            self.add_row(index)

    def remove_one(self, rownum):
        """Remove the condition row `rownum` and re-evaluate the conditions."""
        self.remove_one_row(rownum)
        self.conditions_changed()

    def remove_all(self):
        """Remove every condition row and re-evaluate the conditions."""
        self.remove_all_rows()
        self.conditions_changed()

    def remove_one_row(self, rownum):
        """Remove row `rownum` from the table without touching the outputs;
        disable "remove all" once the table is empty."""
        self.cond_list.removeRow(rownum)
        rows_left = self.cond_list.model().rowCount()
        if not rows_left:
            self.remove_all_button.setDisabled(True)

    def remove_all_rows(self):
        """Empty the condition table without touching the outputs and
        disable the "remove all" button."""
        self.cond_list.clear()
        self.cond_list.setRowCount(0)
        self.remove_all_button.setDisabled(True)

    def set_new_operators(self, attr_combo, adding_all,
                          selected_index=None, selected_values=None):
        """Create the operator combo for the row of `attr_combo`.

        The combo is filled with the operator labels of the selected
        (pseudo-)variable, `selected_index` (default 0) is made current and
        the value editor is rebuilt through ``set_new_values``.
        """
        attr_name = attr_combo.currentText()
        # Pseudo-variables are keyed by name, real variables by their type.
        if attr_name in self.AllTypes:
            names_key = attr_name
        else:
            names_key = type(self.data.domain[attr_name])

        oper_combo = QComboBox()
        oper_combo.row = attr_combo.row
        oper_combo.attr_combo = attr_combo
        oper_combo.addItems(self.operator_names[names_key])
        oper_combo.setCurrentIndex(selected_index or 0)
        self.cond_list.setCellWidget(oper_combo.row, 1, oper_combo)

        self.set_new_values(oper_combo, adding_all, selected_values)
        # Connected only now, so the setup above does not trigger the slot.
        oper_combo.currentIndexChanged.connect(
            lambda _: self.set_new_values(oper_combo, False))

    @staticmethod
    def _get_lineedit_contents(box):
        """Return the texts of the line edits in `box`.

        `box.controls` is inspected when present; otherwise `box` itself is
        treated as the only candidate.
        """
        texts = []
        for candidate in getattr(box, "controls", [box]):
            if isinstance(candidate, QLineEdit):
                texts.append(candidate.text())
        return texts

    @staticmethod
    def _get_value_contents(box):
        """Collect the condition values entered in a value-cell widget.

        Line edits contribute their text, combo boxes their current index,
        and tool buttons with a popup the 1-based rows of all checked items
        (the button text is refreshed on the way).  Labels and ``None``
        contribute nothing.

        Returns:
            tuple: the collected values, in widget order.

        Raises:
            TypeError: for any other kind of widget.
        """
        collected = []
        checked_names = []
        for ctrl in getattr(box, "controls", [box]):
            if isinstance(ctrl, QLineEdit):
                collected.append(ctrl.text())
                continue
            if isinstance(ctrl, QComboBox):
                collected.append(ctrl.currentIndex())
                continue
            if isinstance(ctrl, QToolButton):
                if ctrl.popup is None:
                    continue
                item_model = ctrl.popup.list_view.model()
                for idx in range(item_model.rowCount()):
                    entry = item_model.item(idx)
                    if entry.checkState():
                        collected.append(idx + 1)
                        checked_names.append(entry.text())
                ctrl.desc_text = ', '.join(checked_names)
                ctrl.set_text()
                continue
            if ctrl is None or isinstance(ctrl, QLabel):
                continue
            raise TypeError('Type %s not supported.' % type(ctrl))
        return tuple(collected)

    class QDoubleValidatorEmpty(QDoubleValidator):
        """Double validator that accepts empty input and rejects any input
        containing the locale's group separator."""

        def validate(self, input_, pos):
            if input_:
                if self.locale().groupSeparator() in input_:
                    return QDoubleValidator.Invalid, input_, pos
                return super().validate(input_, pos)
            # An empty field is a valid "no value yet" state.
            return QDoubleValidator.Acceptable, input_, pos

    def set_new_values(self, oper_combo, adding_all, selected_values=None):
        """Rebuild the value editor (column 2) for the row of `oper_combo`.

        The editor depends on the selected operator and variable type:
        an empty label for "is defined", a drop-down tool button for the
        discrete "is one of", a combo of values for other discrete
        operators, and one or two line edits for numeric, time and string
        variables.

        The operator is identified by its entry in ``Operators`` rather than
        by its label: the labels are localized, so matching on English text
        endings does not work for them.  "Is defined" is the last operator
        of every list.

        Args:
            oper_combo: the operator combo of the row; it carries ``row``
                and ``attr_combo`` attributes.
            adding_all: when true, ``conditions_changed`` is not called.
            selected_values: values to preload into the editor, if any.
        """
        def add_textual(contents):
            le = gui.lineEdit(box, self, None,
                              sizePolicy=QSizePolicy(QSizePolicy.Expanding,
                                                     QSizePolicy.Expanding))
            if contents:
                le.setText(contents)
            le.setAlignment(Qt.AlignRight)
            le.editingFinished.connect(self.conditions_changed)
            return le

        def add_numeric(contents):
            le = add_textual(contents)
            le.setValidator(OWSelectRows.QDoubleValidatorEmpty())
            return le

        def add_datetime(contents):
            le = add_textual(contents)
            le.setValidator(QRegExpValidator(QRegExp(TimeVariable.REGEX)))
            return le

        box = self.cond_list.cellWidget(oper_combo.row, 2)
        lc = ["", ""]
        oper = oper_combo.currentIndex()
        attr_name = oper_combo.attr_combo.currentText()
        if attr_name in self.AllTypes:
            vtype = self.AllTypes[attr_name]
            var = None
            operators = self.Operators[attr_name]
        else:
            var = self.data.domain[attr_name]
            vtype = vartype(var)
            operators = self.Operators[type(var)]
            if selected_values is not None:
                lc = list(selected_values) + ["", ""]
                lc = [str(x) for x in lc[:2]]
        # Carry over what the user typed when the variable type is unchanged.
        if box and vtype == box.var_type:
            lc = self._get_lineedit_contents(box) + lc

        opertype = operators[oper][0] if 0 <= oper < len(operators) else None
        is_defined = oper == len(operators) - 1
        is_one_of = opertype == FilterDiscreteType.In

        if is_defined:
            # No value needed; an empty label only remembers the type.
            label = QLabel()
            label.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, label)
        elif var is not None and var.is_discrete:
            if is_one_of:
                if selected_values:
                    lc = list(selected_values)
                button = DropDownToolButton(self, var, lc)
                button.var_type = vtype
                self.cond_list.setCellWidget(oper_combo.row, 2, button)
            else:
                combo = QComboBox()
                # list() keeps this working if ``values`` is a tuple.
                combo.addItems([""] + list(var.values))
                if lc[0]:
                    combo.setCurrentIndex(int(lc[0]))
                else:
                    combo.setCurrentIndex(0)
                combo.var_type = vartype(var)
                self.cond_list.setCellWidget(oper_combo.row, 2, combo)
                combo.currentIndexChanged.connect(self.conditions_changed)
        else:
            box = gui.hBox(self, addToLayout=False)
            box.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, box)
            if vtype in (2, 4):  # continuous, time:
                validator = add_datetime if isinstance(var, TimeVariable) else add_numeric
                box.controls = [validator(lc[0])]
                # Operators past index 5 (between, outside) take two values.
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(validator(lc[1]))
            elif vtype == 3:  # string:
                box.controls = [add_textual(lc[0])]
                # Indices 6 and 7 are "between" and "outside".
                if oper in [6, 7]:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_textual(lc[1]))
            else:
                box.controls = []
        if not adding_all:
            self.conditions_changed()

    @Inputs.data
    def set_data(self, data):
        """Handle the data input.

        Closes the previous context, resets the condition table, enables or
        disables the controls for the new data, restores the conditions
        stored for it (or adds one default row) and commits.
        """
        self.closeContext()
        self.data = data

        # Purging is not offered for SQL-backed tables.
        purge_allowed = not isinstance(data, SqlTable)
        self.cb_pa.setEnabled(purge_allowed)
        self.cb_pc.setEnabled(purge_allowed)

        self.remove_all_rows()
        no_data = data is None
        self.add_button.setDisabled(no_data)
        # "Add all" is also disabled for domains with more than 100 variables.
        self.add_all_button.setDisabled(
            no_data or
            len(data.domain.variables) + len(data.domain.metas) > 100)

        if not data:
            self.data_desc = None
            self.commit()
            return

        self.data_desc = report.describe_data_brief(data)
        self.conditions = []
        try:
            self.openContext(data)
        except Exception:
            # Failures while restoring the context are deliberately ignored.
            pass

        entries = list(self._visible_variables(self.data.domain))
        varnames = [entry.name if isinstance(entry, Variable) else entry
                    for entry in entries]
        if not self.conditions:
            self.add_row()
        else:
            for attr, cond_type, cond_value in self.conditions:
                if attr in varnames:
                    self.add_row(varnames.index(attr), cond_type, cond_value)
                elif attr in self.AllTypes:
                    self.add_row(attr, cond_type, cond_value)

        self.update_info(data, self.data_in_variables, "输入: ")
        self.unconditional_commit()

    def conditions_changed(self):
        """Re-read the conditions from the table widgets and commit when
        ``update_on_change`` is set and they differ from the last output."""
        try:
            self.conditions = []
            gathered = []
            for row in range(self.cond_list.rowCount()):
                attr_widget = self.cond_list.cellWidget(row, 0)
                oper_widget = self.cond_list.cellWidget(row, 1)
                value_widget = self.cond_list.cellWidget(row, 2)
                gathered.append(
                    (attr_widget.currentText(),
                     oper_widget.currentIndex(),
                     self._get_value_contents(value_widget)))
            self.conditions = gathered

            unchanged = (self.last_output_conditions is not None and
                         self.last_output_conditions == self.conditions)
            if self.update_on_change and not unchanged:
                self.commit()
        except AttributeError:
            # Attribute error appears if the signal is triggered when the
            # controls are being constructed
            pass

    def _values_to_floats(self, attr, values):
        """Convert the entered `values` of a numeric condition to floats.

        Time variables are parsed with ``attr.parse``; everything else with
        ``QLocale().toDouble``, i.e. according to the current locale.

        Returns:
            `values` unchanged when it is empty, ``None`` when any value is
            empty (the condition is incomplete), otherwise the parsed floats.

        Raises:
            ValueError: when a value cannot be parsed in the current locale.
        """
        if not len(values):
            return values
        if not all(values):
            return None

        if isinstance(attr, TimeVariable):
            def parse(text):
                # Mimic the (value, ok) result shape of QLocale.toDouble.
                return attr.parse(text), True
        else:
            parse = QLocale().toDouble

        try:
            floats, ok = zip(*[parse(v) for v in values])
            if not all(ok):
                raise ValueError('Some values could not be parsed as floats '
                                 'in the current locale: {}'.format(values))
        except TypeError:
            floats = values  # values already floats
        assert all(isinstance(v, float) for v in floats)
        return floats

    def commit(self):
        """Apply the conditions to the data and send the three outputs.

        Each complete condition becomes a filter; incomplete ones (missing
        values) are skipped.  With at least one filter the data is split
        into matching and non-matching rows and an annotated table is
        built; without filters the input is passed through as matching
        data.  The purge options are applied unless the data is an
        ``SqlTable``, and empty tables are sent as ``None``.

        A value-parsing failure is shown as ``Error.parsing_error`` and
        aborts the commit without sending anything.
        """
        matching_output = self.data
        non_matching_output = None
        annotated_output = None

        self.Error.clear()
        if self.data:
            domain = self.data.domain
            conditions = []
            for attr_name, oper_idx, values in self.conditions:
                if attr_name in self.AllTypes:
                    # Pseudo-variable: the filter applies to a whole group.
                    attr_index = attr = None
                    attr_type = self.AllTypes[attr_name]
                    operators = self.Operators[attr_name]
                else:
                    attr_index = domain.index(attr_name)
                    attr = domain[attr_index]
                    attr_type = vartype(attr)
                    operators = self.Operators[type(attr)]
                opertype, _ = operators[oper_idx]
                if attr_type == 0:
                    cond_filter = data_filter.IsDefined()
                elif attr_type in (2, 4):  # continuous, time
                    try:
                        floats = self._values_to_floats(attr, values)
                    except ValueError as e:
                        self.Error.parsing_error(e.args[0])
                        return
                    if floats is None:
                        continue
                    cond_filter = data_filter.FilterContinuous(
                        attr_index, opertype, *floats)
                elif attr_type == 3:  # string
                    cond_filter = data_filter.FilterString(
                        attr_index, opertype, *[str(v) for v in values])
                else:
                    if opertype == FilterDiscreteType.IsDefined:
                        f_values = None
                    else:
                        if not values or not values[0]:
                            continue
                        # Stored indices are 1-based; 0 is the empty choice.
                        values = [attr.values[i-1] for i in values]
                        if opertype == FilterDiscreteType.Equal:
                            f_values = {values[0]}
                        elif opertype == FilterDiscreteType.NotEqual:
                            f_values = set(attr.values)
                            f_values.remove(values[0])
                        elif opertype == FilterDiscreteType.In:
                            f_values = set(values)
                        else:
                            raise ValueError("invalid operand")
                    cond_filter = data_filter.FilterDiscrete(
                        attr_index, f_values)
                conditions.append(cond_filter)

            if conditions:
                self.filters = data_filter.Values(conditions)
                matching_output = self.filters(self.data)
                # The same filter, negated, yields the complement.
                self.filters.negate = True
                non_matching_output = self.filters(self.data)

                row_sel = np.isin(self.data.ids, matching_output.ids)
                annotated_output = create_annotated_table(self.data, row_sel)

            purge_attrs = self.purge_attributes
            purge_classes = self.purge_classes
            if (purge_attrs or purge_classes) and \
                    not isinstance(self.data, SqlTable):
                attr_flags = sum([Remove.RemoveConstant * purge_attrs,
                                  Remove.RemoveUnusedValues * purge_attrs])
                class_flags = sum([Remove.RemoveConstant * purge_classes,
                                   Remove.RemoveUnusedValues * purge_classes])
                # same settings used for attributes and meta features
                remover = Remove(attr_flags, class_flags, attr_flags)

                matching_output = remover(matching_output)
                non_matching_output = remover(non_matching_output)
                annotated_output = remover(annotated_output)

        # Empty tables are sent as None.
        if matching_output is not None and not len(matching_output):
            matching_output = None
        if non_matching_output is not None and not len(non_matching_output):
            non_matching_output = None
        if annotated_output is not None and not len(annotated_output):
            annotated_output = None

        self.Outputs.matching_data.send(matching_output)
        self.Outputs.unmatched_data.send(non_matching_output)
        self.Outputs.annotated_data.send(annotated_output)

        self.match_desc = report.describe_data_brief(matching_output)
        self.nonmatch_desc = report.describe_data_brief(non_matching_output)

        self.update_info(matching_output, self.data_out_rows, "输出: ")

    def update_info(self, data, lab1, label):
        """Show the approximate row count and the number of variables of
        `data` in the label widget `lab1`, prefixed with `label`; the label
        is cleared when `data` is None."""
        def as_field(count):
            # One-element tuple so two fields can be concatenated into the
            # format arguments; a zero count is shown as "No".
            return (count or "No",)

        if data is None:
            lab1.setText("")
            return

        row_field = as_field(data.approx_len())
        var_field = as_field(
            len(data.domain.variables) + len(data.domain.metas))
        lab1.setText(label + "~%s 行, %s 个变量" % (row_field + var_field))

    def send_report(self):
        """Add the data description, the conditions and the output sizes to
        the report.

        The full domain descriptions are reported only when the input and
        output descriptions differ in more than the number of instances;
        otherwise just the instance counts are listed.

        Operators are identified through ``Operators`` ("is defined" is the
        last entry of each list, "is one of" is ``FilterDiscreteType.In``)
        because the operator labels are localized.
        """
        if not self.data:
            self.report_paragraph("No data.")
            return

        # Compare the descriptions with the instance counts left out.
        pdesc = None
        describe_domain = False
        for d in (self.data_desc, self.match_desc, self.nonmatch_desc):
            if not d or not d["Data instances"]:
                continue
            ndesc = d.copy()
            del ndesc["Data instances"]
            if pdesc is not None and pdesc != ndesc:
                describe_domain = True
            pdesc = ndesc

        conditions = []
        domain = self.data.domain
        for attr_name, oper, values in self.conditions:
            if attr_name in self.AllTypes:
                attr = attr_name
                operators = self.Operators[attr_name]
                names = self.operator_names[attr_name]
                var_type = self.AllTypes[attr_name]
            else:
                attr = domain[attr_name]
                var_type = vartype(attr)
                operators = self.Operators[type(attr)]
                names = self.operator_names[type(attr)]
            name = names[oper]
            if oper == len(names) - 1:
                # "is defined" takes no values.
                conditions.append("{} {}".format(attr, name))
            elif var_type == 1:  # discrete
                if operators[oper][0] == FilterDiscreteType.In:
                    # Stored indices are 1-based.
                    valnames = [attr.values[v - 1] for v in values]
                    if not valnames:
                        continue
                    if len(valnames) == 1:
                        valstr = valnames[0]
                    else:
                        valstr = f"{', '.join(valnames[:-1])} or {valnames[-1]}"
                    conditions.append(f"{attr} is {valstr}")
                elif values and values[0]:
                    value = values[0] - 1
                    conditions.append(f"{attr} {name} {attr.values[value]}")
            elif var_type == 3:  # string variable
                conditions.append(
                    f"{attr} {name} {' and '.join(map(repr, values))}")
            elif all(x for x in values):  # numeric variable
                conditions.append(f"{attr} {name} {' and '.join(values)}")

        items = OrderedDict()
        if describe_domain:
            items.update(self.data_desc)
        else:
            items["Instances"] = self.data_desc["Data instances"]
        items["Condition"] = " AND ".join(conditions) or "no conditions"
        self.report_items("Data", items)

        if describe_domain:
            self.report_items("Matching data", self.match_desc)
            self.report_items("Non-matching data", self.nonmatch_desc)
        else:
            match_inst = \
                bool(self.match_desc) and \
                self.match_desc["Data instances"]
            nonmatch_inst = \
                bool(self.nonmatch_desc) and \
                self.nonmatch_desc["Data instances"]
            # Both counts use the same wording, "None" when empty.
            self.report_items(
                "Output",
                (("Matching data",
                  "{} instances".format(match_inst) if match_inst else "None"),
                 ("Non-matching data",
                  "{} instances".format(nonmatch_inst)
                  if nonmatch_inst else "None")))
Пример #3
0
class OWPaintData(OWWidget):
    """Widget that lets the user paint 1-d or 2-d points and outputs a table.

    Painted points are kept as rows of three values (x, y, class index);
    ``commit`` converts them into an ``Orange.data.Table`` and sends it on
    the ``Data`` output.
    """
    # Each entry is (button text, tooltip, tool class, icon); `_init_ui`
    # unpacks the tuples in this order to build the tool buttons.
    TOOLS = [("Brush", "Create multiple instances", AirBrushTool,
              _icon("brush.svg")),
             ("Put", "Put individual instances", PutInstanceTool,
              _icon("put.svg")),
             ("Select", "Select and move instances", SelectTool,
              _icon("select-transparent_42px.png")),
             ("Jitter", "Jitter instances", JitterTool, _icon("jitter.svg")),
             ("Magnet", "Attract multiple instances", MagnetTool,
              _icon("magnet.svg")),
             ("Clear", "Clear the plot", ClearTool,
              _icon("../../../icons/Dlg_clear.png"))]

    name = "Paint Data"
    description = "Create data by painting data points on a plane."
    icon = "icons/PaintData.svg"
    priority = 60
    keywords = ["create", "draw"]

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    autocommit = Setting(True)
    # Name assigned to the output table in `commit`.
    table_name = Setting("Painted data")
    # Axis / output variable names for the first and second dimension.
    attr1 = Setting("x")
    attr2 = Setting("y")
    # Whether the second dimension is in use (False means 1-d painting).
    hasAttr2 = Setting(True)

    # Values bound to the "Radius", "Intensity" and "Symbol" sliders.
    brushRadius = Setting(75)
    density = Setting(7)
    symbol_size = Setting(10)

    #: current data array (shape=(N, 3)) as presented on the output
    data = Setting(None, schema_only=True)
    # Class label names shown in the "Labels" list.
    labels = Setting(["C1", "C2"], schema_only=True)

    graph_name = "plot"

    class Warning(OWWidget.Warning):
        no_input_variables = Msg("Input data has no variables")
        continuous_target = Msg("Continuous target value can not be used.")
        sparse_not_supported = Msg("Sparse data is ignored.")

    class Information(OWWidget.Information):
        use_first_two = \
            Msg("Paint Data uses data from the first two attributes.")

    def __init__(self):
        """Initialise widget state, restore stored points and build the UI."""
        super().__init__()

        self.input_data = None
        self.input_classes = []
        self.input_colors = None
        self.input_has_attr2 = True
        self.current_tool = None
        self._selected_indices = None
        self._scatter_item = None
        #: A private data buffer (can be modified in place). `self.data` is
        #: a copy of this array (as seen when the `invalidate` method is
        #: called
        self.__buffer = None

        self.undo_stack = QUndoStack(self)

        self.class_model = ColoredListModel(
            self.labels,
            self,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable)

        self.class_model.dataChanged.connect(self._class_value_changed)
        self.class_model.rowsInserted.connect(self._class_count_changed)
        self.class_model.rowsRemoved.connect(self._class_count_changed)

        # The stored `data` setting may be None, a list, or an ndarray.
        # Emptiness is tested explicitly because `not ndarray` raises
        # ValueError for arrays with more than one element, which would make
        # the ndarray branch below unreachable for any real data.
        if self.data is None or len(self.data) == 0:
            self.data = []
            self.__buffer = np.zeros((0, 3))
        elif isinstance(self.data, np.ndarray):
            self.__buffer = self.data.copy()
            self.data = self.data.tolist()
        else:
            self.__buffer = np.array(self.data)

        self.colors = colorpalette.ColorPaletteGenerator(
            len(colorpalette.DefaultRGBColors))
        # Maps tool class -> tool instance; filled lazily by set_current_tool.
        self.tools_cache = {}

        self._init_ui()
        self.commit()

    def _init_ui(self):
        """Build the control area (names, labels, tools, sliders) and the plot.

        Finishes by activating the first tool in ``TOOLS`` and applying the
        current 1-d/2-d layout through ``set_dimensions``.
        """
        namesBox = gui.vBox(self.controlArea, "Names")

        hbox = gui.hBox(namesBox, margin=0, spacing=0)
        gui.lineEdit(hbox,
                     self,
                     "attr1",
                     "Variable X: ",
                     controlWidth=80,
                     orientation=Qt.Horizontal,
                     callback=self._attr_name_changed)
        gui.separator(hbox, 21)
        hbox = gui.hBox(namesBox, margin=0, spacing=0)
        attr2 = gui.lineEdit(hbox,
                             self,
                             "attr2",
                             "Variable Y: ",
                             controlWidth=80,
                             orientation=Qt.Horizontal,
                             callback=self._attr_name_changed)
        gui.separator(hbox)
        gui.checkBox(hbox,
                     self,
                     "hasAttr2",
                     '',
                     disables=attr2,
                     labelWidth=0,
                     callback=self.set_dimensions)
        gui.separator(namesBox)

        gui.widgetLabel(namesBox, "Labels")
        self.classValuesView = listView = gui.ListViewWithSizeHint(
            preferred_size=(-1, 30))
        listView.setModel(self.class_model)
        itemmodels.select_row(listView, 0)
        namesBox.layout().addWidget(listView)

        self.addClassLabel = QAction("+",
                                     self,
                                     toolTip="Add new class label",
                                     triggered=self.add_new_class_label)

        self.removeClassLabel = QAction(
            unicodedata.lookup("MINUS SIGN"),
            self,
            toolTip="Remove selected class label",
            triggered=self.remove_selected_class_label)

        actionsWidget = itemmodels.ModelActionsWidget(
            [self.addClassLabel, self.removeClassLabel], self)
        actionsWidget.layout().addStretch(10)
        actionsWidget.layout().setSpacing(1)
        namesBox.layout().addWidget(actionsWidget)

        tBox = gui.vBox(self.controlArea, "Tools", addSpace=True)
        buttonBox = gui.hBox(tBox)
        toolsBox = gui.widgetBox(buttonBox, orientation=QGridLayout())

        self.toolActions = QActionGroup(self)
        self.toolActions.setExclusive(True)
        self.toolButtons = []

        for i, (name, tooltip, tool, icon) in enumerate(self.TOOLS):
            action = QAction(
                name,
                self,
                toolTip=tooltip,
                checkable=tool.checkable,
                icon=QIcon(icon),
            )
            action.triggered.connect(partial(self.set_current_tool, tool))

            button = QToolButton(iconSize=QSize(24, 24),
                                 toolButtonStyle=Qt.ToolButtonTextUnderIcon,
                                 sizePolicy=QSizePolicy(
                                     QSizePolicy.MinimumExpanding,
                                     QSizePolicy.Fixed))
            button.setDefaultAction(action)
            self.toolButtons.append((button, tool))

            # Lay the buttons out three per row. Grid coordinates must be
            # integers, so use integer division rather than `i / 3`, which
            # yields a float in Python 3.
            row, column = divmod(i, 3)
            toolsBox.layout().addWidget(button, row, column)
            self.toolActions.addAction(action)

        for column in range(3):
            toolsBox.layout().setColumnMinimumWidth(column, 10)
            toolsBox.layout().setColumnStretch(column, 1)

        undo = self.undo_stack.createUndoAction(self)
        redo = self.undo_stack.createRedoAction(self)

        undo.setShortcut(QKeySequence.Undo)
        redo.setShortcut(QKeySequence.Redo)

        self.addActions([undo, redo])
        # Any change of the undo stack position re-publishes the data.
        self.undo_stack.indexChanged.connect(lambda _: self.invalidate())

        gui.separator(tBox)
        indBox = gui.indentedBox(tBox, sep=8)
        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        indBox.layout().addLayout(form)
        slider = gui.hSlider(indBox,
                             self,
                             "brushRadius",
                             minValue=1,
                             maxValue=100,
                             createLabel=False)
        form.addRow("Radius:", slider)

        slider = gui.hSlider(indBox,
                             self,
                             "density",
                             None,
                             minValue=1,
                             maxValue=100,
                             createLabel=False)

        form.addRow("Intensity:", slider)

        slider = gui.hSlider(indBox,
                             self,
                             "symbol_size",
                             None,
                             minValue=1,
                             maxValue=100,
                             createLabel=False,
                             callback=self.set_symbol_size)

        form.addRow("Symbol:", slider)

        self.btResetToInput = gui.button(tBox, self, "Reset to Input Data",
                                         self.reset_to_input)
        self.btResetToInput.setDisabled(True)

        gui.auto_commit(self.left_side, self, "autocommit", "Send")

        # main area GUI
        viewbox = PaintViewBox(enableMouse=False)
        self.plotview = pg.PlotWidget(background="w", viewBox=viewbox)
        self.plotview.sizeHint = lambda: QSize(
            200, 100)  # Minimum size for 1-d painting
        self.plot = self.plotview.getPlotItem()

        axis_color = self.palette().color(QPalette.Text)
        axis_pen = QPen(axis_color)

        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        axis = self.plot.getAxis("bottom")
        axis.setLabel(self.attr1)
        axis.setPen(axis_pen)
        axis.setTickFont(tickfont)

        axis = self.plot.getAxis("left")
        axis.setLabel(self.attr2)
        axis.setPen(axis_pen)
        axis.setTickFont(tickfont)
        if not self.hasAttr2:
            self.plot.hideAxis('left')

        self.plot.hideButtons()
        self.plot.setXRange(0, 1, padding=0.01)

        self.mainArea.layout().addWidget(self.plotview)

        # enable brush tool
        self.toolActions.actions()[0].setChecked(True)
        self.set_current_tool(self.TOOLS[0][2])

        self.set_dimensions()

    def set_symbol_size(self):
        """Apply the ``symbol_size`` setting to the scatter item, if present."""
        item = self._scatter_item
        if not item:
            return
        item.setSize(self.symbol_size)

    def set_dimensions(self):
        """Switch the plot between 2-d and 1-d layout according to ``hasAttr2``.

        Adjusts the y range, the visibility of the left axis and the vertical
        size policy, redraws the points and disables tools flagged ``only2d``
        while painting in one dimension.
        """
        plot = self.plot
        if not self.hasAttr2:
            plot.setYRange(-.5, .5, padding=0.01)
            plot.hideAxis('left')
            self.plotview.setSizePolicy(QSizePolicy.Ignored,
                                        QSizePolicy.Maximum)
        else:
            plot.setYRange(0, 1, padding=0.01)
            plot.showAxis('left')
            self.plotview.setSizePolicy(QSizePolicy.Ignored,
                                        QSizePolicy.Minimum)
        self._replot()

        two_d_only = [btn for btn, tool in self.toolButtons if tool.only2d]
        for btn in two_d_only:
            btn.setDisabled(not self.hasAttr2)

    @Inputs.data
    def set_data(self, data):
        """Set the input_data and call reset_to_input.

        Sparse tables are ignored with a warning. Tables that are missing,
        empty or have no attributes clear the stored input and disable the
        "Reset to Input Data" button. Otherwise the first two attributes are
        rescaled with ``scale`` and stored, together with the class index
        column, as an (N, 3) array in ``self.input_data``.
        """
        def _check_and_set_data(data):
            # Returns True only when `data` is usable as painting input.
            self.clear_messages()
            if data and data.is_sparse():
                self.Warning.sparse_not_supported()
                return False
            if data is not None and len(data):
                if not data.domain.attributes:
                    self.Warning.no_input_variables()
                    data = None
                elif len(data.domain.attributes) > 2:
                    self.Information.use_first_two()
            else:
                # Treat an empty table like a missing input. Otherwise the
                # Table itself would stay in `input_data` with the reset
                # button enabled, and `reset_to_input` would fail calling
                # `.tolist()` on it.
                data = None
            self.input_data = data
            self.btResetToInput.setDisabled(data is None)
            return data is not None

        if not _check_and_set_data(data):
            return

        X = np.array([scale(vals) for vals in data.X[:, :2].T]).T
        try:
            y = next(cls for cls in data.domain.class_vars if cls.is_discrete)
        except StopIteration:
            # No discrete class variable: put all points into a single class.
            if data.domain.class_vars:
                self.Warning.continuous_target()
            self.input_classes = ["C1"]
            self.input_colors = None
            y = np.zeros(len(data))
        else:
            self.input_classes = y.values
            self.input_colors = y.colors

            y = data[:, y].Y

        self.input_has_attr2 = len(data.domain.attributes) >= 2
        if not self.input_has_attr2:
            # Pad the missing second coordinate with zeros so that the
            # stored array always has three columns.
            self.input_data = np.column_stack((X, np.zeros(len(data)), y))
        else:
            self.input_data = np.column_stack((X, y))
        self.reset_to_input()
        self.unconditional_commit()

    def reset_to_input(self):
        """Reset the painting to input data if present.

        Clears the undo stack, replaces the class labels and palette with
        those of the input, copies the input points into the painting buffer
        and redraws the plot.
        """
        if self.input_data is None:
            return
        self.undo_stack.clear()

        index = self.selected_class_label()
        if index is None:
            # No row is selected; fall back to the first label instead of
            # failing on `max(None, 0)` below.
            index = 0
        if self.input_colors is not None:
            colors = self.input_colors
        else:
            colors = colorpalette.DefaultRGBColors
        palette = colorpalette.ColorPaletteGenerator(
            number_of_colors=len(colors), rgb_colors=colors)
        self.colors = palette
        self.class_model.colors = palette
        self.class_model[:] = self.input_classes

        # Keep the previous selection when it is still a valid row.
        newindex = min(max(index, 0), len(self.class_model) - 1)
        itemmodels.select_row(self.classValuesView, newindex)

        self.data = self.input_data.tolist()
        self.__buffer = self.input_data.copy()

        prev_attr2 = self.hasAttr2
        self.hasAttr2 = self.input_has_attr2
        if prev_attr2 != self.hasAttr2:
            self.set_dimensions()
        else:  # set_dimensions already calls _replot, no need to call it again
            self._replot()

    def add_new_class_label(self, undoable=True):
        """Append a new class label with the first unused generated name.

        With ``undoable`` true the change is pushed on the undo stack;
        otherwise it is applied directly.
        """
        candidates = namegen('C', 1)
        newlabel = next(
            name for name in candidates if name not in self.class_model)

        def do():
            return self.class_model.append(newlabel)

        def undo():
            return self.class_model.__delitem__(-1)

        command = SimpleUndoCommand(do, undo)
        if not undoable:
            command.redo()
            return
        self.undo_stack.push(command)

    def remove_selected_class_label(self):
        """Delete the selected class label together with its points.

        The removal is recorded as a single undo macro: the points of the
        class are deleted, points with a higher class index are shifted down
        by one, and the label is removed from the model.
        """
        index = self.selected_class_label()
        if index is None:
            return

        label = self.class_model[index]
        class_column = self.__buffer[:, 2]
        mask = class_column == index
        # Of the points that remain, those whose class index must decrease.
        move_mask = class_column[~mask] > index

        stack = self.undo_stack
        stack.beginMacro("Delete class label")
        stack.push(UndoCommand(DeleteIndices(mask), self))
        stack.push(UndoCommand(Move((move_mask, 2), -1), self))
        stack.push(
            SimpleUndoCommand(lambda: self.class_model.__delitem__(index),
                              lambda: self.class_model.insert(index, label)))
        stack.endMacro()

        last_row = len(self.class_model) - 1
        newindex = min(max(index - 1, 0), last_row)
        itemmodels.select_row(self.classValuesView, newindex)

    def _class_count_changed(self):
        """Sync ``labels`` and the add/remove actions after rows change."""
        self.labels = list(self.class_model)
        n_labels = len(self.class_model)
        self.removeClassLabel.setEnabled(n_labels > 1)
        self.addClassLabel.setEnabled(n_labels < self.colors.number_of_colors)
        if self.selected_class_label() is not None:
            return
        # Nothing is selected any more: select the first label.
        itemmodels.select_row(self.classValuesView, 0)

    def _class_value_changed(self, index, _):
        """Copy an edited label from the model into the ``labels`` setting."""
        row = index.row()
        edited = self.class_model[row]
        stored = self.labels[row]
        if edited == stored:
            return
        self.labels[row] = edited


#             command = Command(
#                 lambda: self.class_model.__setitem__(index, newvalue),
#                 lambda: self.class_model.__setitem__(index, oldvalue),
#             )
#             self.undo_stack.push(command)

    def selected_class_label(self):
        """Return the row of the first selected label, or None if none is."""
        selected = self.classValuesView.selectedIndexes()
        return selected[0].row() if selected else None

    def set_current_tool(self, tool):
        """Activate the tool of class ``tool``, deactivating the current one.

        Tool instances are created on first use and cached per class. When
        the activated tool is not checkable, the previously active tool class
        is activated again right afterwards.
        """
        prev_tool = self.current_tool.__class__

        outgoing = self.current_tool
        if outgoing is not None:
            outgoing.deactivate()
            outgoing.editingStarted.disconnect(self._on_editing_started)
            outgoing.editingFinished.disconnect(self._on_editing_finished)
            self.current_tool = None
            self.plot.getViewBox().tool = None

        cache = self.tools_cache
        if tool not in cache:
            created = tool(self, self.plot)
            cache[tool] = created
            created.issueCommand.connect(self._add_command)

        self._selected_region = QRectF()
        incoming = cache[tool]
        self.current_tool = incoming
        self.plot.getViewBox().tool = incoming
        incoming.editingStarted.connect(self._on_editing_started)
        incoming.editingFinished.connect(self._on_editing_finished)
        incoming.activate()

        if not incoming.checkable:
            self.set_current_tool(prev_tool)

    def _on_editing_started(self):
        """Open an undo macro; connected to the tool's ``editingStarted``."""
        self.undo_stack.beginMacro("macro")

    def _on_editing_finished(self):
        """Close the undo macro; connected to the tool's ``editingFinished``."""
        self.undo_stack.endMacro()

    def execute(self, command):
        """Apply a normalized command to the buffer and return its undo.

        Only ``Append``, ``DeleteIndices``, ``Insert`` and ``Move`` are
        accepted; anything else trips an assertion. Commands that delete or
        insert rows also drop the current selection.
        """
        if not isinstance(command, (Append, DeleteIndices, Insert, Move)):
            assert False, "Non normalized command"
        else:
            changes_row_count = isinstance(command, (DeleteIndices, Insert))
            if changes_row_count:
                self._selected_indices = None

                if isinstance(self.current_tool, SelectTool):
                    self.current_tool._reset()

            updated, undo = transform(command, self.__buffer)
            self.__buffer = updated
            self._replot()
            return undo

    def _add_command(self, cmd):
        """Translate a tool command into undoable buffer operations.

        Connected to each tool's ``issueCommand`` signal. High-level commands
        (``AirBrush``, ``Jitter``, ``Magnet``) are converted into ``Append``
        or ``Move`` commands and fed back into this method; selection
        commands act on ``self._selected_indices``; the remaining commands
        are wrapped in an ``UndoCommand`` and pushed on the undo stack.
        """
        name = "Name"

        if (not self.hasAttr2
                and isinstance(cmd, (Move, MoveSelection, Jitter, Magnet))):
            # tool only supported if both x and y are enabled
            return

        if isinstance(cmd, Append):
            # New points take the currently selected class; y is forced to
            # 0 when painting in one dimension.
            cls = self.selected_class_label()
            points = np.array([(p.x(), p.y() if self.hasAttr2 else 0, cls)
                               for p in cmd.points])
            self.undo_stack.push(UndoCommand(Append(points), self, text=name))
        elif isinstance(cmd, Move):
            self.undo_stack.push(UndoCommand(cmd, self, text=name))
        elif isinstance(cmd, SelectRegion):
            # Remember which points fall inside the region; nothing is
            # pushed on the undo stack for a selection.
            indices = [
                i for i, (x, y) in enumerate(self.__buffer[:, :2])
                if cmd.region.contains(QPointF(x, y))
            ]
            indices = np.array(indices, dtype=int)
            self._selected_indices = indices
        elif isinstance(cmd, DeleteSelection):
            indices = self._selected_indices
            if indices is not None and indices.size:
                self.undo_stack.push(
                    UndoCommand(DeleteIndices(indices), self, text="Delete"))
        elif isinstance(cmd, MoveSelection):
            indices = self._selected_indices
            if indices is not None and indices.size:
                self.undo_stack.push(
                    UndoCommand(Move((self._selected_indices, slice(0, 2)),
                                     np.array([cmd.delta.x(),
                                               cmd.delta.y()])),
                                self,
                                text="Move"))
        elif isinstance(cmd, DeleteIndices):
            self.undo_stack.push(UndoCommand(cmd, self, text="Delete"))
        elif isinstance(cmd, Insert):
            self.undo_stack.push(UndoCommand(cmd, self))
        elif isinstance(cmd, AirBrush):
            # Generate points around the brush position and re-enter as an
            # Append command.
            data = create_data(cmd.pos.x(), cmd.pos.y(),
                               self.brushRadius / 1000,
                               int(1 + self.density / 20), cmd.rstate)
            self._add_command(Append([QPointF(*p) for p in zip(*data.T)]))
        elif isinstance(cmd, Jitter):
            # The negated result of apply_jitter is used as the per-point
            # displacement of a Move over all rows.
            point = np.array([cmd.pos.x(), cmd.pos.y()])
            delta = -apply_jitter(self.__buffer[:, :2], point,
                                  self.density / 100.0, 0, cmd.rstate)
            self._add_command(Move((..., slice(0, 2)), delta))
        elif isinstance(cmd, Magnet):
            point = np.array([cmd.pos.x(), cmd.pos.y()])
            delta = -apply_attractor(self.__buffer[:, :2], point,
                                     self.density / 100.0, 0)
            self._add_command(Move((..., slice(0, 2)), delta))
        else:
            assert False, "unreachable"

    def _replot(self):
        """Rebuild the scatter plot item from the current buffer contents."""
        def cosmetic_pen(color):
            outline = QPen(color, 1)
            outline.setCosmetic(True)
            return outline

        previous = self._scatter_item
        if previous is not None:
            self.plot.removeItem(previous)
            self._scatter_item = None

        buffer = self.__buffer
        x = buffer[:, 0].copy()
        # In 1-d mode every point is drawn on the y = 0 line.
        y = buffer[:, 1].copy() if self.hasAttr2 else np.zeros(buffer.shape[0])

        colors = self.colors[buffer[:, 2]]
        pens = list(map(cosmetic_pen, colors))
        brushes = list(map(QBrush, colors))

        item = pg.ScatterPlotItem(x, y, symbol="+", brush=brushes, pen=pens)
        self._scatter_item = item
        self.plot.addItem(item)
        self.set_symbol_size()

    def _attr_name_changed(self):
        """Relabel both plot axes from ``attr1``/``attr2`` and re-send data."""
        bottom_axis = self.plot.getAxis("bottom")
        bottom_axis.setLabel(self.attr1)
        left_axis = self.plot.getAxis("left")
        left_axis.setLabel(self.attr2)
        self.invalidate()

    def invalidate(self):
        """Copy the private buffer into the ``data`` setting and commit."""
        self.data = self.__buffer.tolist()
        self.commit()

    def commit(self):
        """Build an Orange table from the painted points and send it.

        Sends ``None`` when there are no points. A discrete class variable is
        included only when at least two distinct class indices are present.
        """
        points = np.array(self.data)
        if len(points) == 0:
            self.Outputs.data.send(None)
            return

        Y = points[:, 2]
        if self.hasAttr2:
            X = points[:, :2]
            names = (self.attr1, self.attr2)
        else:
            X = points[:, :1]
            names = (self.attr1, )
        attrs = tuple(Orange.data.ContinuousVariable(n) for n in names)

        has_classes = len(np.unique(Y)) >= 2
        if has_classes:
            class_var = Orange.data.DiscreteVariable(
                "Class", values=list(self.class_model))
            domain = Orange.data.Domain(attrs, class_var)
            table = Orange.data.Table.from_numpy(domain, X, Y)
        else:
            domain = Orange.data.Domain(attrs)
            table = Orange.data.Table.from_numpy(domain, X)
        table.name = self.table_name
        self.Outputs.data.send(table)

    def sizeHint(self):
        """Return the base size hint expanded to at least 570 x 690."""
        minimum = QSize(570, 690)
        return super().sizeHint().expandedTo(minimum)

    def onDeleteWidget(self):
        """Clear the plot item when the widget is deleted."""
        self.plot.clear()

    def send_report(self):
        """Report the axis names (when not the defaults), point count and plot."""
        if self.data is None:
            return
        default_names = self.attr1 == "x" and self.attr2 == "y"
        settings = []
        if not default_names:
            settings.append(("Axis x", self.attr1))
            settings.append(("Axis y", self.attr2))
        settings.append(("Number of points", len(self.data)))
        self.report_items("Painted data", settings)
        self.report_plot()
Пример #4
0
class OWPivot(OWWidget):
    """Pivot table widget.

    Builds a ``Pivot`` from the input table using the selected row, column
    and value features and the checked aggregation functions, shows it in a
    table view and sends the pivot table, the grouped table and the rows
    matching the selected cells.
    """
    name = "数据透视表(Pivot Table)"
    description = "根据列值重新调整数据表的形状。"
    icon = "icons/Pivot.svg"
    priority = 1000
    keywords = ["pivot", "group", "aggregate"]

    class Inputs:
        data = Input("数据(Data)", Table, default=True, replaces=['Data'])

    class Outputs:
        pivot_table = Output("数据透视表(Pivot Table)",
                             Table,
                             default=True,
                             replaces=['Pivot Table'])
        filtered_data = Output("筛选的数据(Filtered Data)",
                               Table,
                               replaces=['Filtered Data'])
        grouped_data = Output("分组数据(Grouped Data)",
                              Table,
                              replaces=['Grouped Data'])

    class Warning(OWWidget.Warning):
        # TODO - inconsistent for different variable types
        no_col_feature = Msg("Column feature should be selected.")
        # The placeholder receives the comma-separated names of the
        # aggregations that were skipped (see `skipped_aggs`).
        cannot_aggregate = Msg("({}) 无法执行.")

    settingsHandler = DomainContextHandler()
    # Variables chosen in the row, column and value combo boxes.
    row_feature = ContextSetting(None)
    col_feature = ContextSetting(None)
    val_feature = ContextSetting(None)
    # Aggregation functions whose check boxes are ticked.
    sel_agg_functions = Setting(set([Pivot.Count]))
    # Selected cells of the table view; iterated as (i, j) pairs in
    # `get_filtered_data`.
    selection = ContextSetting(set())
    auto_commit = Setting(True)

    # Order of the aggregation check boxes; None marks a separator line.
    AGGREGATIONS = (Pivot.Count, Pivot.Count_defined, None, Pivot.Sum,
                    Pivot.Mean, Pivot.Mode, Pivot.Min, Pivot.Max, Pivot.Median,
                    Pivot.Var, None, Pivot.Majority)

    def __init__(self):
        """Initialise state with no data or pivot and build both GUI areas."""
        super().__init__()
        self.data = None  # type: Table
        self.pivot = None  # type: Pivot
        self._add_control_area_controls()
        self._add_main_area_controls()

    def _add_control_area_controls(self):
        """Create the row/column/value combo boxes, the aggregation check
        boxes and the auto-apply button in the control area."""
        area = self.controlArea

        rows_box = gui.vBox(area, "行")
        rows_model = DomainModel(valid_types=DomainModel.PRIMITIVE)
        gui.comboBox(rows_box, self, "row_feature", contentsLength=12,
                     model=rows_model, callback=self.__feature_changed)

        cols_box = gui.vBox(area, "列")
        cols_model = DomainModel(placeholder="(与行相同)",
                                 valid_types=DiscreteVariable)
        gui.comboBox(cols_box, self, "col_feature", contentsLength=12,
                     model=cols_model, callback=self.__feature_changed)

        vals_box = gui.vBox(area, "值")
        vals_model = DomainModel(placeholder="(None)")
        gui.comboBox(vals_box, self, "val_feature", contentsLength=12,
                     model=vals_model, orientation=Qt.Horizontal,
                     callback=self.__val_feature_changed)

        self.__add_aggregation_controls()
        gui.rubber(area)
        gui.auto_apply(area, self, "auto_commit")

    def __add_aggregation_controls(self):
        """Add one check box per aggregation, with separator lines between
        the groups marked by ``None`` in ``AGGREGATIONS``."""
        box = gui.vBox(self.controlArea, "聚合")
        chinese_aggs = [
            "计数(Count)", "计数已定义项(Count_defined)", None, "总和(Sum)", "平均(Mean)",
            "样式(Mode)", "最小(Min)", "最大(Max)", "中位数(Median)", "变量(Var)", None,
            "大多数(Majority)"
        ]

        def add_separator_line():
            gui.separator(box, height=1)
            rule = QFrame()
            rule.setFrameShape(QFrame.HLine)
            rule.setLineWidth(1)
            rule.setFrameShadow(QFrame.Sunken)
            box.layout().addWidget(rule)

        for agg, caption in zip(self.AGGREGATIONS, chinese_aggs):
            if agg is None:
                add_separator_line()
            else:
                check_box = QCheckBox(str(caption), box)
                check_box.setChecked(agg in self.sel_agg_functions)
                # Bind `agg` as a default so each box keeps its own function.
                check_box.clicked.connect(
                    lambda *args, a=agg:
                    self.__aggregation_cb_clicked(a, args[0]))
                box.layout().addWidget(check_box)

    def _add_main_area_controls(self):
        self.table_view = PivotTableView()
        self.table_view.selection_changed.connect(self.__invalidate_filtered)
        self.mainArea.layout().addWidget(self.table_view)

    @property
    def no_col_feature(self):
        """Whether no column feature is set while the row feature is
        continuous."""
        if self.col_feature is not None:
            return False
        if self.row_feature is None:
            return False
        return self.row_feature.is_continuous

    @property
    def skipped_aggs(self):
        """Comma-separated, sorted names of the selected aggregations that do
        not apply to the current value feature."""
        def is_skipped(fun):
            data, var = self.data, self.val_feature
            if var:
                # A value feature is set: skip functions of the other type.
                return (var.is_discrete and fun in Pivot.ContVarFunctions
                        or var.is_continuous and fun in Pivot.DiscVarFunctions)
            # No value feature: only autonomous functions can be computed.
            return data and fun not in Pivot.AutonomousFunctions

        names = sorted(
            str(fun) for fun in self.sel_agg_functions if is_skipped(fun))
        return ", ".join(names)

    def __feature_changed(self):
        """Callback of the row and column combo boxes.

        Clears the selection and drops the pivot so that ``commit`` builds a
        new one.
        """
        self.selection = set()
        self.pivot = None
        self.commit()

    def __val_feature_changed(self):
        """Callback of the value combo box.

        Clears the selection and, unless a column feature is required but
        missing, updates the pivot table for the new value feature and
        commits.
        """
        self.selection = set()
        if self.no_col_feature:
            return
        # `self.pivot` is None until `commit` has built it. In that case
        # skip the in-place update and let `commit` create the pivot with
        # the current value feature instead of raising AttributeError.
        if self.pivot is not None:
            self.pivot.update_pivot_table(self.val_feature)
        self.commit()

    def __aggregation_cb_clicked(self, agg_fun: Pivot.Functions,
                                 checked: bool):
        """Add or remove ``agg_fun`` from the selected aggregations.

        The group table is updated and the outputs re-sent only when a pivot
        and data exist and a column feature is not missing.
        """
        self.selection = set()
        toggle = (self.sel_agg_functions.add if checked
                  else self.sel_agg_functions.remove)
        toggle(agg_fun)

        can_update = (not self.no_col_feature) and self.pivot and self.data
        if not can_update:
            return
        self.pivot.update_group_table(self.sel_agg_functions, self.val_feature)
        self.commit()

    def __invalidate_filtered(self):
        """Store the table view's selection and commit; connected to the
        view's ``selection_changed`` signal."""
        self.selection = self.table_view.get_selection()
        self.commit()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle new input data.

        Closes the current settings context, stores the data, discards the
        pivot, resets the feature combo boxes, reopens the context for the
        new data and commits unconditionally.
        """
        self.closeContext()
        self.data = data
        self.pivot = None
        self.check_data()
        self.init_attr_values()
        self.openContext(self.data)
        self.unconditional_commit()

    def check_data(self):
        """Clear all messages and, when there is no data, the table view."""
        self.clear_messages()
        if self.data:
            return
        self.table_view.clear()

    def init_attr_values(self):
        """Reset the three feature combo boxes for the current data.

        Each model gets the data's domain (``None`` for missing or empty
        data) and each feature is cleared. The row feature then defaults to
        the first model entry and the value feature to the first domain
        variable when the model contains it, otherwise to the model's third
        entry.
        """
        has_rows = self.data and len(self.data)
        domain = self.data.domain if has_rows else None
        for attr in ("row_feature", "col_feature", "val_feature"):
            combo = getattr(self.controls, attr)
            combo.model().set_domain(domain)
            setattr(self, attr, None)

        row_model = self.controls.row_feature.model()
        if row_model:
            self.row_feature = row_model[0]

        val_model = self.controls.val_feature.model()
        if val_model and len(val_model) > 2:
            if domain.variables[0] in val_model:
                self.val_feature = domain.variables[0]
            else:
                self.val_feature = val_model[2]

    def commit(self):
        """Build the pivot if needed, refresh the view and send all outputs.

        When a column feature is required but missing, a warning is shown
        and nothing is sent.
        """
        if self.pivot is None:
            self.Warning.no_col_feature.clear()
            if self.no_col_feature:
                self.Warning.no_col_feature()
                return
            self.pivot = Pivot(self.data, self.sel_agg_functions,
                               self.row_feature, self.col_feature,
                               self.val_feature)

        self.Warning.cannot_aggregate.clear()
        skipped = self.skipped_aggs
        if skipped:
            self.Warning.cannot_aggregate(skipped)

        self._update_graph()
        outputs = self.Outputs
        outputs.grouped_data.send(self.pivot.group_table)
        outputs.pivot_table.send(self.pivot.pivot_table)
        outputs.filtered_data.send(self.get_filtered_data())

    def _update_graph(self):
        """Clear the table view and refill it from the current pivot table."""
        view = self.table_view
        view.clear()
        if not self.pivot.pivot_table:
            return
        # Without a column feature the row feature is used for columns too.
        col_feature = self.col_feature or self.row_feature
        view.update_table(col_feature.name,
                          self.row_feature.name,
                          *self.pivot.pivot_tables)
        view.set_selection(self.selection)

    def get_filtered_data(self):
        """Return the input rows matching any selected pivot cell, or None.

        Each selected ``(i, j)`` cell becomes a conjunction of a row-feature
        filter and a column-feature filter; the per-cell filters are combined
        as a disjunction.
        """
        if not self.data or not self.selection or not self.pivot.pivot_table:
            return None

        def make_filter(var, value):
            # Returns None for variables that are neither discrete nor
            # continuous (including a missing column feature).
            if isinstance(var, DiscreteVariable):
                return FilterDiscrete(var, [value])
            if isinstance(var, ContinuousVariable):
                return FilterContinuous(var, FilterContinuous.Equal, value)
            return None

        cell_filters = []
        for i, j in self.selection:
            pairs = [(self.row_feature, self.pivot.pivot_table.X[i, 0]),
                     (self.col_feature, j)]
            candidates = [make_filter(var, value) for var, value in pairs]
            cell_filters.append(
                Values([flt for flt in candidates if flt is not None]))
        return Values(list(cell_filters), conjunction=False)(self.data)

    def sizeHint(self):
        """Return the preferred widget size of 640 x 525."""
        return QSize(640, 525)

    def send_report(self):
        """Report the selected features and, depending on the data, the
        contents of the table view."""
        self.report_items((("Row feature", self.row_feature),
                           ("Column feature", self.col_feature),
                           ("Value feature", self.val_feature)))
        if self.data and self.val_feature is not None:
            self.report_table("", self.table_view)
        # NOTE(review): this branch runs only when there is NO data and calls
        # report_table with the view as its sole argument, unlike the call
        # above. It looks like a leftover; confirm the intended condition
        # and arguments against report_table's signature.
        if not self.data:
            self.report_items((("Group by", self.row_feature), ))
            self.report_table(self.table_view)
Пример #5
0
class OWContinuize(widget.OWWidget):
    """
    Widget that turns categorical variables into numeric ones and can
    normalize numeric variables.

    Each radio-button group stores the index of the chosen option in a
    setting. The `*_treats` tables map that index to the label shown in
    the GUI and to the value passed to `DomainContinuizer`.
    """
    name = "Continuize"
    description = ("Transform categorical attributes into numeric and, "
                   "optionally, normalize numeric values.")
    icon = "icons/Continuize.svg"
    category = "Data"
    keywords = ["data", "continuize"]

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    want_main_area = False
    buttons_area_orientation = Qt.Vertical
    resizing_enabled = False

    # Indices into the tables of treatments defined below
    multinomial_treatment = Setting(0)
    zero_based = Setting(1)
    continuous_treatment = Setting(0)
    class_treatment = Setting(0)

    transform_class = Setting(False)

    autosend = Setting(True)

    # (label, treatment) pairs for categorical features
    multinomial_treats = (
        ("Target or first value as base", Continuize.FirstAsBase),
        ("Most frequent value as base", Continuize.FrequentAsBase),
        ("One attribute per value", Continuize.Indicators),
        ("Ignore multinomial attributes", Continuize.RemoveMultinomial),
        ("Remove categorical attributes", Continuize.Remove),
        ("Treat as ordinal", Continuize.AsOrdinal),
        ("Divide by number of values", Continuize.AsNormalizedOrdinal),
    )

    # (label, treatment) pairs for numeric features
    continuous_treats = (
        ("Leave them as they are", Continuize.Leave),
        ("Normalize by span", Normalize.NormalizeBySpan),
        ("Normalize by standard deviation", Normalize.NormalizeBySD),
    )

    # (label, treatment) pairs for categorical class variables
    class_treats = (
        ("Leave it as it is", Continuize.Leave),
        ("Treat as ordinal", Continuize.AsOrdinal),
        ("Divide by number of values", Continuize.AsNormalizedOrdinal),
        ("One class per value", Continuize.Indicators),
    )

    # Labels for the two values of `zero_based` (index 0 and 1)
    value_ranges = ["From -1 to 1", "From 0 to 1"]

    def __init__(self):
        super().__init__()

        # (box title, name of the setting, button labels) in display order
        radio_groups = (
            ("Categorical Features", "multinomial_treatment",
             [label for label, _ in self.multinomial_treats]),
            ("Numeric Features", "continuous_treatment",
             [label for label, _ in self.continuous_treats]),
            ("Categorical Outcomes", "class_treatment",
             [label for label, _ in self.class_treats]),
            ("Value Range", "zero_based", self.value_ranges),
        )
        for title, setting_name, labels in radio_groups:
            group_box = gui.vBox(self.controlArea, title)
            gui.radioButtonsInBox(group_box,
                                  self,
                                  setting_name,
                                  btnLabels=labels,
                                  callback=self.settings_changed)

        gui.auto_commit(self.buttonsArea, self, "autosend", "Apply", box=False)

        self.data = None

    def settings_changed(self):
        """Callback of all radio-button groups; calls `commit`."""
        self.commit()

    @Inputs.data
    @check_sql_input
    def setData(self, data):
        """Store the input data and output the result (or `None`)."""
        self.data = data
        if data is not None:
            self.unconditional_commit()
        else:
            self.Outputs.data.send(None)

    def constructContinuizer(self):
        """Return a `DomainContinuizer` set up from the current settings."""
        multinomial = self.multinomial_treats[self.multinomial_treatment][1]
        continuous = self.continuous_treats[self.continuous_treatment][1]
        class_ = self.class_treats[self.class_treatment][1]
        return DomainContinuizer(
            zero_based=self.zero_based,
            multinomial_treatment=multinomial,
            continuous_treatment=continuous,
            class_treatment=class_)

    def commit(self):
        """
        Send the data transformed into the continuized domain.

        `None` and empty tables are passed through unchanged.
        """
        continuizer = self.constructContinuizer()
        if self.data is None or not len(self.data):
            self.Outputs.data.send(self.data)  # None or empty data
            return
        new_domain = continuizer(self.data)
        self.Outputs.data.send(self.data.transform(new_domain))

    def send_report(self):
        """Report the labels of the currently chosen options."""
        categorical = self.multinomial_treats[self.multinomial_treatment][0]
        numeric = self.continuous_treats[self.continuous_treatment][0]
        target = self.class_treats[self.class_treatment][0]
        value_range = self.value_ranges[self.zero_based]
        self.report_items(
            "Settings",
            [("Categorical features", categorical),
             ("Numeric features", numeric),
             ("Class", target),
             ("Value range", value_range)])
Пример #6
0
class OWScoreCells(widget.OWWidget):
    """
    Widget that appends a meta column 'Score' to the input data.

    The score of a row is the largest (non-nan) value over the columns
    whose names are listed in the chosen string column of the "Genes"
    input; if no listed name matches a variable, the score is 0.
    """
    name = "Score Cells"
    description = "Add a cell score based on the given set of genes"
    icon = "icons/ScoreCells.svg"
    priority = 180

    settingsHandler = DomainContextHandler()
    gene = ContextSetting(None)
    auto_apply = Setting(True)

    want_main_area = False

    class Warning(OWWidget.Warning):
        no_genes = Msg("No matching genes in data")
        some_genes = Msg("{} (of {}) genes not found in data")

    class Inputs:
        data = Input("Data", Table)
        genes = Input("Genes", Table)

    class Outputs:
        data = Output("Data", Table)

    def __init__(self):
        super().__init__()

        self.data = None
        self.genes = None
        # The combo box offers only string variables of the "Genes" input
        self.feature_model = DomainModel(valid_types=StringVariable)

        gene_box = gui.vBox(self.controlArea, "Gene name")
        gui.comboBox(gene_box,
                     self,
                     'gene',
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._invalidate)

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Store the input data; output is updated in `handleNewSignals`."""
        self.data = data

    @Inputs.genes
    @check_sql_input
    def set_genes(self, genes):
        """Store the gene table and preselect its first string variable."""
        self.closeContext()
        self.genes = genes
        self.feature_model.set_domain(None)
        self.gene = None
        if not self.genes:
            return
        self.feature_model.set_domain(self.genes.domain)
        if not self.feature_model:
            return
        self.gene = self.feature_model[0]
        self.openContext(genes)

    def handleNewSignals(self):
        self._invalidate()

    def commit(self):
        """Compute the scores and send the data with the added column."""
        self.clear_messages()

        if self.data is None:
            self.Outputs.data.send(None)
            return

        scores = np.zeros(len(self.data))
        if self.genes and self.gene:
            known_names = {var.name for var in self.data.domain.variables}
            requested = [str(row[self.gene]) for row in self.genes]
            matched = [name for name in requested if name in known_names]
            n_missing = len(requested) - len(matched)
            if not matched:
                self.Warning.no_genes()
            else:
                if n_missing:
                    self.Warning.some_genes(n_missing, len(requested))
                scores = np.nanmax(self.data[:, matched].X, axis=1)

        domain = self.data.domain
        score_var = ContinuousVariable('Score')
        extended = Domain(domain.attributes, domain.class_vars,
                          domain.metas + (score_var, ))
        table = self.data.transform(extended)
        column, _ = table.get_column_view(score_var)
        column[:] = scores
        self.Outputs.data.send(table)

    def _invalidate(self):
        self.commit()

    def send_report(self):
        """Report the variable (or its name) that provides gene names."""
        reported = None
        if self.genes is not None:
            reported = self.gene
            genes_domain = self.genes.domain
            if reported in genes_domain:
                reported = genes_domain[reported]
        self.report_items((("Gene", reported), ))
Пример #7
0
class OWScatterPlotBase(gui.OWComponent, QObject):
    """
    Provide a graph component for widgets that show any kind of point plot

    The component plots a set of points with given coordinates, shapes,
    sizes and colors. Its function is similar to that of a *view*, whereas
    the widget represents a *model* and a *controller*.

    The model (widget) needs to provide methods:

    - `get_coordinates_data`, `get_size_data`, `get_color_data`,
      `get_shape_data`, `get_label_data`, which return a 1d array (or two
      arrays, for `get_coordinates_data`) of `dtype` `float64`, except for
      `get_label_data`, which returns formatted labels;
    - `get_color_labels`, `get_shape_labels`, which return lists of
      strings used for the color and shape legend;
    - `get_tooltip`, which gives a tooltip for a single data point
    - (optional) `impute_sizes`, `impute_shapes` get final coordinates and
      shapes, and replace nans;
    - `get_subset_mask` returns a bool array indicating whether a
      data point is in the subset or not (e.g. in the 'Data Subset' signal
      in the Scatter plot and similar widgets);
    - `get_palette` returns a palette appropriate for visualizing the
      current color data;
    - `is_continuous_color` decides the type of the color legend;

    The widget (in a role of controller) must also provide methods
    - `selection_changed`

    If `get_coordinates_data` returns `(None, None)`, the plot is cleared. If
    `get_size_data`, `get_color_data` or `get_shape_data` return `None`,
    all points will have the same size, color or shape, respectively.
    If `get_label_data` returns `None`, there are no labels.

    The view (this component) provides methods `update_coordinates`,
    `update_sizes`, `update_colors`, `update_shapes` and `update_labels`
    that the widget (in a role of a controller) should call when any of
    these properties are changed. If the widget calls, for instance, the
    plot's `update_colors`, the plot will react by calling the widget's
    `get_color_data` as well as the widget's methods needed to construct the
    legend.

    The view also provides a method `reset_graph`, which should be called only
    when
    - the widget gets entirely new data
    - the number of points may have changed, for instance when selecting
    a different attribute for x or y in the scatter plot, where the points
    with missing x or y coordinates are hidden.

    Every `update_something` calls the plot's `get_something`, which
    calls the model's `get_something_data`, then it transforms this data
    into whatever is needed (colors, shapes, scaled sizes) and changes the
    plot. For the simplest example, here is `update_shapes`:

    ```
        def update_shapes(self):
            if self.scatterplot_item:
                shape_data = self.get_shapes()
                self.scatterplot_item.setSymbol(shape_data)
            self.update_legends()

        def get_shapes(self):
            shape_data = self.master.get_shape_data()
            shape_data = self.master.impute_shapes(
                shape_data, len(self.CurveSymbols) - 1)
            return self.CurveSymbols[shape_data]
    ```

    On the widget's side, `get_something_data` is essentially just:

    ```
        def get_size_data(self):
            return self.get_column(self.attr_size)
    ```

    where `get_column` retrieves a column while also filtering out the
    points with missing x and y and so forth. (Here we present the simplest
    two cases, "shapes" for the view and "sizes" for the model. The colors
    for the view are more complicated since they deal with discrete and
    continuous palettes, and the shapes for the view merge infrequent shapes.)

    The plot can also show just a random sample of the data. The sample size is
    set by `set_sample_size`, and the rest is taken care of by the plot: the
    widget keeps providing the data for all points, selection indices refer
    to the entire set etc. Internally, sampling happens as early as possible
    (in methods `get_<something>`).
    """
    too_many_labels = Signal(bool)
    # Emitted by `update_sizes` before, during and after changing the sizes
    begin_resizing = Signal()
    step_resizing = Signal()
    end_resizing = Signal()

    label_only_selected = Setting(False)
    point_width = Setting(10)
    alpha_value = Setting(128)
    show_grid = Setting(False)
    show_legend = Setting(True)
    class_density = Setting(False)
    jitter_size = Setting(0)

    resolution = 256

    # Symbol codes; indexed with (imputed) shape data, see `get_shapes` above
    CurveSymbols = np.array("o x t + d s t2 t3 p h star ?".split())
    MinShapeSize = 6
    DarkerValue = 120
    UnknownColor = (168, 50, 168)

    # Brush colors (the four values are passed to `QColor`)
    COLOR_NOT_SUBSET = (128, 128, 128, 0)
    COLOR_SUBSET = (128, 128, 128, 255)
    COLOR_DEFAULT = (128, 128, 128, 0)

    MAX_VISIBLE_LABELS = 500

    def __init__(self, scatter_widget, parent=None, view_box=ViewBox):
        """
        Create the plot widget, legends and drag tooltip, and connect signals

        Args:
            scatter_widget: the widget that owns this component; it is
                passed to `gui.OWComponent.__init__` and stored as
                `self.master`
            parent: parent passed to `pg.PlotWidget`
            view_box: callable that is called with this component and
                returns the view box used by the plot widget
        """
        QObject.__init__(self)
        gui.OWComponent.__init__(self, scatter_widget)

        self.subset_is_shown = False

        self.view_box = view_box(self)
        self.plot_widget = pg.PlotWidget(viewBox=self.view_box, parent=parent,
                                         background="w")
        self.plot_widget.hideAxis("left")
        self.plot_widget.hideAxis("bottom")
        self.plot_widget.getPlotItem().buttonsHidden = True
        self.plot_widget.setAntialiasing(True)
        # Replace the plot widget's size hint with a fixed 500 x 500
        self.plot_widget.sizeHint = lambda: QSize(500, 500)

        # Handles to graphical items; all are reset in `clear`
        self.density_img = None
        self.scatterplot_item = None
        self.scatterplot_item_sel = None
        self.labels = []

        self.master = scatter_widget
        # Also sets `self.tiptexts` and `self.tip_textitem`
        self._create_drag_tooltip(self.plot_widget.scene())

        self.selection = None  # np.ndarray

        # Number of points given by the widget and the number of those
        # that are shown after sampling; set in `get_coordinates`
        self.n_valid = 0
        self.n_shown = 0
        self.sample_size = None
        self.sample_indices = None

        self.palette = None

        self.shape_legend = self._create_legend(((1, 0), (1, 0)))
        self.color_legend = self._create_legend(((1, 1), (1, 1)))
        self.update_legend_visibility()

        self.scale = None  # DiscretizedScale
        self._too_many_labels = False

        # self.setMouseTracking(True)
        # self.grabGesture(QPinchGesture)
        # self.grabGesture(QPanGesture)

        self.update_grid_visibility()

        # Events of the scene are filtered through `self.help_event`
        self._tooltip_delegate = EventDelegate(self.help_event)
        self.plot_widget.scene().installEventFilter(self._tooltip_delegate)
        self.view_box.sigTransformChanged.connect(self.update_density)
        self.view_box.sigRangeChangedManually.connect(self.update_labels)

        # QTimer that animates the change of sizes; see `update_sizes`
        self.timer = None

    def _create_legend(self, anchor):
        """
        Create a legend whose parent item is the plot's view box

        Args:
            anchor: the anchor passed to the legend's `restoreAnchor`

        Returns:
            (LegendItem): the new legend
        """
        new_legend = LegendItem()
        view_box = self.plot_widget.getViewBox()
        new_legend.setParentItem(view_box)
        new_legend.restoreAnchor(anchor)
        return new_legend

    def _create_drag_tooltip(self, scene):
        """
        Create a (hidden) item group with hints about selection modifiers

        The method sets `self.tiptexts`, which maps integer values of
        modifiers to html in which the hint for that modifier is bold (key 0
        maps to the text without any bold parts), and `self.tip_textitem`,
        the text item that shows the hint. The group, consisting of the
        background rectangle and the text, is stored as `scene.drag_tooltip`.

        Args:
            scene: the scene in which the item group is created
        """
        ctrl_key = "Cmd" if sys.platform == "darwin" else "Ctrl"
        shift_ctrl = Qt.ShiftModifier + Qt.ControlModifier
        hints = [
            (Qt.ShiftModifier, "Shift: Add group"),
            (shift_ctrl, "Shift-{}: Append to group".format(ctrl_key)),
            (Qt.AltModifier, "Alt: Remove")
        ]
        plain_text = ", ".join(hint for _, hint in hints)
        tiptexts = {}
        for modifier, hint in hints:
            emphasized = "<b>{}</b>".format(hint)
            tiptexts[int(modifier)] = plain_text.replace(hint, emphasized)
        tiptexts[0] = plain_text
        self.tiptexts = tiptexts

        label = QGraphicsTextItem()
        self.tip_textitem = label
        # Set to the longest text
        label.setHtml(self.tiptexts[shift_ctrl])
        label.setPos(4, 2)
        bounds = label.boundingRect()
        background = QGraphicsRectItem(
            0, 0, bounds.width() + 8, bounds.height() + 4)
        background.setBrush(QColor(224, 224, 224, 212))
        background.setPen(QPen(Qt.NoPen))
        self.update_tooltip()

        scene.drag_tooltip = scene.createItemGroup([background, label])
        scene.drag_tooltip.hide()

    def update_tooltip(self, modifiers=Qt.NoModifier):
        """
        Set the text of the drag tooltip for the given keyboard modifiers

        The hint that corresponds to the modifiers is shown in bold (see
        `_create_drag_tooltip`); the warning about jittering is appended
        if applicable.

        Args:
            modifiers: keyboard modifiers; all except Shift, Control and Alt
                are ignored
        """
        modifiers &= Qt.ShiftModifier + Qt.ControlModifier + Qt.AltModifier
        # Combinations without their own text show the text without bold parts
        text = self.tiptexts.get(int(modifiers), self.tiptexts[0])
        self.tip_textitem.setHtml(text + self._get_jittering_tooltip())

    def _get_jittering_tooltip(self):
        warn_jittered = ""
        if self.jitter_size:
            warn_jittered = \
                '<br/><br/>' \
                '<span style="background-color: red; color: white; ' \
                'font-weight: 500;">' \
                '&nbsp;Warning: Selection is applied to unjittered data&nbsp;' \
                '</span>'
        return warn_jittered

    def update_jittering(self):
        """
        Update the tooltip and move the points to newly jittered coordinates

        Coordinates and labels are updated only if there are any points and
        the scatter plot item exists.
        """
        self.update_tooltip()
        x, y = self.get_coordinates()
        can_update = x is not None and len(x) \
            and self.scatterplot_item is not None
        if not can_update:
            return
        for item in (self.scatterplot_item, self.scatterplot_item_sel):
            self._update_plot_coordinates(item, x, y)
        self.update_labels()

    # TODO: Rename to remove_plot_items
    def clear(self):
        """
        Remove all graphical elements from the plot

        Calls the pyqtgraph's plot widget's clear, stops the timer that
        animates resizing (if it is active), sets all handles to `None`,
        removes labels and restarts the history of the view box.

        This method should generally not be called by the widget. If the data
        is gone (*e.g.* upon receiving `None` as an input data signal), this
        should be handled by calling `reset_graph`, which will in turn call
        `clear`.

        Derived classes should override this method if they add more graphical
        elements. For instance, the regression line in the scatterplot adds
        `self.reg_line_item = None` (the line in the plot is already removed
        in this method).
        """
        self.plot_widget.clear()

        self.density_img = None
        running_timer = self.timer
        if running_timer is not None and running_timer.isActive():
            running_timer.stop()
            self.timer = None
        self.scatterplot_item = self.scatterplot_item_sel = None
        self.labels = []
        self._signal_too_many_labels(False)
        history_owner = self.view_box
        history_owner.init_history()
        history_owner.tag_history()

    # TODO: I hate `keep_something` and `reset_something` arguments
    # __keep_selection is used exclusively by set_sample_size which would
    # otherwise just repeat the code from reset_graph except for resetting
    # the selection. I'm uncomfortable with this; we may prefer to have a
    # method _reset_graph which does everything except resetting the selection,
    # and reset_graph would call it.
    def reset_graph(self, __keep_selection=False):
        """
        Reset the graph to new data (or no data)

        The method must be called when the plot receives new data, in
        particular when the number of points change. If only their properties
        - like coordinates or shapes - change, an update method
        (`update_coordinates`, `update_shapes`...) should be called instead.

        The method must also be called when the data is gone.

        The method calls `clear`, followed by calls of all update methods.

        NB. Argument `__keep_selection` is for internal use only

        Args:
            __keep_selection (bool): if `True`, `self.selection` is kept.
                Python mangles names with two leading underscores within
                a class body, so the argument is meant to be given by
                position, as `set_sample_size` does.
        """
        self.clear()
        if not __keep_selection:
            self.selection = None
        # Drop the sample; `_create_sample` makes a new one when needed
        self.sample_indices = None
        self.update_coordinates()
        self.update_point_props()

    def set_sample_size(self, sample_size):
        """
        Set the sample size

        Args:
            sample_size (int or None): sample size or `None` to show all points
        """
        if self.sample_size != sample_size:
            self.sample_size = sample_size
            self.reset_graph(True)

    def update_point_props(self):
        """
        Update the sizes, colors, selection colors, shapes and labels

        The method calls the update methods for individual properties,
        in this order.
        """
        updaters = (
            self.update_sizes,
            self.update_colors,
            self.update_selection_colors,
            self.update_shapes,
            self.update_labels,
        )
        for update in updaters:
            update()

    # Coordinates
    # TODO: It could be nice if this method was run on entire data, not just
    # a sample. For this, however, it would need to either be called from
    # `get_coordinates` before sampling (very ugly) or call
    # `self.master.get_coordinates_data` (beyond ugly) or the widget would
    # have to store the ranges of unsampled data (ugly).
    # Maybe we leave it as it is.
    def _reset_view(self, x_data, y_data):
        """
        Set the range of the view box to the bounding box of the points

        If all points have the same x (or y), the width (or height) of
        the rectangle is set to 1.

        Args:
            x_data (np.ndarray): x coordinates
            y_data (np.ndarray): y coordinates
        """
        left, right = np.min(x_data), np.max(x_data)
        bottom, top = np.min(y_data), np.max(y_data)
        width = right - left
        height = top - bottom
        view_rect = QRectF(left, bottom, width or 1, height or 1)
        self.view_box.setRange(view_rect, padding=0.025)

    def _filter_visible(self, data):
        """Return the sample from the data using the stored sample_indices"""
        if data is None or self.sample_indices is None:
            return data
        else:
            return np.asarray(data[self.sample_indices])

    def get_coordinates(self):
        """
        Prepare coordinates of the points in the plot

        The method is called by `update_coordinates`. It gets the coordinates
        from the widget, jitters them and return them.

        The methods also initializes the sample indices if neededd and stores
        the original and sampled number of points.

        Returns:
            (tuple): a pair of numpy arrays containing (sampled) coordinates,
                or `(None, None)`.
        """
        x, y = self.master.get_coordinates_data()
        if x is None:
            self.n_valid = self.n_shown = 0
            return None, None
        self.n_valid = len(x)
        self._create_sample()
        x = self._filter_visible(x)
        y = self._filter_visible(y)
        # Jittering after sampling is OK if widgets do not change the sample
        # semi-permanently, e.g. take a sample for the duration of some
        # animation. If the sample size changes dynamically (like by adding
        # a "sample size" slider), points would move around when the sample
        # size changes. To prevent this, jittering should be done before
        # sampling (i.e. two lines earlier). This would slow it down somewhat.
        x, y = self.jitter_coordinates(x, y)
        return x, y

    def _create_sample(self):
        """
        Create a random sample if the data is larger than the set sample size
        """
        self.n_shown = min(self.n_valid, self.sample_size or self.n_valid)
        if self.sample_size is not None \
                and self.sample_indices is None \
                and self.n_valid != self.n_shown:
            random = np.random.RandomState(seed=0)
            self.sample_indices = random.choice(
                self.n_valid, self.n_shown, replace=False)
            # TODO: Is this really needed?
            np.sort(self.sample_indices)

    def jitter_coordinates(self, x, y):
        """
        Display coordinates to random positions within ellipses with
        radiuses of `self.jittter_size` percents of spans
        """
        if self.jitter_size == 0:
            return x, y
        return self._jitter_data(x, y)

    def _jitter_data(self, x, y, span_x=None, span_y=None):
        if span_x is None:
            span_x = np.max(x) - np.min(x)
        if span_y is None:
            span_y = np.max(y) - np.min(y)
        random = np.random.RandomState(seed=0)
        rs = random.uniform(0, 1, len(x))
        phis = random.uniform(0, 2 * np.pi, len(x))
        magnitude = self.jitter_size / 100
        return (x + magnitude * span_x * rs * np.cos(phis),
                y + magnitude * span_y * rs * np.sin(phis))

    def _update_plot_coordinates(self, plot, x, y):
        """
        Change the coordinates of points while keeping other properites

        Note. Pyqtgraph does not offer a method for this: setting coordinates
        invalidates other data. We therefore retrieve the data to set it
        together with the coordinates. Pyqtgraph also does not offer a
        (documented) method for retrieving the data, yet using
        `plot.data[prop]` looks reasonably safe. The alternative, calling
        update for every property would essentially reset the graph, which
        can be time consuming.
        """
        data = dict(x=x, y=y)
        for prop in ('pen', 'brush', 'size', 'symbol', 'data',
                     'sourceRect', 'targetRect'):
            data[prop] = plot.data[prop]
        plot.setData(**data)

    def update_coordinates(self):
        """
        Trigger the update of coordinates while keeping other features intact.

        The method gets the coordinates by calling `self.get_coordinates`,
        which in turn calls the widget's `get_coordinates_data`. The number of
        coordinate pairs returned by the latter must match the current number
        of points. If this is not the case, the widget should trigger
        the complete update by calling `reset_graph` instead of this method.

        If the plot items do not exist yet, the method creates them; otherwise
        it moves the existing points and updates the labels. The method
        does nothing if there are no points.
        """
        x, y = self.get_coordinates()
        if x is None or not len(x):
            return
        if self.scatterplot_item is not None:
            for item in (self.scatterplot_item, self.scatterplot_item_sel):
                self._update_plot_coordinates(item, x, y)
            self.update_labels()
        else:
            # Points carry their indices in the original (unsampled) data
            indices = np.arange(self.n_valid) \
                if self.sample_indices is None else self.sample_indices
            self.scatterplot_item = ScatterPlotItem(x=x, y=y, data=indices)
            self.scatterplot_item.sigClicked.connect(self.select_by_click)
            self.scatterplot_item_sel = ScatterPlotItem(
                x=x, y=y, data=indices)
            # The item with selection marks is added before the points
            self.plot_widget.addItem(self.scatterplot_item_sel)
            self.plot_widget.addItem(self.scatterplot_item)

        self.update_density()  # Todo: doesn't work: try MDS with density on
        self._reset_view(x, y)

    # Sizes
    def get_sizes(self):
        """
        Prepare data for sizes of points in the plot

        The method is called by `update_sizes`. It gets the sizes
        from the widget and performs the necessary scaling and sizing.

        Returns:
            (np.ndarray): sizes
        """
        size_column = self.master.get_size_data()
        if size_column is None:
            return np.full((self.n_shown,),
                           self.MinShapeSize + (5 + self.point_width) * 0.5)
        size_column = self._filter_visible(size_column)
        size_column = size_column.copy()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            size_column -= np.nanmin(size_column)
            mx = np.nanmax(size_column)
        if mx > 0:
            size_column /= mx
        else:
            size_column[:] = 0.5
        return self.MinShapeSize + (5 + self.point_width) * size_column

    def update_sizes(self):
        """
        Trigger an update of point sizes

        The method calls `self.get_sizes`, which in turn calls the widget's
        `get_size_data`. The result are properly scaled and then passed
        back to widget for imputing (`master.impute_sizes`); if the widget
        has no such method, `default_impute_sizes` is used.

        If the number of points does not exceed `MAX_N_VALID_SIZE_ANIMATE`,
        all current sizes are positive and at least one size has changed,
        the change is animated by a timer that gradually sets the sizes.
        Otherwise the sizes are set immediately. In both cases the method
        emits `begin_resizing` at the start and `end_resizing` when the final
        sizes are set; the animation emits `step_resizing` after each
        intermediate step.

        The method does nothing if there is no scatter plot item.
        """
        if self.scatterplot_item:
            size_data = self.get_sizes()
            size_imputer = getattr(
                self.master, "impute_sizes", self.default_impute_sizes)
            # The imputer changes `size_data` in place
            size_imputer(size_data)

            # Cancel the animation that may still be running
            if self.timer is not None and self.timer.isActive():
                self.timer.stop()
                self.timer = None

            current_size_data = self.scatterplot_item.data["size"].copy()
            diff = size_data - current_size_data
            # `self` within `Timeout` refers to the instance of `Timeout`
            widget = self

            class Timeout:
                # 0.5 - np.cos(np.arange(0.17, 1, 0.17) * np.pi) / 2
                factors = [0.07, 0.26, 0.52, 0.77, 0.95, 1]

                def __init__(self):
                    self._counter = 0

                def __call__(self):
                    factor = self.factors[self._counter]
                    self._counter += 1
                    size = current_size_data + diff * factor
                    if len(self.factors) == self._counter:
                        # The last step: stop the timer and set exact sizes
                        widget.timer.stop()
                        widget.timer = None
                        size = size_data
                    widget.scatterplot_item.setSize(size)
                    widget.scatterplot_item_sel.setSize(size + SELECTION_WIDTH)
                    if widget.timer is None:
                        widget.end_resizing.emit()
                    else:
                        widget.step_resizing.emit()

            if self.n_valid <= MAX_N_VALID_SIZE_ANIMATE and \
                    np.all(current_size_data > 0) and np.any(diff != 0):
                # If encountered any strange behaviour when updating sizes,
                # implement it with threads
                self.begin_resizing.emit()
                self.timer = QTimer(self.scatterplot_item, interval=50)
                self.timer.timeout.connect(Timeout())
                self.timer.start()
            else:
                self.begin_resizing.emit()
                self.scatterplot_item.setSize(size_data)
                self.scatterplot_item_sel.setSize(size_data + SELECTION_WIDTH)
                self.end_resizing.emit()

    update_point_size = update_sizes  # backward compatibility (needed?!)
    update_size = update_sizes

    @classmethod
    def default_impute_sizes(cls, size_data):
        """
        Fallback imputation for sizes.

        Set the size to two pixels smaller than the minimal size

        Returns:
            (bool): True if there was any missing data
        """
        nans = np.isnan(size_data)
        if np.any(nans):
            size_data[nans] = cls.MinShapeSize - 2
            return True
        else:
            return False

    # Colors
    def get_colors(self):
        """
        Prepare data for colors of the points in the plot

        The method is called by `update_colors`. It gets the palette, the
        colors and the indices of the data subset from the widget
        (`get_palette`, `get_color_data`, `get_subset_mask`), and constructs
        lists of pens and brushes for each data point.

        The method uses different palettes for discrete and continuous data,
        as determined by calling the widget's method `is_continuous_color`.

        It also marks the points that are in the subset as defined by, for
        instance the 'Data Subset' signal in the Scatter plot and similar
        widgets. (Do not confuse this with *selected points*, which are
        marked by circles around the points, which are colored by groups
        and thus independent of this method.)

        Returns:
            (tuple): a list of pens and list of brushes
        """
        widget = self.master
        self.palette = widget.get_palette()
        color_data = self._filter_visible(widget.get_color_data())
        subset_mask = self._filter_visible(widget.get_subset_mask())
        self.subset_is_shown = subset_mask is not None
        if color_data is None:  # same color
            return self._get_same_colors(subset_mask)
        if widget.is_continuous_color():
            return self._get_continuous_colors(color_data, subset_mask)
        return self._get_discrete_colors(color_data, subset_mask)

    def _get_same_colors(self, subset):
        """
        Return pens and brushes for a plot without color data.

        All points get the same pen. Without a subset, all points also get
        the same default brush with the current alpha value; with a subset,
        the brush is chosen per point from `COLOR_SUBSET` and
        `COLOR_NOT_SUBSET`.

        Args:
            subset (np.ndarray): a bool array indicating whether a data point
                is in the subset or not (e.g. in the 'Data Subset' signal
                in the Scatter plot and similar widgets), or None

        Returns:
            (tuple): a list of pens and list of brushes
        """
        outline = self.plot_widget.palette().color(OWPalette.Data)
        pens = [_make_pen(outline, 1.5) for _ in range(self.n_shown)]

        if subset is None:
            fill = QColor(*self.COLOR_DEFAULT)
            fill.setAlpha(self.alpha_value)
            brushes = [QBrush(fill) for _ in range(self.n_shown)]
        else:
            in_subset = QBrush(QColor(*self.COLOR_SUBSET))
            not_in_subset = QBrush(QColor(*self.COLOR_NOT_SUBSET))
            brushes = np.where(subset, in_subset, not_in_subset)
        return pens, brushes

    def _get_continuous_colors(self, c_data, subset):
        """
        Return the pens and brushes whose color represent an index into
        a continuous palette. The same color is used for pen and brush,
        except the former is darker. If the data has a subset, the brush
        is transparent for points that are not in the subset.

        As a side effect, `self.scale` is set to a `DiscretizedScale` built
        from the smallest and largest non-missing value, or to `None` if all
        values are missing.

        Args:
            c_data (np.ndarray): numeric color values, may contain NaN;
                the array is left unmodified
            subset (np.ndarray): a bool array telling which points are in
                the subset, or None

        Returns:
            (tuple): a list of pens and an array of brushes
        """
        if np.isnan(c_data).all():
            self.scale = None
        else:
            self.scale = DiscretizedScale(np.nanmin(c_data), np.nanmax(c_data))
            # Out-of-place arithmetic on purpose: `c_data` may alias an array
            # owned by the caller, which in-place `-=` and `/=` would corrupt
            # (and which would fail outright for integer arrays).
            c_data = (c_data - self.scale.offset) / self.scale.width
            c_data = (np.floor(c_data) + 0.5) / self.scale.bins
            c_data = np.clip(c_data, 0, 1)
        pen = self.palette.getRGB(c_data)
        # The brush is assembled before `pen` is darkened in place below
        brush = np.hstack(
            [pen, np.full((len(pen), 1), self.alpha_value, dtype=int)])
        pen *= 100
        pen //= self.DarkerValue
        pen = [_make_pen(QColor(*col), 1.5) for col in pen.tolist()]

        if subset is not None:
            # opaque inside the subset, fully transparent outside of it
            brush[:, 3] = 0
            brush[subset, 3] = 255
        brush = np.array([QBrush(QColor(*col)) for col in brush.tolist()])
        return pen, brush

    def _get_discrete_colors(self, c_data, subset):
        """
        Return the pens and brushes whose color represent an index into
        a discrete palette. The pen is a darker version of the brush color.
        Missing values are mapped to an extra gray color appended after the
        palette colors. If the data has a subset, the brush is opaque for
        points in the subset and transparent for the others; without
        a subset, brushes use the current alpha value.
        """
        palette_size = self.palette.number_of_colors

        # Work on a copy; NaN is replaced by the index of the extra gray
        indices = c_data.copy()
        indices[np.isnan(indices)] = palette_size
        indices = indices.astype(int)

        rgb = np.r_[self.palette.getRGB(np.arange(palette_size)),
                    [[128, 128, 128]]]
        pen_table = np.array(
            [_make_pen(QColor(*col).darker(self.DarkerValue), 1.5)
             for col in rgb])
        pen = pen_table[indices]

        has_subset = subset is not None
        opacity = 255 if has_subset else self.alpha_value
        # Per color: column 0 is a transparent brush, column 1 a filled one
        brush_table = np.array([
            [QBrush(QColor(0, 0, 0, 0)),
             QBrush(QColor(col[0], col[1], col[2], opacity))]
            for col in rgb])
        pairs = brush_table[indices]

        if has_subset:
            brush = np.where(subset, pairs[:, 1], pairs[:, 0])
        else:
            brush = pairs[:, 1]
        return pen, brush

    def update_colors(self):
        """
        Trigger an update of point colors

        If there is a scatter plot item, the method calls `self.get_colors`
        and sets the returned pens and brushes on the item. In any case,
        it then updates the legends and the density plot.
        """
        item = self.scatterplot_item
        if item is not None:
            pens, brushes = self.get_colors()
            item.setPen(pens, update=False, mask=None)
            item.setBrush(brushes, mask=None)
        self.update_legends()
        self.update_density()

    # The alpha value is part of the brushes built in `get_colors`, so
    # changing it is handled by the same method as changing colors.
    update_alpha_value = update_colors

    def update_density(self):
        """
        Remove the existing density plot (if there is one) and replace it
        with a new one (if enabled).

        The colors are taken from the pens of the currently plotted points;
        points without a pen count as white. No density image is created
        if all points have the same color.
        """
        if self.density_img:
            self.plot_widget.removeItem(self.density_img)
            self.density_img = None

        if not self.class_density or self.scatterplot_item is None:
            return

        white = (255, 255, 255)
        rgb_data = [
            white if pen is None else pen.color().getRgb()[:3]
            for pen in self.scatterplot_item.data['pen']]
        if len(set(rgb_data)) <= 1:
            return

        (min_x, max_x), (min_y, max_y) = self.view_box.viewRange()
        x_data, y_data = self.scatterplot_item.getData()
        self.density_img = classdensity.class_density_image(
            min_x, max_x, min_y, max_y, self.resolution,
            x_data, y_data, rgb_data)
        self.plot_widget.addItem(self.density_img)

    def update_selection_colors(self):
        """
        Trigger an update of selection markers

        This update method is usually not called by the widget but by the
        plot, since it is the plot that handles the selections.

        Like other update methods, it calls the corresponding get method
        (`get_colors_sel`) and applies the returned pens and brushes.
        Nothing happens if there is no item with selection markers.
        """
        markers = self.scatterplot_item_sel
        if markers is not None:
            pens, brushes = self.get_colors_sel()
            markers.setPen(pens, update=False, mask=None)
            markers.setBrush(brushes, mask=None)

    def get_colors_sel(self):
        """
        Return pens and brushes for selection markers.

        A pen is set to `Qt.NoPen` if a point is not selected. If the
        largest group index is 1, selected points share a single pen;
        otherwise each selection group gets a pen with a color from
        a generated palette.

        All brushes are completely transparent whites.

        Returns:
            (tuple): a list of pens and a list of brushes
        """
        no_pen = QPen(Qt.NoPen)
        if self.selection is None:
            pen = [no_pen] * self.n_shown
        else:
            n_groups = np.max(self.selection)
            marker_width = SELECTION_WIDTH + 1
            if n_groups == 1:
                shown = self._filter_visible(self.selection)
                selected_pen = _make_pen(QColor(255, 190, 0, 255),
                                         marker_width)
                pen = np.where(shown, selected_pen, no_pen)
            else:
                palette = ColorPaletteGenerator(number_of_colors=n_groups + 1)
                shown = self._filter_visible(self.selection)
                # index 0 (not selected) maps to the empty pen
                group_pens = [no_pen]
                group_pens.extend(
                    _make_pen(palette[i], marker_width)
                    for i in range(n_groups))
                pen = np.choose(shown, group_pens)
        brush = [QBrush(QColor(255, 255, 255, 0))] * self.n_shown
        return pen, brush

    # Labels
    def get_labels(self):
        """
        Prepare data for labels for points

        The labels are obtained from the widget's `get_label_data` and
        passed through `_filter_visible`.

        Returns:
            (labels): a sequence of labels
        """
        label_data = self.master.get_label_data()
        return self._filter_visible(label_data)

    def update_labels(self):
        """
        Trigger an update of labels

        Existing labels are removed first. The mask of points to label comes
        from `_label_mask`; the labels come from `get_labels`, which in turn
        calls the widget's `get_label_data`. No labels are shown if there is
        no plot item, no mask, no label data, no point to label, or if the
        number of points to label exceeds `MAX_VISIBLE_LABELS`; the latter
        condition is reported through `_signal_too_many_labels`.
        """
        for old_label in self.labels:
            self.plot_widget.removeItem(old_label)
        self.labels = []

        mask = None
        if self.scatterplot_item is not None:
            x, y = self.scatterplot_item.getData()
            mask = self._label_mask(x, y)

        if mask is not None:
            labels = self.get_labels()
            if labels is None:
                mask = None

        too_many = mask is not None and mask.sum() > self.MAX_VISIBLE_LABELS
        self._signal_too_many_labels(too_many)
        if self._too_many_labels or mask is None or not np.any(mask):
            return

        black = pg.mkColor(0, 0, 0)
        for text, xp, yp in zip(labels[mask], x[mask], y[mask]):
            text_item = TextItem(text, black)
            text_item.setPos(xp, yp)
            self.plot_widget.addItem(text_item)
            self.labels.append(text_item)

    def _signal_too_many_labels(self, too_many):
        if self._too_many_labels != too_many:
            self._too_many_labels = too_many
            self.too_many_labels.emit(too_many)

    def _label_mask(self, x, y):
        (x0, x1), (y0, y1) = self.view_box.viewRange()
        mask = np.logical_and(
            np.logical_and(x >= x0, x <= x1),
            np.logical_and(y >= y0, y <= y1))
        if self.label_only_selected:
            sub_mask = self._filter_visible(self.master.get_subset_mask())
            if self.selection is None:
                if sub_mask is None:
                    return None
                else:
                    sel_mask = sub_mask
            else:
                sel_mask = self._filter_visible(self.selection) != 0
                if sub_mask is not None:
                    sel_mask = np.logical_or(sel_mask, sub_mask)
            mask = np.logical_and(mask, sel_mask)
        return mask

    # Shapes
    def get_shapes(self):
        """
        Prepare data for shapes of points in the plot

        The method is called by `update_shapes`. It gets the data from
        the widget's `get_shape_data`, and then calls the widget's
        `impute_shapes` (or `default_impute_shapes` if the widget has none)
        to impute the missing shapes with the last symbol's index.
        Without shape data, all points get the first symbol.

        Returns:
            (np.ndarray): an array of symbols (e.g. o, x, + ...)
        """
        shape_data = self._filter_visible(self.master.get_shape_data())
        # Data has to be copied so the imputation can change it in-place
        # TODO: Try avoiding this when we move imputation to the widget
        if shape_data is not None:
            shape_data = np.copy(shape_data)

        impute = getattr(
            self.master, "impute_shapes", self.default_impute_shapes)
        last_symbol = len(self.CurveSymbols) - 1
        impute(shape_data, last_symbol)

        if isinstance(shape_data, np.ndarray):
            symbol_indices = shape_data.astype(int)
        else:
            symbol_indices = np.zeros(self.n_shown, dtype=int)
        return self.CurveSymbols[symbol_indices]

    @staticmethod
    def default_impute_shapes(shape_data, default_symbol):
        """
        Fallback imputation for shapes; works in place.

        Missing values (NaN) are replaced by `default_symbol`, which is
        usually the last symbol in the list. `None` is accepted in place
        of the data and treated as having no missing values.

        Returns:
            (bool): True if there was any missing data
        """
        if shape_data is None:
            return False
        missing = np.isnan(shape_data)
        if not np.any(missing):
            return False
        shape_data[missing] = default_symbol
        return True

    def update_shapes(self):
        """
        Trigger an update of point symbols

        If there is a plot item, the method sets its symbols to the array
        returned by `get_shapes`.

        Finally, the method updates the legend.
        """
        item = self.scatterplot_item
        if item:
            item.setSymbol(self.get_shapes())
        self.update_legends()

    def update_grid_visibility(self):
        """Show or hide the grid for both axes, following `show_grid`"""
        visible = self.show_grid
        self.plot_widget.showGrid(x=visible, y=visible)

    def update_legend_visibility(self):
        """
        Show the shape and the color legend if legends are enabled
        (`show_legend`) and the particular legend has any items;
        hide it otherwise
        """
        for legend in (self.shape_legend, self.color_legend):
            legend.setVisible(self.show_legend and bool(legend.items))

    def update_legends(self):
        """
        Update content of legends and their visibility

        If shapes and (discrete) colors have the same labels, a single
        combined legend is shown. Otherwise the shape legend is updated
        separately from the color legend, which is either continuous or
        discrete, depending on the widget's `is_continuous_color`.
        """
        master = self.master
        cont_color = master.is_continuous_color()
        shape_labels = master.get_shape_labels()
        color_labels = None if cont_color else master.get_color_labels()

        if shape_labels == color_labels and shape_labels is not None:
            self._update_combined_legend(shape_labels)
        elif cont_color:
            self._update_shape_legend(shape_labels)
            self._update_continuous_color_legend()
        else:
            self._update_shape_legend(shape_labels)
            self._update_color_legend(color_labels)
        self.update_legend_visibility()

    def _update_shape_legend(self, labels):
        self.shape_legend.clear()
        if labels is None or self.scatterplot_item is None:
            return
        color = QColor(0, 0, 0)
        color.setAlpha(self.alpha_value)
        for label, symbol in zip(labels, self.CurveSymbols):
            self.shape_legend.addItem(
                ScatterPlotItem(pen=color, brush=color, size=10, symbol=symbol),
                escape(label))

    def _update_continuous_color_legend(self):
        self.color_legend.clear()
        if self.scale is None or self.scatterplot_item is None:
            return
        label = PaletteItemSample(self.palette, self.scale)
        self.color_legend.addItem(label, "")
        self.color_legend.setGeometry(label.boundingRect())

    def _update_color_legend(self, labels):
        self.color_legend.clear()
        if labels is None:
            return
        self._update_colored_legend(self.color_legend, labels, 'o')

    def _update_combined_legend(self, labels):
        # update_colored_legend will already clear the shape legend
        # so we remove colors here
        use_legend = \
            self.shape_legend if self.shape_legend.items else self.color_legend
        self.color_legend.clear()
        self.shape_legend.clear()
        self._update_colored_legend(use_legend, labels, self.CurveSymbols)

    def _update_colored_legend(self, legend, labels, symbols):
        if self.scatterplot_item is None or not self.palette:
            return
        if isinstance(symbols, str):
            symbols = itertools.repeat(symbols, times=len(labels))
        for i, (label, symbol) in enumerate(zip(labels, symbols)):
            color = QColor(*self.palette.getRGB(i))
            pen = _make_pen(color.darker(self.DarkerValue), 1.5)
            color.setAlpha(255 if self.subset_is_shown else self.alpha_value)
            brush = QBrush(color)
            legend.addItem(
                ScatterPlotItem(pen=pen, brush=brush, size=10, symbol=symbol),
                escape(label))

    def zoom_button_clicked(self):
        """Put the view box into `RectMode`"""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def pan_button_clicked(self):
        """Put the view box into `PanMode`"""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.PanMode)

    def select_button_clicked(self):
        """Put the view box into `RectMode`, the same mode as for zooming"""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def reset_button_clicked(self):
        """Auto-range the view box and then update the labels"""
        view_box = self.plot_widget.getViewBox()
        view_box.autoRange()
        self.update_labels()

    def select_by_click(self, _, points):
        """
        Select the given points, unless there is no plot item.
        The first argument is ignored.
        """
        if self.scatterplot_item is None:
            return
        self.select(points)

    def select_by_rectangle(self, rect):
        """
        Select the points whose coordinates, as given by the widget's
        `get_coordinates_data`, lie within the rectangle (borders included).
        Nothing happens if there is no plot item.
        """
        if self.scatterplot_item is None:
            return
        top_left, bottom_right = rect.topLeft(), rect.bottomRight()
        x0, x1 = sorted((top_left.x(), bottom_right.x()))
        y0, y1 = sorted((top_left.y(), bottom_right.y()))
        x, y = self.master.get_coordinates_data()
        inside = (x0 <= x) & (x <= x1) & (y0 <= y) & (y <= y1)
        self.select_by_indices(np.flatnonzero(inside).astype(int))

    def unselect_all(self):
        """
        Remove the selection, if there is any: update selection markers,
        update labels if they depend on the selection, and notify the
        widget through `selection_changed`.
        """
        if self.selection is None:
            return
        self.selection = None
        self.update_selection_colors()
        if self.label_only_selected:
            self.update_labels()
        self.master.selection_changed()

    def select(self, points):
        """
        Select the given plot points; the indices passed on to
        `select_by_indices` are the values returned by the points' `data()`.
        Nothing happens if there is no plot item.
        """
        if self.scatterplot_item is None:
            return
        # noinspection PyArgumentList
        self.select_by_indices([point.data() for point in points])

    def select_by_indices(self, indices):
        """
        Change the selection for the given indices according to the
        keyboard modifiers that are currently pressed:

        - Alt: remove the points from the selection
        - Shift + Ctrl: append the points to the last group
        - Shift: put the points into a new group
        - otherwise: replace the selection with the given points

        An empty selection is created first if there is none.
        """
        if self.selection is None:
            self.selection = np.zeros(self.n_valid, dtype=np.uint8)

        modifiers = QApplication.keyboardModifiers()
        if modifiers & Qt.AltModifier:
            action = self.selection_remove
        elif modifiers & Qt.ShiftModifier and modifiers & Qt.ControlModifier:
            action = self.selection_append
        elif modifiers & Qt.ShiftModifier:
            action = self.selection_new_group
        else:
            action = self.selection_select
        action(indices)

    def selection_select(self, indices):
        """
        Replace the selection with a single group (group 1) that contains
        the points with the given indices.
        """
        fresh = np.zeros(self.n_valid, dtype=np.uint8)
        fresh[indices] = 1
        self.selection = fresh
        self._update_after_selection()

    def selection_append(self, indices):
        """
        Add the points with the given indices to the last selection group,
        that is, the group with the largest index.

        If nothing is selected yet, the largest index is 0, which stands
        for "not selected"; in this case the points are put into group 1,
        so that appending to an empty selection still selects them.
        """
        self.selection[indices] = max(np.max(self.selection), 1)
        self._update_after_selection()

    def selection_new_group(self, indices):
        """
        Put the points with the given indices into a new selection group,
        whose index is one above the largest index currently in use.
        """
        new_group = np.max(self.selection) + 1
        self.selection[indices] = new_group
        self._update_after_selection()

    def selection_remove(self, indices):
        """Remove the points with the given indices from the selection."""
        selection = self.selection
        selection[indices] = 0
        self._update_after_selection()

    def _update_after_selection(self):
        self._compress_indices()
        self.update_selection_colors()
        if self.label_only_selected:
            self.update_labels()
        self.master.selection_changed()

    def _compress_indices(self):
        indices = sorted(set(self.selection) | {0})
        if len(indices) == max(indices) + 1:
            return
        mapping = np.zeros((max(indices) + 1,), dtype=int)
        for i, ind in enumerate(indices):
            mapping[ind] = i
        self.selection = mapping[self.selection]

    def get_selection(self):
        """
        Return indices of selected points, regardless of their group;
        the result is an empty array if there is no selection.
        """
        selection = self.selection
        if selection is not None:
            return np.flatnonzero(selection)
        return np.array([], dtype=np.uint8)

    def help_event(self, event):
        """
        Create a `QToolTip` for the point hovered by the mouse

        The text is provided by the widget's `get_tooltip`, which is given
        the data of the points at the position of the event.

        Returns:
            (bool): True if a tooltip was shown; False if there is no plot
                item or the widget gave no text
        """
        item = self.scatterplot_item
        if item is None:
            return False
        item_pos = item.mapFromScene(event.scenePos())
        hovered = [point.data() for point in item.pointsAt(item_pos)]
        text = self.master.get_tooltip(hovered)
        if not text:
            return False
        QToolTip.showText(event.screenPos(), text, widget=self.plot_widget)
        return True
Пример #8
0
class OWDistanceMatrix(widget.OWWidget):
    """
    Widget that shows a distance matrix in a table view.

    The part of the matrix selected in the view is sent to the output
    "Distances"; the corresponding rows of the data table are sent to
    "Selected Data" when the matrix has a truthy `axis` and its
    `row_items` is a `Table`.
    """
    name = "Distance Matrix"
    description = "View distance matrix."
    icon = "icons/DistanceMatrix.svg"
    priority = 200
    keywords = []

    class Inputs:
        distances = Input("Distances", DistMatrix)

    class Outputs:
        distances = Output("Distances", DistMatrix, dynamic=False)
        table = Output("Selected Data", Table, replaces=["Table"])

    settingsHandler = DistanceMatrixContextHandler()
    auto_commit = Setting(True)
    # Index into the model of `annot_combo`: 0 is no labels, 1 is
    # enumeration, larger indices depend on the items of the matrix
    annotation_idx = ContextSetting(1)
    selection = ContextSetting([])

    want_control_area = True
    want_main_area = False

    def __init__(self):
        super().__init__()
        self.distances = None
        self.items = None

        # Table view with the matrix
        self.tablemodel = DistanceMatrixModel()
        view = self.tableview = TableView()
        view.setWordWrap(False)
        view.setTextElideMode(Qt.ElideNone)
        view.setEditTriggers(QTableView.NoEditTriggers)
        view.setItemDelegate(
            TableBorderItem(roles=(Qt.DisplayRole, Qt.BackgroundRole,
                                   Qt.ForegroundRole)))
        view.setModel(self.tablemodel)
        view.setShowGrid(False)
        for header in (view.horizontalHeader(), view.verticalHeader()):
            header.setResizeContentsPrecision(1)
            header.setSectionResizeMode(QHeaderView.ResizeToContents)
            header.setHighlightSections(True)
            header.setSectionsClickable(False)
        view.verticalHeader().setDefaultAlignment(Qt.AlignRight
                                                  | Qt.AlignVCenter)
        # Any change of the selection schedules a (deferred) commit
        selmodel = SymmetricSelectionModel(view.model(), view)
        selmodel.selectionChanged.connect(self.commit.deferred)
        view.setSelectionModel(selmodel)
        view.setSelectionBehavior(QTableView.SelectItems)
        self.controlArea.layout().addWidget(view)

        # Combo for choosing labels, and the auto-send button
        self.annot_combo = gui.comboBox(self.buttonsArea,
                                        self,
                                        "annotation_idx",
                                        label="Labels: ",
                                        orientation=Qt.Horizontal,
                                        callback=self._invalidate_annotations,
                                        contentsLength=12)
        self.annot_combo.setModel(VariableListModel())
        self.annot_combo.model()[:] = ["None", "Enumeration"]
        gui.rubber(self.buttonsArea)
        acb = gui.auto_send(self.buttonsArea, self, "auto_commit", box=False)
        acb.setFixedWidth(200)

    def sizeHint(self):
        return QSize(800, 500)

    @Inputs.distances
    def set_distances(self, distances):
        """
        Handle a new distance matrix (or `None`).

        The method resets the selection, rebuilds the list of annotations
        offered in the combo according to the matrix's `row_items`, picks
        a default annotation, opens the context (if there are any items)
        and commits.
        """
        self.closeContext()
        self.distances = distances
        self.tablemodel.set_data(self.distances)
        self.selection = []
        self.tableview.selectionModel().clear()

        # `items` is False if there is no matrix
        self.items = items = distances is not None and distances.row_items
        annotations = ["None", "Enumerate"]
        pending_idx = 1
        if items and not distances.axis:
            annotations.append("Attribute names")
            pending_idx = 2
        elif isinstance(items, list) and \
                all(isinstance(item, Variable) for item in items):
            annotations.append("Name")
            pending_idx = 2
        elif isinstance(items, Table):
            annotations.extend(
                itertools.chain(items.domain.variables, items.domain.metas))
            if items.domain.class_var:
                # the first class variable follows the two fixed entries
                # and all attributes
                pending_idx = 2 + len(items.domain.attributes)
        self.annot_combo.model()[:] = annotations
        self.annotation_idx = pending_idx

        if items:
            self.openContext(distances, annotations)
            self._update_labels()
            self.tableview.resizeColumnsToContents()
        self.commit.now()

    def _invalidate_annotations(self):
        """Callback for the combo: refresh labels if there is a matrix."""
        if self.distances is not None:
            self._update_labels()

    def _update_labels(self):
        """
        Compute the labels for the chosen annotation and pass them to the
        model; headers are shown only if there are any labels.
        """
        # `labels` defaults to None so that an annotation that matches none
        # of the branches below (e.g. "Name" when the items are a plain
        # list of variables rather than an `AttributeList`) hides the
        # headers instead of raising UnboundLocalError.
        labels = var = column = None
        if self.annotation_idx == 0:
            labels = None
        elif self.annotation_idx == 1:
            labels = [str(i + 1) for i in range(self.distances.shape[0])]
        elif self.annot_combo.model()[
                self.annotation_idx] == "Attribute names":
            attr = self.distances.row_items.domain.attributes
            labels = [str(attr[i]) for i in range(self.distances.shape[0])]
        elif self.annotation_idx == 2 and \
                isinstance(self.items, widget.AttributeList):
            labels = [v.name for v in self.items]
        elif isinstance(self.items, Table):
            var = self.annot_combo.model()[self.annotation_idx]
            column, _ = self.items.get_column_view(var)
            labels = [var.str_val(value) for value in column]
        if labels:
            self.tableview.horizontalHeader().show()
            self.tableview.verticalHeader().show()
        else:
            self.tableview.horizontalHeader().hide()
            self.tableview.verticalHeader().hide()
        self.tablemodel.set_labels(labels, var, column)
        self.tableview.resizeColumnsToContents()

    @gui.deferred
    def commit(self):
        """
        Send the selected submatrix and, when applicable, the selected
        data instances; `None` is sent if there is nothing to send.
        """
        sub_table = sub_distances = None
        if self.distances is not None:
            inds = self.tableview.selectionModel().selectedItems()
            if inds:
                sub_distances = self.distances.submatrix(inds)
                if self.distances.axis and isinstance(self.items, Table):
                    sub_table = self.items[inds]
        self.Outputs.distances.send(sub_distances)
        self.Outputs.table.send(sub_table)

    def send_report(self):
        """
        Report the matrix as an HTML table whose cells have the same
        background colors as in the view; labels are included if the model
        has any.
        """
        if self.distances is None:
            return
        model = self.tablemodel
        dim = self.distances.shape[0]
        col_cell = model.color_for_cell

        def _rgb(brush):
            return "rgb({}, {}, {})".format(*brush.color().getRgb())

        if model.labels:
            col_label = model.color_for_label
            label_colors = [_rgb(col_label(i)) for i in range(dim)]
            self.report_raw('<table style="border-collapse:collapse">')
            # header row: an empty corner cell followed by column labels
            self.report_raw("<tr><td></td>")
            self.report_raw("".join(
                '<td style="background-color: {}">{}</td>'.format(*cv)
                for cv in zip(label_colors, model.labels)))
            self.report_raw("</tr>")
            for i in range(dim):
                self.report_raw("<tr>")
                self.report_raw(
                    '<td style="background-color: {}">{}</td>'.format(
                        label_colors[i], model.labels[i]))
                self.report_raw("".join(
                    '<td style="background-color: {};'
                    'border-top:1px solid {}; border-left:1px solid {};">'
                    '{:.3f}</td>'.format(_rgb(col_cell(i, j)), label_colors[i],
                                         label_colors[j], self.distances[i, j])
                    for j in range(dim)))
                self.report_raw("</tr>")
            self.report_raw("</table>")
        else:
            self.report_raw('<table>')
            for i in range(dim):
                self.report_raw("<tr>" + "".join(
                    '<td style="background-color: {}">{:.3f}</td>'.format(
                        _rgb(col_cell(i, j)), self.distances[i, j])
                    for j in range(dim)) + "</tr>")
            self.report_raw("</table>")
Пример #9
0
class OWCorrelations(OWWidget):
    """
    Widget that ranks pairs of numeric features by their correlation.

    The ranking is computed and shown by a `CorrelationRank` instance
    (`self.vizrank`). The widget outputs the input data, the selected
    pair of features, and a table with all computed correlations.
    """
    name = "Correlations"
    description = "Compute all pairwise attribute correlations."
    icon = "icons/Correlations.svg"
    priority = 1106
    category = "Unsupervised"

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList)
        correlations = Output("Correlations", Table)

    want_main_area = False
    want_control_area = True

    correlation_type: int

    settings_version = 3
    settingsHandler = DomainContextHandler()
    selection = ContextSetting([])
    feature = ContextSetting(None)
    correlation_type = Setting(0)

    class Information(OWWidget.Information):
        removed_cons_feat = Msg("Constant features have been removed.")

    class Warning(OWWidget.Warning):
        not_enough_vars = Msg("At least two numeric features are needed.")
        not_enough_inst = Msg("At least two instances are needed.")

    def __init__(self):
        super().__init__()
        self.data = None  # type: Table
        # numeric variables of `data`, without constant ones, and imputed
        self.cont_data = None  # type: Table

        # GUI
        box = gui.vBox(self.controlArea)
        self.correlation_combo = gui.comboBox(
            box, self, "correlation_type", items=CorrelationType.items(),
            orientation=Qt.Horizontal, callback=self._correlation_combo_changed
        )

        self.feature_model = DomainModel(
            order=DomainModel.ATTRIBUTES, separators=False,
            placeholder="(All combinations)", valid_types=ContinuousVariable)
        gui.comboBox(
            box, self, "feature", callback=self._feature_combo_changed,
            model=self.feature_model
        )

        self.vizrank, _ = CorrelationRank.add_vizrank(
            None, self, None, self._vizrank_selection_changed)
        self.vizrank.button.setEnabled(False)
        self.vizrank.threadStopped.connect(self._vizrank_stopped)

        gui.separator(box)
        box.layout().addWidget(self.vizrank.filter)
        box.layout().addWidget(self.vizrank.rank_table)

        button_box = gui.hBox(self.buttonsArea)
        button_box.layout().addWidget(self.vizrank.button)

    @staticmethod
    def sizeHint():
        return QSize(350, 400)

    def _correlation_combo_changed(self):
        self.apply()

    def _feature_combo_changed(self):
        self.apply()

    def _vizrank_selection_changed(self, *args):
        """Store the selected features and send the outputs."""
        self.selection = list(args)
        self.commit()

    def _vizrank_stopped(self):
        self._vizrank_select()

    def _vizrank_select(self):
        """
        Select the row of the rank table whose features have the same
        names as the features in `self.selection`; select the first row
        if there is no such row. Nothing is done if the table is empty.
        """
        model = self.vizrank.rank_table.model()
        if not model.rowCount():
            return
        selection = QItemSelection()

        # This flag is needed because data in the model could be
        # filtered by a feature and therefore selection could not be found
        selection_in_model = False
        if self.selection:
            sel_names = sorted(var.name for var in self.selection)
            for i in range(model.rowCount()):
                # pylint: disable=protected-access
                names = sorted(x.name for x in model.data(
                    model.index(i, 0), CorrelationRank._AttrRole))
                if names == sel_names:
                    selection.select(model.index(i, 0),
                                     model.index(i, model.columnCount() - 1))
                    selection_in_model = True
                    break
        if not selection_in_model:
            selection.select(model.index(0, 0),
                             model.index(0, model.columnCount() - 1))
        self.vizrank.rank_table.selectionModel().select(
            selection, QItemSelectionModel.ClearAndSelect)

    @Inputs.data
    def set_data(self, data):
        """
        Handle new input data (or `None`).

        `self.cont_data` is set to a table with the numeric variables of
        the data (class variables, metas and attributes, all as
        attributes), with constant features removed and missing values
        imputed. It stays `None`, with a warning shown, if the data has
        fewer than two instances or fewer than two such features.
        """
        self.closeContext()
        self.clear_messages()
        self.data = data
        self.cont_data = None
        self.selection = []
        if data is not None:
            if len(data) < 2:
                self.Warning.not_enough_inst()
            else:
                domain = data.domain
                cont_vars = [a for a in domain.class_vars + domain.metas +
                             domain.attributes if a.is_continuous]
                cont_data = Table.from_table(Domain(cont_vars), data)
                remover = Remove(Remove.RemoveConstant)
                cont_data = remover(cont_data)
                if remover.attr_results["removed"]:
                    self.Information.removed_cons_feat()
                if len(cont_data.domain.attributes) < 2:
                    self.Warning.not_enough_vars()
                else:
                    self.cont_data = SklImpute()(cont_data)
        self.set_feature_model()
        self.openContext(self.cont_data)
        self.apply()
        self.vizrank.button.setEnabled(self.cont_data is not None)

    def set_feature_model(self):
        """
        Fill the feature combo with the variables of `cont_data` and choose
        the default feature: the class variable, if the input data has
        a numeric class, and no feature ("all combinations") otherwise.
        """
        cont_data = self.cont_data
        has_cont_data = cont_data is not None
        self.feature_model.set_domain(
            cont_data.domain if has_cont_data else None)
        feature = None
        data = self.data
        if has_cont_data and data.domain.has_continuous_class:
            class_name = data.domain.class_var.name
            # A constant class variable has been removed from `cont_data`
            # together with other constant features in `set_data`; looking
            # it up would raise KeyError, so fall back to no feature.
            if class_name in cont_data.domain:
                feature = cont_data.domain[class_name]
        self.feature = feature

    def apply(self):
        """Restart the ranking, or just commit if there is no usable data."""
        self.vizrank.initialize()
        if self.cont_data is not None:
            # this triggers self.commit() by changing vizrank selection
            self.vizrank.toggle()
        else:
            self.commit()

    def commit(self):
        """
        Send the outputs: the input data, the list of selected features
        (taken from the domain of the input data) and a table with one row
        per ranked pair, with the correlation and the FDR-adjusted p-value
        as attributes and the names of the two features as metas.
        """
        self.Outputs.data.send(self.data)

        if self.data is None or self.cont_data is None:
            self.Outputs.features.send(None)
            self.Outputs.correlations.send(None)
            return

        attrs = [ContinuousVariable("Correlation"), ContinuousVariable("FDR")]
        metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
        domain = Domain(attrs, metas=metas)
        model = self.vizrank.rank_model
        x = np.array([[float(model.data(model.index(row, 0), role))
                       for role in (Qt.DisplayRole, CorrelationRank.PValRole)]
                      for row in range(model.rowCount())])
        # the second column holds p-values, which FDR() adjusts
        x[:, 1] = FDR(list(x[:, 1]))
        # pylint: disable=protected-access
        m = np.array([[a.name for a in model.data(model.index(row, 0),
                                                  CorrelationRank._AttrRole)]
                      for row in range(model.rowCount())], dtype=object)
        corr_table = Table(domain, x, metas=m)
        corr_table.name = "Correlations"

        # data has been imputed; send original attributes
        self.Outputs.features.send(AttributeList(
            [self.data.domain[var.name] for var in self.selection]))
        self.Outputs.correlations.send(corr_table)

    def send_report(self):
        self.report_table(CorrelationType.items()[self.correlation_type],
                          self.vizrank.rank_table)

    @classmethod
    def migrate_context(cls, context, version):
        """
        Convert the stored selection from older settings versions:
        before version 2, variables in `sel[0]` are replaced by pairs
        (name, type); before version 3, 100 is added to each type and
        the list is wrapped into a tuple with -3.
        """
        if version < 2:
            sel = context.values["selection"]
            context.values["selection"] = [(var.name, vartype(var))
                                           for var in sel[0]]
        if version < 3:
            sel = context.values["selection"]
            context.values["selection"] = ([(name, vtype + 100)
                                            for name, vtype in sel], -3)
Пример #10
0
class OWSpectralResidualTransform(SingleInputWidget):
    """Widget exposing the hyperparameters of the spectral residual primitive.

    The control area contains one editor per hyperparameter; the values are
    stored as widget settings. ``use_columns`` and ``exclude_columns`` are
    edited as text (``*_buf`` settings) and parsed into Python values by the
    line-edit callbacks.
    """
    name = "SpectralResidualTransform"
    description = ("Spectral Residual For feature Analysis.")
    icon = "icons/SpectralResidualTransform.svg"
    category = "Feature Analysis"
    keywords = []

    want_main_area = False
    # NOTE(review): probably meant ``buttons_area_orientation``; as spelled
    # the base widget most likely ignores this attribute. Renaming it would
    # change the layout, so confirm the intended look before fixing.
    buttons_area_orientatio = Qt.Vertical
    resizing_enabled = False

    # set default hyperparameters here
    autosend = Setting(True)

    avg_filter_dimension = Setting(1)

    return_subseq_inds = Setting(False)

    # ``*_buf`` hold the raw text typed by the user; the parsed values live
    # in the plain (non-setting) attributes next to them.
    use_columns_buf = Setting(())
    use_columns = ()
    exclude_columns_buf = Setting(())
    exclude_columns = ()
    return_result = Setting('new')
    use_semantic_types = Setting(False)
    add_index_columns = Setting(False)
    error_on_no_input = Setting(True)
    return_semantic_type = Setting(
        'https://metadata.datadrivendiscovery.org/types/Attribute')

    primitive = SpectralResidualTransformPrimitive

    @staticmethod
    def _parse_column_spec(buf):
        """Turn the text of a column line edit into a Python value.

        ``buf`` is either the typed string or the default empty tuple.
        Blank input yields ``()`` instead of raising ``SyntaxError`` from
        ``eval('')``, so clearing the field is a valid way to reset it.
        """
        text = ''.join(buf).strip()
        if not text:
            return ()
        # NOTE(security): eval() executes whatever expression is typed into
        # the field. Kept for backward compatibility with existing inputs;
        # consider ast.literal_eval if only literal tuples must be accepted.
        return eval(text)  # pylint: disable=eval-used

    def _use_columns_callback(self):
        """Parse ``use_columns_buf`` into ``use_columns`` and notify."""
        self.use_columns = self._parse_column_spec(self.use_columns_buf)
        self.settings_changed()

    def _exclude_columns_callback(self):
        """Parse ``exclude_columns_buf`` into ``exclude_columns`` and notify."""
        self.exclude_columns = self._parse_column_spec(
            self.exclude_columns_buf)
        self.settings_changed()

    def _init_ui(self):
        """Build the hyperparameter editors in the control area."""
        # implement your user interface here (for setting hyperparameters)
        gui.separator(self.controlArea)
        box = gui.widgetBox(self.controlArea, "Hyperparameter")
        gui.separator(self.controlArea)

        gui.lineEdit(box,
                     self,
                     'avg_filter_dimension',
                     label='The square filter dimension. (IntVariable)',
                     callback=None)

        # return_subseq_inds = Setting(False)
        gui.checkBox(box,
                     self,
                     "return_subseq_inds",
                     label='If return subsequence index.',
                     callback=None)

        # use_semantic_types = Setting(False)
        gui.checkBox(box,
                     self,
                     "use_semantic_types",
                     label='Mannally select columns if active.',
                     callback=None)

        # use_columns = Setting(())
        gui.lineEdit(
            box,
            self,
            "use_columns_buf",
            label=
            'Column index to use when use_semantic_types is activated. Tuple, e.g. (0,1,2)',
            validator=None,
            callback=self._use_columns_callback)

        # exclude_columns = Setting(())
        gui.lineEdit(
            box,
            self,
            "exclude_columns_buf",
            label=
            'Column index to exclude when use_semantic_types is activated. Tuple, e.g. (0,1,2)',
            validator=None,
            callback=self._exclude_columns_callback)

        # return_result = Setting(['append', 'replace', 'new'])
        gui.comboBox(
            box,
            self,
            "return_result",
            sendSelectedValue=True,
            label='Output results.',
            items=['new', 'append', 'replace'],
        )

        # add_index_columns = Setting(False)
        gui.checkBox(box,
                     self,
                     "add_index_columns",
                     label='Keep index in the outputs.',
                     callback=None)

        # error_on_no_input = Setting(True)
        gui.checkBox(box,
                     self,
                     "error_on_no_input",
                     label='Error on no input.',
                     callback=None)

        # return_semantic_type = Setting(['https://metadata.datadrivendiscovery.org/types/Attribute',
        #                                 'https://metadata.datadrivendiscovery.org/types/ConstructedAttribute'])
        gui.comboBox(
            box,
            self,
            "return_semantic_type",
            sendSelectedValue=True,
            label='Semantic type attach with results.',
            items=[
                'https://metadata.datadrivendiscovery.org/types/Attribute',
                'https://metadata.datadrivendiscovery.org/types/ConstructedAttribute'
            ],
        )
        # Only for test
        gui.button(box,
                   self,
                   "Print Hyperparameters",
                   callback=self._print_hyperparameter)

        gui.auto_apply(box, self, "autosend", box=False)

        self.data = None
        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    def _print_hyperparameter(self):
        """Debug helper: print every hyperparameter with its type, then commit."""
        print(self.avg_filter_dimension, type(self.avg_filter_dimension))
        print(self.return_subseq_inds, type(self.return_subseq_inds))
        print(self.use_columns, type(self.use_columns))
        print(self.exclude_columns, type(self.exclude_columns))
        print(self.return_result, type(self.return_result))
        print(self.use_semantic_types, type(self.use_semantic_types))
        print(self.add_index_columns, type(self.add_index_columns))
        print(self.error_on_no_input, type(self.error_on_no_input))
        print(self.return_semantic_type, type(self.return_semantic_type))

        self.commit()
Пример #11
0
class OWNomogram(OWWidget):
    """Nomogram widget for Naive Bayes and logistic regression classifiers."""
    name = "Nomogram"
    description = " Nomograms for Visualization of Naive Bayesian" \
                  " and Logistic Regression Classifiers."
    icon = "icons/Nomogram.svg"
    priority = 2000

    inputs = [("Classifier", Model, "set_classifier"),
              ("Data", Table, "set_instance")]

    # upper bound of the "Best ranked" spin box; domains with more
    # attributes are switched to the "Best ranked" display mode
    MAX_N_ATTRS = 1000
    # value of the ``scale`` setting that selects the point scale
    POINT_SCALE = 0
    # possible values of ``self.align``
    ALIGN_LEFT = 0
    ALIGN_ZERO = 1
    # classifier types accepted by ``set_classifier``
    ACCEPTABLE = (NaiveBayesModel, LogisticRegressionClassifier)
    settingsHandler = ClassValuesContextHandler()
    target_class_index = ContextSetting(0)
    normalize_probabilities = Setting(False)
    scale = Setting(1)
    display_index = Setting(1)
    n_attributes = Setting(10)
    sort_index = Setting(SortBy.ABSOLUTE)
    cont_feature_dim_index = Setting(0)

    graph_name = "scene"

    class Error(OWWidget.Error):
        invalid_classifier = Msg("Nomogram accepts only Naive Bayes and "
                                 "Logistic Regression classifiers.")

    def __init__(self):
        """Initialise the widget state and build the controls and the view."""
        super().__init__()
        self.instances = None
        self.domain = None
        self.data = None
        self.classifier = None
        self.align = OWNomogram.ALIGN_ZERO
        self.log_odds_ratios = []
        self.log_reg_coeffs = []
        self.log_reg_coeffs_orig = []
        self.log_reg_cont_data_extremes = []
        self.p = None
        self.b0 = None
        self.points = []
        self.feature_items = []
        self.feature_marker_values = []
        # conversions between displayed and unscaled marker values;
        # replaced in update_scene according to scale and alignment
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.old_target_class_index = self.target_class_index
        self.markers_set = False
        # update_scene is a no-op until this is set by the resize handler
        self.repaint = False

        # GUI
        box = gui.vBox(self.controlArea, "Target class")
        self.class_combo = gui.comboBox(box,
                                        self,
                                        "target_class_index",
                                        callback=self._class_combo_changed,
                                        contentsLength=12)
        self.norm_check = gui.checkBox(
            box,
            self,
            "normalize_probabilities",
            "Normalize probabilities",
            hidden=True,
            callback=self._norm_check_changed,
            tooltip="For multiclass data 1 vs. all probabilities do not"
            " sum to 1 and therefore could be normalized.")

        self.scale_radio = gui.radioButtons(
            self.controlArea,
            self,
            "scale", ["Point scale", "Log odds ratios"],
            box="Scale",
            callback=self._radio_button_changed)

        box = gui.vBox(self.controlArea, "Display features")
        grid = QGridLayout()
        self.display_radio = gui.radioButtonsInBox(
            box,
            self,
            "display_index", [],
            orientation=grid,
            callback=self._display_radio_button_changed)
        # the radio buttons are placed into the grid manually below so that
        # the spin box can sit next to "Best ranked:"
        radio_all = gui.appendRadioButton(self.display_radio,
                                          "All:",
                                          addToLayout=False)
        radio_best = gui.appendRadioButton(self.display_radio,
                                           "Best ranked:",
                                           addToLayout=False)
        spin_box = gui.hBox(None, margin=0)
        self.n_spin = gui.spin(spin_box,
                               self,
                               "n_attributes",
                               1,
                               self.MAX_N_ATTRS,
                               label=" ",
                               controlWidth=60,
                               callback=self._n_spin_changed)
        grid.addWidget(radio_all, 1, 1)
        grid.addWidget(radio_best, 2, 1)
        grid.addWidget(spin_box, 2, 2)

        self.sort_combo = gui.comboBox(box,
                                       self,
                                       "sort_index",
                                       label="Sort by: ",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._sort_combo_changed)

        self.cont_feature_dim_combo = gui.comboBox(
            box,
            self,
            "cont_feature_dim_index",
            label="Continuous features: ",
            items=["1D projection", "2D curve"],
            orientation=Qt.Horizontal,
            callback=self._cont_feature_dim_combo_changed)

        gui.rubber(self.controlArea)

        self.scene = QGraphicsScene()
        self.view = QGraphicsView(
            self.scene,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            renderHints=QPainter.Antialiasing | QPainter.TextAntialiasing
            | QPainter.SmoothPixmapTransform,
            alignment=Qt.AlignLeft)
        # viewport resizes are handled in eventFilter (they trigger a redraw)
        self.view.viewport().installEventFilter(self)
        self.view.viewport().setMinimumWidth(300)
        self.view.sizeHint = lambda: QSize(600, 500)
        self.mainArea.layout().addWidget(self.view)

    def _class_combo_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        coeffs = [
            np.nan_to_num(p[self.target_class_index] /
                          p[self.old_target_class_index]) for p in self.points
        ]
        points = [p[self.old_target_class_index] for p in self.points]
        self.feature_marker_values = [
            self.get_points_from_coeffs(v, c, p)
            for (v, c, p) in zip(self.feature_marker_values, coeffs, points)
        ]
        self.update_scene()
        self.old_target_class_index = self.target_class_index

    def _norm_check_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def _radio_button_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def _display_radio_button_changed(self):
        self.__hide_attrs(self.n_attributes if self.display_index else None)

    def _n_spin_changed(self):
        """Switch to "Best ranked" mode and apply the new feature limit."""
        # editing the spin box implies the second radio button
        self.display_index = 1
        self.__hide_attrs(self.n_attributes)

    def __hide_attrs(self, n_show):
        """Limit the displayed features to ``n_show`` and refit the scene.

        ``n_show`` is passed on to ``nomogram_main.hide``. Afterwards both
        vertical marker lines are resized to the new height of the main
        nomogram and the scene rectangle is recomputed.
        """
        main = self.nomogram_main
        if main is None:
            return
        main.hide(n_show)

        marker = self.vertical_line
        if marker:
            x = marker.line().x1()
            bottom = main.layout.preferredHeight() + 30
            for line_item in (marker, self.hidden_vertical_line):
                line_item.setLine(x, -6, x, bottom)

        current = self.scene.sceneRect()
        rect = QRectF(current.x(),
                      current.y(),
                      self.scene.itemsBoundingRect().width(),
                      self.nomogram.preferredSize().height())
        self.scene.setSceneRect(rect.adjusted(0, 0, 70, 70))

    def _sort_combo_changed(self):
        if self.nomogram_main is None:
            return
        self.nomogram_main.hide(None)
        self.nomogram_main.sort(self.sort_index)
        self.__hide_attrs(self.n_attributes if self.display_index else None)

    def _cont_feature_dim_combo_changed(self):
        values = [item.dot.value for item in self.feature_items]
        self.feature_marker_values = self.scale_back(values)
        self.update_scene()

    def eventFilter(self, obj, event):
        """Redraw the nomogram whenever the view's viewport is resized.

        The first resize also enables drawing by setting ``self.repaint``.
        The event is always passed on to the base implementation.
        """
        viewport_resized = (obj is self.view.viewport()
                            and event.type() == QEvent.Resize)
        if viewport_resized:
            self.repaint = True
            self.feature_marker_values = self.scale_back(
                [feature.dot.value for feature in self.feature_items])
            self.update_scene()
        return super().eventFilter(obj, event)

    def update_controls(self):
        """Synchronise the controls with the current domain and classifier.

        * The target-class combo is refilled from the first class variable.
        * The normalization check box is shown only for more than two
          class values.
        * The continuous-feature combo is disabled (and reset to the first
          entry) when the domain has no continuous attributes.
        * For logistic regression the nomogram is left-aligned and the
          "positive"/"negative" sort options are disabled; a stored sort
          order pointing at either of them is reset to no sorting.
        """
        self.class_combo.clear()
        self.norm_check.setHidden(True)
        self.cont_feature_dim_combo.setEnabled(True)
        if self.domain:
            class_values = self.domain.class_vars[0].values
            self.class_combo.addItems(class_values)
            if len(self.domain.attributes) > self.MAX_N_ATTRS:
                self.display_index = 1
            if len(class_values) > 2:
                self.norm_check.setHidden(False)
            if not self.domain.has_continuous_attributes():
                self.cont_feature_dim_combo.setEnabled(False)
                self.cont_feature_dim_index = 0

        model = self.sort_combo.model()
        signed_sorts = (SortBy.POSITIVE, SortBy.NEGATIVE)
        # re-enable both entries first; they may have been disabled by a
        # previously connected logistic regression classifier
        for sort_by in signed_sorts:
            item = model.item(sort_by)
            item.setFlags(item.flags() | Qt.ItemIsEnabled)

        self.align = OWNomogram.ALIGN_ZERO
        if self.classifier and isinstance(self.classifier,
                                          LogisticRegressionClassifier):
            self.align = OWNomogram.ALIGN_LEFT
            for sort_by in signed_sorts:
                item = model.item(sort_by)
                item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
            # both disabled entries must be covered here; the original
            # tested SortBy.POSITIVE twice and never reset NEGATIVE
            if self.sort_index in signed_sorts:
                self.sort_index = SortBy.NO_SORTING

    def set_instance(self, data):
        """Input handler for "Data": store the instances and reset markers."""
        self.instances = data
        # emptying the list makes set_feature_marker_values recompute the
        # initial marker positions
        self.feature_marker_values = []
        self.set_feature_marker_values()

    def set_classifier(self, classifier):
        """Input handler for "Classifier".

        Classifiers that are not instances of ``ACCEPTABLE`` raise the
        ``invalid_classifier`` error and are treated as no classifier.
        The points are then recomputed, the controls updated, the context
        reopened and the scene redrawn.
        """
        self.closeContext()
        self.classifier = classifier
        self.Error.clear()
        if self.classifier and not isinstance(self.classifier,
                                              self.ACCEPTABLE):
            self.Error.invalid_classifier()
            self.classifier = None
        self.domain = self.classifier.domain if self.classifier else None
        self.data = None
        # each of these is a no-op unless the classifier has the matching type
        self.calculate_log_odds_ratios()
        self.calculate_log_reg_coefficients()
        self.update_controls()
        # default target class; openContext may restore a stored one
        self.target_class_index = 0
        self.openContext(self.domain and self.domain.class_var)
        # whichever of the two lists was filled above
        self.points = self.log_odds_ratios or self.log_reg_coeffs
        self.feature_marker_values = []
        self.old_target_class_index = self.target_class_index
        self.update_scene()

    def calculate_log_odds_ratios(self):
        self.log_odds_ratios = []
        self.p = None
        if self.classifier is None or self.domain is None:
            return
        if not isinstance(self.classifier, NaiveBayesModel):
            return

        log_cont_prob = self.classifier.log_cont_prob
        class_prob = self.classifier.class_prob
        for i in range(len(self.domain.attributes)):
            ca = np.exp(log_cont_prob[i]) * class_prob[:, None]
            _or = (ca / (1 - ca)) / (class_prob / (1 - class_prob))[:, None]
            self.log_odds_ratios.append(np.log(_or))
        self.p = class_prob

    def calculate_log_reg_coefficients(self):
        self.log_reg_coeffs = []
        self.log_reg_cont_data_extremes = []
        self.b0 = None
        if self.classifier is None or self.domain is None:
            return
        if not isinstance(self.classifier, LogisticRegressionClassifier):
            return

        self.domain = self.reconstruct_domain(self.classifier.original_domain,
                                              self.domain)
        self.data = Table.from_table(self.domain,
                                     self.classifier.original_data)
        attrs, ranges, start = self.domain.attributes, [], 0
        for attr in attrs:
            stop = start + len(attr.values) if attr.is_discrete else start + 1
            ranges.append(slice(start, stop))
            start = stop

        self.b0 = self.classifier.intercept
        coeffs = self.classifier.coefficients
        if len(self.domain.class_var.values) == 2:
            self.b0 = np.hstack((self.b0 * (-1), self.b0))
            coeffs = np.vstack((coeffs * (-1), coeffs))
        self.log_reg_coeffs = [coeffs[:, ranges[i]] for i in range(len(attrs))]
        self.log_reg_coeffs_orig = self.log_reg_coeffs.copy()

        min_values = nanmin(self.data.X, axis=0)
        max_values = nanmax(self.data.X, axis=0)

        for i, min_t, max_t in zip(range(len(self.log_reg_coeffs)), min_values,
                                   max_values):
            if self.log_reg_coeffs[i].shape[1] == 1:
                coef = self.log_reg_coeffs[i]
                self.log_reg_coeffs[i] = np.hstack(
                    (coef * min_t, coef * max_t))
                self.log_reg_cont_data_extremes.append(
                    [sorted([min_t, max_t], reverse=(c < 0)) for c in coef])
            else:
                self.log_reg_cont_data_extremes.append([None])

    def update_scene(self):
        """Rebuild the complete nomogram in the graphics scene.

        Nothing is drawn until ``self.repaint`` is set. The scene is cleared
        first; it stays empty when there is no domain or the first entry of
        ``self.points`` is empty. Besides creating the items, this method
        installs the ``scale_back``/``scale_forth`` conversions matching the
        current scale and alignment.
        """
        if not self.repaint:
            return
        self.clear_scene()
        if self.domain is None or not len(self.points[0]):
            return

        name_items = [
            QGraphicsTextItem(a.name) for a in self.domain.attributes
        ]
        point_text = QGraphicsTextItem("Points")
        probs_text = QGraphicsTextItem("Probabilities (%)")
        all_items = name_items + [point_text, probs_text]
        # labels are placed left of the rulers: widest label plus a margin
        name_offset = -max(t.boundingRect().width() for t in all_items) - 50
        w = self.view.viewport().rect().width()
        max_width = w + name_offset - 100

        # points of the target class, one array per attribute
        points = [pts[self.target_class_index] for pts in self.points]
        minimums = [min(p) for p in points]
        if self.align == OWNomogram.ALIGN_LEFT:
            # shift every attribute so that its smallest point is at zero
            points = [p - m for m, p in zip(minimums, points)]
        max_ = np.nan_to_num(max(max(abs(p)) for p in points))
        # factor that maps the largest absolute point to 100
        d = 100 / max_ if max_ else 1
        if self.scale == OWNomogram.POINT_SCALE:
            points = [p * d for p in points]

        # scale_forth applies the shift and factor used above to a list of
        # values; scale_back undoes it
        if self.scale == OWNomogram.POINT_SCALE and \
                self.align == OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [
                p / d + m for m, p in zip(minimums, x)
            ]
            self.scale_forth = lambda x: [(p - m) * d
                                          for m, p in zip(minimums, x)]
        if self.scale == OWNomogram.POINT_SCALE and \
                self.align != OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [p / d for p in x]
            self.scale_forth = lambda x: [p * d for p in x]
        if self.scale != OWNomogram.POINT_SCALE and \
                self.align == OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: [p + m for m, p in zip(minimums, x)]
            self.scale_forth = lambda x: [p - m for m, p in zip(minimums, x)]
        if self.scale != OWNomogram.POINT_SCALE and \
                self.align != OWNomogram.ALIGN_LEFT:
            self.scale_back = lambda x: x
            self.scale_forth = lambda x: x

        point_item, nomogram_head = self.create_main_nomogram(
            name_items, points, max_width, point_text, name_offset)
        probs_item, nomogram_foot = self.create_footer_nomogram(
            probs_text, d, minimums, max_width, name_offset)
        # give every feature dot access to the header/footer dots and the
        # marker line
        for item in self.feature_items:
            item.dot.point_dot = point_item.dot
            item.dot.probs_dot = probs_item.dot
            item.dot.vertical_line = self.hidden_vertical_line

        self.nomogram = nomogram = NomogramItem()
        nomogram.add_items([nomogram_head, self.nomogram_main, nomogram_foot])
        self.scene.addItem(nomogram)
        self.set_feature_marker_values()
        rect = QRectF(self.scene.itemsBoundingRect().x(),
                      self.scene.itemsBoundingRect().y(),
                      self.scene.itemsBoundingRect().width(),
                      self.nomogram.preferredSize().height())
        self.scene.setSceneRect(rect.adjusted(0, 0, 70, 70))

    def create_main_nomogram(self, name_items, points, max_width, point_text,
                             name_offset):
        """Create the "Points" ruler and one item per attribute.

        Sets ``self.nomogram_main``, ``self.feature_items`` and both
        vertical marker lines as a side effect.

        Returns:
            tuple: ``(point_item, nomogram_header)`` -- the points ruler and
            the header item that contains it.
        """
        cls_index = self.target_class_index
        min_p = min(min(p) for p in points)
        max_p = max(max(p) for p in points)
        values = self.get_ruler_values(min_p, max_p, max_width)
        # the ruler is extended to its first and last tick
        min_p, max_p = min(values), max(values)
        diff_ = np.nan_to_num(max_p - min_p)
        # scene units per point
        scale_x = max_width / diff_ if diff_ else max_width

        nomogram_header = NomogramItem()
        point_item = RulerItem(point_text, values, scale_x, name_offset,
                               -scale_x * min_p)
        point_item.setPreferredSize(point_item.preferredWidth(), 35)
        nomogram_header.add_items([point_item])

        self.nomogram_main = SortableNomogramItem()
        cont_feature_item_class = ContinuousFeature2DItem if \
            self.cont_feature_dim_index else ContinuousFeatureItem
        # discrete attributes get a DiscreteFeatureItem, all others the
        # continuous item class chosen above
        self.feature_items = [
            DiscreteFeatureItem(name_items[i], [val for val in att.values],
                                points[i], scale_x, name_offset, -scale_x *
                                min_p, self.points[i][cls_index])
            if att.is_discrete else cont_feature_item_class(
                name_items[i], self.log_reg_cont_data_extremes[i][cls_index],
                self.get_ruler_values(
                    np.min(points[i]), np.max(points[i]),
                    scale_x * (np.max(points[i]) - np.min(points[i])),
                    False), scale_x, name_offset, -scale_x *
                min_p, self.log_reg_coeffs_orig[i][cls_index][0])
            for i, att in enumerate(self.domain.attributes)
        ]
        self.nomogram_main.add_items(
            self.feature_items, self.sort_index,
            self.n_attributes if self.display_index else None)

        # two vertical lines at the same x position with different pens,
        # both parented to the points ruler
        x = -scale_x * min_p
        y = self.nomogram_main.layout.preferredHeight() + 30
        self.vertical_line = QGraphicsLineItem(x, -6, x, y)
        self.vertical_line.setPen(QPen(Qt.DotLine))
        self.vertical_line.setParentItem(point_item)
        self.hidden_vertical_line = QGraphicsLineItem(x, -6, x, y)
        pen = QPen(Qt.DashLine)
        pen.setBrush(QColor(Qt.red))
        self.hidden_vertical_line.setPen(pen)
        self.hidden_vertical_line.setParentItem(point_item)

        return point_item, nomogram_header

    def create_footer_nomogram(self, probs_text, d, minimums, max_width,
                               name_offset):
        """Create the probabilities ruler shown below the features.

        ``d`` is the point-scale factor and ``minimums`` the per-attribute
        minimal points computed in ``update_scene``.

        Returns:
            tuple: ``(probs_item, nomogram_footer)`` -- the probabilities
            ruler and the footer item that contains it.
        """
        # eps bounds the ruler: the sums below correspond to
        # (non-normalized) probabilities eps and 1 - eps
        eps, d_ = 0.05, 1
        k = -np.log(self.p / (1 - self.p)) if self.p is not None else -self.b0
        min_sum = k[self.target_class_index] - np.log((1 - eps) / eps)
        max_sum = k[self.target_class_index] - np.log(eps / (1 - eps))
        if self.align == OWNomogram.ALIGN_LEFT:
            # account for the shift applied to left-aligned attributes
            max_sum = max_sum - sum(minimums)
            min_sum = min_sum - sum(minimums)
            for i in range(len(k)):
                k[i] = k[i] - sum(
                    [min(q) for q in [p[i] for p in self.points]])
        if self.scale == OWNomogram.POINT_SCALE:
            min_sum *= d
            max_sum *= d
            d_ = d

        values = self.get_ruler_values(min_sum, max_sum, max_width)
        min_sum, max_sum = min(values), max(values)
        diff_ = np.nan_to_num(max_sum - min_sum)
        scale_x = max_width / diff_ if diff_ else max_width
        cls_var, cls_index = self.domain.class_var, self.target_class_index
        nomogram_footer = NomogramItem()

        def get_normalized_probabilities(val):
            # total points -> probability of the target class
            if not self.normalize_probabilities:
                return 1 / (1 + np.exp(k[cls_index] - val / d_))
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return 1 / (1 + np.exp(k[cls_index] - val / d_)) / p_sum

        def get_points(prob):
            # inverse of get_normalized_probabilities
            if not self.normalize_probabilities:
                return (k[cls_index] - np.log(1 / prob - 1)) * d_
            totals = self.__get_totals_for_class_values(minimums)
            p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
            return (k[cls_index] - np.log(1 / (prob * p_sum) - 1)) * d_

        # markers_set is False only while the ruler item is constructed;
        # __get_totals_for_class_values reads it
        self.markers_set = False
        probs_item = ProbabilitiesRulerItem(
            probs_text,
            values,
            scale_x,
            name_offset,
            -scale_x * min_sum,
            get_points=get_points,
            title="{}='{}'".format(cls_var.name, cls_var.values[cls_index]),
            get_probabilities=get_normalized_probabilities)
        self.markers_set = True
        nomogram_footer.add_items([probs_item])
        return probs_item, nomogram_footer

    def __get_totals_for_class_values(self, minimums):
        """Return the total points of the current markers for every class.

        The entry of the target class is the plain sum of the marker values;
        the entries of the other classes are obtained by translating each
        marker with ``get_points_from_coeffs`` and applying the same
        alignment and scaling as used for the target class.
        """
        cls_index = self.target_class_index
        marker_values = [item.dot.value for item in self.feature_items]
        if not self.markers_set:
            marker_values = self.scale_forth(marker_values)
        totals = np.empty(len(self.domain.class_var.values))
        totals[cls_index] = sum(marker_values)
        marker_values = self.scale_back(marker_values)
        for i in range(len(self.domain.class_var.values)):
            if i == cls_index:
                continue
            # ratio between the points of class i and of the target class
            coeffs = [np.nan_to_num(p[i] / p[cls_index]) for p in self.points]
            points = [p[cls_index] for p in self.points]
            total = sum([
                self.get_points_from_coeffs(v, c, p)
                for (v, c, p) in zip(marker_values, coeffs, points)
            ])
            if self.align == OWNomogram.ALIGN_LEFT:
                points = [p - m for m, p in zip(minimums, points)]
                total -= sum([min(p) for p in [p[i] for p in self.points]])
            # same factor as ``d`` in update_scene (largest point -> 100)
            d = 100 / max(max(abs(p)) for p in points)
            if self.scale == OWNomogram.POINT_SCALE:
                total *= d
            totals[i] = total
        return totals

    def set_feature_marker_values(self):
        if not (len(self.points) and len(self.feature_items)):
            return
        if not len(self.feature_marker_values):
            self._init_feature_marker_values()
        self.feature_marker_values = self.scale_forth(
            self.feature_marker_values)
        item = self.feature_items[0]
        for i, item in enumerate(self.feature_items):
            item.dot.move_to_val(self.feature_marker_values[i])
        item.dot.probs_dot.move_to_sum()

    def _init_feature_marker_values(self):
        """Compute the initial (unscaled) marker value of every attribute.

        With logistic regression coefficients the default feature value is
        the most frequent value (discrete) or the mean (other attributes)
        of ``self.data``; a non-missing value of the first input instance
        takes precedence. Attributes without any feature value get 0.
        """
        self.feature_marker_values = []
        cls_index = self.target_class_index
        # input instances converted to the nomogram's domain
        instances = Table(self.domain, self.instances) \
            if self.instances else None
        for i, attr in enumerate(self.domain.attributes):
            value, feature_val = 0, None
            if len(self.log_reg_coeffs):
                if attr.is_discrete:
                    ind, n = unique(self.data.X[:, i], return_counts=True)
                    feature_val = np.nan_to_num(ind[np.argmax(n)])
                else:
                    feature_val = mean(self.data.X[:, i])
            inst_in_dom = instances and attr in instances.domain
            if inst_in_dom and not np.isnan(instances[0][attr]):
                feature_val = instances[0][attr]
            if feature_val is not None:
                # discrete: points of that value; otherwise coefficient
                # times the feature value
                value = self.points[i][cls_index][int(feature_val)] \
                    if attr.is_discrete else \
                    self.log_reg_coeffs_orig[i][cls_index][0] * feature_val
            self.feature_marker_values.append(value)

    def clear_scene(self):
        self.feature_items = []
        self.scale_back = lambda x: x
        self.scale_forth = lambda x: x
        self.nomogram = None
        self.nomogram_main = None
        self.vertical_line = None
        self.hidden_vertical_line = None
        self.scene.clear()

    def send_report(self):
        """Add the plot to the report."""
        self.report_plot()

    @staticmethod
    def reconstruct_domain(original, preprocessed):
        """Build a domain with the source variables of ``preprocessed``.

        Every preprocessed attribute is traced back through its compute
        values; when the inner compute value is empty, the variable with the
        same name is taken from ``original``. Each source variable appears
        once, in order of first occurrence. Class variable and metas are
        taken from ``original``.
        """
        def source_variable(attr):
            inner = attr._compute_value.variable._compute_value
            return inner.variable if inner else original[attr.name]

        # an ordered dict serves as an ordered set with fast membership
        unique_vars = OrderedDict.fromkeys(
            source_variable(attr) for attr in preprocessed.attributes)
        return Domain(list(unique_vars), original.class_var, original.metas)

    @staticmethod
    def get_ruler_values(start, stop, max_width, round_to_nearest=True):
        if max_width == 0:
            return [0]
        diff = np.nan_to_num((stop - start) / max_width)
        if diff <= 0:
            return [0]
        decimals = int(np.floor(np.log10(diff)))
        if diff > 4 * pow(10, decimals):
            step = 5 * pow(10, decimals + 2)
        elif diff > 2 * pow(10, decimals):
            step = 2 * pow(10, decimals + 2)
        elif diff > 1 * pow(10, decimals):
            step = 1 * pow(10, decimals + 2)
        else:
            step = 5 * pow(10, decimals + 1)
        round_by = int(-np.floor(np.log10(step)))
        r = start % step
        if not round_to_nearest:
            _range = np.arange(start + step, stop + r, step) - r
            start, stop = np.floor(start * 100) / 100, np.ceil(
                stop * 100) / 100
            return np.round(np.hstack((start, _range, stop)), 2)
        return np.round(np.arange(start, stop + r + step, step) - r, round_by)

    @staticmethod
    def get_points_from_coeffs(current_value, coefficients, possible_values):
        if any(np.isnan(possible_values)):
            return 0
        indices = np.argsort(possible_values)
        sorted_values = possible_values[indices]
        sorted_coefficients = coefficients[indices]
        for i, val in enumerate(sorted_values):
            if current_value < val:
                break
        diff = sorted_values[i] - sorted_values[i - 1]
        k = 0 if diff < 1e-6 else (sorted_values[i] - current_value) / \
                                  (sorted_values[i] - sorted_values[i - 1])
        return sorted_coefficients[i - 1] * sorted_values[i - 1] * k + \
               sorted_coefficients[i] * sorted_values[i] * (1 - k)
Пример #12
0
class ScoreTable(OWComponent, QObject):
    """Table view of per-model scores with user-selectable score columns.

    The component owns a ``QStandardItemModel`` wrapped in a ``ScoreModel``
    and shown in a ``gui.TableView``. Columns other than the first can be
    shown or hidden from a context menu on the horizontal header; the names
    of the shown columns are kept in the ``shown_scores`` setting.
    """
    # Names of the columns that are shown; defaults to all built-in scorers
    shown_scores = Setting(
        set(chain.from_iterable(BUILTIN_SCORERS_ORDER.values())))

    # Emitted every time column visibility has been re-applied
    shownScoresChanged = Signal()

    class ItemDelegate(QStyledItemDelegate):
        """Delegate that pads rows and prints floats with three decimals."""

        def sizeHint(self, *args):
            hint = super().sizeHint(*args)
            return QSize(hint.width(), hint.height() + 6)

        def displayText(self, value, locale):
            if not isinstance(value, float):
                return super().displayText(value, locale)
            return f"{value:.3f}"

    def __init__(self, master):
        """Create the models and the view; ``master`` is the owning widget."""
        QObject.__init__(self)
        OWComponent.__init__(self, master)

        self.model = QStandardItemModel(master)
        self.model.setHorizontalHeaderLabels(["Method"])
        self.sorted_model = ScoreModel()
        self.sorted_model.setSourceModel(self.model)

        self.view = gui.TableView(wordWrap=True,
                                  editTriggers=gui.TableView.NoEditTriggers)
        hheader = self.view.horizontalHeader()
        hheader.setSectionResizeMode(QHeaderView.ResizeToContents)
        hheader.setDefaultAlignment(Qt.AlignCenter)
        hheader.setStretchLastSection(False)
        # The header's context menu is the column chooser
        hheader.setContextMenuPolicy(Qt.CustomContextMenu)
        hheader.customContextMenuRequested.connect(self.show_column_chooser)

        self.view.setModel(self.sorted_model)
        self.view.setItemDelegate(self.ItemDelegate())

    def _column_names(self):
        """Return a generator of header labels of all columns but the first."""
        model = self.model
        sections = range(1, model.columnCount())
        return (model.horizontalHeaderItem(index).data(Qt.DisplayRole)
                for index in sections)

    def show_column_chooser(self, pos):
        """Pop up a menu of checkable column names at header position ``pos``."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        def toggle(col_name, checked):
            change = (self.shown_scores.add if checked
                      else self.shown_scores.remove)
            change(col_name)
            self._update_shown_columns()

        menu = QMenu()
        hheader = self.view.horizontalHeader()
        for col_name in self._column_names():
            entry = menu.addAction(col_name)
            entry.setCheckable(True)
            entry.setChecked(col_name in self.shown_scores)
            entry.triggered.connect(partial(toggle, col_name))
        menu.exec(hheader.mapToGlobal(pos))

    def _update_shown_columns(self):
        """Hide columns not in ``shown_scores``, resize, and emit the signal."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        hheader = self.view.horizontalHeader()
        visible = self.shown_scores
        for section, col_name in enumerate(self._column_names(), start=1):
            hheader.setSectionHidden(section, col_name not in visible)
        self.view.resizeColumnsToContents()
        self.shownScoresChanged.emit()

    def update_header(self, scorers):
        """Set header items: three fixed columns, then one per scorer."""
        fixed_titles = ("Model", "Train time [s]", "Test time [s]")
        n_fixed = len(fixed_titles)
        self.model.setColumnCount(n_fixed + len(scorers))
        for column, title in enumerate(fixed_titles):
            self.model.setHorizontalHeaderItem(column, QStandardItem(title))
        for column, scorer in enumerate(scorers, start=n_fixed):
            header_item = QStandardItem(scorer.name)
            header_item.setToolTip(scorer.long_name)
            self.model.setHorizontalHeaderItem(column, header_item)
        self._update_shown_columns()

    def copy_selection_to_clipboard(self):
        """Put the view's current selection on the application clipboard."""
        selection = table_selection_to_mime_data(self.view)
        QApplication.clipboard().setMimeData(selection, QClipboard.Clipboard)
Пример #13
0
class OWRank(OWWidget):
    """Widget that scores the attributes of the input data and lets the user
    select a subset of them.

    Scores come from the built-in methods in ``SCORES`` that are checked in
    the GUI and from scorers received on the scorer input. They are shown in
    a sortable table; the selected rows determine the attributes of the
    reduced data that is sent to the output.
    """
    name = "排名(Rank)"
    description = "根据数据特征的相关性对其进行排名和筛选。"
    icon = "icons/Rank.svg"
    priority = 1102
    keywords = []

    buttons_area_orientation = Qt.Vertical

    class Inputs:
        data = Input("数据(Data)", Table, replaces=['Data'])
        scorer = Input("评分器(Scorer)", score.Scorer, multiple=True, replaces=['Scorer'])

    class Outputs:
        reduced_data = Output("选中的数据(Reduced Data)", Table, default=True, replaces=['Reduced Data'])
        scores = Output("分数(Scores)", Table, replaces=['Scores'])
        features = Output("特征(Features)", AttributeList, dynamic=False, replaces=['Features'])

    # Selection modes; the values double as ids of the radio buttons
    SelectNone, SelectAll, SelectManual, SelectNBest = range(4)

    # Number of top rows selected in SelectNBest mode
    nSelected = ContextSetting(5)
    auto_apply = Setting(True)

    # (column, order) of the last sort; the column index does not count the
    # leading '#' column (see updateScores and headerClick)
    sorting = Setting((0, Qt.DescendingOrder))
    # Names of the checked built-in scoring methods
    selected_methods = Setting(set())

    settings_version = 2
    settingsHandler = DomainContextHandler()
    # Selected rows as indices into the original (unsorted) attribute order
    selected_rows = ContextSetting([])
    selectionMethod = ContextSetting(SelectNBest)

    class Information(OWWidget.Information):
        no_target_var = Msg("Data does not have a single target variable. "
                            "You can still connect in unsupervised scorers "
                            "such as PCA.")
        missings_imputed = Msg('Missing values will be imputed as needed.')

    class Error(OWWidget.Error):
        invalid_type = Msg("Cannot handle target variable type {}")
        inadequate_learner = Msg("Scorer {} inadequate: {}")
        no_attributes = Msg("Data does not have a single attribute.")

    def __init__(self):
        """Build the GUI: the score table, one page of method check boxes
        per problem type, and the controls for choosing the selection mode.
        """
        super().__init__()
        # External scorers keyed by the id of the input signal
        self.scorers = OrderedDict()
        self.out_domain_desc = None
        self.data = None
        self.problem_type_mode = ProblemType.CLASSIFICATION

        # Fall back to the methods flagged as default when nothing is stored
        if not self.selected_methods:
            self.selected_methods = {method.name for method in SCORES
                                     if method.is_default}

        # GUI

        self.ranksModel = model = TableModel(parent=self)  # type: TableModel
        self.ranksView = view = TableView(self)            # type: TableView
        self.mainArea.layout().addWidget(view)
        view.setModel(model)
        view.setColumnWidth(0, 30)
        view.selectionModel().selectionChanged.connect(self.on_select)

        # Any click into the table or its vertical header switches the
        # selection mode to manual
        def _set_select_manual():
            self.setSelectionMethod(OWRank.SelectManual)

        view.pressed.connect(_set_select_manual)
        view.verticalHeader().sectionClicked.connect(_set_select_manual)
        view.horizontalHeader().sectionClicked.connect(self.headerClick)

        self.measuresStack = stacked = QStackedWidget(self)
        self.controlArea.layout().addWidget(stacked)

        # One stacked page per problem type; the third page has no methods
        # and no box title
        for scoring_methods in (CLS_SCORES,
                                REG_SCORES,
                                []):
            box = gui.vBox(None, "评分方法(Scoring Methods)" if scoring_methods else None)
            stacked.addWidget(box)
            for method in scoring_methods:
                box.layout().addWidget(QCheckBox(
                    method.name, self,
                    objectName=method.shortname,  # To be easily found in tests
                    checked=method.name in self.selected_methods,
                    stateChanged=partial(self.methodSelectionChanged, method_name=method.name)))
            gui.rubber(box)

        gui.rubber(self.controlArea)
        self.switchProblemType(ProblemType.CLASSIFICATION)

        selMethBox = gui.vBox(self.controlArea, "选择属性", addSpace=True)

        grid = QGridLayout()
        grid.setContentsMargins(6, 0, 6, 0)
        self.selectButtons = QButtonGroup()
        self.selectButtons.buttonClicked[int].connect(self.setSelectionMethod)

        # Create a radio button and register it in the button group under
        # the given selection-mode id
        def button(text, buttonid, toolTip=None):
            b = QRadioButton(text)
            self.selectButtons.addButton(b, buttonid)
            if toolTip is not None:
                b.setToolTip(toolTip)
            return b

        b1 = button(self.tr("无"), OWRank.SelectNone)
        b2 = button(self.tr("所有"), OWRank.SelectAll)
        b3 = button(self.tr("手动"), OWRank.SelectManual)
        b4 = button(self.tr("最佳排名:"), OWRank.SelectNBest)

        # Changing the spin box value also switches to SelectNBest
        s = gui.spin(selMethBox, self, "nSelected", 1, 999,
                     callback=lambda: self.setSelectionMethod(OWRank.SelectNBest))

        grid.addWidget(b1, 0, 0)
        grid.addWidget(b2, 1, 0)
        grid.addWidget(b3, 2, 0)
        grid.addWidget(b4, 3, 0)
        grid.addWidget(s, 3, 1)

        self.selectButtons.button(self.selectionMethod).setChecked(True)

        selMethBox.layout().addLayout(grid)

        gui.auto_send(selMethBox, self, "auto_apply", box=False)

        self.resize(690, 500)

    def switchProblemType(self, index):
        """
        Switch between discrete/continuous/no_class mode: show the matching
        page of scoring methods and remember the mode.
        """
        self.measuresStack.setCurrentIndex(index)
        self.problem_type_mode = index

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle new input data.

        Resets the table, the selection, the score caches and the messages,
        determines the problem type from the class variable and restores the
        context. Data without attributes is treated as no data and an error
        is shown.
        """
        self.closeContext()
        self.selected_rows = []
        self.ranksModel.clear()
        self.ranksModel.resetSorting(True)

        # Cached scores belong to the previous data
        self.get_method_scores.cache_clear()  # pylint: disable=no-member
        self.get_scorer_scores.cache_clear()  # pylint: disable=no-member

        self.Error.clear()
        self.Information.clear()
        self.Information.missings_imputed(
            shown=data is not None and data.has_missing())

        if data is not None and not data.domain.attributes:
            data = None
            self.Error.no_attributes()
        self.data = data
        self.switchProblemType(ProblemType.CLASSIFICATION)
        if self.data is not None:
            domain = self.data.domain

            if domain.has_discrete_class:
                problem_type = ProblemType.CLASSIFICATION
            elif domain.has_continuous_class:
                problem_type = ProblemType.REGRESSION
            elif not domain.class_var:
                self.Information.no_target_var()
                problem_type = ProblemType.UNSUPERVISED
            else:
                # This can happen?
                self.Error.invalid_type(type(domain.class_var).__name__)
                problem_type = None

            if problem_type is not None:
                self.switchProblemType(problem_type)

            self.ranksModel.setVerticalHeaderLabels(domain.attributes)
            # Size the vertical header for the longest attribute name
            self.ranksView.setVHeaderFixedWidthFromLabel(
                max((a.name for a in domain.attributes), key=len))

            self.selectionMethod = OWRank.SelectNBest

        self.openContext(data)
        self.selectButtons.button(self.selectionMethod).setChecked(True)

    def handleNewSignals(self):
        """Recompute the scores and re-apply the selection."""
        self.setStatusMessage('Running')
        self.updateScores()
        self.setStatusMessage('')
        self.on_select()

    @Inputs.scorer
    def set_learner(self, scorer, id):  # pylint: disable=redefined-builtin
        """Add, replace or (when ``scorer`` is None) remove the external
        scorer connected under signal ``id``.
        """
        if scorer is None:
            self.scorers.pop(id, None)
        else:
            # Avoid caching a (possibly stale) previous instance of the same
            # Scorer passed via the same signal
            if id in self.scorers:
                # pylint: disable=no-member
                self.get_scorer_scores.cache_clear()

            self.scorers[id] = ScoreMeta(scorer.name, scorer.name, scorer,
                                         ProblemType.from_variable(scorer.class_type),
                                         False)

    @memoize_method()
    def get_method_scores(self, method):
        """Return the scores of built-in ``method`` for all attributes.

        If scoring the whole data raises ValueError, each attribute is
        scored separately; if that raises ValueError too, an array of NaNs
        is returned.
        """
        # These errors often happen, but they result in nans, which
        # are handled correctly by the widget
        estimator = method.scorer()
        data = self.data
        try:
            scores = np.asarray(estimator(data))
        except ValueError:
            try:
                scores = np.array([estimator(data, attr)
                                   for attr in data.domain.attributes])
            except ValueError:
                log.error("%s doesn't work on this data", method.name)
                scores = np.full(len(data.domain.attributes), np.nan)
            else:
                log.warning("%s had to be computed separately for each "
                            "variable", method.name)
        return scores

    @memoize_method()
    def get_scorer_scores(self, scorer):
        """Return ``(scores, labels)`` for an external scorer.

        ``scores`` has one row per attribute; on ValueError it is a single
        column of NaNs. ``labels`` holds one label per score column: the
        scorer's short name, suffixed with ``_1``, ``_2``, ... when there is
        more than one column.
        """
        try:
            scores = scorer.scorer.score_data(self.data).T
        except ValueError:
            log.error("%s doesn't work on this data", scorer.name)
            scores = np.full((len(self.data.domain.attributes), 1), np.nan)

        labels = ((scorer.shortname,)
                  if scores.shape[1] == 1 else
                  tuple(scorer.shortname + '_' + str(i)
                        for i in range(1, 1 + scores.shape[1])))
        return scores, labels

    def updateScores(self):
        """Recompute all scores, refill the table, restore the sort order
        and the selection, and send the scores table.
        """
        if self.data is None:
            self.ranksModel.clear()
            self.Outputs.scores.send(None)
            return

        # Checked built-in methods that match the current problem type and,
        # for sparse data, support sparse input
        methods = [method
                   for method in SCORES
                   if (method.name in self.selected_methods and
                       method.problem_type == self.problem_type_mode and
                       (not issparse(self.data.X) or
                        method.scorer.supports_sparse_data))]

        # External scorers are used if they match the problem type or are
        # unsupervised; the others are reported as errors
        scorers = []
        self.Error.inadequate_learner.clear()
        for scorer in self.scorers.values():
            if scorer.problem_type in (self.problem_type_mode, ProblemType.UNSUPERVISED):
                scorers.append(scorer)
            else:
                self.Error.inadequate_learner(scorer.name, scorer.learner_adequacy_err_msg)

        method_scores = tuple(self.get_method_scores(method)
                              for method in methods)

        scorer_scores, scorer_labels = (), ()
        if scorers:
            scorer_scores, scorer_labels = zip(*(self.get_scorer_scores(scorer)
                                                 for scorer in scorers))
            scorer_labels = tuple(chain.from_iterable(scorer_labels))

        labels = tuple(method.shortname for method in methods) + scorer_labels
        # First column: number of values for discrete attributes, NaN for
        # the others; then the method scores and the scorer scores
        model_array = np.column_stack(
            ([len(a.values) if a.is_discrete else np.nan
              for a in self.data.domain.attributes],) +
            (method_scores if method_scores else ()) +
            (scorer_scores if scorer_scores else ())
        )
        for column, values in enumerate(model_array.T):
            self.ranksModel.setExtremesFrom(column, values)

        self.ranksModel.wrap(model_array.tolist())
        self.ranksModel.setHorizontalHeaderLabels(('#',) + labels)
        self.ranksView.setColumnWidth(0, 40)

        # Re-apply sort
        try:
            sort_column, sort_order = self.sorting
            if sort_column < len(labels):
                # adds 1 for '#' (discrete count) column
                self.ranksModel.sort(sort_column + 1, sort_order)
                self.ranksView.horizontalHeader().setSortIndicator(sort_column + 1, sort_order)
        except ValueError:
            pass

        self.autoSelection()
        self.Outputs.scores.send(self.create_scores_table(labels))

    def on_select(self):
        """Store the selected rows and commit the output."""
        # Save indices of attributes in the original, unsorted domain
        self.selected_rows = list(self.ranksModel.mapToSourceRows([
            i.row() for i in self.ranksView.selectionModel().selectedRows(0)]))
        self.commit()

    def setSelectionMethod(self, method):
        """Switch to selection mode ``method`` and apply it to the view."""
        self.selectionMethod = method
        self.selectButtons.button(method).setChecked(True)
        self.autoSelection()

    def autoSelection(self):
        """Select rows in the view according to the current selection mode.

        SelectNone clears the selection, SelectAll selects every row,
        SelectNBest selects the first ``nSelected`` rows in the current
        order, and the remaining mode re-selects ``selected_rows``.
        """
        selModel = self.ranksView.selectionModel()
        model = self.ranksModel
        rowCount = model.rowCount()
        columnCount = model.columnCount()

        if self.selectionMethod == OWRank.SelectNone:
            selection = QItemSelection()
        elif self.selectionMethod == OWRank.SelectAll:
            selection = QItemSelection(
                model.index(0, 0),
                model.index(rowCount - 1, columnCount - 1)
            )
        elif self.selectionMethod == OWRank.SelectNBest:
            nSelected = min(self.nSelected, rowCount)
            selection = QItemSelection(
                model.index(0, 0),
                model.index(nSelected - 1, columnCount - 1)
            )
        else:
            selection = QItemSelection()
            if self.selected_rows is not None:
                # Stored rows are source rows; map them to the sorted order
                for row in model.mapFromSourceRows(self.selected_rows):
                    selection.append(QItemSelectionRange(
                        model.index(row, 0), model.index(row, columnCount - 1)))

        selModel.select(selection, QItemSelectionModel.ClearAndSelect)

    def headerClick(self, index):
        """React to a click on horizontal header section ``index``: refresh
        the N-best selection and remember the sort column and order.
        """
        if index >= 1 and self.selectionMethod == OWRank.SelectNBest:
            # Reselect the top ranked attributes
            self.autoSelection()

        # Store the header states
        sort_order = self.ranksModel.sortOrder()
        sort_column = self.ranksModel.sortColumn() - 1  # -1 for '#' (discrete count) column
        self.sorting = (sort_column, sort_order)

    def methodSelectionChanged(self, state, method_name):
        """Add or remove ``method_name`` from the checked methods according
        to the check box ``state`` and recompute the scores.
        """
        if state == Qt.Checked:
            self.selected_methods.add(method_name)
        elif method_name in self.selected_methods:
            self.selected_methods.remove(method_name)

        self.updateScores()

    def send_report(self):
        """Report the input domain, the ranks table and, if any, the
        description of the output domain. Does nothing without data.
        """
        if not self.data:
            return
        self.report_domain("Input", self.data.domain)
        self.report_table("Ranks", self.ranksView, num_format="{:.3f}")
        if self.out_domain_desc is not None:
            self.report_items("Output", self.out_domain_desc)

    def commit(self):
        """Send the data reduced to the selected attributes and the list of
        those attributes, or None on both outputs when nothing is selected.
        """
        selected_attrs = []
        if self.data is not None:
            selected_attrs = [self.data.domain.attributes[i]
                              for i in self.selected_rows]
        if not selected_attrs:
            self.Outputs.reduced_data.send(None)
            self.Outputs.features.send(None)
            self.out_domain_desc = None
        else:
            # Class variable and metas are kept as they are
            reduced_domain = Domain(
                selected_attrs, self.data.domain.class_var, self.data.domain.metas)
            data = self.data.transform(reduced_domain)
            self.Outputs.reduced_data.send(data)
            self.Outputs.features.send(AttributeList(selected_attrs))
            self.out_domain_desc = report.describe_domain(data.domain)

    def create_scores_table(self, labels):
        """Return a Table of the scores in the model, or None if the model
        is empty or holds only the first ('#') column.

        The table has one continuous column per label and the feature names
        in a string meta column.
        """
        model_list = self.ranksModel.tolist()
        if not model_list or len(model_list[0]) == 1:  # Empty or just n_values column
            return None

        domain = Domain([ContinuousVariable(label) for label in labels],
                        metas=[StringVariable("Feature")])

        # Prevent np.inf scores
        finfo = np.finfo(np.float64)
        # The first column (number of values) is not a score and is dropped
        scores = np.clip(np.array(model_list)[:, 1:], finfo.min, finfo.max)

        feature_names = np.array([a.name for a in self.data.domain.attributes])
        # Reshape to 2d array as Table does not like 1d arrays
        feature_names = feature_names[:, None]

        new_table = Table(domain, scores, metas=feature_names)
        new_table.name = "Feature Scores"
        return new_table

    @classmethod
    def migrate_settings(cls, settings, version):
        """Convert settings older than version 2: replace the stored
        ``headerState`` with a ``sorting`` entry.
        """
        # If older settings, restore sort header to default
        # Saved selected_rows will likely be incorrect
        if version is None or version < 2:
            column, order = 0, Qt.DescendingOrder
            headerState = settings.pop("headerState", None)

            # Lacking knowledge of last problemType, use discrete ranks view's ordering
            if isinstance(headerState, (tuple, list)):
                headerState = headerState[0]

            if isinstance(headerState, bytes):
                # Let a temporary header view decode the saved state
                hview = QHeaderView(Qt.Horizontal)
                hview.restoreState(headerState)
                column, order = hview.sortIndicatorSection() - 1, hview.sortIndicatorOrder()
            settings["sorting"] = (column, order)

    @classmethod
    def migrate_context(cls, context, version):
        """Reset the stored row selection in contexts older than version 2."""
        if version is None or version < 2:
            # Old selection was saved as sorted indices. New selection is original indices.
            # Since we can't devise the latter without first computing the ranks,
            # just reset the selection to avoid confusion.
            context.values['selected_rows'] = []
Пример #14
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    """Widget that reads a data table from a local file or a URL and sends
    it to the output, with an editor for the domain of the loaded data.
    """
    name = "File"
    id = "orange.widgets.data.file"
    description = "Read data from an input file or network " \
                  "and send a data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["file", "load", "read", "open"]

    class Outputs:
        data = Output("Data",
                      Table,
                      doc="Attribute-valued dataset read from the input file.")

    want_main_area = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]
    # Local files larger than this (as reported by os.path.getsize) are not
    # loaded automatically in __init__; a warning is shown instead
    SIZE_LIMIT = 1e7
    # Values of the `source` setting
    LOCAL_FILE, URL = range(2)

    settingsHandler = PerfectDomainContextHandler(
        match_values=PerfectDomainContextHandler.MATCH_VALUES_ALL)

    # pylint seems to want declarations separated from definitions
    recent_paths: List[RecentPath]
    recent_urls: List[str]
    variables: list

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
        RecentPath("", "sample-datasets", "heart_disease.tab"),
    ])
    recent_urls = Setting([])
    # Which source is active: LOCAL_FILE or URL
    source = Setting(LOCAL_FILE)
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    url = Setting("")

    variables = ContextSetting([])

    domain_editor = SettingProvider(DomainEditor)

    class Warning(widget.OWWidget.Warning):
        file_too_big = widget.Msg(
            "The file is too large to load automatically."
            " Press Reload to load.")
        load_warning = widget.Msg("Read warning:\n{}")

    class Error(widget.OWWidget.Error):
        file_not_found = widget.Msg("File not found.")
        missing_reader = widget.Msg("Missing reader.")
        sheet_error = widget.Msg("Error listing available sheets.")
        unknown = widget.Msg("Read error:\n{}")

    # Marker returned by _get_reader when there is no file or URL to read
    class NoFileSelected:
        pass

    def __init__(self):
        """Build the GUI and schedule loading of the last used source.

        The controls are laid out in a grid: a row for local files (radio
        button, recent-files combo, browse and reload buttons), a row for
        the sheet selector, and a row for URLs. Below are the info box, the
        domain editor and the buttons.
        """
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self.reader = None

        layout = QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        # Radio group bound to `source`; its buttons are placed into the
        # grid manually below
        vbox = gui.radioButtons(None,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data,
                                addToLayout=False)

        # Grid row 0: local file
        rb_button = gui.appendRadioButton(vbox, "File:", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(None,
                                 self,
                                 '...',
                                 callback=self.browse_file,
                                 autoDefault=False)
        file_button.setIcon(self.style().standardIcon(QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(None,
                                   self,
                                   "Reload",
                                   callback=self.load_data,
                                   autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Grid row 2: sheet selector, initially hidden
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(
            None,
            self,
            "xls_sheet",
            callback=self.select_sheet,
            sendSelectedValue=True,
        )
        self.sheet_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(self.sheet_label, Qt.AlignLeft)
        self.sheet_box.layout().addWidget(self.sheet_combo, Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        # Grid row 3: URL
        rb_button = gui.appendRadioButton(vbox, "URL:", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, Qt.AlignVCenter)

        # Editable combo whose model wraps the `recent_urls` list
        self.url_combo = url_combo = QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setLineEdit(LineEditSelectOnFocus())
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        # Widen the left text margin of the line edit by 5
        l, t, r, b = url_edit.getTextMargins()
        url_edit.setTextMargins(l + 5, t, r, b)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.widgetBox(self.controlArea, "Columns (Double click to edit)")
        self.domain_editor = DomainEditor(self)
        self.editor_model = self.domain_editor.model()
        box.layout().addWidget(self.domain_editor)

        box = gui.hBox(self.controlArea)
        gui.button(box,
                   self,
                   "Browse documentation datasets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)

        # Apply is disabled until the domain editor's model reports a change
        self.apply_button = gui.button(box,
                                       self,
                                       "Apply",
                                       callback=self.apply_domain_edit)
        self.apply_button.setEnabled(False)
        self.apply_button.setFixedWidth(170)
        self.editor_model.dataChanged.connect(
            lambda: self.apply_button.setEnabled(True))

        self.set_file_list()
        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)

        self.setAcceptDrops(True)

        # Do not load an existing local file above SIZE_LIMIT automatically;
        # show a warning and leave without scheduling the load
        if self.source == self.LOCAL_FILE:
            last_path = self.last_path()
            if last_path and os.path.exists(last_path) and \
                    os.path.getsize(last_path) > self.SIZE_LIMIT:
                self.Warning.file_too_big()
                return

        # Load through a zero-delay timer rather than directly (see above)
        QTimer.singleShot(0, self.load_data)

    def sizeHint(self):
        """Return the preferred size of the widget (600 x 550)."""
        width, height = 600, 550
        return QSize(width, height)

    def select_file(self, n):
        """Make the ``n``-th recent path current and load it.

        After the mixin has handled the selection, the source is switched to
        the local file, the data is loaded and the file list is refreshed,
        unless there are no recent paths.
        """
        assert n < len(self.recent_paths)
        super().select_file(n)
        if not self.recent_paths:
            return
        self.source = self.LOCAL_FILE
        self.load_data()
        self.set_file_list()

    def select_sheet(self):
        """Store the sheet chosen in the combo on the first recent path and
        reload the data.
        """
        chosen_sheet = self.sheet_combo.currentText()
        current_path = self.recent_paths[0]
        current_path.sheet = chosen_sheet
        self.load_data()

    def _url_set(self):
        """Handle activation of an entry in the URL combo.

        A URL without a scheme gets ``http://`` prepended, both in the combo
        and in ``recent_urls``. The source is then switched to URL and the
        data is loaded.
        """
        entered = self.url_combo.currentText()
        # Look the entry up before stripping: the list holds the raw text
        position = self.recent_urls.index(entered)
        cleaned = entered.strip()

        has_scheme = bool(urlparse(cleaned).scheme)
        if not has_scheme:
            cleaned = 'http://' + cleaned
            self.url_combo.setItemText(position, cleaned)
            self.recent_urls[position] = cleaned

        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file in a dialog and load it.

        Parameters
        ----------
        in_demos : bool
            If true, the dialog starts in the sample datasets directory; a
            message box is shown and nothing else happens when that
            directory does not exist. Otherwise the dialog starts at the
            last used path or, failing that, the home directory.
        """
        if not in_demos:
            start_file = self.last_path() or os.path.expanduser("~/")
        else:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation datasets")
                return

        def can_read(fmt):
            # Offer only formats that have a reader and declare extensions
            return getattr(fmt, 'read', None) and \
                getattr(fmt, "EXTENSIONS", None)

        readers = list(filter(can_read, FileFormat.formats))
        filename, reader, _ = open_filename_dialog(start_file, None, readers)
        if not filename:
            return

        self.add_path(filename)
        if reader is not None:
            # Remember the format that was explicitly chosen in the dialog
            self.recent_paths[0].file_format = reader.qualified_name()

        self.source = self.LOCAL_FILE
        self.load_data()

    # Load data from the current source and send it on the data output
    def load_data(self):
        """Reset the widget state and (re)load data from the current source.

        The context is closed, the domain editor, the Apply button, the
        messages and the file list are reset, and ``_try_load`` is run. If
        it returns an error reporter, the reporter is called, the data is
        dropped and None is sent to the output.
        """
        self.closeContext()
        self.domain_editor.set_domain(None)
        self.apply_button.setEnabled(False)
        self.clear_messages()
        self.set_file_list()

        report_failure = self._try_load()
        if not report_failure:
            return

        report_failure()
        self.data = None
        self.sheet_box.hide()
        self.Outputs.data.send(None)
        self.infolabel.setText("No data.")
            self.infolabel.setText("No data.")

    def _try_load(self):
        """Try to read the current source.

        Returns
        -------
        callable or None
            None on success and also when no file is selected (None is then
            sent to the output). On failure, a callable that shows the
            matching error message when called with no arguments.
        """
        # Any exception type is caught below since anything can happen in
        # file readers
        # pylint: disable=broad-except
        if self.last_path() and not os.path.exists(self.last_path()):
            return self.Error.file_not_found

        try:
            self.reader = self._get_reader()
            assert self.reader is not None
        except Exception:
            return self.Error.missing_reader

        if self.reader is self.NoFileSelected:
            self.Outputs.data.send(None)
            return None

        try:
            self._update_sheet_combo()
        except Exception:
            return self.Error.sheet_error

        with catch_warnings(record=True) as caught:
            try:
                table = self.reader.read()
            except Exception as exc:
                log.exception(exc)

                # Bind the exception as a default: the name ``exc`` is
                # cleared when the except block ends
                def report_unknown(x=exc):
                    return self.Error.unknown(str(x))

                return report_unknown
            if caught:
                # Only the last recorded warning is shown
                self.Warning.load_warning(caught[-1].message.args[0])

        self.infolabel.setText(self._describe(table))

        self.loaded_file = self.last_path()
        add_origin(table, self.loaded_file)
        self.data = table
        self.openContext(table.domain)
        self.apply_domain_edit()  # sends data
        return None

    def _get_reader(self):
        """Create a reader for the current source.

        For a URL source a ``UrlReader`` for the stripped combo text is
        returned. For a local file, the format remembered on the first
        recent path is used if there is one, otherwise the reader is chosen
        by ``FileFormat.get_reader``; a remembered sheet is selected on the
        reader.

        Returns
        -------
        FileFormat
            The reader, or the ``NoFileSelected`` class when there is no
            path or the URL is empty.
        """
        if self.source != self.LOCAL_FILE:
            url = self.url_combo.currentText().strip()
            return UrlReader(url) if url else self.NoFileSelected

        path = self.last_path()
        if path is None:
            return self.NoFileSelected

        remembered_format = \
            self.recent_paths and self.recent_paths[0].file_format
        if remembered_format:
            reader = class_from_qualified_name(remembered_format)(path)
        else:
            reader = FileFormat.get_reader(path)

        remembered_sheet = self.recent_paths and self.recent_paths[0].sheet
        if remembered_sheet:
            reader.select_sheet(remembered_sheet)
        return reader

    def _update_sheet_combo(self):
        """Show the sheet selector filled with the reader's sheets when there
        are at least two; otherwise hide it and deselect the sheet.
        """
        reader = self.reader
        if len(reader.sheets) >= 2:
            combo = self.sheet_combo
            combo.clear()
            combo.addItems(reader.sheets)
            self._select_active_sheet()
            self.sheet_box.show()
        else:
            self.sheet_box.hide()
            reader.select_sheet(None)

    def _select_active_sheet(self):
        """Point the sheet combo at the reader's sheet.

        Without a sheet set on the reader the first entry is selected. If
        the reader's sheet is not among its sheets, the sheet is deselected
        on the reader.
        """
        if not self.reader.sheet:
            self.sheet_combo.setCurrentIndex(0)
            return

        try:
            position = self.reader.sheets.index(self.reader.sheet)
            self.sheet_combo.setCurrentIndex(position)
        except ValueError:
            # Requested sheet does not exist in this file
            self.reader.select_sheet(None)

    def _describe(self, table):
        domain = table.domain
        text = ""

        attrs = getattr(table, "attributes", {})
        descs = [
            attrs[desc] for desc in ("Name", "Description") if desc in attrs
        ]
        if len(descs) == 2:
            descs[0] = "<b>{}</b>".format(descs[0])
        if descs:
            text += "<p>{}</p>".format("<br/>".join(descs))
        # Instances
        text += "<p>{} instance(s)".format(len(table))
        # Attributes
        missing_attr = "({:.1f}% missing values)".format(table.get_nan_frequency_attribute() * 100) \
            if table.has_missing_attribute() else "(no missing values)"
        text += "<br/>{} feature(s) {}".format(len(domain.attributes),
                                               missing_attr)
        # Classes
        missing_class = "({:.1f}% missing values)".format(table.get_nan_frequency_class() * 100) \
            if table.has_missing_class() else "(no missing values)"
        if domain.has_continuous_class:
            text += "<br/>Regression; numerical class {}".format(missing_class)
        elif domain.has_discrete_class:
            text += "<br/>Classification; categorical class with {} values {}".format(
                len(domain.class_var.values), missing_class)
        elif table.domain.class_vars:
            text += "<br/>Multi-target; {} target variables {}".format(
                len(table.domain.class_vars), missing_class)
        else:
            text += "<br/>Data has no target variable."
        # Metas
        text += "<br/>{} meta attribute(s)".format(len(domain.metas))
        text += "</p>"

        if 'Timestamp' in table.domain:
            # Google Forms uses this header to timestamp responses
            text += '<p>First entry: {}<br/>Last entry: {}</p>'.format(
                table[0, 'Timestamp'], table[-1, 'Timestamp'])
        return text

    def storeSpecificSettings(self):
        """Store a copy of ``self.variables`` on the current context."""
        snapshot = self.variables[:]
        self.current_context.modified_variables = snapshot

    def retrieveSpecificSettings(self):
        """Restore ``self.variables`` in place from the current context.

        Does nothing when the context has no ``modified_variables``.
        """
        context = self.current_context
        if not hasattr(context, "modified_variables"):
            return
        # Slice assignment keeps the identity of the existing list object.
        self.variables[:] = context.modified_variables

    def apply_domain_edit(self):
        """Send the loaded data rebuilt with the domain from the editor.

        ``None`` is sent when no data is loaded or when the edited domain has
        neither variables nor metas. Otherwise a new table is built from the
        editor's columns, keeping the original weights, name, ids and
        attributes. The apply button is disabled afterwards.
        """
        table = None
        if self.data is not None:
            domain, cols = self.domain_editor.get_domain(
                self.data.domain, self.data)
            if domain.variables or domain.metas:
                X, y, m = cols
                table = Table.from_numpy(domain, X, y, m, self.data.W)
                table.name = self.data.name
                table.ids = np.array(self.data.ids)
                table.attributes = getattr(self.data, 'attributes', {})

        self.Outputs.data.send(table)
        self.apply_button.setEnabled(False)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its last extension."""
        base_name = os.path.basename(self.loaded_file)
        stem, _ = os.path.splitext(base_name)
        return stem

    def send_report(self):
        """Add the data source and a data summary to the report.

        With no data loaded, only a "No file." paragraph is reported. For a
        local file the reported name is abbreviated with ``~`` when the file
        lives under the user's home directory, and the selected sheet is
        appended when the sheet combo is visible. For the URL source the
        resource URL is reported instead.
        """
        def get_ext_name(filename):
            # Map a file extension to the registered format name.
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            home = os.path.expanduser("~")
            remainder = self.loaded_file[len(home):]
            # Only abbreviate when `home` is a whole leading path component;
            # a bare prefix match would also hit e.g. "/home/alice" for the
            # home directory "/home/al".
            if self.loaded_file.startswith(home) and \
                    remainder[:1] in ("", "/", "\\"):
                # os.path.join does not like ~
                name = "~" + os.path.sep + \
                       remainder.lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            # Resolve the format from the real path before the sheet name is
            # appended; "file.xlsx (Sheet1)" has no recognisable extension.
            file_format = get_ext_name(self.loaded_file)
            if self.sheet_combo.isVisible():
                name += " ({})".format(self.sheet_combo.currentText())
            self.report_items("File", [("File name", name),
                                       ("Format", file_format)])
        else:
            self.report_items("Data", [("Resource", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)

    def dragEnterEvent(self, event):
        """Accept the drag when the first dropped URL has a known reader.

        Events without URLs are left untouched; an ``IOError`` raised while
        looking up the reader leaves the event unaccepted.
        """
        urls = event.mimeData().urls()
        if not urls:
            return
        try:
            first = urls[0]
            local_path = OSX_NSURL_toLocalFile(first) or first.toLocalFile()
            FileFormat.get_reader(local_path)
            event.acceptProposedAction()
        except IOError:
            return

    def dropEvent(self, event):
        """Load the first dropped file as a local file.

        The path is added via ``add_path``, the source is switched to
        ``LOCAL_FILE`` and the data is loaded. Drops without URLs are
        ignored.
        """
        urls = event.mimeData().urls()
        if not urls:
            return
        first = urls[0]  # only the first dropped file is used
        self.add_path(OSX_NSURL_toLocalFile(first) or first.toLocalFile())
        self.source = self.LOCAL_FILE
        self.load_data()

    def workflowEnvChanged(self, key, value, oldvalue):
        """
        Called when the workflow environment changes (e.g. while saving the
        scheme).

        Forwards the change to ``update_file_list`` so that values tied to
        the environment (e.g. relative file paths) are updated.
        """
        self.update_file_list(key, value, oldvalue)
Пример #15
0
class OWFace(widget.OWWidget):
    """Widget that extracts the largest detected face of each input image.

    The first meta attribute tagged ``type == "image"`` is read as a file
    path or URL. For every row the largest detected face is cropped and
    written to a temporary file, whose path is output in a new "face" meta
    column (an empty string when no face was found).
    """
    name = "Face Detector"
    description = "Detect and extract a face from an image."
    icon = "icons/Face.svg"
    priority = 123

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)

    auto_run = Setting(True)

    want_main_area = False

    def __init__(self):
        super().__init__()
        self.data = None
        self.img_attr = None
        self.faces = None

        haarcascade = os.path.join(os.path.dirname(__file__),
                                   'data/haarcascade_frontalface_default.xml')
        self.face_cascade = cv2.CascadeClassifier(haarcascade)

        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, "No data.")

        gui.auto_commit(self.controlArea,
                        self,
                        "auto_run",
                        "Run",
                        checkbox_label="Run after any change",
                        orientation="horizontal")

    def get_ext(self, file_path):
        """Find the extension of a file or url."""
        if not os.path.isfile(file_path):
            file_path = urllib.parse.urlparse(file_path).path
        return os.path.splitext(file_path)[1].strip().lower()

    def read_img(self, file_path):
        """Read an image from file or url and convert it to grayscale.

        Returns the decoded image, or ``None`` when it cannot be fetched or
        decoded.
        """
        try:
            if os.path.isfile(file_path):
                return cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            # Close the response as soon as the payload has been read.
            with urllib.request.urlopen(file_path) as res:
                arr = np.asarray(bytearray(res.read()), dtype=np.uint8)
            return cv2.imdecode(arr, cv2.IMREAD_GRAYSCALE)
        except Exception:
            # Best effort: an unreadable image is treated as "no face".
            # Unlike a bare ``except``, KeyboardInterrupt and SystemExit
            # still propagate.
            return None

    def find_face(self, file_path, face_path):
        """Find the face in image file_path and store it in face_path.

        Returns True when a face was found and written, False otherwise.
        """
        img = self.read_img(file_path)
        if img is None:
            return False
        # downscale to a reasonable size (long edge <= 1024)
        f = min(1024 / img.shape[0], 1024 / img.shape[1], 1)
        img = cv2.resize(img, None, fx=f, fy=f)
        faces = self.face_cascade.detectMultiScale(img)
        if len(faces) == 0:
            return False
        # Keep the detection with the largest area.
        x, y, w, h = max(faces, key=lambda xywh: xywh[2] * xywh[3])
        face = img[y:y + h, x:x + w]
        try:
            # imwrite picks the encoder from the file extension and fails
            # for extensions it cannot write.
            return bool(cv2.imwrite(face_path, face))
        except cv2.error:
            return False

    @staticmethod
    def cleanup(filenames):
        """Remove the given temporary files, ignoring ones that cannot be
        removed (e.g. already deleted)."""
        for fname in filenames:
            try:
                os.unlink(fname)
            except OSError:
                pass

    def commit(self):
        """Detect faces for every row and send the extended table.

        When there is no data or no image attribute, the input is passed
        through unchanged.
        """
        if self.data is None or self.img_attr is None:
            self.Outputs.data.send(self.data)
            return
        face_var = StringVariable("face")
        face_var.attributes["type"] = "image"
        domain = Domain([], metas=[face_var])
        faces_list = []
        tmp_files = []
        for row in self.data:
            file_abs = str(row[self.img_attr])
            file_ext = self.get_ext(file_abs)
            # Reserve a uniquely named file; it is written by find_face.
            with tempfile.NamedTemporaryFile(suffix=file_ext,
                                             delete=False) as f:
                face_abs = f.name
            if self.find_face(file_abs, face_abs):
                faces_list.append([face_abs])
                tmp_files.append(face_abs)
            else:
                faces_list.append([""])
                # Nothing references this file; remove it right away.
                self.cleanup([face_abs])
        if tmp_files:
            # Face images must outlive this call; delete them at exit.
            atexit.register(self.cleanup, tmp_files)
        self.info.setText("Detected %d faces." % len(tmp_files))

        self.faces = Table.from_list(domain, faces_list)
        # Keep the class variables of the input; only a meta is added.
        new_domain = Domain(self.data.domain.attributes,
                            self.data.domain.class_vars,
                            metas=self.data.domain.metas +
                            self.faces.domain.metas)
        comb = self.data.transform(new_domain)
        comb[:, face_var] = faces_list
        self.Outputs.data.send(comb)

    @Inputs.data
    def set_data(self, data):
        """Store the input data and pick its first image meta attribute."""
        self.data = data
        self.faces = None
        # Reset so a later manual run cannot use an attribute of old data.
        self.img_attr = None
        if not self.data:
            self.info.setText("No data.")
            self.Outputs.data.send(None)
            return
        atts = [
            a for a in data.domain.metas if a.attributes.get("type") == "image"
        ]
        self.img_attr = atts[0] if atts else None
        if not self.img_attr:
            self.info.setText("No image attribute.")
        else:
            self.info.setText("Image attribute: %s" % str(self.img_attr))
        if self.auto_run:
            self.commit()
Пример #16
0
class OWParallelGraph(OWPlot, ScaleData):
    """Plot with one vertical axis per visualized attribute, with data rows
    drawn as line curves across the axes."""

    # Display toggles: draw_distributions only draws when show_distributions
    # is set; draw_axes shows axis titles and scales when show_attr_values
    # is set; update_data picks the alpha values from show_statistics and
    # then resets it to False.
    show_distributions = Setting(False)
    show_attr_values = Setting(True)
    show_statistics = Setting(default=False)

    # group_lines selects draw_groups over draw_curves in update_data; the
    # two numbers are passed on to `lac` in compute_groups.
    group_lines = Setting(default=False)
    number_of_groups = Setting(default=5)
    number_of_steps = Setting(default=30)

    # alpha_value is applied to selected curves and alpha_value_2 to the
    # remaining ones (see draw_curves).
    use_splines = Setting(False)
    alpha_value = Setting(150)
    alpha_value_2 = Setting(150)

    def __init__(self, widget, parent=None, name=None):
        """Initialise both base classes and reset all plot state.

        ``widget`` is stored on the plot and also passed on to ``OWPlot``.
        """
        OWPlot.__init__(self, parent, name, axes=[], widget=widget)
        ScaleData.__init__(self)

        self.update_antialiasing(False)

        self.widget = widget
        self.last_selected_curve = None
        self.enableGridXB(0)
        self.enableGridYL(0)
        # Filled lazily by draw_distributions.
        self.domain_contingencies = None
        self.auto_update_axes = 1
        self.old_legend_keys = []
        # Attribute name -> pair of selection bounds; [0, 1] is used as the
        # default when an attribute has no entry.
        self.selection_conditions = {}
        self.attributes = []
        self.visualized_mid_labels = []
        self.attribute_indices = []
        self.valid_data = []
        # Cache of compute_groups results.
        self.groups = {}

        # Row indices collected by draw_curves.
        self.selected_examples = []
        self.unselected_examples = []
        # Arrow pixmaps used for hit-testing in testArrowContact.
        self.bottom_pixmap = QPixmap(
            gui.resource_filename("icons/upgreenarrow.png"))
        self.top_pixmap = QPixmap(
            gui.resource_filename("icons/downgreenarrow.png"))

    def set_data(self, data, subset_data=None, **args):
        """Store ``data`` on the plot and reset everything derived from it.

        The cached contingencies and groups are dropped, then the data is
        handed to both base classes. ``subset_data`` is accepted but not
        used here; extra keyword arguments go to ``ScaleData.set_data``.
        """
        self.start_progress()
        try:
            self.set_progress(1, 100)
            self.data = data
            self.have_data = True
            self.domain_contingencies = None
            self.groups = {}
            OWPlot.setData(self, data)
            ScaleData.set_data(self, data, no_data=True, **args)
        finally:
            # Close the progress indicator even when setting the data fails.
            self.end_progress()

    def update_data(self, attributes, mid_labels=None):
        """Redraw the whole plot for the given attributes.

        The plot is cleared first. Nothing else is drawn when there is no
        data or when fewer than two attributes are given.
        """
        previous_conditions = self.selection_conditions

        self.clear()

        if not self.have_data or len(attributes) < 2:
            return

        if self.show_statistics:
            self.alpha_value, self.alpha_value_2 = TRANSPARENT, VISIBLE
        else:
            self.alpha_value, self.alpha_value_2 = VISIBLE, TRANSPARENT

        self.attributes = attributes
        self.attribute_indices = [
            self.attribute_name_index[attr_name]
            for attr_name in self.attributes
        ]
        self.valid_data = self.get_valid_list(self.attribute_indices)

        self.visualized_mid_labels = mid_labels
        self.add_relevant_selections(previous_conditions)

        if self.data_has_discrete_class:
            n_colors = len(self.data_domain.class_var.values)
            self.discrete_palette.set_number_of_colors(n_colors)

        # Statistics are switched off whichever way the lines are drawn.
        draw_lines = self.draw_groups if self.group_lines else self.draw_curves
        self.show_statistics = False
        draw_lines()

        self.draw_distributions()
        self.draw_axes()
        self.draw_statistics()
        self.draw_mid_labels(mid_labels)
        self.draw_legend()

        self.replot()

    def add_relevant_selections(self, old_selection_conditions):
        """Copy over the old conditions of currently visualized attributes.

        Existing entries of ``self.selection_conditions`` are overwritten for
        matching names; other entries are left as they are.
        """
        relevant = {
            attr_name: condition
            for attr_name, condition in old_selection_conditions.items()
            if attr_name in self.attributes
        }
        self.selection_conditions.update(relevant)

    def draw_axes(self):
        """Replace all axes with one vertical axis per visualized attribute.

        Axis ``i`` runs from (i, 0) to (i, 1) and is titled with the
        attribute's name. When ``show_attr_values`` is set, continuous axes
        get the attribute's value range as scale and discrete axes get one
        evenly spaced label per value.
        """
        self.remove_all_axes()
        for position, attr_name in enumerate(self.attributes):
            axis_id = UserAxis + position
            axis = self.add_axis(axis_id,
                                 line=QLineF(position, 0, position, 1),
                                 arrows=AxisStart | AxisEnd,
                                 zoomable=True)
            axis.always_horizontal_text = True
            axis.max_text_width = 100
            axis.title_margin = -10
            axis.text_margin = 0
            axis.setZValue(5)

            attr = self.data_domain[attr_name]
            self.set_axis_title(axis_id, attr.name)
            self.set_show_axis_title(axis_id, self.show_attr_values)
            if not self.show_attr_values:
                continue

            if attr.is_continuous:
                bounds = self.attr_values[attr.name]
                self.set_axis_scale(axis_id, bounds[0], bounds[1])
            elif attr.is_discrete:
                labels = get_variable_values_sorted(attr)
                n_labels = len(labels)
                # Label j sits in the middle of the j-th of n equal slots.
                ticks = [
                    float(1.0 + 2.0 * j) / float(2 * n_labels)
                    for j in range(n_labels)
                ]
                axis.set_bounds((0, 1))
                self.set_axis_labels(axis_id, labels=labels, values=ticks)

    def draw_curves(self):
        """Draw one polyline per data row across the visualized attributes.

        Rows containing NaN values are skipped. Each remaining row is
        classified as selected (all selection conditions hold) or
        background, scaled per attribute and grouped by its RGBA colour.
        The row indices are appended to ``selected_examples`` or
        ``unselected_examples``.
        """
        conditions = {
            name: self.attributes.index(name)
            for name in self.selection_conditions
        }

        def is_selected(example):
            return all(self.selection_conditions[name][0] <= example[index] <=
                       self.selection_conditions[name][1]
                       for name, index in conditions.items())

        selected_curves = defaultdict(list)
        background_curves = defaultdict(list)

        # Per-attribute offset and range used to scale the values.
        diff, mins = [], []
        for i in self.attribute_indices:
            var = self.data_domain[i]
            if var.is_discrete:
                diff.append(len(var.values))
                mins.append(-0.5)
            else:
                # `or 1` avoids dividing by zero for constant attributes.
                diff.append(
                    self.domain_data_stat[i].max - self.domain_data_stat[i].min
                    or 1)
                mins.append(self.domain_data_stat[i].min)

        def scale_row(row):
            return [(x - m) / d for x, m, d in zip(row, mins, diff)]

        # NOTE(review): selected_examples and unselected_examples are only
        # appended to here; confirm they are reset elsewhere between redraws.
        for row_idx, row in enumerate(self.data[:, self.attribute_indices]):
            if any(np.isnan(v) for v in row.x):
                continue

            color = tuple(self.select_color(row_idx))

            if is_selected(row):
                color += (self.alpha_value, )
                selected_curves[color].extend(scale_row(row))
                self.selected_examples.append(row_idx)
            else:
                color += (self.alpha_value_2, )
                # Background rows share the axes with the selected ones, so
                # they need the same scaling (raw values were stored before).
                background_curves[color].extend(scale_row(row))
                self.unselected_examples.append(row_idx)

        self._draw_curves(selected_curves)
        self._draw_curves(background_curves)

    def select_color(self, row_index):
        """Return the RGB colour for the row at ``row_index``.

        Without a class the colour is black; otherwise the row's class value
        is looked up in the continuous or the discrete palette.
        """
        if not self.data_has_class:
            return 0, 0, 0
        if self.data_has_continuous_class:
            palette = self.continuous_palette
        else:
            palette = self.discrete_palette
        return palette.getRGB(self.data[row_index, self.data_class_index])

    def _draw_curves(self, selected_curves):
        n_attr = len(self.attributes)
        for color, y_values in sorted(selected_curves.items()):
            n_rows = int(len(y_values) / n_attr)
            x_values = list(range(n_attr)) * n_rows
            curve = OWCurve()
            curve.set_style(OWCurve.Lines)
            curve.set_color(QColor(*color))
            curve.set_segment_length(n_attr)
            curve.set_data(x_values, y_values)
            curve.attach(self)

    def draw_groups(self):
        """Draw each computed group as polygons between neighbouring axes.

        For every group and every pair of adjacent attributes, the means and
        the square roots of the sigma values are scaled by the attribute's
        range and drawn as a ``ParallelCoordinatePolygon`` in the group's
        palette colour. The plot is replotted afterwards.
        """
        phis, mus, sigmas = self.compute_groups()

        # Per-attribute offset and range used for scaling.
        ranges, offsets = [], []
        for attr_idx in self.attribute_indices:
            var = self.data_domain[attr_idx]
            if var.is_discrete:
                ranges.append(len(var.values))
                offsets.append(-0.5)
            else:
                stat = self.domain_data_stat[attr_idx]
                ranges.append(stat.max - stat.min or 1)
                offsets.append(self.domain_data_stat[attr_idx].min)

        for group_idx, (phi, group_mus, group_sigmas) in enumerate(
                zip(phis, mus, sigmas)):
            neighbours = zip(group_mus, group_sigmas,
                             group_mus[1:], group_sigmas[1:])
            for left, (mu1, sigma1, mu2, sigma2) in enumerate(neighbours):
                right = left + 1
                polygon = ParallelCoordinatePolygon(
                    left,
                    (mu1 - offsets[left]) / ranges[left],
                    (mu2 - offsets[right]) / ranges[right],
                    math.sqrt(sigma1) / ranges[left],
                    math.sqrt(sigma2) / ranges[right],
                    phi,
                    tuple(self.discrete_palette.getRGB(group_idx)))
                polygon.attach(self)

        self.replot()

    def compute_groups(self):
        """Return ``[phi, mu, sigma]`` for the current attributes.

        Results are cached in ``self.groups`` under the visualized
        attributes, the number of groups and the number of steps. On a cache
        miss the contingencies are built and passed to ``lac``, with NaNs in
        the result replaced via ``np.nan_to_num``.
        """
        cache_key = (tuple(self.attributes), self.number_of_groups,
                     self.number_of_steps)
        try:
            return self.groups[cache_key]
        except KeyError:
            pass

        def report_progress(i, n):
            # Building contingencies accounts for the first half.
            self.set_progress(i, 2 * n)

        contingencies = create_contingencies(
            self.data[:, self.attribute_indices], callback=report_progress)
        self.set_progress(50, 100)
        _, mu, sigma, phi = lac(contingencies, self.number_of_groups,
                                self.number_of_steps)
        self.set_progress(100, 100)
        self.groups[cache_key] = [np.nan_to_num(part)
                                  for part in (phi, mu, sigma)]
        return self.groups[cache_key]

    def draw_legend(self):
        """Rebuild the legend for the class variable.

        Without a class the legend is cleared. A discrete class gets one
        rectangle item per sorted class value; otherwise a colour gradient
        labelled with the formatted class values is added.
        """
        if not self.data_has_class:
            self.legend().clear()
            self.old_legend_keys = []
            return

        if self.data_domain.has_discrete_class:
            self.legend().clear()
            class_var = self.data_domain.class_var
            for idx, value in enumerate(get_variable_values_sorted(class_var)):
                marker = OWPoint(OWPoint.Rect, self.discrete_palette[idx],
                                 self.point_width)
                self.legend().add_item(self.data_domain.class_var.name, value,
                                       marker)
            return

        class_var = self.data_domain.class_var
        values = self.attr_values[class_var.name]
        # Fixed-point format with the variable's number of decimals.
        number_format = "%%.%df" % class_var.number_of_decimals
        self.legend().add_color_gradient(
            self.data_domain.class_var.name,
            [number_format % v for v in values])

    def draw_mid_labels(self, mid_labels):
        """Add a marker for each label halfway between two adjacent axes.

        Label ``j`` is placed at x = j + 0.5, y = 1.0. Nothing is drawn when
        ``mid_labels`` is empty or ``None``.
        """
        if not mid_labels:
            return
        for position, label in enumerate(mid_labels):
            self.addMarker(label,
                           position + 0.5,
                           1.0,
                           alignment=Qt.AlignCenter | Qt.AlignTop)

    def draw_statistics(self):
        """Draw lines that represent standard deviation or quartiles.

        Currently a no-op: drawing statistics is not implemented yet.

        The implementation that used to follow the unconditional return was
        unreachable and no longer valid: one branch used a variable before
        assigning it and contained statements after a ``continue``. It has
        been removed and remains available in version control. It drew, for
        every continuous attribute, a vertical bar spanning either
        (mean - dev, mean, mean + dev) or the quartiles, per class value for
        a discrete class, and connected the middle values across axes.
        """
        # TODO: Implement using BasicStats
        return

    def draw_distributions(self):
        """Draw class distribution bars next to discrete attribute axes.

        Nothing is drawn unless ``show_distributions`` is set, data is
        present and the class is discrete. For every discrete visualized
        attribute, each attribute value gets one horizontal bar per class
        value, with a width proportional to the contingency count.
        """
        if not (self.show_distributions and self.have_data
                and self.data_has_discrete_class):
            return
        class_ = self.data_domain.class_var
        class_count = len(class_.values)

        # Contingencies are computed once and cached on the plot.
        if self.domain_contingencies is None:
            self.domain_contingencies = dict(
                zip([attr for attr in self.data_domain if attr.is_discrete],
                    get_contingencies(self.raw_data, skipContinuous=True)))
            self.domain_contingencies[class_] = get_contingency(
                self.raw_data, class_, class_)

        # `or 1` keeps the width computation safe when all counts are zero.
        max_count = max([
            contingency.max()
            for contingency in self.domain_contingencies.values()
        ] or [1]) or 1
        sorted_class_values = get_variable_values_sorted(class_)

        for axis_idx, attr_idx in enumerate(self.attribute_indices):
            attr = self.data_domain[attr_idx]
            # Only discrete attributes have contingencies and `.values`.
            # The previous test was inverted and skipped exactly those.
            if not attr.is_discrete:
                continue

            contingency = self.domain_contingencies[attr]
            attr_len = len(attr.values)
            sorted_variable_values = get_variable_values_sorted(attr)
            # Bar height is the same for every bar of this attribute.
            height = 0.7 / float(class_count * attr_len)

            # create bar curve
            for j in range(attr_len):
                attribute_value = sorted_variable_values[j]
                value_count = contingency[:, attribute_value]
                # Vertical centre of the slot belonging to value j.
                y_off = float(1.0 + 2.0 * j) / float(2 * attr_len)

                for i in range(class_count):
                    class_value = sorted_class_values[i]

                    color = QColor(self.discrete_palette[i])
                    color.setAlpha(self.alpha_value)

                    width = float(
                        value_count[class_value] * 0.5) / float(max_count)
                    y_low_bottom = y_off + float(
                        class_count * height) / 2.0 - i * height
                    # NOTE(review): `event` unpacks `curve.tooltip` as a
                    # 4-tuple (name, value, total, dist); confirm a plain
                    # attribute name is what it should receive here.
                    curve = PolygonCurve(QPen(color),
                                         QBrush(color),
                                         xData=[
                                             axis_idx, axis_idx + width,
                                             axis_idx + width, axis_idx
                                         ],
                                         yData=[
                                             y_low_bottom, y_low_bottom,
                                             y_low_bottom - height,
                                             y_low_bottom - height
                                         ],
                                         tooltip=attr.name)
                    curve.attach(self)

    # handle tooltip events
    def event(self, ev):
        """Show tooltips for selection arrows and distribution bars.

        On a tooltip event over a selection arrow of a continuous attribute,
        the value at that arrow is shown. Otherwise the class distribution
        of each polygon curve under the cursor is shown. Mouse moves hide
        the tooltip. The event is always passed on to ``OWPlot.event``.
        """
        event_type = ev.type()
        if event_type == QEvent.MouseMove:
            QToolTip.hideText()
        elif event_type == QEvent.ToolTip:
            x = self.inv_transform(xBottom, ev.pos().x())
            y = self.inv_transform(yLeft, ev.pos().y())

            canvas_position = self.mapToScene(ev.pos())
            x_float = self.inv_transform(xBottom, canvas_position.x())
            contact, (index, pos) = self.testArrowContact(
                int(round(x_float)), canvas_position.x(),
                canvas_position.y())
            if not contact:
                for curve in self.items():
                    hit = type(curve) == PolygonCurve and \
                        curve.boundingRect().contains(x, y) and \
                        getattr(curve, "tooltip", None)
                    if not hit:
                        continue
                    name, value, total, dist = curve.tooltip
                    count = sum(entry[1] for entry in dist)
                    if count == 0:
                        continue
                    header = "Attribute: <b>%s</b><br>Value: <b>%s</b><br>" \
                             "Total instances: <b>%i</b> (%.1f%%)<br>" \
                             "Class distribution:<br>" % (
                                 name, value, count,
                                 100.0 * count / float(total))
                    rows = [
                        "&nbsp; &nbsp; <b>%s</b> : <b>%i</b> (%.1f%%)<br>" % (
                            val, n, 100.0 * float(n) / float(count))
                        for (val, n) in dist
                    ]
                    tooltip_text = header + "".join(rows)
                    # Drop the trailing "<br>".
                    QToolTip.showText(ev.globalPos(), tooltip_text[:-4])
            else:
                attr = self.data_domain[self.attributes[index]]
                if attr.is_continuous:
                    condition = self.selection_conditions.get(
                        attr.name, [0, 1])
                    low = self.attr_values[attr.name][0]
                    span = self.attr_values[attr.name][1] - \
                        self.attr_values[attr.name][0]
                    # Map the 0..1 condition back to the attribute's range.
                    val = low + condition[pos] * span
                    str_val = attr.name + \
                        "= %%.%df" % attr.number_of_decimals % val
                    QToolTip.showText(ev.globalPos(), str_val)

        return OWPlot.event(self, ev)

    def testArrowContact(self, indices, x, y):
        if type(indices) != list: indices = [indices]
        for index in indices:
            if index >= len(self.attributes) or index < 0:
                continue
            int_x = self.transform(xBottom, index)
            bottom = self.transform(
                yLeft,
                self.selection_conditions.get(self.attributes[index],
                                              [0, 1])[0])
            bottom_rect = QRect(int_x - self.bottom_pixmap.width() / 2, bottom,
                                self.bottom_pixmap.width(),
                                self.bottom_pixmap.height())
            if bottom_rect.contains(QPoint(x, y)):
                return 1, (index, 0)
            top = self.transform(
                yLeft,
                self.selection_conditions.get(self.attributes[index],
                                              [0, 1])[1])
            top_rect = QRect(int_x - self.top_pixmap.width() / 2,
                             top - self.top_pixmap.height(),
                             self.top_pixmap.width(), self.top_pixmap.height())
            if top_rect.contains(QPoint(x, y)):
                return 1, (index, 1)
        return 0, (0, 0)

    def mousePressEvent(self, e):
        """Start dragging a selection arrow, or defer to ``OWPlot``.

        When the press lands on a selection arrow, the ``(index, pos)`` pair
        returned by ``testArrowContact`` is stored in ``self.pressed_arrow``;
        otherwise the event is passed to ``OWPlot.mousePressEvent``.
        """
        scene_pos = self.mapToScene(e.pos())
        nearest_axis = int(round(self.inv_transform(xBottom, scene_pos.x())))
        hit, arrow = self.testArrowContact(nearest_axis,
                                           scene_pos.x(),
                                           scene_pos.y())
        if not hit:
            OWPlot.mousePressEvent(self, e)
            return
        self.pressed_arrow = arrow

    def mouseMoveEvent(self, e):
        """Drag the pressed selection arrow, or defer to ``OWPlot``.

        While ``self.pressed_arrow`` is set, the pointer's y position
        (clamped to [0, 1]) is written into the selection condition of the
        corresponding attribute and the plot is redrawn.  For a continuous
        attribute a tooltip with the current value is shown.  When
        ``sendSelectionOnUpdate`` is set, ``auto_send_selection_callback``
        is invoked.
        """
        if not hasattr(self, "pressed_arrow"):
            OWPlot.mouseMoveEvent(self, e)
            return

        scene_pos = self.mapToScene(e.pos())
        relative_y = self.inv_transform(yLeft, scene_pos.y())
        y = min(1, max(0, relative_y))

        index, pos = self.pressed_arrow
        attr = self.data_domain[self.attributes[index]]
        condition = self.selection_conditions.get(attr.name, [0, 1])
        condition[pos] = y
        self.selection_conditions[attr.name] = condition
        self.update_data(self.attributes, self.visualized_mid_labels)

        if attr.is_continuous:
            bounds = self.attr_values[attr.name]
            # map the relative arrow position onto the value range
            val = bounds[0] + condition[pos] * (bounds[1] - bounds[0])
            strVal = attr.name + "= %.2f" % val
            QToolTip.showText(e.globalPos(), strVal)
        if self.sendSelectionOnUpdate and self.auto_send_selection_callback:
            self.auto_send_selection_callback()

    def mouseReleaseEvent(self, e):
        """Finish an arrow drag, or defer to ``OWPlot.mouseReleaseEvent``.

        If an arrow was being dragged, the ``pressed_arrow`` attribute is
        removed and the event is not propagated further.
        """
        if not hasattr(self, "pressed_arrow"):
            OWPlot.mouseReleaseEvent(self, e)
            return
        del self.pressed_arrow

    def zoom_to_rect(self, r):
        """Zoom horizontally only.

        The top and bottom of ``r`` are overwritten (in place) with those of
        ``self.graph_area`` before the rectangle is handed to the parent
        implementation.
        """
        area_top = self.graph_area.top()
        r.setTop(area_top)
        area_bottom = self.graph_area.bottom()
        r.setBottom(area_bottom)
        super().zoom_to_rect(r)

    def removeAllSelections(self, send=1):
        """Drop all selection conditions and redraw the plot.

        ``send`` is accepted but not used by this method.
        """
        self.selection_conditions = dict()
        attributes = self.attributes
        mid_labels = self.visualized_mid_labels
        self.update_data(attributes, mid_labels)

    # draw the curves and the selection conditions
    def drawCanvas(self, painter):
        """Draw the plot, then the selection arrows of the visible axes.

        After ``OWPlot.drawCanvas`` has run, the bottom and top arrow
        pixmaps are painted for every attribute index that falls inside the
        current bottom-axis interval (clipped to the valid index range).
        """
        OWPlot.drawCanvas(self, painter)

        lowest = math.floor(
            self.axisScaleDiv(xBottom).interval().minValue())
        first_index = int(max(0, lowest))
        highest = math.ceil(
            self.axisScaleDiv(xBottom).interval().maxValue()) + 1
        stop_index = int(min(len(self.attributes), highest))

        for i in range(first_index, stop_index):
            bottom, top = self.selection_conditions.get(
                self.attributes[i], (0, 1))
            # bottom arrow: centred on the axis, hanging below the value
            painter.drawPixmap(
                self.transform(xBottom, i) - self.bottom_pixmap.width() / 2,
                self.transform(yLeft, bottom),
                self.bottom_pixmap)
            # top arrow: centred on the axis, sitting above the value
            painter.drawPixmap(
                self.transform(xBottom, i) - self.top_pixmap.width() / 2,
                self.transform(yLeft, top) - self.top_pixmap.height(),
                self.top_pixmap)

    def auto_send_selection_callback(self):
        """Hook called from ``mouseMoveEvent`` after an arrow is dragged.

        It is only called when ``sendSelectionOnUpdate`` is set.  The default
        implementation does nothing.
        """
        return None

    def clear(self):
        """Clear the plot and reset attribute and selection state.

        Calls the parent ``clear`` and then replaces the attribute list, the
        mid labels and the selected/unselected example lists with fresh
        empty lists, and the selection conditions with an empty dict.
        """
        super().clear()

        # each attribute gets its own new list object
        for list_attribute in ("attributes",
                               "visualized_mid_labels",
                               "selected_examples",
                               "unselected_examples"):
            setattr(self, list_attribute, [])
        self.selection_conditions = {}
Пример #17
0
class OWDBSCAN(widget.OWWidget):
    """Widget that runs DBSCAN on the input table and annotates it.

    The output table is the input data extended with two meta columns:
    the cluster label ("Cluster") and a core-sample flag ("DBSCAN Core").
    The main area shows a plot of k-th neighbour distances with a movable
    cut line that is tied to the ``eps`` parameter.
    """
    name = "DBSCAN"
    description = "Density-based spatial clustering."
    icon = "icons/DBSCAN.svg"
    priority = 2150

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    class Error(widget.OWWidget.Error):
        not_enough_instances = Msg("Not enough unique data instances. "
                                   "At least two are required.")

    # (label shown in the combo box, metric identifier passed to the model)
    METRICS = [
        ("Euclidean", "euclidean"),
        ("Manhattan", "cityblock"),
        ("Cosine", "cosine")
    ]

    min_samples = Setting(4)
    eps = Setting(0.5)
    metric_idx = Setting(0)
    auto_commit = Setting(True)
    # k-th neighbour distances of the current data and the index of the
    # cut line within them; both are None until data is received
    k_distances = None
    cut_point = None

    def __init__(self):
        super().__init__()

        self.data = None
        self.data_normalized = None
        self.db = None
        self.model = None

        box = gui.widgetBox(self.controlArea, "Parameters")
        gui.spin(box, self, "min_samples", 1, 100, 1,
                 callback=self._min_samples_changed,
                 label="Core point neighbors")
        gui.doubleSpin(box, self, "eps", EPS_BOTTOM_LIMIT, 1000, 0.01,
                       callback=self._eps_changed,
                       label="Neighborhood distance")

        box = gui.widgetBox(self.controlArea, self.tr("Distance Metric"))
        gui.comboBox(box, self, "metric_idx",
                     items=list(zip(*self.METRICS))[0],
                     callback=self._metirc_changed)

        gui.auto_apply(self.controlArea, self, "auto_commit")
        gui.rubber(self.controlArea)

        self.controlArea.layout().addStretch()

        self.plot = SliderGraph(
            x_axis_label="Data items sorted by score",
            y_axis_label="Distance to the k-th nearest neighbour",
            callback=self._on_cut_changed
        )

        self.mainArea.layout().addWidget(self.plot)

    def check_data_size(self, data):
        """Return True when ``data`` has at least two instances.

        ``None`` yields False silently; a table with fewer than two rows
        yields False and raises the ``not_enough_instances`` error message.
        """
        if data is None:
            return False
        if len(data) < 2:
            self.Error.not_enough_instances()
            return False
        return True

    def commit(self):
        """Recompute the clustering and send the output."""
        self.cluster()

    def cluster(self):
        """Fit DBSCAN on the preprocessed data and send the annotated table.

        Does nothing when the current data does not pass
        ``check_data_size``.
        """
        if not self.check_data_size(self.data):
            return
        self.model = DBSCAN(
            eps=self.eps,
            min_samples=self.min_samples,
            metric=self.METRICS[self.metric_idx][1]
        ).get_model(self.data_normalized)
        self.send_data()

    def _compute_and_plot(self, cut_point=None):
        """Recompute k-distances and redraw the plot.

        The cut point (and ``eps``) is recomputed only when ``cut_point``
        is None; otherwise the current cut point is kept.
        """
        self._compute_kdistances()
        if cut_point is None:
            self._compute_cut_point()
        self._plot_graph()

    def _plot_graph(self):
        """Draw the k-distances curve with the current cut point.

        The selection limit passed to the plot is ``(0, count - 1)`` where
        ``count`` is the number of distances above ``EPS_BOTTOM_LIMIT``.
        """
        nonzero = np.sum(self.k_distances > EPS_BOTTOM_LIMIT)
        self.plot.update(np.arange(len(self.k_distances)),
                         [self.k_distances],
                         colors=[QColor('red')],
                         cutpoint_x=self.cut_point,
                         selection_limit=(0, nonzero - 1))

    def _compute_kdistances(self):
        """Store the k-th neighbour distances for the current settings."""
        self.k_distances = get_kth_distances(
            self.data_normalized, metric=self.METRICS[self.metric_idx][1],
            k=self.min_samples
        )

    def _compute_cut_point(self):
        """Pick the default cut point and set ``eps`` accordingly.

        The cut point is placed at ``DEFAULT_CUT_POINT`` of the distance
        list.  If the distance there is below ``EPS_BOTTOM_LIMIT``, the
        smallest distance that reaches the limit is used instead.  When no
        distance reaches the limit the initial choice is kept (``np.min``
        of an empty selection would otherwise raise ``ValueError``).
        """
        self.cut_point = int(DEFAULT_CUT_POINT * len(self.k_distances))
        self.eps = self.k_distances[self.cut_point]

        if self.eps < EPS_BOTTOM_LIMIT:
            above_limit = self.k_distances[
                self.k_distances >= EPS_BOTTOM_LIMIT]
            if len(above_limit):
                self.eps = np.min(above_limit)
                self.cut_point = self._find_nearest_dist(self.eps)

    @Inputs.data
    def set_data(self, data):
        """Receive new input data, preprocess it, replot and recluster.

        Data that does not pass ``check_data_size`` is treated as no data:
        ``None`` is sent to the output and the plot is cleared.
        """
        self.Error.clear()
        if not self.check_data_size(data):
            data = None
        self.data = self.data_normalized = data
        if self.data is None:
            self.Outputs.annotated_data.send(None)
            self.plot.clear_plot()
            return

        # preprocess data
        for pp in PREPROCESSORS:
            self.data_normalized = pp(self.data_normalized)

        self._compute_and_plot()
        self.unconditional_commit()

    def send_data(self):
        """Build and send the input table annotated with cluster metas.

        Negative labels of the fitted model are stored as NaN (unknown
        value) in the "Cluster" column.
        """
        model = self.model

        clusters = [c if c >= 0 else np.nan for c in model.labels]
        # number of distinct clusters, not counting the NaN placeholder
        k = len(set(clusters) - {np.nan})
        clusters = np.array(clusters).reshape(len(self.data), 1)
        core_samples = set(model.projector.core_sample_indices_)
        in_core = np.array([1 if (i in core_samples) else 0
                            for i in range(len(self.data))])
        in_core = in_core.reshape(len(self.data), 1)

        clust_var = DiscreteVariable(
            "Cluster", values=["C%d" % (x + 1) for x in range(k)])
        in_core_var = DiscreteVariable("DBSCAN Core", values=("0", "1"))

        domain = self.data.domain
        attributes, classes = domain.attributes, domain.class_vars
        meta_attrs = domain.metas
        x, y, metas = self.data.X, self.data.Y, self.data.metas

        meta_attrs += (clust_var, )
        metas = np.hstack((metas, clusters))
        meta_attrs += (in_core_var, )
        metas = np.hstack((metas, in_core))

        domain = Domain(attributes, classes, meta_attrs)
        new_table = Table(domain, x, y, metas, self.data.W)

        self.Outputs.annotated_data.send(new_table)

    def _invalidate(self):
        """Recompute the output after a parameter change."""
        self.commit()

    def _find_nearest_dist(self, value):
        """Return the index of the k-distance closest to ``value``."""
        array = np.asarray(self.k_distances)
        idx = (np.abs(array - value)).argmin()
        return idx

    def _eps_changed(self):
        """Move the cut line to the distance closest to the new ``eps``."""
        # find the closest value to eps
        if self.data is None:
            return
        self.cut_point = self._find_nearest_dist(self.eps)
        self.plot.set_cut_point(self.cut_point)
        self._invalidate()

    def _metirc_changed(self):
        """Recompute distances, cut point and clustering for a new metric."""
        if self.data is not None:
            self._compute_and_plot()
            self._invalidate()

    def _on_cut_changed(self, value):
        """Update ``eps`` from the cut line position and recluster."""
        # cut changed by means of a cut line over the scree plot.
        self.cut_point = value
        self.eps = self.k_distances[value]

        self.commit()

    def _min_samples_changed(self):
        """Recompute distances for the new k, keeping the cut point."""
        if self.data is None:
            return
        self._compute_and_plot(cut_point=self.cut_point)
        self._invalidate()
Пример #18
0
class RadvizVizRank(VizRankDialog, OWComponent):
    """VizRank dialog that scores Radviz projections.

    Attribute subsets (and orderings) are enumerated by ``iterate_states``
    and each projection is scored by cross-validated k-nearest-neighbours
    prediction of the colour variable from the projected coordinates.
    """
    captionTitle = "Score Plots"
    n_attrs = Setting(3)
    minK = 10

    attrsSelected = Signal([])
    _AttrRole = next(gui.OrangeUserRole)

    percent_data_used = Setting(100)

    def __init__(self, master):
        """Add the spin box for maximal number of attributes"""
        VizRankDialog.__init__(self, master)
        OWComponent.__init__(self, master)

        self.master = master
        self.n_neighbors = 10
        max_n_attrs = len(master.model_selected) + len(master.model_other) - 1

        box = gui.hBox(self)
        self.n_attrs_spin = gui.spin(box,
                                     self,
                                     "n_attrs",
                                     3,
                                     max_n_attrs,
                                     label="Maximum number of variables: ",
                                     controlWidth=50,
                                     alignment=Qt.AlignRight,
                                     callback=self._n_attrs_changed)
        gui.rubber(box)
        self.last_run_n_attrs = None
        self.attr_color = master.attr_color
        self.attr_ordering = None
        self.data = None
        self.valid_data = None

    def initialize(self):
        """Reset the dialog and pick up the master's colour variable."""
        super().initialize()
        self.attr_color = self.master.attr_color

    def _compute_attr_order(self):
        """
        used by VizRank to evaluate attributes

        Stores the master's data (restricted to the candidate attributes,
        with the colour variable as class) in ``self.data`` and returns the
        attributes sorted by decreasing ReliefF/RReliefF weight, ties
        broken by name.
        """
        master = self.master
        attrs = [
            v for v in chain(master.model_selected[:], master.model_other[:])
            if v is not self.attr_color
        ]
        data = self.master.data.transform(
            Domain(attributes=attrs, class_vars=self.attr_color))
        self.data = data
        self.valid_data = np.hstack(
            (~np.isnan(data.X), ~np.isnan(data.Y.reshape(len(data.Y), 1))))
        relief = ReliefF if self.attr_color.is_discrete else RReliefF
        weights = relief(n_iterations=100, k_nearest=self.minK)(data)
        attrs = sorted(zip(weights, attrs), key=lambda x: (-x[0], x[1].name))
        self.attr_ordering = attr_ordering = [a for _, a in attrs]
        return attr_ordering

    def _evaluate_projection(self, x, y):
        """
        kNNEvaluate - evaluate class separation in the given projection using a k-NN method
        Parameters
        ----------
        x - variables to evaluate
        y - class

        Returns
        -------
        scores

        The returned value is the mean of 3-fold cross-validation scores.
        When ``percent_data_used`` is not 100, a random subsample of that
        share of the rows is evaluated instead of all rows.
        """
        if self.percent_data_used != 100:
            rand = np.random.choice(len(x),
                                    int(len(x) * self.percent_data_used / 100),
                                    replace=False)
            x = x[rand]
            y = y[rand]
        neigh = KNeighborsClassifier(n_neighbors=3) if self.attr_color.is_discrete else \
            KNeighborsRegressor(n_neighbors=3)
        # neither the projected points nor the target may contain NaN
        assert not (np.isnan(x).any() or np.isnan(y).any())
        neigh.fit(x, y)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            scores = cross_val_score(neigh, x, y, cv=3)
        return scores.mean()

    def _n_attrs_changed(self):
        """
        Change the button label when the number of attributes changes. The method does not reset
        anything so the user can still see the results until actually restarting the search.
        """
        if self.n_attrs != self.last_run_n_attrs or self.saved_state is None:
            self.button.setText("Start")
        else:
            self.button.setText("Continue")
        self.button.setEnabled(self.check_preconditions())

    def progressBarSet(self, value, processEvents=None):
        """Show the number of evaluated permutations in the window title.

        Pending Qt events are processed only when ``processEvents`` is
        neither None nor False; its value is passed to
        ``qApp.processEvents``.
        """
        self.setWindowTitle(self.captionTitle +
                            " Evaluated {} permutations".format(value))
        if processEvents is not None and processEvents is not False:
            qApp.processEvents(processEvents)

    def check_preconditions(self):
        """Return True when ranking can be run.

        Requires the parent's preconditions to hold and the master's
        VizRank button to be enabled.  On success the spin box maximum is
        set to 20.
        """
        master = self.master
        if not super().check_preconditions():
            return False
        elif not master.btn_vizrank.isEnabled():
            return False
        self.n_attrs_spin.setMaximum(20)  # fixed upper bound on variables
        return True

    def on_selection_changed(self, selected, deselected):
        """Emit ``selectionChanged`` with the attributes of the chosen row.

        Nothing is emitted when the new selection is empty (e.g. when a row
        is only deselected).
        """
        indexes = selected.indexes()
        if not indexes:
            return
        attrs = indexes[0].data(self._AttrRole)
        self.selectionChanged.emit([attrs])

    def iterate_states(self, state):
        """Yield attribute-index tuples to evaluate, resuming from ``state``.

        With ``state`` None the attribute order is computed first and the
        enumeration starts from the first three attributes.
        """
        if state is None:  # on the first call, compute order
            self.attrs = self._compute_attr_order()
            state = list(range(3))
        else:
            state = list(state)

        def combinations(n, s):
            # Enumerate index combinations in place, starting from `s`; once
            # all combinations of the current size are exhausted, continue
            # with one more index until `self.n_attrs` indices are used.
            while True:
                yield s
                for up, _ in enumerate(s):
                    s[up] += 1
                    if up + 1 == len(s) or s[up] < s[up + 1]:
                        break
                    s[up] = up
                if s[-1] == n:
                    if len(s) < self.n_attrs:
                        s = list(range(len(s) + 1))
                    else:
                        break

        for c in combinations(len(self.attrs), state):
            # The first index stays fixed and only the first half of the
            # permutations of the rest is used (presumably the other half
            # gives mirrored, equivalent layouts -- TODO confirm).
            for p in islice(permutations(c[1:]), factorial(len(c) - 1) // 2):
                yield (c[0], ) + p

    def compute_score(self, state):
        """Return the negated k-NN score of the projection for ``state``."""
        attrs = [self.attrs[i] for i in state]
        domain = Domain(attributes=attrs, class_vars=[self.attr_color])
        data = self.data.transform(domain)
        radviz_xy, _, mask = radviz(data, attrs)
        y = data.Y[mask]
        return -self._evaluate_projection(radviz_xy, y)

    def bar_length(self, score):
        """Return the bar length for ``score`` (scores are stored negated)."""
        return -score

    def row_for_state(self, score, state):
        """Return the list of model items (one item) describing ``state``."""
        attrs = [self.attrs[s] for s in state]
        item = QStandardItem("[{:0.6f}] ".format(-score) +
                             ", ".join(a.name for a in attrs))
        item.setData(attrs, self._AttrRole)
        return [item]

    def _update_progress(self):
        """Refresh the progress display from ``saved_progress``."""
        self.progressBarSet(int(self.saved_progress))

    def before_running(self):
        """
        Disable the spin for number of attributes before running and
        enable afterwards. Also, if the number of attributes is different than
        in the last run, reset the saved state (if it was paused).
        """
        if self.n_attrs != self.last_run_n_attrs:
            self.saved_state = None
            self.saved_progress = 0
        if self.saved_state is None:
            self.scores = []
            self.rank_model.clear()
        self.last_run_n_attrs = self.n_attrs
        self.n_attrs_spin.setDisabled(True)

    def stopped(self):
        """Re-enable the spin box once the search has stopped."""
        self.n_attrs_spin.setDisabled(False)
Пример #19
0
class OWDataSampler(OWWidget):
    """Widget that outputs a random sample of the input data.

    Supports a fixed proportion, a fixed sample size (optionally with
    replacement), one fold of a cross-validation split and bootstrap.  For
    ``SqlTable`` inputs a separate set of controls (time- or
    percentage-based sampling) is shown instead.
    """
    name = "数据采样器(Data Sampler)"
    description = "从输入数据集中随机抽取数据点的子集 "

    icon = "icons/DataSampler.svg"
    priority = 100
    category = "Data"
    keywords = ["random"]

    _MAX_SAMPLE_SIZE = 2**31 - 1

    class Inputs:
        data = Input("数据(Data)", Table, replaces=['Data'])

    class Outputs:
        data_sample = Output("数据样本(Data Sample)",
                             Table,
                             default=True,
                             replaces=['Data Sample'])
        remaining_data = Output("剩余数据(Remaining Data)",
                                Table,
                                replaces=['Remaining Data'])

    want_main_area = False
    resizing_enabled = False

    RandomSeed = 42
    # sampling types for in-memory tables ...
    FixedProportion, FixedSize, CrossValidation, Bootstrap = range(4)
    # ... and for SQL tables; both share the `sampling_type` setting
    SqlTime, SqlProportion = range(2)

    selectedFold: int

    use_seed = Setting(True)
    replacement = Setting(False)
    stratify = Setting(False)
    sql_dl = Setting(False)
    sampling_type = Setting(FixedProportion)
    sampleSizeNumber = Setting(1)
    sampleSizePercentage = Setting(70)
    sampleSizeSqlTime = Setting(1)
    sampleSizeSqlPercentage = Setting(0.1)
    number_of_folds = Setting(10)
    selectedFold = Setting(1)

    class Warning(OWWidget.Warning):
        could_not_stratify = Msg("Stratification failed\n{}")
        bigger_sample = Msg('Sample is bigger than input')

    class Error(OWWidget.Error):
        too_many_folds = Msg("Number of folds exceeds data size")
        sample_larger_than_data = Msg("Sample can't be larger than data")
        not_enough_to_stratify = Msg("Data is too small to stratify")
        no_data = Msg("Dataset is empty")

    def __init__(self):
        super().__init__()
        self.data = None
        self.indices = None
        self.sampled_instances = self.remaining_instances = None

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoInput)

        self.sampling_box = gui.vBox(self.controlArea, "采样类型")
        sampling = gui.radioButtons(self.sampling_box,
                                    self,
                                    "sampling_type",
                                    callback=self.sampling_type_changed)

        def set_sampling_type(i):
            # Build a callback that selects sampling type `i` whenever one
            # of the controls belonging to that type is changed.
            def set_sampling_type_i():
                self.sampling_type = i
                self.sampling_type_changed()

            return set_sampling_type_i

        gui.appendRadioButton(sampling, "固定数据比例(Fixed proportion of data):")
        self.sampleSizePercentageSlider = gui.hSlider(
            gui.indentedBox(sampling),
            self,
            "sampleSizePercentage",
            minValue=0,
            maxValue=100,
            ticks=10,
            labelFormat="%d %%",
            callback=set_sampling_type(self.FixedProportion),
            addSpace=12)

        gui.appendRadioButton(sampling, "固定样本量(Fixed sample size)")
        ibox = gui.indentedBox(sampling)
        self.sampleSizeSpin = gui.spin(ibox,
                                       self,
                                       "sampleSizeNumber",
                                       label="实例量: ",
                                       minv=1,
                                       maxv=self._MAX_SAMPLE_SIZE,
                                       callback=set_sampling_type(
                                           self.FixedSize),
                                       controlWidth=90)
        gui.checkBox(ibox,
                     self,
                     "replacement",
                     "放回抽样(Sample with replacement)",
                     callback=set_sampling_type(self.FixedSize),
                     addSpace=12)

        gui.appendRadioButton(sampling, "交叉验证(Cross validation)")
        form = QFormLayout(formAlignment=Qt.AlignLeft | Qt.AlignTop,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        ibox = gui.indentedBox(sampling, addSpace=True, orientation=form)
        form.addRow(
            "折叠次数(Number of folds):",
            gui.spin(ibox,
                     self,
                     "number_of_folds",
                     2,
                     100,
                     addToLayout=False,
                     callback=self.number_of_folds_changed))
        self.selected_fold_spin = gui.spin(ibox,
                                           self,
                                           "selectedFold",
                                           1,
                                           self.number_of_folds,
                                           addToLayout=False,
                                           callback=self.fold_changed)
        form.addRow("选定的折叠(Selected fold):", self.selected_fold_spin)

        gui.appendRadioButton(sampling, "Bootstrap")

        self.sql_box = gui.vBox(self.controlArea, "Sampling Type")
        sampling = gui.radioButtons(self.sql_box,
                                    self,
                                    "sampling_type",
                                    callback=self.sampling_type_changed)
        gui.appendRadioButton(sampling, "Time:")
        ibox = gui.indentedBox(sampling)
        spin = gui.spin(ibox,
                        self,
                        "sampleSizeSqlTime",
                        minv=1,
                        maxv=3600,
                        callback=set_sampling_type(self.SqlTime))
        spin.setSuffix(" sec")
        gui.appendRadioButton(sampling, "Percentage")
        ibox = gui.indentedBox(sampling)
        spin = gui.spin(ibox,
                        self,
                        "sampleSizeSqlPercentage",
                        spinType=float,
                        minv=0.0001,
                        maxv=100,
                        step=0.1,
                        decimals=4,
                        callback=set_sampling_type(self.SqlProportion))
        spin.setSuffix(" %")
        self.sql_box.setVisible(False)

        self.options_box = gui.vBox(self.controlArea, "选项")
        self.cb_seed = gui.checkBox(
            self.options_box,
            self,
            "use_seed",
            "可复制(确定性)抽样 Replicable (deterministic) sampling",
            callback=self.settings_changed)
        self.cb_stratify = gui.checkBox(self.options_box,
                                        self,
                                        "stratify",
                                        "分层抽样(如果可能)",
                                        callback=self.settings_changed)
        self.cb_sql_dl = gui.checkBox(self.options_box,
                                      self,
                                      "sql_dl",
                                      "Download data to local memory",
                                      callback=self.settings_changed)
        self.cb_sql_dl.setVisible(False)

        gui.button(self.buttonsArea,
                   self,
                   "执行抽样(Sample Data)",
                   callback=self.commit)

    def sampling_type_changed(self):
        """React to a change of the sampling type."""
        self.settings_changed()

    def number_of_folds_changed(self):
        """Cap the selected-fold spin box and switch to cross validation."""
        self.selected_fold_spin.setMaximum(self.number_of_folds)
        self.sampling_type = self.CrossValidation
        self.settings_changed()

    def fold_changed(self):
        """Switch to cross validation when the selected fold is changed."""
        # a separate callback - if we decide to cache indices
        self.sampling_type = self.CrossValidation

    def settings_changed(self):
        """Update the size limit and invalidate the cached indices."""
        self._update_sample_max_size()
        self.indices = None

    @Inputs.data
    def set_data(self, dataset):
        """Receive new data, adapt the visible controls and resample."""
        self.data = dataset
        if dataset is not None:
            sql = isinstance(dataset, SqlTable)
            self.sampling_box.setVisible(not sql)
            self.sql_box.setVisible(sql)
            self.cb_seed.setVisible(not sql)
            self.cb_stratify.setVisible(not sql)
            self.cb_sql_dl.setVisible(sql)
            self.info.set_input_summary(str(len(dataset)))

            if not sql:
                self._update_sample_max_size()
                self.updateindices()
        else:
            self.info.set_input_summary(self.info.NoInput)
            self.info.set_output_summary(self.info.NoInput)
            self.indices = None
            self.clear_messages()
        self.commit()

    def _update_sample_max_size(self):
        """Limit number of instances to input size unless using replacement."""
        if not self.data or self.replacement:
            self.sampleSizeSpin.setMaximum(self._MAX_SAMPLE_SIZE)
        else:
            self.sampleSizeSpin.setMaximum(len(self.data))

    def commit(self):
        """Compute the sample and send it together with the remaining data.

        Without data, None is sent on both outputs.  For SQL tables only
        the sample is produced (remaining data is None).  For in-memory
        tables nothing is sent when the indices could not be computed.
        """
        if self.data is None:
            sample = other = None
            self.sampled_instances = self.remaining_instances = None
        elif isinstance(self.data, SqlTable):
            other = None
            if self.sampling_type == self.SqlProportion:
                sample = self.data.sample_percentage(
                    self.sampleSizeSqlPercentage, no_cache=True)
            else:
                sample = self.data.sample_time(self.sampleSizeSqlTime,
                                               no_cache=True)
            if self.sql_dl:
                sample.download_data()
                sample = Table(sample)

        else:
            # indices are reused only for deterministic sampling
            if self.indices is None or not self.use_seed:
                self.updateindices()
                if self.indices is None:
                    return
            if self.sampling_type in (self.FixedProportion, self.FixedSize,
                                      self.Bootstrap):
                remaining, sample = self.indices
            elif self.sampling_type == self.CrossValidation:
                remaining, sample = self.indices[self.selectedFold - 1]
            self.info.set_output_summary(str(len(sample)))

            sample = self.data[sample]
            other = self.data[remaining]
            self.sampled_instances = len(sample)
            self.remaining_instances = len(other)
        self.Outputs.data_sample.send(sample)
        self.Outputs.remaining_data.send(other)

    def updateindices(self):
        """Validate the settings against the data and compute the indices.

        On a validation error the matching error message is raised and
        ``self.indices`` is set to None.  If sampling raises ``ValueError``
        a warning is shown and sampling is retried without stratification.
        """
        self.Error.clear()
        self.Warning.clear()
        repl = True
        data_length = len(self.data)
        num_classes = len(self.data.domain.class_var.values) \
            if self.data.domain.has_discrete_class else 0

        size = None
        if not data_length:
            self.Error.no_data()
        elif self.sampling_type == self.FixedSize:
            size = self.sampleSizeNumber
            repl = self.replacement
        elif self.sampling_type == self.FixedProportion:
            size = np.ceil(self.sampleSizePercentage / 100 * data_length)
            repl = False
        elif self.sampling_type == self.CrossValidation:
            if data_length < self.number_of_folds:
                self.Error.too_many_folds()
        else:
            assert self.sampling_type == self.Bootstrap

        if not repl and size is not None and (size > data_length):
            self.Error.sample_larger_than_data()
        if not repl and data_length <= num_classes and self.stratify:
            self.Error.not_enough_to_stratify()

        if self.Error.active:
            self.indices = None
            return

        # By the above, we can safely assume there is data
        if self.sampling_type == self.FixedSize and repl and size and \
                size > len(self.data):
            # This should only be possible when using replacement
            self.Warning.bigger_sample()

        stratified = (self.stratify and isinstance(self.data, Table)
                      and self.data.domain.has_discrete_class)
        try:
            self.indices = self.sample(data_length, size, stratified)
        except ValueError as ex:
            self.Warning.could_not_stratify(str(ex))
            self.indices = self.sample(data_length, size, stratified=False)

    def sample(self, data_length, size, stratified):
        """Return the result of the sampler for the current sampling type.

        The sampler is seeded with ``RandomSeed`` when ``use_seed`` is set
        and is applied to ``self.data``.
        """
        rnd = self.RandomSeed if self.use_seed else None
        if self.sampling_type == self.FixedSize:
            sampler = SampleRandomN(size,
                                    stratified=stratified,
                                    replace=self.replacement,
                                    random_state=rnd)
        elif self.sampling_type == self.FixedProportion:
            sampler = SampleRandomP(self.sampleSizePercentage / 100,
                                    stratified=stratified,
                                    random_state=rnd)
        elif self.sampling_type == self.Bootstrap:
            sampler = SampleBootstrap(data_length, random_state=rnd)
        else:
            sampler = SampleFoldIndices(self.number_of_folds,
                                        stratified=stratified,
                                        random_state=rnd)
        return sampler(self.data)

    def send_report(self):
        """Add a description of the sampling settings to the report."""
        if self.sampling_type == self.FixedProportion:
            tpe = "Random sample with {} % of data".format(
                self.sampleSizePercentage)
        elif self.sampling_type == self.FixedSize:
            if self.sampleSizeNumber == 1:
                tpe = "Random data instance"
            else:
                tpe = "Random sample with {} data instances".format(
                    self.sampleSizeNumber)
                if self.replacement:
                    tpe += ", with replacement"
        elif self.sampling_type == self.CrossValidation:
            tpe = "Fold {} of {}-fold cross-validation".format(
                self.selectedFold, self.number_of_folds)
        elif self.sampling_type == self.Bootstrap:
            # Bootstrap is a valid sampling type and must not be reported
            # as "Undefined".
            tpe = "Bootstrap"
        else:
            tpe = "Undefined"  # should not come here at all
        if self.stratify:
            tpe += ", stratified (if possible)"
        if self.use_seed:
            tpe += ", deterministic"
        items = [("Sampling type", tpe)]
        if self.sampled_instances is not None:
            items += [
                ("Input", "{} instances".format(len(self.data))),
                ("Sample", "{} instances".format(self.sampled_instances)),
                ("Remaining", "{} instances".format(self.remaining_instances)),
            ]
        self.report_items(items)
Пример #20
0
class OWChoroplethPlotGraph(gui.OWComponent, QObject):
    """
    Main class containing functionality for piloting `ChoroplethItem`.
    It is very similar to `OWScatterPlotBase`. In fact some functionality
    is directly copied from there.
    """

    # Alpha channel applied to region fill colors and legend symbols.
    alpha_value = Setting(128)
    # Whether the color legend is shown (when it has any items).
    show_legend = Setting(True)

    def __init__(self, widget, parent=None):
        """Create the plot widget, drag tooltip, legend and tooltip filter.

        Args:
            widget: the owning widget; stored as ``self.master`` and passed
                to ``gui.OWComponent.__init__``.
            parent: forwarded as ``parent`` to ``pg.PlotWidget``.
        """
        QObject.__init__(self)
        gui.OWComponent.__init__(self, widget)

        self.view_box = MapViewBox(self)
        self.plot_widget = pg.PlotWidget(viewBox=self.view_box, parent=parent,
                                         background="w")
        # The plot shows only the map: no axes and no pyqtgraph buttons.
        self.plot_widget.hideAxis("left")
        self.plot_widget.hideAxis("bottom")
        self.plot_widget.getPlotItem().buttonsHidden = True
        self.plot_widget.setAntialiasing(True)
        self.plot_widget.sizeHint = lambda: QSize(500, 500)

        self.master = widget  # type: OWChoropleth
        self._create_drag_tooltip(self.plot_widget.scene())

        self.choropleth_items = []  # type: List[ChoroplethItem]

        # Number of regions and their selection groups (0 = not selected).
        self.n_ids = 0
        self.selection = None  # np.ndarray

        self.palette = None
        self.color_legend = self._create_legend(((1, 1), (1, 1)))
        self.update_legend_visibility()

        # Route scene help events to `help_event` to show region tooltips.
        self._tooltip_delegate = HelpEventDelegate(self.help_event)
        self.plot_widget.scene().installEventFilter(self._tooltip_delegate)

    def _create_legend(self, anchor):
        """Return a new `LegendItem` attached to the plot's view box.

        Args:
            anchor: anchor description passed to `restoreAnchor`.
        """
        item = LegendItem()
        view_box = self.plot_widget.getViewBox()
        item.setParentItem(view_box)
        item.restoreAnchor(anchor)
        return item

    def _create_drag_tooltip(self, scene):
        """Build the modifier-key hint shown while dragging a selection.

        Fills ``self.tiptexts`` with one HTML text per modifier combination
        (the active hint in bold; key 0 holds the plain text), creates the
        text item and its background rectangle, and stores them, hidden, as
        ``scene.drag_tooltip``.

        Args:
            scene: the graphics scene that will own the tooltip group.
        """
        tip_parts = [
            (Qt.ShiftModifier, "Shift: Add group"),
            (Qt.ShiftModifier + Qt.ControlModifier,
             "Shift-{}: Append to group".
             format("Cmd" if sys.platform == "darwin" else "Ctrl")),
            (Qt.AltModifier, "Alt: Remove")
        ]
        all_parts = ", ".join(part for _, part in tip_parts)
        # One variant of the full text per modifier, with its part in bold.
        self.tiptexts = {
            int(modifier): all_parts.replace(part, "<b>{}</b>".format(part))
            for modifier, part in tip_parts
        }
        self.tiptexts[0] = all_parts

        self.tip_textitem = text = QGraphicsTextItem()
        # Set to the longest text so the background rectangle fits them all
        text.setHtml(self.tiptexts[Qt.ShiftModifier + Qt.ControlModifier])
        text.setPos(4, 2)
        r = text.boundingRect()
        rect = QGraphicsRectItem(0, 0, r.width() + 8, r.height() + 4)
        rect.setBrush(QColor(224, 224, 224, 212))
        rect.setPen(QPen(Qt.NoPen))
        # Reset the text to the variant for the default (no) modifiers.
        self.update_tooltip()

        scene.drag_tooltip = scene.createItemGroup([rect, text])
        scene.drag_tooltip.hide()

    def update_tooltip(self, modifiers=Qt.NoModifier):
        """Show the drag hint that matches the pressed modifier keys.

        Only Shift, Control and Alt are considered; combinations without a
        dedicated text fall back to the plain hint stored under key 0.
        """
        relevant = Qt.ShiftModifier + Qt.ControlModifier + Qt.AltModifier
        modifiers &= relevant
        html = self.tiptexts.get(int(modifiers), self.tiptexts[0])
        self.tip_textitem.setHtml(html)

    def clear(self):
        """Remove all plot items and legend entries and reset the state."""
        self.plot_widget.clear()
        self.color_legend.clear()
        self.update_legend_visibility()
        # Forget the drawn regions together with their selection.
        self.choropleth_items, self.n_ids, self.selection = [], 0, None

    def reset_graph(self):
        """Reset plot on data change.

        Clears the current items, drops the selection, then redraws the
        region polygons and recolors them.
        """
        self.clear()
        self.selection = None
        self.update_choropleth()
        self.update_colors()

    def update_choropleth(self):
        """Create a `ChoroplethItem` for every region supplied by the master.

        Each item gets a thin white outline and no fill, is wired to
        `select_by_id` and added to the plot. If any items exist afterwards,
        `n_ids` is set to the number of the master's region ids.
        """
        outline = self._make_pen(QColor(Qt.white), 1)
        no_fill = QBrush(Qt.NoBrush)
        for region in self.master.get_choropleth_regions():
            item = ChoroplethItem(region, pen=outline, brush=no_fill)
            item.itemClicked.connect(self.select_by_id)
            self.plot_widget.addItem(item)
            self.choropleth_items.append(item)

        if not self.choropleth_items:
            return
        self.n_ids = len(self.master.region_ids)

    def update_colors(self):
        """Refresh the aggregated value and fill brush of every polygon.

        Does nothing when no polygons are drawn. The legend is rebuilt
        afterwards.
        """
        if not self.choropleth_items:
            return

        # Aggregated data must be fetched before the colors are computed.
        values = self.master.get_agg_data()
        brushes = self.get_colors()
        for item, value, brush in zip(self.choropleth_items, values, brushes):
            item.agg_value = self.master.format_agg_val(value)
            item.setBrush(brush)
        self.update_legends()

    def get_colors(self):
        """Return the fill brushes for the regions.

        Stores the master's palette in ``self.palette``. Without color data
        the palette is reset to None and an empty list is returned;
        otherwise the discrete or continuous helper is used depending on
        whether the master is in "mode" aggregation.
        """
        self.palette = self.master.get_palette()
        color_data = self.master.get_color_data()
        if color_data is None:
            self.palette = None
            return []
        if self.master.is_mode():
            return self._get_discrete_colors(color_data)
        return self._get_continuous_colors(color_data)

    def _get_continuous_colors(self, c_data):
        """Return brushes for numeric values using a binned palette.

        The master's palette is binned at the thresholds of the current
        binning and stored in ``self.palette``; ``alpha_value`` is appended
        to every color as its alpha channel.
        """
        base_palette = self.master.get_palette()
        thresholds = self.master.get_binning().thresholds
        self.palette = BinnedContinuousPalette.from_palette(
            base_palette, thresholds)
        rgb = self.palette.values_to_colors(c_data)
        alpha = np.full((len(rgb), 1), self.alpha_value, dtype=np.ubyte)
        rgba = np.hstack([rgb, alpha])

        return [QBrush(QColor(*channels)) for channels in rgba]

    def _get_discrete_colors(self, c_data):
        """Return brushes for discrete values.

        NaN values are mapped to index ``len(self.palette)`` of the
        palette's ``qcolors_w_nan``. ``alpha_value`` is set on every palette
        color before the brushes are created.
        """
        self.palette = self.master.get_palette()
        indices = c_data.copy()
        indices[np.isnan(indices)] = len(self.palette)
        indices = indices.astype(int)
        qcolors = self.palette.qcolors_w_nan
        for qcolor in qcolors:
            qcolor.setAlpha(self.alpha_value)
        all_brushes = np.array([QBrush(qcolor) for qcolor in qcolors])
        return all_brushes[indices]

    def update_legends(self):
        """Rebuild the color legend for the current aggregation.

        Uses the discrete legend in "mode" aggregation and the binned
        continuous legend otherwise, then updates the legend visibility.
        """
        labels = self.master.get_color_labels()
        self.color_legend.clear()
        fill_legend = (self._update_color_legend if self.master.is_mode()
                       else self._update_continuous_color_legend)
        fill_legend(labels)
        self.update_legend_visibility()

    def _update_continuous_color_legend(self, label_formatter):
        """Add a single binned-palette sample to the color legend.

        The legend geometry is set to the sample's bounding rectangle.
        """
        binning = self.master.get_binning()
        sample = BinningPaletteItemSample(self.palette, binning,
                                          label_formatter)
        self.color_legend.addItem(sample, "")
        self.color_legend.setGeometry(sample.boundingRect())

    def _update_color_legend(self, labels):
        """Add one circle sample per label to the color legend.

        Colors are taken from ``self.palette`` for the indices
        ``0..len(labels) - 1``; the outline is a darker shade of the fill
        and the fill uses ``alpha_value``.
        """
        palette_colors = self.palette.values_to_colors(np.arange(len(labels)))
        for rgb, label in zip(palette_colors, labels):
            fill = QColor(*rgb)
            # The outline pen is derived before alpha is applied to the fill.
            outline = self._make_pen(fill.darker(120), 1.5)
            fill.setAlpha(self.alpha_value)
            sample = SymbolItemSample(pen=outline, brush=QBrush(fill),
                                      size=10, symbol='o')
            self.color_legend.addItem(sample, escape(label))

    def update_legend_visibility(self):
        """Show the legend only if it is enabled and has at least one item."""
        visible = self.show_legend and bool(self.color_legend.items)
        self.color_legend.setVisible(visible)

    def update_selection_colors(self):
        """Apply the selection-dependent outline pen to every region."""
        for item, pen in zip(self.choropleth_items, self.get_colors_sel()):
            item.setPen(pen)

    def get_colors_sel(self):
        """Return one outline pen per region according to the selection.

        Unselected regions get a thin white pen. With a single selection
        group the selected regions get a thick orange pen; otherwise each
        group gets a thick pen in its own palette color.
        """
        white_pen = self._make_pen(QColor(Qt.white), 1)
        if self.selection is None:
            return [white_pen] * self.n_ids

        n_groups = np.max(self.selection)
        if n_groups == 1:
            orange_pen = self._make_pen(QColor(255, 190, 0, 255), 3)
            return np.where(self.selection, orange_pen, white_pen)

        palette = LimitedDiscretePalette(number_of_colors=n_groups + 1)
        group_pens = [self._make_pen(palette[i], 3) for i in range(n_groups)]
        # Index 0 (not selected) maps to the white pen.
        return np.choose(self.selection, [white_pen] + group_pens)

    @staticmethod
    def _make_pen(color, width):
        """Return a cosmetic `QPen` with the given color and width."""
        pen = QPen(color, width)
        pen.setCosmetic(True)
        return pen

    def zoom_button_clicked(self):
        """Switch the view box to rectangle mouse mode."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def pan_button_clicked(self):
        """Switch the view box to pan mouse mode."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.PanMode)

    def select_button_clicked(self):
        """Switch the view box to rectangle mouse mode for selecting."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def select_by_id(self, region_id):
        """
        Select the region with the given id.

        This is called by a `ChoroplethItem` on click; the id is translated
        to its position(s) in the master's `region_ids`.
        """
        matches = self.master.region_ids == region_id
        self.select_by_indices(np.nonzero(matches)[0])

    def select_by_rectangle(self, rect: QRectF):
        """
        Select all regions that intersect the given rectangle.

        Nothing happens when no region intersects it.
        """
        selection_poly = QPolygonF(rect)
        hits = {
            np.where(self.master.region_ids == item.region.id)[0][0]
            for item in self.choropleth_items
            if item.intersects(selection_poly)
        }
        if hits:
            self.select_by_indices(np.array(list(hits)))

    def unselect_all(self):
        """Drop the selection, if any, and notify the master widget."""
        if self.selection is None:
            return
        self.selection = None
        self.update_selection_colors()
        self.master.selection_changed()

    def select_by_indices(self, indices):
        """Change the selection of the given regions based on modifier keys.

        Alt removes them from the selection, Shift+Control appends them to
        the last group, Shift starts a new group, and no modifier replaces
        the selection with them.
        """
        if self.selection is None:
            self.selection = np.zeros(self.n_ids, dtype=np.uint8)
        keys = QApplication.keyboardModifiers()
        alt = keys & Qt.AltModifier
        shift = keys & Qt.ShiftModifier
        ctrl = keys & Qt.ControlModifier
        if alt:
            action = self.selection_remove
        elif shift and ctrl:
            action = self.selection_append
        elif shift:
            action = self.selection_new_group
        else:
            action = self.selection_select
        action(indices)

    def selection_select(self, indices):
        """Replace the selection with a single group made of `indices`."""
        fresh = np.zeros(self.n_ids, dtype=np.uint8)
        fresh[indices] = 1
        self.selection = fresh
        self._update_after_selection()

    def selection_append(self, indices):
        """Add `indices` to the group with the highest number."""
        last_group = np.max(self.selection)
        self.selection[indices] = last_group
        self._update_after_selection()

    def selection_new_group(self, indices):
        """Put `indices` into a new group numbered after the last one."""
        new_group = np.max(self.selection) + 1
        self.selection[indices] = new_group
        self._update_after_selection()

    def selection_remove(self, indices):
        """Remove `indices` from the selection (group 0 = not selected)."""
        not_selected = 0
        self.selection[indices] = not_selected
        self._update_after_selection()

    def _update_after_selection(self):
        """Finish a selection change.

        Renumbers the groups so they are consecutive, repaints the region
        outlines and notifies the master widget.
        """
        self._compress_indices()
        self.update_selection_colors()
        self.master.selection_changed()

    def _compress_indices(self):
        indices = sorted(set(self.selection) | {0})
        if len(indices) == max(indices) + 1:
            return
        mapping = np.zeros((max(indices) + 1,), dtype=int)
        for i, ind in enumerate(indices):
            mapping[ind] = i
        self.selection = mapping[self.selection]

    def get_selection(self):
        """Return the selection groups; all zeros when nothing is selected."""
        if self.selection is not None:
            return self.selection
        return np.zeros(self.n_ids, dtype=np.uint8)

    def help_event(self, event):
        """Show the tooltip of the region under the cursor.

        Returns:
            True if a region contains the event position and its tooltip
            was shown, False otherwise (including when nothing is drawn).
        """
        if not self.choropleth_items:
            return False
        # Scene position mapped into the coordinates of the first item.
        act_pos = self.choropleth_items[0].mapFromScene(event.scenePos())
        for item in self.choropleth_items:
            if item.contains(act_pos):
                QToolTip.showText(event.screenPos(), item.tooltip(),
                                  widget=self.plot_widget)
                return True
        return False
Пример #21
0
class OWSave(widget.OWWidget):
    """Widget that saves the input data table to a file."""
    name = "Save Data"
    description = "Save data to an output file."
    icon = "icons/Save.svg"
    category = "Data"
    keywords = ["data", "save"]

    class Inputs:
        data = Input("Data", Table)

    want_main_area = False
    resizing_enabled = False

    # Directory and file-dialog filter used by the last "Save As...".
    last_dir = Setting("")
    last_filter = Setting("")
    # State of the auto-save checkbox created by `gui.auto_commit`.
    auto_save = Setting(False)

    @classmethod
    def get_writers(cls, sparse):
        """Return the file formats that can be used for writing.

        A format qualifies if it has a truthy ``write_file`` and
        ``EXTENSIONS``; when `sparse` is true it must also declare
        ``SUPPORT_SPARSE_DATA``.
        """
        def usable(fmt):
            if not getattr(fmt, 'write_file', None):
                return False
            if not getattr(fmt, "EXTENSIONS", None):
                return False
            return not sparse or getattr(fmt, 'SUPPORT_SPARSE_DATA', False)

        return [fmt for fmt in FileFormat.formats if usable(fmt)]

    def __init__(self):
        """Create the "Save" (auto-commit) and "Save As..." controls.

        Both controls start disabled; `dataset` enables them when data is
        present.
        """
        super().__init__()
        self.data = None
        self.filename = ""
        self.writer = None

        # "Save" button with an auto-save checkbox bound to `auto_save`;
        # toggling the checkbox refreshes the button label.
        self.save = gui.auto_commit(self.controlArea,
                                    self,
                                    "auto_save",
                                    "Save",
                                    box=False,
                                    commit=self.save_file,
                                    callback=self.adjust_label,
                                    disabled=True,
                                    addSpace=True)
        self.saveAs = gui.button(self.controlArea,
                                 self,
                                 "Save As...",
                                 callback=self.save_file_as,
                                 disabled=True)
        self.saveAs.setMinimumWidth(220)
        self.adjustSize()

    def adjust_label(self):
        """Show the target file name on the save button, if one is set."""
        if not self.filename:
            return
        basename = os.path.basename(self.filename)
        templates = ("Save as '{}'", "Auto save as '{}'")
        # `auto_save` selects the template: False -> first, True -> second.
        label = templates[self.auto_save].format(basename)
        self.save.button.setText(label)

    @Inputs.data
    def dataset(self, data):
        """Store the input data, toggle the buttons and trigger a save.

        The buttons are disabled when `data` is None; otherwise `save_file`
        is invoked.
        """
        self.data = data
        missing = data is None
        self.save.setDisabled(missing)
        self.saveAs.setDisabled(missing)
        if not missing:
            self.save_file()

    def save_file_as(self):
        """Ask for a file name and format, then save the data there.

        The dialog starts at the current file name or, if there is none, at
        the data's name inside the last used directory (the home directory
        when no directory is remembered). Cancelling the dialog changes
        nothing.
        """
        if self.filename:
            start_path = self.filename
        else:
            start_dir = self.last_dir or os.path.expanduser("~")
            start_path = os.path.join(start_dir,
                                      getattr(self.data, 'name', ''))
        writers = self.get_writers(self.data.is_sparse())
        chosen_name, chosen_writer, chosen_filter = \
            filedialogs.open_filename_dialog_save(
                start_path, self.last_filter, writers)
        if not chosen_name:
            return
        self.filename = chosen_name
        self.writer = chosen_writer
        self.unconditional_save_file()
        # Remember where and how the file was saved for the next dialog.
        self.last_dir = os.path.split(self.filename)[0]
        self.last_filter = chosen_filter
        self.adjust_label()

    def save_file(self):
        """Write the data to the current file, asking for a name if needed.

        Does nothing without data. A failure while writing is shown as a
        widget error; a successful write clears the error.
        """
        if self.data is None:
            return
        if not self.filename:
            self.save_file_as()
            return
        try:
            self.writer.write(self.filename, self.data)
        except Exception as err_value:
            self.error(str(err_value))
        else:
            self.error()
Пример #22
0
class OWChoropleth(OWWidget):
    """
    This is to `OWDataProjectionWidget` what
    `OWChoroplethPlotGraph` is to `OWScatterPlotBase`.
    """

    name = 'Choropleth Map'
    description = 'A thematic map in which areas are shaded in proportion ' \
                  'to the measurement of the statistical variable being displayed.'
    icon = "icons/Choropleth.svg"
    priority = 120

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settings_version = 2
    settingsHandler = DomainContextHandler()
    # Selection stored as a list of (region index, group) pairs.
    selection = Setting(None, schema_only=True)
    auto_commit = Setting(True)

    # Latitude and longitude variables used to map points to regions.
    attr_lat = ContextSetting(None)
    attr_lon = ContextSetting(None)

    # Aggregated variable and the name of the aggregation function.
    agg_attr = ContextSetting(None)
    agg_func = ContextSetting(DEFAULT_AGG_FUNC)
    # Position of the "Detail" slider (0-2), passed on to the region lookup.
    admin_level = Setting(0)
    # Index into `self.binnings` selected with the "Bin width" slider.
    binning_index = Setting(0)

    GRAPH_CLASS = OWChoroplethPlotMapGraph
    graph = SettingProvider(OWChoroplethPlotMapGraph)
    graph_name = "graph.plot_widget.plotItem"

    # Emitted with the new input / output data to update the summaries.
    input_changed = Signal(object)
    output_changed = Signal(object)

    class Error(OWWidget.Error):
        no_lat_lon_vars = Msg("Data has no latitude and longitude variables.")

    class Warning(OWWidget.Warning):
        no_region = Msg("{} points are not in any region.")

    def __init__(self):
        """Initialise the widget state, connect summaries and build the GUI."""
        super().__init__()
        self.data = None
        # Region id of every data point (set in `get_regions`).
        self.data_ids = None  # type: Optional[np.ndarray]

        # Aggregated value and id per region (set in `get_agg_data`).
        self.agg_data = None  # type: Optional[np.ndarray]
        self.region_ids = None  # type: Optional[np.ndarray]

        self.choropleth_regions = []
        self.binnings = []

        self.input_changed.connect(self.set_input_summary)
        self.output_changed.connect(self.set_output_summary)
        self.setup_gui()

    def setup_gui(self):
        """Create the graph and the controls and reset both summaries."""
        self._add_graph()
        self._add_controls()
        # Emitting None shows the "no input" / "no output" summaries.
        self.input_changed.emit(None)
        self.output_changed.emit(None)

    def _add_graph(self):
        """Create the graph component and place its plot in the main area."""
        container = gui.vBox(self.mainArea, True, margin=0)
        self.graph = self.GRAPH_CLASS(self, container)
        plot = self.graph.plot_widget
        container.layout().addWidget(plot)

    def _add_controls(self):
        """Build the control area.

        Adds, top to bottom: latitude/longitude combos, the aggregation
        attribute and function combos with the detail slider, the
        visualization box (bin width, opacity, legend checkbox), the
        zoom/select buttons and the auto-send button.
        """
        # Keyword arguments shared by all combo boxes below.
        options = dict(
            labelWidth=75, orientation=Qt.Horizontal, sendSelectedValue=True,
            contentsLength=14
        )

        lat_lon_box = gui.vBox(self.controlArea, True)
        self.lat_lon_model = DomainModel(DomainModel.MIXED,
                                         valid_types=(ContinuousVariable,))
        gui.comboBox(lat_lon_box, self, 'attr_lat', label='Latitude:',
                     callback=self.setup_plot, model=self.lat_lon_model,
                     **options, searchable=True)

        gui.comboBox(lat_lon_box, self, 'attr_lon', label='Longitude:',
                     callback=self.setup_plot, model=self.lat_lon_model,
                     **options, searchable=True)

        agg_box = gui.vBox(self.controlArea, True)
        self.agg_attr_model = DomainModel(valid_types=(ContinuousVariable,
                                                       DiscreteVariable))
        gui.comboBox(agg_box, self, 'agg_attr', label='Attribute:',
                     callback=self.update_agg, model=self.agg_attr_model,
                     **options, searchable=True)

        # The items of this combo are replaced in `update_agg`.
        self.agg_func_combo = gui.comboBox(agg_box, self, 'agg_func',
                                           label='Agg.:',
                                           items=[DEFAULT_AGG_FUNC],
                                           callback=self.graph.update_colors,
                                           **options)

        a_slider = gui.hSlider(agg_box, self, 'admin_level', minValue=0,
                               maxValue=2, step=1, label='Detail:',
                               createLabel=False, callback=self.setup_plot)
        a_slider.setFixedWidth(176)

        visualization_box = gui.vBox(self.controlArea, True)
        # The maximum is adjusted in `recompute_binnings`.
        b_slider = gui.hSlider(visualization_box, self, "binning_index",
                               label="Bin width:", minValue=0,
                               maxValue=max(1, len(self.binnings) - 1),
                               createLabel=False,
                               callback=self.graph.update_colors)
        b_slider.setFixedWidth(176)

        av_slider = gui.hSlider(visualization_box, self, "graph.alpha_value",
                                minValue=0, maxValue=255, step=10,
                                label="Opacity:", createLabel=False,
                                callback=self.graph.update_colors)
        av_slider.setFixedWidth(176)

        gui.checkBox(visualization_box, self, "graph.show_legend",
                     "Show legend",
                     callback=self.graph.update_legend_visibility)

        self.controlArea.layout().addStretch(100)

        plot_gui = OWPlotGUI(self)
        plot_gui.box_zoom_select(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

    @property
    def effective_variables(self):
        return [self.attr_lat, self.attr_lon] \
            if self.attr_lat and self.attr_lon else []

    @property
    def effective_data(self):
        """The input data transformed to the effective variables only."""
        domain = Domain(self.effective_variables)
        return self.data.transform(domain)

    # Input
    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle new input data.

        Validates the data, (re)initialises the attribute settings from the
        context and redraws the map unless the latitude/longitude columns
        are unchanged from the previous data. Aggregation, stored selection
        and outputs are always refreshed.
        """
        data_existed = self.data is not None
        # Snapshot of the old coordinates to detect whether they changed.
        effective_data = self.effective_data if data_existed else None

        self.closeContext()
        self.data = data
        self.Warning.no_region.clear()
        self.Error.no_lat_lon_vars.clear()
        self.agg_func = DEFAULT_AGG_FUNC
        # Both calls below may reset `self.data` to None.
        self.check_data()
        self.init_attr_values()
        self.openContext(self.data)

        if not (data_existed and self.data is not None and
                array_equal(effective_data.X, self.effective_data.X)):
            self.clear(cache=True)
            self.input_changed.emit(data)
            self.setup_plot()
        self.update_agg()
        self.apply_selection()
        self.unconditional_commit()

    def check_data(self):
        """Discard data that has no instances or an empty domain."""
        data = self.data
        if data is None:
            return
        if len(data) == 0 or len(data.domain) == 0:
            self.data = None

    def init_attr_values(self):
        """Initialise the attribute settings and combo models for the data.

        Latitude and longitude are looked up with `find_lat_lon`; if either
        is missing, an error is shown and the data is discarded. The
        aggregation attribute defaults to the class variable.
        """
        lat = lon = None
        if self.data is not None:
            lat, lon = find_lat_lon(self.data, filter_hidden=True)
            if lat is None or lon is None:
                # we either find both or we don't have valid data
                self.Error.no_lat_lon_vars()
                self.data = None
                lat = lon = None

        domain = None if self.data is None else self.data.domain
        for model in (self.lat_lon_model, self.agg_attr_model):
            model.set_domain(domain)
        self.agg_attr = None if domain is None else domain.class_var
        self.attr_lat, self.attr_lon = lat, lon

    def set_input_summary(self, data):
        """Show the number of input instances, or "no input" if falsy."""
        if data:
            summary = str(len(data))
        else:
            summary = self.info.NoInput
        self.info.set_input_summary(summary)

    def set_output_summary(self, data):
        """Show the number of output instances, or "no output" if falsy."""
        if data:
            summary = str(len(data))
        else:
            summary = self.info.NoOutput
        self.info.set_output_summary(summary)

    def update_agg(self):
        """Refill the aggregation combo for the current attribute.

        Discrete and time attributes offer only the functions flagged for
        them in `AGG_FUNCS`; without an attribute only the default function
        is offered. The previously chosen function is kept if it is still
        available, and the colors are refreshed.
        """
        # Remember the choice before the combo is emptied.
        previous = self.agg_func
        self.agg_func_combo.clear()

        attr = self.agg_attr
        if attr is None:
            available = [DEFAULT_AGG_FUNC]
        elif attr.is_discrete:
            available = [agg for agg in AGG_FUNCS if AGG_FUNCS[agg].disc]
        elif attr.is_time:
            available = [agg for agg in AGG_FUNCS if AGG_FUNCS[agg].time]
        else:
            available = list(AGG_FUNCS)

        self.agg_func_combo.addItems(available)
        self.agg_func = previous if previous in available \
            else DEFAULT_AGG_FUNC

        self.graph.update_colors()

    def setup_plot(self):
        """Redraw the map from scratch.

        The bin width slider is enabled only when not in "mode"
        aggregation.
        """
        self.controls.binning_index.setEnabled(not self.is_mode())
        self.clear()
        self.graph.reset_graph()

    def apply_selection(self):
        """Restore the stored selection on the graph.

        `self.selection` holds (region index, group) pairs; they are
        expanded into the per-region group array used by the graph.
        """
        if self.data is None or self.selection is None:
            return
        region_indices, groups = np.array(self.selection).T
        restored = np.zeros(self.graph.n_ids, dtype=np.uint8)
        restored[region_indices] = groups
        self.graph.selection = restored
        self.graph.update_selection_colors()

    def selection_changed(self):
        """Store the graph's selection in the settings and commit.

        The selection is saved as (region index, group) pairs for selected
        regions only; for `SqlTable` data nothing is stored.
        """
        if self.data and isinstance(self.data, SqlTable):
            sel = None
        else:
            sel = self.graph.selection
        if sel is None:
            self.selection = None
        else:
            self.selection = [(i, x) for i, x in enumerate(sel) if x]
        self.commit()

    def commit(self):
        """Send the selected and annotated data to the outputs."""
        self.send_data()

    def send_data(self):
        """Compute and send the selected and the annotated data.

        The graph's selection is per region; it is translated to a
        per-instance group array through `region_ids` and `data_ids`.
        Without data both outputs are None.
        """
        data = self.data
        graph_sel = self.graph.get_selection()
        selected_data = ann_data = None
        if data:
            group_sel = np.zeros(len(data), dtype=int)

            if len(graph_sel):
                # we get selection by region ids so we have to map it to points
                for region_id, group in zip(self.region_ids, graph_sel):
                    if group != 0:
                        in_region = np.where(self.data_ids == region_id)[0]
                        group_sel[in_region] = group
            else:
                graph_sel = [0]

            if np.sum(graph_sel) > 0:
                selected_data = create_groups_table(data, group_sel, False,
                                                    "Group")

            # Several groups need a group column; otherwise a boolean
            # "selected" annotation is enough.
            if np.max(graph_sel) > 1:
                ann_data = create_groups_table(data, group_sel)
            else:
                ann_data = create_annotated_table(data,
                                                  group_sel.astype(bool))

        self.output_changed.emit(selected_data)
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(ann_data)

    def recompute_binnings(self):
        """Recompute the candidate binnings of the aggregated data.

        Skipped in "mode" aggregation. Time values use `time_binnings`,
        everything else `decimal_binnings`. The slider maximum and the
        current binning index are clamped to the new number of binnings.
        """
        if self.is_mode():
            return

        make_binnings = time_binnings if self.is_time() else decimal_binnings
        self.binnings = make_binnings(self.agg_data, min_bins=3, max_bins=15)

        last_index = len(self.binnings) - 1
        self.controls.binning_index.setMaximum(last_index)
        self.binning_index = min(last_index, self.binning_index)

    def get_binning(self) -> BinDefinition:
        """Return the binning selected by `binning_index`."""
        return self.binnings[self.binning_index]

    def get_palette(self):
        """Return the palette for the current aggregation.

        Counts use the default continuous palette, "mode" aggregation a
        discrete palette limited to `MAX_COLORS`, and everything else the
        palette of the aggregated attribute.
        """
        if self.agg_func in ('Count', 'Count defined'):
            return DefaultContinuousPalette
        if self.is_mode():
            return LimitedDiscretePalette(MAX_COLORS)
        return self.agg_attr.palette

    def get_color_data(self):
        """Return the values used for coloring (the reduced agg data)."""
        return self.get_reduced_agg_data()

    def get_color_labels(self):
        """Return what the legend needs to label colors.

        In "mode" aggregation this is the list of value labels, for time
        attributes the attribute's `str_val` formatter, otherwise None.
        """
        if self.is_mode():
            return self.get_reduced_agg_data(return_labels=True)
        if self.is_time():
            return self.agg_attr.str_val
        return None

    def get_reduced_agg_data(self, return_labels=False):
        """
        This returns agg data or its labels. It also merges infrequent data.

        Merging applies only in "mode" aggregation when the attribute has
        at least `MAX_COLORS` values: the `MAX_COLORS - 1` most frequent
        values keep their own index/label and all others are merged into
        index `MAX_COLORS - 1`, labelled "Other".

        Args:
            return_labels: if true, return the value labels instead of the
                data.
        """
        needs_merging = self.is_mode() \
                        and len(self.agg_attr.values) >= MAX_COLORS
        if return_labels and not needs_merging:
            return self.agg_attr.values

        if not needs_merging:
            return self.agg_data

        dist = bincount(self.agg_data, max_val=len(self.agg_attr.values) - 1)[0]
        infrequent = np.zeros(len(self.agg_attr.values), dtype=bool)
        # Everything except the (MAX_COLORS - 1) largest counts is infrequent.
        infrequent[np.argsort(dist)[:-(MAX_COLORS - 1)]] = True
        if return_labels:
            return [value for value, infreq in zip(self.agg_attr.values, infrequent)
                    if not infreq] + ["Other"]
        else:
            result = self.agg_data.copy()
            freq_vals = [i for i, f in enumerate(infrequent) if not f]
            # Masks are computed on the original data so that already
            # remapped entries in `result` are not remapped again.
            for i, infreq in enumerate(infrequent):
                if infreq:
                    result[self.agg_data == i] = MAX_COLORS - 1
                else:
                    result[self.agg_data == i] = freq_vals.index(i)
            return result

    def is_mode(self):
        """Whether a discrete attribute is aggregated with 'Mode'."""
        attr = self.agg_attr
        return (attr is not None
                and attr.is_discrete
                and self.agg_func == 'Mode')

    def is_time(self):
        """Whether a time attribute is aggregated with a non-count function."""
        attr = self.agg_attr
        return (attr is not None
                and attr.is_time
                and self.agg_func not in ('Count', 'Count defined'))

    @memoize_method(3)
    def get_regions(self, lat_attr, lon_attr, admin):
        """
        Map points to regions and get regions information.

        Also stores the per-point region ids in `self.data_ids` and shows
        the `no_region` warning with the number of points that have no
        region.

        NOTE(review): this method is memoized, so presumably the
        `self.data_ids` assignment and the warning below only happen when
        the result is not served from the cache - confirm against
        `memoize_method`.

        Returns:
            ndarray of ids corresponding to points,
            dict of region ids matched to their additional info,
            dict of region ids matched to their polygon
        """
        latlon = np.c_[self.data.get_column_view(lat_attr)[0],
                       self.data.get_column_view(lon_attr)[0]]
        region_info = latlon2region(latlon, admin)
        ids = np.array([region.get('_id') for region in region_info])
        region_info = {info.get('_id'): info for info in region_info}

        self.data_ids = np.array(ids)
        # Elementwise comparison with None counts points without a region.
        no_region = np.sum(self.data_ids == None)
        if no_region:
            self.Warning.no_region(no_region)

        unique_ids = list(set(ids) - {None})
        polygons = {_id: poly
                    for _id, poly in zip(unique_ids, get_shape(unique_ids))}
        return ids, region_info, polygons

    def get_grouped(self, lat_attr, lon_attr, admin, attr, agg_func):
        """
        Get aggregation value for points grouped by regions.

        Without an attribute every point contributes the value 1.

        Returns:
            Series of aggregated values
        """
        if attr is None:
            values = np.ones(len(self.data))
        else:
            values = self.data.get_column_view(attr)[0]

        point_regions, _, _ = self.get_regions(lat_attr, lon_attr, admin)
        grouped = pd.Series(values, dtype=float).groupby(point_regions)
        return grouped.agg(AGG_FUNCS[agg_func].transform)

    def get_agg_data(self) -> np.ndarray:
        """Aggregate the data per region and cache the result on the widget.

        Sets `agg_data` and `region_ids` (both sorted by region id),
        re-synchronises `data_ids` with the current region assignment and
        recomputes the binnings.

        Returns:
            The aggregated values, ordered by region id.
        """
        result = self.get_grouped(self.attr_lat, self.attr_lon,
                                  self.admin_level, self.agg_attr,
                                  self.agg_func)

        # `get_regions` is memoized and assigns `data_ids` only as a side
        # effect of an actual computation. When its result comes from the
        # cache (e.g. after switching back to an earlier detail level),
        # `data_ids` would still describe another region assignment and no
        # longer match `region_ids`, so it is refreshed explicitly here.
        ids, _, _ = self.get_regions(self.attr_lat, self.attr_lon,
                                     self.admin_level)
        self.data_ids = np.array(ids)

        region_ids = np.array(result.index)
        order = np.argsort(region_ids)
        self.region_ids = region_ids[order]
        self.agg_data = np.array(result.values)[order]

        self.recompute_binnings()

        return self.agg_data

    def format_agg_val(self, value):
        """Format an aggregated value for display.

        Counts are formatted as integers; other values use the aggregated
        attribute's `repr_val`.
        """
        is_count = self.agg_func in ('Count', 'Count defined')
        if not is_count:
            return self.agg_attr.repr_val(value)
        return format(value, "d")

    def get_choropleth_regions(self) -> List[_ChoroplethRegion]:
        """Recalculate regions.

        Builds one `_ChoroplethRegion` per region polygon, converted to
        canvas coordinates, stores them sorted by id and refreshes the
        aggregated data. Returns an empty list without a latitude variable.
        """
        if self.attr_lat is None:
            # if we don't have locations we can't compute regions
            return []

        _, region_info, polygons = self.get_regions(self.attr_lat,
                                                    self.attr_lon,
                                                    self.admin_level)

        regions = []
        for _id, shape in polygons.items():
            # some regions consist of multiple polygons
            if isinstance(shape, MultiPolygon):
                parts = list(shape.geoms)
            else:
                parts = [shape]

            qpolys = [self.poly2qpoly(transform(self.deg2canvas, part))
                      for part in parts]
            regions.append(_ChoroplethRegion(id=_id, info=region_info[_id],
                                             qpolys=qpolys))

        self.choropleth_regions = sorted(regions, key=lambda cr: cr.id)
        self.get_agg_data()
        return self.choropleth_regions

    @staticmethod
    def poly2qpoly(poly: Polygon) -> QPolygonF:
        """Convert a polygon's exterior ring to a `QPolygonF`."""
        points = [QPointF(x, y) for x, y in poly.exterior.coords]
        return QPolygonF(points)

    @staticmethod
    def deg2canvas(x, y):
        """Convert coordinates with `deg2norm` and flip the y axis."""
        norm_x, norm_y = deg2norm(x, y)
        return norm_x, 1 - norm_y

    def clear(self, cache=False):
        """Forget the computed regions.

        Args:
            cache: if true, also clear the `get_regions` cache.
        """
        self.choropleth_regions = []
        if not cache:
            return
        self.get_regions.cache_clear()

    def send_report(self):
        """Add the plot to the report when there is data."""
        if self.data is not None:
            self.report_plot()

    def sizeHint(self):
        """Return the preferred widget size (1132 x 708)."""
        return QSize(1132, 708)

    def onDeleteWidget(self):
        """Release the graph's view box and plot items on widget deletion."""
        super().onDeleteWidget()
        self.graph.plot_widget.getViewBox().deleteLater()
        self.graph.plot_widget.clear()
        self.graph.clear()

    def keyPressEvent(self, event):
        """Update the tip about using the modifier keys when selecting.

        The event is first passed to the base class; the graph's drag
        tooltip is then updated with the event's modifiers.
        """
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Update the tip about using the modifier keys when selecting.

        The event is first passed to the base class; the graph's drag
        tooltip is then updated with the event's modifiers.
        """
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def showEvent(self, ev):
        """Reset the map view when the widget is shown."""
        super().showEvent(ev)
        # reset the map on show event since before that we didn't know the
        # right resolution
        # NOTE(review): `update_view_range` is presumably provided by
        # `GRAPH_CLASS`; its definition is not visible here - confirm.
        self.graph.update_view_range()

    def resizeEvent(self, ev):
        """Refresh the map view when the widget is resized."""
        super().resizeEvent(ev)
        # when resizing we need to constantly reset the map so that new
        # portions are drawn
        # NOTE(review): `update_view_range` is presumably provided by
        # `GRAPH_CLASS`; its definition is not visible here - confirm.
        self.graph.update_view_range(match_data=False)

    @classmethod
    def migrate_settings(cls, settings, version):
        """Migrate settings stored by widget versions older than 2.

        Renames ``admin`` and ``autocommit`` and moves opacity and legend
        visibility into the ``graph`` settings; the opacity is scaled by
        2.55 and rounded to give ``alpha_value``.
        """
        if version < 2:
            graph_settings = settings["graph"] = {}
            for old_name, new_name in (("admin", "admin_level"),
                                       ("autocommit", "auto_commit")):
                rename_setting(settings, old_name, new_name)
            graph_settings["alpha_value"] = round(settings["opacity"] * 2.55)
            graph_settings["show_legend"] = settings["show_legend"]

    @classmethod
    def migrate_context(cls, context, version):
        """Migrate context settings stored by widget versions older than 2.

        Converts the attribute names to variables, renames the settings to
        their current names and maps the old aggregation function names
        ("Max", "Min", "Std") to the current ones.
        """
        if version < 2:
            for old_name in ("lat_attr", "lon_attr", "attr"):
                migrate_str_to_variable(context, names=old_name,
                                        none_placeholder="")

            renames = (
                ("lat_attr", "attr_lat"),
                ("lon_attr", "attr_lon"),
                ("attr", "agg_attr"),
                # old selection will not be ported
                ("selection", "old_selection"),
            )
            for old_name, new_name in renames:
                rename_setting(context, old_name, new_name)

            new_func_names = {"Max": "Maximal", "Min": "Minimal",
                              "Std": "Std."}
            stored = context.values["agg_func"]
            if stored[0] in new_func_names:
                context.values["agg_func"] = (new_func_names[stored[0]],
                                              stored[1])
Пример #23
0
class OWLoadClassifier(widget.OWWidget):
    """Widget that unpickles a classifier from a file and sends it on.

    Previously loaded files are remembered in ``history`` and offered in a
    combo box; the last loaded file is stored in ``filename`` and is loaded
    again when the widget is created (if it still exists).
    """
    name = "Load Classifier"
    description = "Load a classifier from an input file."
    priority = 3050
    icon = "icons/LoadClassifier.svg"

    outputs = [("Classifier", Model, widget.Dynamic)]

    #: List of recent filenames.
    history = Setting([])
    #: Current (last selected) filename or None.
    filename = Setting(None)

    FILTER = owsaveclassifier.OWSaveClassifier.FILTER

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        """Build the file selection row and restore the saved selection."""
        super().__init__()
        self.selectedIndex = -1

        box = gui.widgetBox(self.controlArea,
                            self.tr("File"),
                            orientation=QHBoxLayout())

        self.filesCB = gui.comboBox(box,
                                    self,
                                    "selectedIndex",
                                    callback=self._on_recent)
        self.filesCB.setMinimumContentsLength(20)
        self.filesCB.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLength)

        self.loadbutton = gui.button(box, self, "...", callback=self.browse)
        self.loadbutton.setIcon(self.style().standardIcon(
            QStyle.SP_DirOpenIcon))
        self.loadbutton.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed)

        self.reloadbutton = gui.button(box,
                                       self,
                                       "Reload",
                                       callback=self.reload,
                                       default=True)
        self.reloadbutton.setIcon(self.style().standardIcon(
            QStyle.SP_BrowserReload))
        self.reloadbutton.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed)

        # filter valid existing filenames (at most 20 are kept)
        self.history = list(filter(os.path.isfile, self.history))[:20]
        for filename in self.history:
            self.filesCB.addItem(os.path.basename(filename), userData=filename)

        # restore the current selection if the filename is
        # in the history list
        if self.filename in self.history:
            self.selectedIndex = self.history.index(self.filename)
        else:
            self.selectedIndex = -1
            self.filename = None
            self.reloadbutton.setEnabled(False)

        if self.filename:
            # defer loading to a zero-delay single-shot timer instead of
            # calling load() directly from the constructor
            QTimer.singleShot(0, lambda: self.load(self.filename))

    def browse(self):
        """Select a filename using an open file dialog."""
        if self.filename is None:
            startdir = stdpaths.Documents
        else:
            startdir = os.path.dirname(self.filename)

        filename, _ = QFileDialog.getOpenFileName(self,
                                                  self.tr("Open"),
                                                  directory=startdir,
                                                  filter=self.FILTER)

        # an empty filename means no file was chosen
        if filename:
            self.load(filename)

    def reload(self):
        """Reload the current file."""
        self.load(self.filename)

    def load(self, filename):
        """Load the object from filename and send it to output.

        The file is opened in a ``with`` block so that its handle is closed
        as soon as unpickling finishes, including when unpickling fails.
        On success the file is recorded in the history and the loaded object
        is sent on the "Classifier" output.

        NOTE(security): ``pickle.load`` can execute arbitrary code while
        loading; only files from trusted sources should be opened.

        Raises
        ------
        pickle.UnpicklingError
            If the file content cannot be unpickled.
        OSError
            If the file cannot be opened or read.
        """
        try:
            with open(filename, "rb") as stream:
                classifier = pickle.load(stream)
        except (pickle.UnpicklingError, OSError):
            raise  # TODO: error reporting
        else:
            self._remember(filename)
            self.send("Classifier", classifier)

    def _remember(self, filename):
        """
        Remember `filename` was accessed.

        The file is moved (or inserted) to the front of both the history
        list and the combo box, and becomes the current selection.
        """
        if filename in self.history:
            # drop the old entry so the file is not listed twice
            index = self.history.index(filename)
            del self.history[index]
            self.filesCB.removeItem(index)

        self.history.insert(0, filename)

        self.filesCB.insertItem(0,
                                os.path.basename(filename),
                                userData=filename)
        self.selectedIndex = 0
        self.filename = filename
        self.reloadbutton.setEnabled(self.selectedIndex != -1)

    def _on_recent(self):
        """Load the history entry currently selected in the combo box."""
        self.load(self.history[self.selectedIndex])
Пример #24
0
class OWAdaBoost(OWBaseLearner):
    """Learner widget wrapping ``SklAdaBoostLearner``.

    The base estimator can be supplied on the "Learner" input; otherwise
    ``DEFAULT_BASE_ESTIMATOR`` is used.
    """
    name = "AdaBoost"
    description = "An ensemble meta-algorithm that combines weak learners " \
                  "and adapts to the 'hardness' of each training sample. "
    icon = "icons/AdaBoost.svg"
    replaces = [
        "Orange.widgets.classify.owadaboost.OWAdaBoostClassification",
        "Orange.widgets.regression.owadaboostregression.OWAdaBoostRegression",
    ]
    priority = 80
    keywords = ["boost"]

    LEARNER = SklAdaBoostLearner

    class Inputs(OWBaseLearner.Inputs):
        learner = Input("Learner", Learner)

    #: Algorithms for classification problems
    algorithms = ["SAMME", "SAMME.R"]
    #: Losses for regression problems
    losses = ["Linear", "Square", "Exponential"]

    n_estimators = Setting(50)
    learning_rate = Setting(1.)
    #: Index into `algorithms`
    algorithm_index = Setting(1)
    #: Index into `losses`
    loss_index = Setting(0)
    #: Whether `random_seed` is passed to the learner
    use_random_seed = Setting(False)
    random_seed = Setting(0)

    DEFAULT_BASE_ESTIMATOR = SklTreeLearner()

    class Error(OWBaseLearner.Error):
        no_weight_support = Msg('The base learner does not support weights.')

    def add_main_layout(self):
        """Create the parameter and boosting-method controls."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        box = gui.widgetBox(self.controlArea, "Parameters")
        self.base_estimator = self.DEFAULT_BASE_ESTIMATOR
        self.base_label = gui.label(
            box, self, "Base estimator: " + self.base_estimator.name.title())

        self.n_estimators_spin = gui.spin(box,
                                          self,
                                          "n_estimators",
                                          1,
                                          10000,
                                          label="Number of estimators:",
                                          alignment=Qt.AlignRight,
                                          controlWidth=80,
                                          callback=self.settings_changed)
        self.learning_rate_spin = gui.doubleSpin(
            box,
            self,
            "learning_rate",
            1e-5,
            1.0,
            1e-5,
            label="Learning rate:",
            decimals=5,
            alignment=Qt.AlignRight,
            controlWidth=80,
            callback=self.settings_changed)
        # the spin box carries a check box bound to `use_random_seed`
        self.random_seed_spin = gui.spin(
            box,
            self,
            "random_seed",
            0,
            2**31 - 1,
            controlWidth=80,
            label="Fixed seed for random generator:",
            alignment=Qt.AlignRight,
            callback=self.settings_changed,
            checked="use_random_seed",
            checkCallback=self.settings_changed)

        # Algorithms
        box = gui.widgetBox(self.controlArea, "Boosting method")
        self.cls_algorithm_combo = gui.comboBox(
            box,
            self,
            "algorithm_index",
            label="Classification algorithm:",
            items=self.algorithms,
            orientation=Qt.Horizontal,
            callback=self.settings_changed)
        self.reg_algorithm_combo = gui.comboBox(
            box,
            self,
            "loss_index",
            label="Regression loss function:",
            items=self.losses,
            orientation=Qt.Horizontal,
            callback=self.settings_changed)

    def create_learner(self):
        """Build the learner from the current settings.

        Returns
        -------
        An instance of ``LEARNER``, or ``None`` when there is no valid base
        estimator (see `set_base_learner`).
        """
        if self.base_estimator is None:
            return None
        # Honour the "Fixed seed" check box: the seed is passed on only when
        # the box is ticked, otherwise the random state is left unset.
        random_state = self.random_seed if self.use_random_seed else None
        return self.LEARNER(base_estimator=self.base_estimator,
                            n_estimators=self.n_estimators,
                            learning_rate=self.learning_rate,
                            random_state=random_state,
                            preprocessors=self.preprocessors,
                            algorithm=self.algorithms[self.algorithm_index],
                            loss=self.losses[self.loss_index].lower())

    @Inputs.learner
    def set_base_learner(self, learner):
        """Set the base estimator from the "Learner" input.

        A learner that does not support weights is rejected: the error is
        shown and the base estimator is set to ``None``.  When no learner is
        given, ``DEFAULT_BASE_ESTIMATOR`` is used.
        """
        # base_estimator is defined in add_main_layout
        # pylint: disable=attribute-defined-outside-init
        self.Error.no_weight_support.clear()
        if learner and not learner.supports_weights:
            # Show the error and mark the base estimator as invalid
            self.Error.no_weight_support()
            self.base_estimator = None
            self.base_label.setText("Base estimator: INVALID")
        else:
            self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
            self.base_label.setText("Base estimator: %s" %
                                    self.base_estimator.name.title())
        if self.auto_apply:
            self.apply()

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the learner settings."""
        return (("Base estimator", self.base_estimator),
                ("Number of estimators", self.n_estimators),
                ("Algorithm (classification)",
                 self.algorithms[self.algorithm_index].capitalize()),
                ("Loss (regression)",
                 self.losses[self.loss_index].capitalize()))
Пример #25
0
class OWSelectAttributes(widget.OWWidget):
    # Widget for assigning the columns of the input table to the roles
    # feature / target / meta, or leaving them out ("Ignored").
    # pylint: disable=too-many-instance-attributes
    name = "Select Columns"
    description = "Select columns from the data table and assign them to " \
                  "data features, classes or meta variables."
    icon = "icons/SelectColumns.svg"
    priority = 100
    keywords = ["filter", "attributes", "target", "variable"]

    class Inputs:
        data = Input("Data", Table, default=True)
        features = Input("Features", AttributeList)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList, dynamic=False)

    want_main_area = False
    want_control_area = True

    settingsHandler = SelectAttributesDomainContextHandler(first_match=False)
    # Context setting holding the per-variable (role, position) hints;
    # it is written by update_domain_role_hints and read by restore_hints.
    domain_role_hints = ContextSetting({})
    # When set, the feature list is taken from the "Features" input.
    use_input_features = Setting(False)
    # When set (and a context exists), variables without a hint are placed
    # in the "available" list (see restore_hints).
    ignore_new_features = Setting(False)
    auto_commit = Setting(True)

    class Warning(widget.OWWidget.Warning):
        mismatching_domain = Msg("Features and data domain do not match")
        multiple_targets = Msg("Most widgets do not support multiple targets")

    def __init__(self):
        """Build the four variable lists, the move buttons and the footer.

        The control area is replaced by a widget with a grid layout:
        column 0 holds the "Ignored" list, column 1 the move buttons and
        column 2 the "Features", "Target" and "Metas" lists.
        """
        super().__init__()
        self.data = None
        self.features = None

        # Schedule interface updates (enabled buttons) using a coalescing
        # single shot timer (complex interactions on selection and filtering
        # updates in the 'available_attrs_view')
        self.__interface_update_timer = QTimer(self,
                                               interval=0,
                                               singleShot=True)
        self.__interface_update_timer.timeout.connect(
            self.__update_interface_state)
        # The last view that has the selection for move operation's source
        self.__last_active_view = None  # type: Optional[QListView]

        def update_on_change(view):
            # Schedule interface state update on selection change in `view`
            self.__last_active_view = view
            self.__interface_update_timer.start()

        # Replace the control area with a fresh widget that hosts the grid
        new_control_area = QWidget(self.controlArea)
        self.controlArea.layout().addWidget(new_control_area)
        self.controlArea = new_control_area

        # init grid
        layout = QGridLayout()
        self.controlArea.setLayout(layout)
        layout.setContentsMargins(0, 0, 0, 0)
        # 1st column: variables that are not used in the output
        box = gui.vBox(self.controlArea, "Ignored", addToLayout=False)

        self.available_attrs = VariablesListItemModel()
        filter_edit, self.available_attrs_view = variables_filter(
            parent=self, model=self.available_attrs)
        box.layout().addWidget(filter_edit)

        def dropcompleted(action):
            # Commit only after a drag-and-drop that actually moved items
            if action == Qt.MoveAction:
                self.commit()

        self.available_attrs_view.selectionModel().selectionChanged.connect(
            partial(update_on_change, self.available_attrs_view))
        self.available_attrs_view.dragDropActionDidComplete.connect(
            dropcompleted)

        box.layout().addWidget(self.available_attrs_view)
        layout.addWidget(box, 0, 0, 3, 1)

        # 3rd column
        box = gui.vBox(self.controlArea, "Features", addToLayout=False)
        self.used_attrs = VariablesListItemModel()
        filter_edit, self.used_attrs_view = variables_filter(
            parent=self,
            model=self.used_attrs,
            accepted_type=(Orange.data.DiscreteVariable,
                           Orange.data.ContinuousVariable))
        # Any change of the feature list re-evaluates the "use input
        # features" controls
        self.used_attrs.rowsInserted.connect(self.__used_attrs_changed)
        self.used_attrs.rowsRemoved.connect(self.__used_attrs_changed)
        self.used_attrs_view.selectionModel().selectionChanged.connect(
            partial(update_on_change, self.used_attrs_view))
        self.used_attrs_view.dragDropActionDidComplete.connect(dropcompleted)
        self.use_features_box = gui.auto_commit(
            self.controlArea,
            self,
            "use_input_features",
            "Use input features",
            "Always use input features",
            box=False,
            commit=self.__use_features_clicked,
            callback=self.__use_features_changed,
            addToLayout=False)
        self.enable_use_features_box()
        box.layout().addWidget(self.use_features_box)
        box.layout().addWidget(filter_edit)
        box.layout().addWidget(self.used_attrs_view)
        layout.addWidget(box, 0, 2, 1, 1)

        box = gui.vBox(self.controlArea, "Target", addToLayout=False)
        self.class_attrs = VariablesListItemModel()
        self.class_attrs_view = VariablesListItemView(
            acceptedType=(Orange.data.DiscreteVariable,
                          Orange.data.ContinuousVariable))
        self.class_attrs_view.setModel(self.class_attrs)
        self.class_attrs_view.selectionModel().selectionChanged.connect(
            partial(update_on_change, self.class_attrs_view))
        self.class_attrs_view.dragDropActionDidComplete.connect(dropcompleted)

        box.layout().addWidget(self.class_attrs_view)
        layout.addWidget(box, 1, 2, 1, 1)

        box = gui.vBox(self.controlArea, "Metas", addToLayout=False)
        self.meta_attrs = VariablesListItemModel()
        self.meta_attrs_view = VariablesListItemView(
            acceptedType=Orange.data.Variable)
        self.meta_attrs_view.setModel(self.meta_attrs)
        self.meta_attrs_view.selectionModel().selectionChanged.connect(
            partial(update_on_change, self.meta_attrs_view))
        self.meta_attrs_view.dragDropActionDidComplete.connect(dropcompleted)
        box.layout().addWidget(self.meta_attrs_view)
        layout.addWidget(box, 2, 2, 1, 1)

        # 2nd column
        # One move button per target list; each button moves the selection
        # between the "Ignored" list and its own list (see move_selected)
        bbox = gui.vBox(self.controlArea, addToLayout=False, margin=0)
        self.move_attr_button = gui.button(bbox,
                                           self,
                                           ">",
                                           callback=partial(
                                               self.move_selected,
                                               self.used_attrs_view))
        layout.addWidget(bbox, 0, 1, 1, 1)

        bbox = gui.vBox(self.controlArea, addToLayout=False, margin=0)
        self.move_class_button = gui.button(bbox,
                                            self,
                                            ">",
                                            callback=partial(
                                                self.move_selected,
                                                self.class_attrs_view))
        layout.addWidget(bbox, 1, 1, 1, 1)

        bbox = gui.vBox(self.controlArea, addToLayout=False)
        self.move_meta_button = gui.button(bbox,
                                           self,
                                           ">",
                                           callback=partial(
                                               self.move_selected,
                                               self.meta_attrs_view))
        layout.addWidget(bbox, 2, 1, 1, 1)

        # footer
        gui.button(self.buttonsArea, self, "Reset", callback=self.reset)

        bbox = gui.vBox(self.buttonsArea)
        gui.checkBox(
            widget=bbox,
            master=self,
            value="ignore_new_features",
            label="Ignore new variables by default",
            tooltip="When the widget receives data with additional columns "
            "they are added to the available attributes column if "
            "<i>Ignore new variables by default</i> is checked.")

        gui.rubber(self.buttonsArea)
        gui.auto_send(self.buttonsArea, self, "auto_commit")

        # Row stretch factors for the Features / Target / Metas rows
        layout.setRowStretch(0, 2)
        layout.setRowStretch(1, 0)
        layout.setRowStretch(2, 1)
        layout.setHorizontalSpacing(0)
        self.controlArea.setLayout(layout)

        self.output_data = None
        self.original_completer_items = []

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        self.resize(600, 600)

    @property
    def features_from_data_attributes(self):
        if self.data is None or self.features is None:
            return []
        domain = self.data.domain
        return [
            domain[feature.name] for feature in self.features
            if feature.name in domain
            and domain[feature.name] in domain.attributes
        ]

    def can_use_features(self):
        return bool(self.features_from_data_attributes) and \
               self.features_from_data_attributes != self.used_attrs[:]

    def __use_features_changed(self):  # Use input features check box
        """React to toggling of the "use input features" check box.

        The feature list is disabled while input features are in use, and
        the input features are applied right away when that would change
        the list.  When the box is unchecked the button state is refreshed.
        """
        # The callback may fire before the box has been created
        if not hasattr(self, "use_features_box"):
            return

        if self.use_input_features:
            self.enable_used_attrs(False)
            if self.can_use_features():
                self.use_features()
        else:
            self.enable_used_attrs(True)

        if not self.use_input_features:
            self.enable_use_features_box()

    def __use_features_clicked(self):  # Use input features button
        """Apply the input features when the button is pressed."""
        self.use_features()

    def __used_attrs_changed(self):
        """Refresh the "use input features" controls.

        Connected in ``__init__`` to the feature model's ``rowsInserted``
        and ``rowsRemoved`` signals.
        """
        self.enable_use_features_box()

    @Inputs.data
    def set_data(self, data=None):
        """Receive a new data table and distribute its variables to lists.

        The hints of the previous data are stored and its context closed
        before the new data is taken over.  Without data all four lists are
        emptied.  With data, the context is opened and every variable is
        placed in the list named by its hint, ordered by the hinted position.
        """
        self.update_domain_role_hints()
        self.closeContext()
        self.domain_role_hints = {}

        self.data = data
        # (model, role) pairs in the order in which the lists are filled
        role_models = (
            (self.used_attrs, "attribute"),
            (self.class_attrs, "class"),
            (self.meta_attrs, "meta"),
            (self.available_attrs, "available"),
        )

        if data is None:
            for model, _ in role_models:
                model[:] = []
            self.info.set_input_summary(self.info.NoInput)
            return

        self.openContext(data)
        domain = data.domain
        all_vars = domain.variables + domain.metas
        domain_hints = self.restore_hints(domain)

        def hinted_role(var):
            return domain_hints[var][0]

        def hinted_position(var):
            return domain_hints[var][1]

        for model, role in role_models:
            model[:] = sorted(
                (var for var in all_vars if hinted_role(var) == role),
                key=hinted_position)

        self.info.set_input_summary(len(data), format_summary_details(data))

        self.update_interface_state(self.class_attrs_view)

    def restore_hints(self, domain: Domain) -> Dict[Variable, Tuple[str, int]]:
        """
        Define hints for selected/unselected features.
        Rules:
        - if context available, restore new features based on checked/unchecked
          ignore_new_features, context hint should be took into account
        - in no context, restore features based on the domain (as selected)

        Parameters
        ----------
        domain
            Data domain

        Returns
        -------
        Dictionary with hints about order and model in which each feature
        should appear
        """
        domain_hints = {}
        if not self.ignore_new_features or len(self.domain_role_hints) == 0:
            # select_new_features selected or no context - restore based on domain
            domain_hints.update(
                self._hints_from_seq("attribute", domain.attributes))
            domain_hints.update(self._hints_from_seq("meta", domain.metas))
            domain_hints.update(
                self._hints_from_seq("class", domain.class_vars))
        else:
            # if context restored and ignore_new_features selected - restore
            # new features as available
            d = domain.attributes + domain.metas + domain.class_vars
            domain_hints.update(self._hints_from_seq("available", d))

        domain_hints.update(self.domain_role_hints)
        return domain_hints

    def update_domain_role_hints(self):
        """ Update the domain hints to be stored in the widgets settings.
        """
        hints = {}
        hints.update(self._hints_from_seq("available", self.available_attrs))
        hints.update(self._hints_from_seq("attribute", self.used_attrs))
        hints.update(self._hints_from_seq("class", self.class_attrs))
        hints.update(self._hints_from_seq("meta", self.meta_attrs))
        self.domain_role_hints = hints

    @staticmethod
    def _hints_from_seq(role, model):
        return [(attr, (role, i)) for i, attr in enumerate(model)]

    @Inputs.features
    def set_features(self, features):
        """Store the incoming feature list.

        The list is only stored here; it is read through
        `features_from_data_attributes`, e.g. in `handleNewSignals`.
        """
        self.features = features

    def handleNewSignals(self):
        """Update warnings and controls for the current inputs and commit."""
        self.check_data()
        # Start from an enabled feature list; it is disabled again below
        # when the input features take over
        self.enable_used_attrs()
        self.enable_use_features_box()
        if self.use_input_features and self.features_from_data_attributes:
            self.enable_used_attrs(False)
            self.use_features()
        self.unconditional_commit()

    def check_data(self):
        """Warn when none of the input features is a data attribute.

        The warning is cleared first and shown again only when both inputs
        are present but no input feature matches the data's attributes.
        """
        self.Warning.mismatching_domain.clear()
        both_inputs_present = \
            self.data is not None and self.features is not None
        if both_inputs_present and not self.features_from_data_attributes:
            self.Warning.mismatching_domain()

    def enable_used_attrs(self, enable=True):
        """Enable or disable the feature list and its move button.

        Parameters
        ----------
        enable
            State to set on both controls (defaults to enabled).
        """
        self.move_attr_button.setEnabled(enable)
        self.used_attrs_view.setEnabled(enable)
        self.used_attrs_view.repaint()

    def enable_use_features_box(self):
        """Refresh the state of the "use input features" controls.

        The button is enabled when applying the input features would change
        the feature list; the whole box is hidden when there are no usable
        input features.
        """
        features_box = self.use_features_box
        features_box.button.setEnabled(self.can_use_features())
        features_box.setHidden(not self.features_from_data_attributes)
        features_box.repaint()

    def use_features(self):
        """Make the usable input features the content of the feature list.

        Variables currently in the feature or available list that are not
        among the input features end up in the available list (former
        features first); the result is then committed.
        """
        wanted = self.features_from_data_attributes
        available = self.available_attrs[:]
        used = self.used_attrs[:]

        leftovers = []
        for variable in used + available:
            if variable not in wanted:
                leftovers.append(variable)

        self.available_attrs[:] = leftovers
        self.used_attrs[:] = wanted
        self.commit()

    @staticmethod
    def selected_rows(view):
        """Return the row numbers selected in `view`.

        When the view shows a ``QSortFilterProxyModel``, the selected
        indexes are first mapped to the source model, so the returned rows
        refer to the source model.
        """
        indexes = view.selectionModel().selectedRows()
        shown_model = view.model()
        if isinstance(shown_model, QSortFilterProxyModel):
            indexes = map(shown_model.mapToSource, indexes)
        return [index.row() for index in indexes]

    def move_rows(self, view: QListView, offset: int, roles=(Qt.EditRole, )):
        """Shift the selected rows of `view` by `offset` and commit.

        Parameters
        ----------
        view
            View whose selected rows are moved.
        offset
            Number of rows to move by; target rows are clamped to the valid
            row range of the view's model.
        roles
            Item data roles that are exchanged between the rows.
        """
        rows = [idx.row() for idx in view.selectionModel().selectedRows()]
        model = view.model()  # type: QAbstractItemModel
        rowcount = model.rowCount()
        # Target row for each selected row, kept within [0, rowcount - 1]
        newrows = [min(max(0, row + offset), rowcount - 1) for row in rows]

        def itemData(index):
            # Data of `index` for the requested roles only
            return {role: model.data(index, role) for role in roles}

        # Swap each row with its target; rows are processed from the end
        # when moving down (offset > 0) and from the start otherwise
        for row, newrow in sorted(zip(rows, newrows), reverse=offset > 0):
            d1 = itemData(model.index(row, 0))
            d2 = itemData(model.index(newrow, 0))
            model.setItemData(model.index(row, 0), d2)
            model.setItemData(model.index(newrow, 0), d1)

        # Re-select the rows at their new positions
        selection = QItemSelection()
        for nrow in newrows:
            index = model.index(nrow, 0)
            selection.select(index, index)
        view.selectionModel().select(selection,
                                     QItemSelectionModel.ClearAndSelect)

        self.commit()

    def move_up(self, view: QListView):
        """Move the selected rows of `view` by an offset of -1."""
        self.move_rows(view, -1)

    def move_down(self, view: QListView):
        """Move the selected rows of `view` by an offset of +1."""
        self.move_rows(view, 1)

    def move_selected(self, view):
        """Move the selection between `view` and the available list.

        A selection in `view` is moved to the available list; otherwise a
        selection in the available list is moved to `view`.  Nothing happens
        when neither has a selection.
        """
        pool_view = self.available_attrs_view
        if self.selected_rows(view):
            self.move_selected_from_to(view, pool_view)
            return
        if self.selected_rows(pool_view):
            self.move_selected_from_to(pool_view, view)

    def move_selected_from_to(self, src, dst):
        """Move the rows currently selected in view `src` to view `dst`."""
        self.move_from_to(src, dst, self.selected_rows(src))

    def move_from_to(self, src, dst, rows):
        """Move the given `rows` from view `src` to the end of view `dst`.

        The items are removed from the source model of `src`, appended to
        the source model of `dst`, and the result is committed.
        """
        origin = source_model(src)
        moved = [origin[row] for row in rows]

        # delete slice by slice, last slice first
        row_slices = list(slices(rows))
        for start, stop in reversed(row_slices):
            del origin[start:stop]

        source_model(dst).extend(moved)

        self.commit()

    def __update_interface_state(self):
        """Run the scheduled interface update for the last active view.

        Does nothing when no view has been recorded as active.
        """
        pending_view = self.__last_active_view
        if pending_view is None:
            return
        self.update_interface_state(pending_view)

    def update_interface_state(self, focus=None):
        """Update the move buttons and the height of the target list.

        Selections in views other than `focus` that do not have keyboard
        focus are cleared.  Each move button is then enabled according to
        the remaining selections and shows ">" when something is selected
        in the available list and "<" otherwise.

        Parameters
        ----------
        focus
            View whose selection must be preserved.
        """
        all_views = (self.available_attrs_view, self.used_attrs_view,
                     self.class_attrs_view, self.meta_attrs_view)
        for view in all_views:
            if view is focus or view.hasFocus():
                continue
            selection_model = view.selectionModel()
            if selection_model.hasSelection():
                selection_model.clear()

        def selected_vars(view):
            model = source_model(view)
            return [model[row] for row in self.selected_rows(view)]

        available_selected = selected_vars(self.available_attrs_view)
        attrs_selected = selected_vars(self.used_attrs_view)
        class_selected = selected_vars(self.class_attrs_view)
        meta_selected = selected_vars(self.meta_attrs_view)

        available_types = set(map(type, available_selected))
        all_primitive = all(var.is_primitive() for var in available_types)

        arrow = ">" if available_selected else "<"
        # a non-empty selection of primitive variables in the available list
        primitive_available = bool(available_selected) and all_primitive

        def refresh(button, enabled):
            button.setEnabled(enabled)
            if enabled:
                button.setText(arrow)

        refresh(self.move_attr_button,
                bool((primitive_available or attrs_selected)
                     and self.used_attrs_view.isEnabled()))
        refresh(self.move_class_button,
                bool(primitive_available or class_selected))
        refresh(self.move_meta_button,
                bool(available_selected or meta_selected))

        # update class_vars height
        target_count = self.class_attrs.rowCount()
        if target_count:
            height = \
                target_count * self.class_attrs_view.sizeHintForRow(0) + 2
        else:
            height = 22
        self.class_attrs_view.setFixedHeight(height)

        self.__last_active_view = None
        self.__interface_update_timer.stop()

    def commit(self):
        """Send the data transformed to the currently selected domain.

        The role hints are stored first.  Without input data ``None`` is
        sent on both outputs.  Otherwise the data is transformed to a domain
        built from the feature, target and meta lists, and a warning is
        shown when more than one target is selected.
        """
        self.update_domain_role_hints()
        self.Warning.multiple_targets.clear()

        if self.data is None:
            self.output_data = None
            self.Outputs.data.send(None)
            self.Outputs.features.send(None)
            self.info.set_output_summary(self.info.NoOutput)
            return

        attributes = list(self.used_attrs)
        class_var = list(self.class_attrs)
        metas = list(self.meta_attrs)

        new_domain = Orange.data.Domain(attributes, class_var, metas)
        transformed = self.data.transform(new_domain)
        self.output_data = transformed
        self.Outputs.data.send(transformed)
        self.Outputs.features.send(AttributeList(attributes))
        self.info.set_output_summary(len(transformed),
                                     format_summary_details(transformed))
        self.Warning.multiple_targets(shown=len(class_var) > 1)

    def reset(self):
        """Restore the variable roles defined by the input data's domain.

        The feature list is re-enabled and the "use input features" box is
        unchecked.  With data present, the lists are refilled from the
        domain (nothing is left in the available list) and committed.
        """
        self.enable_used_attrs()
        self.use_features_box.checkbox.setChecked(False)
        if self.data is None:
            return

        self.available_attrs[:] = []
        self.used_attrs[:] = self.data.domain.attributes
        self.class_attrs[:] = self.data.domain.class_vars
        self.meta_attrs[:] = self.data.domain.metas
        self.update_domain_role_hints()
        self.commit()

    def send_report(self):
        """Report the input and output domains and the removed variables.

        Nothing is reported when the input or the output data is missing
        (or falsy).  When the output roles equal the input roles, only
        "No changes." is reported for the output.
        """
        if not self.data or not self.output_data:
            return

        def role_parts(domain):
            return domain.attributes, domain.class_vars, domain.metas

        in_domain, out_domain = self.data.domain, self.output_data.domain
        self.report_domain("Input data", self.data.domain)

        if role_parts(in_domain) == role_parts(out_domain):
            self.report_paragraph("Output data", "No changes.")
            return

        self.report_domain("Output data", self.output_data.domain)
        input_vars = set(in_domain.variables + in_domain.metas)
        output_vars = set(out_domain.variables + out_domain.metas)
        removed = list(input_vars - output_vars)
        if removed:
            names = ", ".join(var.name for var in removed)
            text = "%i (%s)" % (len(removed), names)
            self.report_items((("Removed", text), ))
class {{cookiecutter.widget_name}}(OWWidget):
    """Docstring."""

    # -------------------------------------------------------------------------
    # Widget info
    # -------------------------------------------------------------------------
    name = '{{cookiecutter.widget_name}}'
    description = 'Widget description'
    # icon = 'Widget icon'


    # -------------------------------------------------------------------------
    # Widget settings (variables)
    # -------------------------------------------------------------------------
    some_setting = Setting(0.5)


    # -------------------------------------------------------------------------
    # Inputs and Outputs
    # -------------------------------------------------------------------------
    class Inputs:
        in_data = Input('Data', Table)


    class Outputs:
        out_data = Output('Data', Table)
    

    # -------------------------------------------------------------------------
    # GUI
Пример #27
0
class Component:
    # Holder of two settings declared as class attributes.
    # Setting with the default value 42.
    int_setting = Setting(42)
    # Setting with the default "only", declared with schema_only=True.
    schema_only_setting = Setting("only", schema_only=True)
Пример #28
0
class OWPythonScript(widget.OWWidget):
    """Widget that runs a user-written Python script on its inputs.

    Values received on the inputs are exposed to the script as ``in_<name>``
    and ``in_<name>s`` variables; after the script has run, the console's
    ``out_<name>`` variables are sent to the matching outputs (see
    ``initial_locals_state`` and ``commit``).
    """

    # Widget metadata.
    name = "Python Script"
    description = "Write a Python script and run it on input data or models."
    icon = "icons/PythonScript.svg"
    priority = 3150
    keywords = ["file", "program"]

    # Every input is declared with ``multiple=True``; the handlers receive
    # an ``id`` along with the value (see ``handle_input``).
    class Inputs:
        data = Input("Data", Table, replaces=["in_data"],
                     default=True, multiple=True)
        learner = Input("Learner", Learner, replaces=["in_learner"],
                        default=True, multiple=True)
        classifier = Input("Classifier", Model, replaces=["in_classifier"],
                           default=True, multiple=True)
        # The right-hand ``object`` is still the builtin when this line runs;
        # only afterwards does the class attribute shadow it in this scope.
        object = Input("Object", object, replaces=["in_object"],
                       default=False, multiple=True)

    class Outputs:
        data = Output("Data", Table, replaces=["out_data"])
        learner = Output("Learner", Learner, replaces=["out_learner"])
        classifier = Output("Classifier", Model, replaces=["out_classifier"])
        # Same builtin-then-shadow ordering as in ``Inputs.object``.
        object = Output("Object", object, replaces=["out_object"])

    # Attribute names shared by Inputs, Outputs and the per-signal dicts that
    # ``__init__`` creates on the instance.
    signal_names = ("data", "learner", "classifier", "object")

    # Persisted settings: the script library, the selected row and the
    # serialized splitter state.
    libraryListSource = \
        Setting([Script("Hello world", "print('Hello world')\n")])
    currentScriptIndex = Setting(0)
    splitterState = Setting(None)

    class Error(OWWidget.Error):
        pass

    def __init__(self):
        """Build the widget: info box, script library, editor and console."""
        super().__init__()

        # One dict per signal name; handle_input() stores received values in
        # it, keyed by the first element of the input id.
        for name in self.signal_names:
            setattr(self, name, {})

        # Reset the flags of every script restored from the settings.
        for s in self.libraryListSource:
            s.flags = 0

        # Cache filled by documentForScript(): script -> QTextDocument.
        self._cachedDocuments = {}

        # --- Info box listing the in_*/out_* variable names ---------------
        self.infoBox = gui.vBox(self.controlArea, 'Info')
        gui.label(
            self.infoBox, self,
            "<p>Execute python script.</p><p>Input variables:<ul><li> " +
            "<li>".join(map("in_{0}, in_{0}s".format, self.signal_names)) +
            "</ul></p><p>Output variables:<ul><li>" +
            "<li>".join(map("out_{0}".format, self.signal_names)) +
            "</ul></p>"
        )

        # --- Script library list -------------------------------------------
        self.libraryList = itemmodels.PyListModel(
            [], self,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable)

        # The model wraps the settings list itself rather than a copy.
        self.libraryList.wrap(self.libraryListSource)

        self.controlBox = gui.vBox(self.controlArea, 'Library')
        self.controlBox.layout().setSpacing(1)

        self.libraryView = QListView(
            editTriggers=QListView.DoubleClicked |
            QListView.EditKeyPressed,
            sizePolicy=QSizePolicy(QSizePolicy.Ignored,
                                   QSizePolicy.Preferred)
        )
        self.libraryView.setItemDelegate(ScriptItemDelegate(self))
        self.libraryView.setModel(self.libraryList)

        self.libraryView.selectionModel().selectionChanged.connect(
            self.onSelectedScriptChanged
        )
        self.controlBox.layout().addWidget(self.libraryView)

        # --- Action buttons below the library list --------------------------
        w = itemmodels.ModelActionsWidget()

        self.addNewScriptAction = action = QAction("+", self)
        action.setToolTip("Add a new script to the library")
        action.triggered.connect(self.onAddScript)
        w.addAction(action)

        action = QAction(unicodedata.lookup("MINUS SIGN"), self)
        action.setToolTip("Remove script from library")
        action.triggered.connect(self.onRemoveScript)
        w.addAction(action)

        action = QAction("Update", self)
        action.setToolTip("Save changes in the editor to library")
        action.setShortcut(QKeySequence(QKeySequence.Save))
        action.triggered.connect(self.commitChangesToLibrary)
        w.addAction(action)

        # "More" opens a menu with the file import/export actions.
        action = QAction("More", self, toolTip="More actions")

        new_from_file = QAction("Import Script from File", self)
        save_to_file = QAction("Save Selected Script to File", self)
        save_to_file.setShortcut(QKeySequence(QKeySequence.SaveAs))

        new_from_file.triggered.connect(self.onAddScriptFromFile)
        save_to_file.triggered.connect(self.saveScript)

        menu = QMenu(w)
        menu.addAction(new_from_file)
        menu.addAction(save_to_file)
        action.setMenu(menu)
        button = w.addAction(action)
        button.setPopupMode(QToolButton.InstantPopup)

        w.layout().setSpacing(1)

        self.controlBox.layout().addWidget(w)

        # "Run" executes the script through commit().
        self.execute_button = gui.button(self.controlArea, self, 'Run', callback=self.commit)

        # --- Main area: editor above console in a vertical splitter --------
        self.splitCanvas = QSplitter(Qt.Vertical, self.mainArea)
        self.mainArea.layout().addWidget(self.splitCanvas)

        self.defaultFont = defaultFont = \
            "Monaco" if sys.platform == "darwin" else "Courier"

        self.textBox = gui.vBox(self, 'Python Script')
        self.splitCanvas.addWidget(self.textBox)
        self.text = PythonScriptEditor(self)
        self.textBox.layout().addWidget(self.text)

        self.textBox.setAlignment(Qt.AlignVCenter)
        self.text.setTabStopWidth(4)

        self.text.modificationChanged[bool].connect(self.onModificationChanged)

        # NOTE(review): the action is stored on the instance but is not added
        # to any widget within this method -- confirm it is attached
        # elsewhere, otherwise its shortcut may never fire.
        self.saveAction = action = QAction("&Save", self.text)
        action.setToolTip("Save script to file")
        action.setShortcut(QKeySequence(QKeySequence.Save))
        action.setShortcutContext(Qt.WidgetWithChildrenShortcut)
        action.triggered.connect(self.saveScript)

        self.consoleBox = gui.vBox(self, 'Console')
        self.splitCanvas.addWidget(self.consoleBox)
        self.console = PythonConsole({}, self)
        self.consoleBox.layout().addWidget(self.console)
        self.console.document().setDefaultFont(QFont(defaultFont))
        self.consoleBox.setAlignment(Qt.AlignBottom)
        self.console.setTabStopWidth(4)

        # Select the script row remembered in the settings.
        select_row(self.libraryView, self.currentScriptIndex)

        # Default sizes first; a saved splitter state, if any, overrides them.
        self.splitCanvas.setSizes([2, 1])
        if self.splitterState is not None:
            self.splitCanvas.restoreState(QByteArray(self.splitterState))

        self.splitCanvas.splitterMoved[int, int].connect(self.onSpliterMoved)
        self.controlArea.layout().addStretch(1)
        self.resize(800, 600)

    def handle_input(self, obj, id, signal):
        """Store or forget one input value in the dict named by `signal`.

        The dict is read from ``self`` by attribute name and is keyed by the
        first element of `id`.  ``obj is None`` removes the entry (a missing
        entry is ignored); any other value is stored under that key.
        """
        key = id[0]
        store = getattr(self, signal)
        if obj is not None:
            store[key] = obj
        else:
            store.pop(key, None)

    @Inputs.data
    def set_data(self, data, id):
        """Handle the "Data" input: store `data` (or drop it when None)
        in ``self.data`` through ``handle_input``."""
        self.handle_input(data, id, "data")

    @Inputs.learner
    def set_learner(self, data, id):
        """Handle the "Learner" input: store `data` (or drop it when None)
        in ``self.learner`` through ``handle_input``."""
        self.handle_input(data, id, "learner")

    @Inputs.classifier
    def set_classifier(self, data, id):
        """Handle the "Classifier" input: store `data` (or drop it when None)
        in ``self.classifier`` through ``handle_input``."""
        self.handle_input(data, id, "classifier")

    @Inputs.object
    def set_object(self, data, id):
        """Handle the "Object" input: store `data` (or drop it when None)
        in ``self.object`` through ``handle_input``."""
        self.handle_input(data, id, "object")

    def handleNewSignals(self):
        """Run the script again by delegating to ``commit``."""
        self.commit()

    def selectedScriptIndex(self):
        """Return the row of the first selected library entry, or None when
        nothing is selected."""
        selected = self.libraryView.selectionModel().selectedRows()
        if not selected:
            return None
        return selected[0].row()

    def setSelectedScript(self, index):
        """Select row `index` in the library view via ``select_row``."""
        select_row(self.libraryView, index)

    def onAddScript(self, *args):
        """Append an empty script named "New script" to the library and
        select it.  Positional arguments are accepted and ignored."""
        new_script = Script("New script", "", 0)
        self.libraryList.append(new_script)
        last_row = len(self.libraryList) - 1
        self.setSelectedScript(last_row)

    def onAddScriptFromFile(self, *args):
        """Ask for a Python file and add its contents to the library.

        The file dialog starts in the user's home directory.  If a file is
        chosen, a new ``Script`` named after the file's base name is appended
        to the library and selected; cancelling the dialog does nothing.
        Positional arguments are accepted and ignored.
        """
        # Local import: keeps this method independent of the module's
        # top-level import list.
        import tokenize

        filename, _ = QFileDialog.getOpenFileName(
            self, 'Open Python Script',
            os.path.expanduser("~/"),
            'Python files (*.py)\nAll files(*.*)'
        )
        if not filename:
            return

        name = os.path.basename(filename)
        # tokenize.open() detects the source encoding the way the Python
        # interpreter does (BOM or PEP 263 coding cookie) and defaults to
        # UTF-8, so scripts declaring another encoding are read correctly.
        with tokenize.open(filename) as f:
            contents = f.read()
        self.libraryList.append(Script(name, contents, 0, filename))
        self.setSelectedScript(len(self.libraryList) - 1)

    def onRemoveScript(self, *args):
        """Delete the selected script and select the row before it (row 0
        when the first row was removed).  Does nothing without a selection;
        positional arguments are accepted and ignored."""
        row = self.selectedScriptIndex()
        if row is None:
            return
        del self.libraryList[row]
        select_row(self.libraryView, max(row - 1, 0))

    def onSaveScriptToFile(self, *args):
        """Call ``saveScript`` when a library entry is selected; positional
        arguments are accepted and ignored."""
        if self.selectedScriptIndex() is None:
            return
        self.saveScript()

    def onSelectedScriptChanged(self, selected, deselected):
        """Show the newly selected script in the editor.

        Only the first index of `selected` is used.  A row past the end of
        the library triggers the "add new script" action instead; otherwise
        the script's document is put into the editor and the row is
        remembered in ``currentScriptIndex``.  `deselected` is unused.
        """
        rows = [model_index.row() for model_index in selected.indexes()]
        if not rows:
            return

        current = rows[0]
        if current >= len(self.libraryList):
            self.addNewScriptAction.trigger()
            return

        document = self.documentForScript(current)
        self.text.setDocument(document)
        self.currentScriptIndex = current

    def documentForScript(self, script=0):
        """Return the cached ``QTextDocument`` for `script`, creating it on
        first use.

        `script` is either a ``Script`` instance or a row index into the
        library list (default: row 0).  A new document is filled with the
        script's text, given the default font and a Python syntax
        highlighter, connected to ``onModificationChanged`` and marked as
        unmodified before it is cached.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # subclasses of Script; anything else is treated as a row index.
        if not isinstance(script, Script):
            script = self.libraryList[script]

        doc = self._cachedDocuments.get(script)
        if doc is None:
            doc = QTextDocument(self)
            doc.setDocumentLayout(QPlainTextDocumentLayout(doc))
            doc.setPlainText(script.script)
            doc.setDefaultFont(QFont(self.defaultFont))
            doc.highlighter = PythonSyntaxHighlighter(doc)
            doc.modificationChanged[bool].connect(self.onModificationChanged)
            doc.setModified(False)
            self._cachedDocuments[script] = doc
        return doc

    def commitChangesToLibrary(self, *args):
        """Copy the editor text into the selected library script.

        The editor document is then marked unmodified and the model is told
        that the row changed.  Does nothing without a selection; positional
        arguments are accepted and ignored.
        """
        row = self.selectedScriptIndex()
        if row is None:
            return
        editor_text = self.text.toPlainText()
        self.libraryList[row].script = editor_text
        self.text.document().setModified(False)
        self.libraryList.emitDataChanged(row)

    def onModificationChanged(self, modified):
        """Mirror the editor's modified state in the selected script's flags
        (``Script.Modified`` or 0) and notify the model.  Does nothing
        without a selection."""
        row = self.selectedScriptIndex()
        if row is None:
            return
        if modified:
            flags = Script.Modified
        else:
            flags = 0
        self.libraryList[row].flags = flags
        self.libraryList.emitDataChanged(row)

    def onSpliterMoved(self, pos, ind):
        """Store the splitter's saved state, as bytes, in the
        ``splitterState`` setting; `pos` and `ind` are unused."""
        self.splitterState = bytes(self.splitCanvas.saveState())

    def updateSelecetdScriptState(self):
        """Replace the selected library entry with a fresh ``Script`` that
        keeps the old name, takes the editor text and has flags 0.

        Does nothing without a selection.
        """
        # NOTE(review): only the name is carried over to the new Script; any
        # filename of the old entry is not passed on -- confirm intended.
        row = self.selectedScriptIndex()
        if row is None:
            return
        previous = self.libraryList[row]
        replacement = Script(previous.name, self.text.toPlainText(), 0)
        self.libraryList[row] = replacement

    def saveScript(self):
        """Ask for a file name and write the editor text to that file.

        The dialog starts at the selected script's ``filename``; when nothing
        is selected, or the script has no filename, it starts in the user's
        home directory.  A chosen name without an extension gets ``.py``
        appended.  Cancelling the dialog writes nothing.
        """
        index = self.selectedScriptIndex()
        filename = None
        if index is not None:
            filename = self.libraryList[index].filename
        # Scripts created in the widget are constructed without a filename
        # (presumably defaulting to an empty value -- confirm against
        # Script); fall back to the home directory in that case as well.
        if not filename:
            filename = os.path.expanduser("~/")

        filename, _ = QFileDialog.getSaveFileName(
            self, 'Save Python Script',
            filename,
            'Python files (*.py)\nAll files(*.*)'
        )
        if not filename:
            return

        head, tail = os.path.splitext(filename)
        fn = filename if tail else head + ".py"

        # Write as UTF-8, the encoding assumed when scripts are imported,
        # and let the context manager close the file even if writing fails.
        with open(fn, 'w', encoding="utf-8") as f:
            f.write(self.text.toPlainText())

    def initial_locals_state(self):
        """Build the ``in_*`` variables for the script's namespace.

        For every signal name, ``in_<name>s`` is the list of all values
        currently stored for that signal and ``in_<name>`` is the single
        value when exactly one is stored, otherwise None.
        """
        namespace = {}
        for signal in self.signal_names:
            received = list(getattr(self, signal).values())
            namespace["in_" + signal + "s"] = received
            if len(received) == 1:
                namespace["in_" + signal] = received[0]
            else:
                namespace["in_" + signal] = None
        return namespace

    def commit(self):
        """Run the editor's script in the console and send its outputs.

        The script text is placed in the console namespace as ``_script``
        next to the ``in_*`` variables and executed with ``exec(_script)``.
        Afterwards each ``out_<name>`` console variable is sent on the
        matching output; a value that is not an instance of the output's
        type raises the corresponding error message and None is sent
        instead.
        """
        self.Error.clear()
        source = str(self.text.toPlainText())
        self._script = source

        namespace = self.initial_locals_state()
        namespace["_script"] = source
        console = self.console
        console.updateLocals(namespace)
        console.write("\nRunning script:\n")
        console.push("exec(_script)")
        console.new_prompt(sys.ps1)

        for signal in self.signal_names:
            value = console.locals.get("out_" + signal)
            output = getattr(self.Outputs, signal)
            expected = output.type
            if value is not None and not isinstance(value, expected):
                message = "'{}' has to be an instance of '{}'.".format(
                    signal, expected.__name__)
                self.Error.add_message(signal, message)
                getattr(self.Error, signal)()
                value = None
            output.send(value)
Пример #29
0
class OWScatterPlot(OWWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    # Widget metadata.
    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        features = Output("Features", AttributeList, dynamic=False)

    # migrate_settings() upgrades settings stored with a version below 2.
    settings_version = 2
    settingsHandler = DomainContextHandler()

    auto_send_selection = Setting(True)
    auto_sample = Setting(True)
    toolbar_selection = Setting(0)

    # Axis variables, stored per data context.
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)

    #: Serialized selection state to be restored
    selection_group = Setting(None, schema_only=True)

    graph = SettingProvider(OWScatterPlotGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    # Dotted attribute path to the plot item.
    # NOTE(review): presumably consumed by the base class when saving or
    # reporting the graph -- confirm.
    graph_name = "graph.plot_widget.plotItem"

    class Information(OWWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")

    def __init__(self):
        """Create the plot, the instance state and the control-area GUI."""
        super().__init__()

        # --- Plot in the main area ------------------------------------------
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWScatterPlotGraph(self, box, "ScatterPlot")
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        # Draw both axes with the palette's text colour.
        axispen = QPen(self.palette().color(QPalette.Text))
        axis = plot.getAxis("bottom")
        axis.setPen(axispen)

        axis = plot.getAxis("left")
        axis.setPen(axispen)

        # --- Instance state ---------------------------------------------------
        self.data = None  # Orange.data.Table
        self.subset_data = None  # Orange.data.Table
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # Timer (interval 1200) that calls add_data() on every timeout.
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        #: Remember the saved state to restore
        self.__pending_selection_restore = self.selection_group
        self.selection_group = None

        # --- Axis selection ---------------------------------------------------
        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str, contentsLength=14
        )
        box = gui.vBox(self.controlArea, "Axis Data")
        dmod = DomainModel
        # Both axis combo boxes share this model.
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self.update_attr,
            model=self.xy_model, **common_options)
        self.cb_attr_y = gui.comboBox(
            box, self, "attr_y", label="Axis y:", callback=self.update_attr,
            model=self.xy_model, **common_options)

        # VizRank button, indented by the combo boxes' label width.
        vizrank_box = gui.hBox(box)
        gui.separator(vizrank_box, width=common_options["labelWidth"])
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

        gui.separator(box)

        g = self.graph.gui
        g.add_widgets([g.JitterSizeSlider,
                       g.JitterNumericValues], box)

        # --- Sampling box; hidden here, set_data() shows it for large SQL
        # tables -----------------------------------------------------------
        self.sampling = gui.auto_commit(
            self.controlArea, self, "auto_sample", "Sample", box="Sampling",
            callback=self.switch_sampling, commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

        g.point_properties_box(self.controlArea)
        self.models = [self.xy_model] + g.points_models

        # --- Plot properties --------------------------------------------------
        box_plot_prop = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([g.ShowLegend,
                       g.ShowGridLines,
                       g.ToolTipShowsAll,
                       g.ClassDensity,
                       g.RegressionLine,
                       g.LabelOnlySelected], box_plot_prop)

        self.graph.box_zoom_select(self.controlArea)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        # Hand the plot widget's palette to the graph.
        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_send_selection",
                        "Send Selection", "Send Automatically")

        self.graph.zoom_actions(self)

        # manually register Matplotlib file writers
        # (on a copy, so the mapping this instance inherited is not mutated)
        self.graph_writers = self.graph_writers.copy()
        for w in [MatplotlibFormat, MatplotlibPDFFormat]:
            for ext in w.EXTENSIONS:
                self.graph_writers[ext] = w


    def sizeHint(self):
        """Return the size hint of the widget: 1132 x 708."""
        return QSize(1132, 708)

    def keyPressEvent(self, event):
        """Pass the event to the base class, then update the graph tooltip
        with the event's keyboard modifiers."""
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Pass the event to the base class, then update the graph tooltip
        with the event's keyboard modifiers."""
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def reset_graph_data(self, *_):
        """Rescale the graph's data and redraw; a no-op without data.

        Positional arguments are accepted and ignored.
        """
        if self.data is None:
            return
        self.graph.rescale_data()
        self.update_graph()

    def _vizrank_color_change(self):
        self.vizrank.initialize()
        is_enabled = self.data is not None and not self.data.is_sparse() and \
                     len([v for v in chain(self.data.domain.variables, self.data.domain.metas)
                          if v.is_primitive]) > 2\
                     and len(self.data) > 1
        self.vizrank_button.setEnabled(
            is_enabled and self.graph.attr_color is not None and
            not np.isnan(self.data.get_column_view(self.graph.attr_color)[0].astype(float)).all())
        if is_enabled and self.graph.attr_color is None:
            self.vizrank_button.setToolTip("Color variable has to be selected.")
        else:
            self.vizrank_button.setToolTip("")

    @Inputs.data
    def set_data(self, data):
        """Handle the "Data" input.

        SQL tables with an approximate length below 4000 are converted to a
        ``Table``; larger ones are sampled, the sampling box is shown and,
        with ``auto_sample`` on, the timer is started.  Data without rows or
        with an empty domain is treated as None.  If the checksum equals
        that of the current data, nothing else happens; otherwise the
        context is closed and reopened around the change, and settings that
        were stored by variable name are converted to variables.
        """
        self.clear_messages()
        self.Information.sampled_sql.clear()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                # Keep the SQL table for add_data() and show only a sample.
                self.Information.sampled_sql()
                self.sql_data = data
                data_sample = data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if data is not None and (len(data) == 0 or len(data.domain) == 0):
            data = None
        # Same content as before: keep the current state untouched.
        if self.data and data and self.data.checksum() == data.checksum():
            return

        self.closeContext()
        # Has to be computed before self.data is replaced below.
        same_domain = (self.data and data and
                       data.domain.checksum() == self.data.domain.checksum())
        self.data = data

        if not same_domain:
            self.init_attr_values()
        self.openContext(self.data)
        self._vizrank_color_change()

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Orange.data.Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.graph.attr_label, str):
            self.graph.attr_label = findvar(
                self.graph.attr_label, self.graph.gui.label_model)
        if isinstance(self.graph.attr_color, str):
            self.graph.attr_color = findvar(
                self.graph.attr_color, self.graph.gui.color_model)
        if isinstance(self.graph.attr_shape, str):
            self.graph.attr_shape = findvar(
                self.graph.attr_shape, self.graph.gui.shape_model)
        if isinstance(self.graph.attr_size, str):
            self.graph.attr_size = findvar(
                self.graph.attr_size, self.graph.gui.size_model)

    def add_data(self, time=0.4):
        """Download another sample of the SQL table and append it to the
        plotted data.

        Once more than 2000 rows are held, the timer is stopped and nothing
        is added.  An empty sample is ignored; otherwise the sample is
        concatenated to ``self.data`` and ``handleNewSignals`` is called.
        """
        if self.data and len(self.data) > 2000:
            return self.__timer.stop()

        sample = self.sql_data.sample_time(time, no_cache=True)
        if not sample:
            return None
        sample.download_data(2000, partial=True)
        new_rows = Table(sample)
        self.data = Table.concatenate((self.data, new_rows), axis=0)
        self.handleNewSignals()
        return None

    def switch_sampling(self):
        """Stop the timer; with ``auto_sample`` on and an SQL table present,
        add a sample right away and start the timer again."""
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        """Handle the "Data Subset" input.

        An SQL table with an approximate length below ``AUTO_DL_LIMIT`` is
        converted to a ``Table``; a larger one is rejected with a warning
        and treated as None.  The graph's alpha value control is enabled
        only when there is no subset.
        """
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() >= AUTO_DL_LIMIT:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
            else:
                subset_data = Table(subset_data)
        self.subset_data = subset_data
        no_subset = subset_data is None
        self.controls.graph.alpha_value.setEnabled(no_subset)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        """Pass the data to the graph, apply pending state and commit.

        A pending "Features" input replaces the axis variables when all its
        attributes are in the graph's domain.  A selection restored from the
        settings is applied once, the first time data is present.
        """
        self.graph.new_data(self.data, self.subset_data)
        if self.attribute_selection_list and self.graph.domain is not None and \
                all(attr in self.graph.domain
                        for attr in self.attribute_selection_list):
            self.attr_x = self.attribute_selection_list[0]
            self.attr_y = self.attribute_selection_list[1]
        # The list is consumed whether or not it was applied.
        self.attribute_selection_list = None
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        if self.data is not None and self.__pending_selection_restore is not None:
            self.apply_selection(self.__pending_selection_restore)
            self.__pending_selection_restore = None
        self.unconditional_commit()

    def apply_selection(self, selection):
        """Apply `selection` to the current plot.

        `selection` is an iterable of ``(row index, group)`` pairs.  Pairs
        whose row index is not below the number of data rows are dropped.
        The remaining pairs are stored in ``selection_group`` and written
        into a fresh ``graph.selection`` array; when none remain, the array
        stays all zeros and ``selection_group`` is set to None.  Does
        nothing without data.
        """
        if self.data is None:
            return

        n_rows = len(self.data)
        self.graph.selection = np.zeros(n_rows, dtype=np.uint8)
        valid = [x for x in selection if x[0] < n_rows]
        if valid:
            self.selection_group = valid
            selection_array = np.array(valid).T
            self.graph.selection[selection_array[0]] = selection_array[1]
        else:
            # Nothing fits the current data.  Indexing the transposed empty
            # array would raise IndexError, so leave the selection cleared.
            self.selection_group = None
        self.graph.update_colors(keep_colors=True)

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Handle the "Features" input: remember the first two attributes,
        or None when fewer than two were given."""
        enough = attributes and len(attributes) >= 2
        self.attribute_selection_list = attributes[:2] if enough else None

    def init_attr_values(self):
        """Reset the axis model and the axis variables for the current data.

        The model receives the data's domain (None for missing or empty
        data).  ``attr_x`` becomes the model's first item and ``attr_y`` its
        second; with a single item both are the same and with an empty model
        both are None.  The graph's domain is set from the data.
        """
        data = self.data
        has_rows = data and len(data)
        self.xy_model.set_domain(data.domain if has_rows else None)

        model = self.xy_model
        self.attr_x = model[0] if model else None
        if len(model) >= 2:
            self.attr_y = model[1]
        else:
            self.attr_y = self.attr_x
        self.graph.set_domain(data)

    def set_attr(self, attr_x, attr_y):
        """Set both axis variables and refresh through ``update_attr``."""
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def update_attr(self):
        """Redraw the graph, refresh the enabled state of the class-density
        and regression-line controls and send the "Features" output."""
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        self.send_features()

    def update_colors(self):
        """Refresh the VizRank button state and the enabled state of the
        class-density control."""
        self._vizrank_color_change()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())

    def update_density(self):
        """Redraw the graph without resetting the view."""
        self.update_graph(reset_view=False)

    def update_regression_line(self):
        """Redraw the graph without resetting the view."""
        self.update_graph(reset_view=False)

    def update_graph(self, reset_view=True, **_):
        """Clear the graph's zoom stack and, if the graph has data, redraw
        it for the current axis variables.

        Extra keyword arguments are accepted and ignored.
        """
        graph = self.graph
        graph.zoomStack = []
        if graph.data is not None:
            graph.update_data(self.attr_x, self.attr_y, reset_view)

    def selection_changed(self):
        """Store the graph's selection in ``selection_group`` and commit.

        The setting becomes a list of ``(index, group)`` pairs, or None when
        nothing is selected, there is no data, or the data is an SQL table.
        """
        # Store current selection in a setting that is stored in workflow
        data = self.data
        if data is None or isinstance(data, SqlTable):
            selection = None
        else:
            selection = self.graph.get_selection()

        group = None
        if selection is not None and len(selection):
            group = list(zip(selection, self.graph.selection[selection]))
        self.selection_group = group

        self.commit()

    def send_data(self):
        """Send the annotated data, then the selected data.

        The annotated output is a groups table when the graph's selection
        array contains a value above 1, otherwise a table annotated with the
        selection.  The selected output is a groups table, or None when
        nothing is selected.
        """
        # TODO: Implement selection for sql data
        graph = self.graph
        data = self.data
        selection = graph.get_selection()

        if graph.selection is not None and np.max(graph.selection) > 1:
            annotated = create_groups_table(data, graph.selection)
        else:
            annotated = create_annotated_table(data, selection)
        self.Outputs.annotated_data.send(annotated)

        if len(selection):
            selected = create_groups_table(data, graph.selection, False,
                                           "Group")
        else:
            selected = None
        self.Outputs.selected_data.send(selected)

    def send_features(self):
        """Send the truthy axis variables as a list on the "Features"
        output, or None when there are none."""
        chosen = []
        for attr in (self.attr_x, self.attr_y):
            if attr:
                chosen.append(attr)
        self.Outputs.features.send(chosen if chosen else None)

    def commit(self):
        """Send the data outputs and then the "Features" output."""
        self.send_data()
        self.send_features()

    def get_widget_name_extension(self):
        """Return "<x name> vs <y name>" for the axis variables, or None
        when there is no data."""
        if self.data is None:
            return None
        x_name = self.attr_x.name
        y_name = self.attr_y.name
        return "{} vs {}".format(x_name, y_name)

    def send_report(self):
        """Report the plot with a caption listing the colour, label, shape
        and size variables and the jitter size; a no-op without data.

        The jitter entry is the graph's jitter size only when an axis
        variable is discrete or continuous jittering is on.
        """
        if self.data is None:
            return

        graph = self.graph

        def name(var):
            # Falsy values (e.g. None) are passed through unchanged.
            return var.name if var else var

        jittered = (self.attr_x.is_discrete or
                    self.attr_y.is_discrete or
                    graph.jitter_continuous)
        items = (
            ("Color", name(graph.attr_color)),
            ("Label", name(graph.attr_label)),
            ("Shape", name(graph.attr_shape)),
            ("Size", name(graph.attr_size)),
            ("Jittering", jittered and graph.jitter_size))
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def onDeleteWidget(self):
        """Run the base-class cleanup, then schedule the plot's view box for
        deletion and clear the plot widget."""
        super().onDeleteWidget()
        self.graph.plot_widget.getViewBox().deleteLater()
        self.graph.plot_widget.clear()

    @classmethod
    def migrate_settings(cls, settings, version):
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1) for a in settings["selection"]]
class OWContrastFactor(OWGenericWidget):
    """Widget that calculates contrast factor A/B coefficients.

    It receives fit global parameters, derives the Ae, Be, As and Bs
    coefficients from the elastic constants c11, c12 and c44, and sends the
    updated fit global parameters on its output.
    """

    name = "Contrast Factor Calculator"
    description = "Contrast Factor Calculator"
    icon = "icons/contrast_factor.png"
    priority = 10

    # All controls live in the control area; no main area is requested.
    want_main_area = False

    # Widget settings holding the elastic constants, with their defaults.
    c11 = Setting(24.65)
    c12 = Setting(13.45)
    c44 = Setting(2.87)

    # Channel declarations; the input names ``set_data`` as its handler.
    inputs = [("Fit Global Parameters", FitGlobalParameters, 'set_data')]
    outputs = [("Fit Global Parameters", FitGlobalParameters)]

    def __init__(self):
        """Build the widget's control-area GUI.

        Creates, in this order: the send button, the "Elastic Constants" box
        with one float input per constant (c11, c12, c44), and a read-only
        monospace text area that displays the calculation result.
        """
        super().__init__(show_automatic_box=True)

        area_width = self.CONTROL_AREA_WIDTH

        # NOTE(review): "Constrast" looks like a typo of "Contrast"; the
        # user-visible strings are deliberately left exactly as they were.
        main_box = gui.widgetBox(
            self.controlArea,
            "Constrast Factor Calculator Parameters",
            orientation="vertical",
            width=area_width - 10,
            height=600)

        button_box = gui.widgetBox(
            main_box, "", orientation="horizontal", width=area_width - 25)
        gui.button(
            button_box,
            self,
            "Send Constrast Factor A/B Parameters",
            height=50,
            callback=self.send_contrast_factor_a_b)

        contrast_factor_box = gui.widgetBox(
            main_box,
            "Elastic Constants",
            orientation="vertical",
            height=300,
            width=area_width - 30)

        # The setting name doubles as the label of its line edit.
        for constant in ("c11", "c12", "c44"):
            gui.lineEdit(
                contrast_factor_box,
                self,
                constant,
                constant,
                labelWidth=90,
                valueType=float)

        text_area_box = gui.widgetBox(
            contrast_factor_box,
            "Calculation Result",
            orientation="vertical",
            height=160,
            width=area_width - 50)

        self.text_area = result_area = gui.textArea(
            height=120, width=area_width - 70, readOnly=True)
        result_area.setText("")
        result_area.setStyleSheet("font-family: Courier, monospace;")
        text_area_box.layout().addWidget(result_area)

        orangegui.separator(main_box, height=280)

    def send_contrast_factor_a_b(self):
        """Compute the contrast factor coefficients and send them downstream.

        Does nothing when no fit global parameters have been received.
        Otherwise the c11, c12 and c44 settings are validated, the Ae, Be, As
        and Bs coefficients are calculated for the symmetry of the first
        crystal structure, shown in the text area, stored as fixed
        ``KrivoglazWilkensModel`` strain parameters, and the updated fit
        global parameters are sent on the "Fit Global Parameters" output.

        Any failure (including a missing or empty list of crystal structures)
        is reported in a critical message box; the exception is re-raised
        only when ``IS_DEVELOP`` is set.
        """
        try:
            fit_global_parameters = self.fit_global_parameters
            if fit_global_parameters is None:
                return

            fit_initialization = fit_global_parameters.fit_initialization
            if (fit_initialization is None
                    or fit_initialization.crystal_structures is None):
                raise ValueError(
                    "Calculation is not possibile, Crystal Structure is missing"
                )

            congruence.checkStrictlyPositiveNumber(self.c11, "c11")
            congruence.checkStrictlyPositiveNumber(self.c12, "c12")
            congruence.checkStrictlyPositiveNumber(self.c44, "c44")

            try:
                first_structure = fit_initialization.crystal_structures[0]
            except IndexError:
                # An empty sequence is as unusable as a missing one; report
                # it with the same message instead of "index out of range".
                raise ValueError(
                    "Calculation is not possibile, Crystal Structure is missing"
                ) from None

            Ae, Be, As, Bs = calculate_A_B_coefficients(
                self.c11, self.c12, self.c44, first_structure.symmetry)
            coefficients = (("Ae", Ae), ("Be", Be), ("As", As), ("Bs", Bs))

            self.text_area.setText("".join(
                label + " = " + str(value) + "\n"
                for label, value in coefficients))

            # Every coefficient becomes a fixed (non-refined) fit parameter
            # named with the strain model's parameter prefix.
            prefix = KrivoglazWilkensModel.get_parameters_prefix()
            fit_global_parameters.strain_parameters = [
                KrivoglazWilkensModel(**{
                    label: FitParameter(parameter_name=prefix + label,
                                        value=value,
                                        fixed=True)
                    for label, value in coefficients
                })
            ]

            self.send("Fit Global Parameters", fit_global_parameters)

        except Exception as e:
            QMessageBox.critical(self, "Error", str(e), QMessageBox.Ok)

            if self.IS_DEVELOP:
                raise

    def set_data(self, data):
        """Store a duplicate of the incoming fit global parameters.

        ``None`` input is ignored. The duplicate is kept as
        ``self.fit_global_parameters``; if it already carries strain
        parameters an exception is raised, because this widget has to come
        before the strain model widget. With automatic run enabled the
        coefficients are calculated and sent right away.
        """
        if data is not None:
            received = data.duplicate()
            self.fit_global_parameters = received
            if received.strain_parameters is not None:
                raise Exception(
                    "This widget should be put BEFORE the strain model widget")

            if self.is_automatic_run:
                self.send_contrast_factor_a_b()