示例#1
0
 def get_combo(model):
     """Return a searchable combo box that displays ``model``.

     The box is created with the enclosing ``self`` as its constructor
     argument and its ``activated`` signal is wired to ``sync_combos``.
     """
     searchable = ComboBoxSearch(self)
     searchable.setModel(model)
     # ``activated`` is used on purpose: it is emitted only when the user
     # interacts with the box, not when the index is changed from code.
     searchable.activated.connect(sync_combos)
     return searchable
示例#2
0
    def __init__(self, *args, **kwargs):
        """Create the editor's line edits, picker combos, and connections.

        All positional and keyword arguments are forwarded unchanged to the
        base-class constructor.
        """
        super().__init__(*args, **kwargs)

        form = QFormLayout(fieldGrowthPolicy=QFormLayout.ExpandingFieldsGrow)
        form.setContentsMargins(0, 0, 0, 0)

        # Line edits for the feature name and its expression.
        name_policy = QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        self.nameedit = QLineEdit(placeholderText="Name...",
                                  sizePolicy=name_policy)
        self.expressionedit = QLineEdit(placeholderText="Expression...",
                                        toolTip=self.ExpressionTooltip)

        # Both pickers share the same sizing options.
        picker_options = {
            "minimumContentsLength": 16,
            "sizeAdjustPolicy":
                QComboBox.AdjustToMinimumContentsLengthWithIcon,
        }

        # Variable picker; its first entry is a prompt, not a variable.
        self.attrs_model = itemmodels.VariableListModel(["Select Feature"],
                                                        parent=self)
        self.attributescb = ComboBoxSearch(
            **picker_options,
            sizePolicy=QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum))
        self.attributescb.setModel(self.attrs_model)

        # Function picker; tooltips are the functions' docstrings, with an
        # empty tooltip for the leading prompt entry.
        func_names = sorted(self.FUNCTIONS)
        self.funcs_model = itemmodels.PyListModelTooltip()
        self.funcs_model.setParent(self)
        self.funcs_model[:] = ["Select Function", *func_names]
        self.funcs_model.tooltips[:] = [
            '', *(self.FUNCTIONS[name].__doc__ for name in func_names)]

        self.functionscb = ComboBoxSearch(
            **picker_options,
            sizePolicy=QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum))
        self.functionscb.setModel(self.funcs_model)

        # Row 1: name / expression. Row 2: the two pickers side by side.
        pickers = QHBoxLayout()
        for picker in (self.attributescb, self.functionscb):
            pickers.addWidget(picker)
        form.addRow(self.nameedit, self.expressionedit)
        form.addRow(self.tr(""), pickers)
        self.setLayout(form)

        connections = (
            (self.nameedit.editingFinished, self._invalidate),
            (self.expressionedit.textChanged, self._invalidate),
            (self.attributescb.currentIndexChanged, self.on_attrs_changed),
            (self.functionscb.currentIndexChanged, self.on_funcs_changed),
        )
        for signal, slot in connections:
            signal.connect(slot)

        self._modified = False
示例#3
0
    def add_row(self, attr=None, condition_type=None, condition_value=None):
        """Append one condition row to ``self.cond_list``.

        Column 0 receives a variable combo box and column 3 a remove button;
        the operator widgets are filled in by ``set_new_operators``.

        Args:
            attr: variable to preselect; when falsy, the index
                ``len(self.AllTypes) + 1`` is selected instead.
            condition_type: forwarded to ``set_new_operators``.
            condition_value: forwarded to ``set_new_operators``.
        """
        table_model = self.cond_list.model()
        new_row = table_model.rowCount()
        table_model.insertRow(new_row)

        # Column 0: variable selector.
        variable_combo = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
        variable_combo.setModel(self.variable_model)
        variable_combo.row = new_row
        if attr:
            initial_index = self.variable_model.indexOf(attr)
        else:
            initial_index = len(self.AllTypes) + 1
        variable_combo.setCurrentIndex(initial_index)
        self.cond_list.setCellWidget(new_row, 0, variable_combo)

        # Column 3: remove button. The row is read from a persistent index
        # at click time instead of capturing ``new_row``.
        button_index = QPersistentModelIndex(table_model.index(new_row, 3))
        remove_button = QPushButton(
            '×',
            self,
            flat=True,
            styleSheet='* {font-size: 16pt; color: silver}'
            '*:hover {color: black}')
        remove_button.clicked.connect(
            lambda: self.remove_one(button_index.row()))
        self.cond_list.setCellWidget(new_row, 3, remove_button)

        self.remove_all_button.setDisabled(False)
        self.set_new_operators(variable_combo, attr is not None,
                               condition_type, condition_value)
        # Connected only after the initial setup above, so that setup does
        # not go through this handler.
        variable_combo.currentIndexChanged.connect(
            lambda _: self.set_new_operators(variable_combo, False))

        self.cond_list.resizeRowToContents(new_row)
示例#4
0
    def __init__(self):
        """Reset the data attributes and add the target-class selector."""
        super().__init__()
        # Nothing is loaded yet.
        self.domain = self.dataset = None
        self.clf_dataset = self.tree_adapter = None

        # "Target class" row: a label plus a searchable combo box.
        self.color_label = QLabel("Target class: ")
        self.color_combo = ComboBoxSearch()
        self.color_combo.setSizePolicy(QSizePolicy.MinimumExpanding,
                                       QSizePolicy.Fixed)
        self.color_combo.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLengthWithIcon)
        self.color_combo.setMinimumContentsLength(8)
        self.color_combo.activated[int].connect(self.color_changed)

        form = self.display_box.layout()
        form.addRow(self.color_label, self.color_combo)
示例#5
0
    def __init__(self):
        """Reset the data attributes and extend the display box.

        Adds a "Target class" selector row and a right-aligned
        "Show details in non-leaves" check box row to ``self.display_box``.
        """
        super().__init__()
        # Nothing is loaded yet.
        self.domain = self.dataset = None
        self.clf_dataset = self.tree_adapter = None

        # "Target class" row: a label plus a searchable combo box.
        self.color_label = QLabel("Target class: ")
        self.color_combo = ComboBoxSearch()
        self.color_combo.setSizePolicy(QSizePolicy.MinimumExpanding,
                                       QSizePolicy.Fixed)
        self.color_combo.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLengthWithIcon)
        self.color_combo.setMinimumContentsLength(8)
        self.color_combo.activated[int].connect(self.color_changed)
        self.display_box.layout().addRow(self.color_label, self.color_combo)

        # Check box row; the rubber is added before the check box.
        details_row = gui.hBox(None)
        gui.rubber(details_row)
        gui.checkBox(details_row, self, "show_intermediate",
                     "Show details in non-leaves",
                     callback=self.set_node_info)
        self.display_box.layout().addRow(details_row)
示例#6
0
class FeatureEditor(QFrame):
    """Form for editing one feature definition: a name and an expression.

    Two combo boxes below the line edits insert a variable name or a
    function-call template into the expression at the cursor position.

    Signals:
        featureChanged: emitted when editor data is loaded and whenever the
            editor is invalidated.
        featureEdited: emitted when the name edit finishes editing or the
            expression text changes.
        modifiedChanged(bool): emitted when the ``modified`` flag changes.
    """
    # Names offered by the function picker: every public name in ``math``
    # plus a fixed set of builtins.
    FUNCTIONS = dict(chain([(key, val) for key, val in math.__dict__.items()
                            if not key.startswith("_")],
                           [(key, val) for key, val in builtins.__dict__.items()
                            if key in {"str", "float", "int", "len",
                                       "abs", "max", "min"}]))
    featureChanged = Signal()
    featureEdited = Signal()

    modifiedChanged = Signal(bool)

    def __init__(self, *args, **kwargs):
        """Build the widgets, the picker models and the signal connections.

        All arguments are forwarded to ``QFrame``.
        """
        super().__init__(*args, **kwargs)

        layout = QFormLayout(
            fieldGrowthPolicy=QFormLayout.ExpandingFieldsGrow
        )
        layout.setContentsMargins(0, 0, 0, 0)
        self.nameedit = QLineEdit(
            placeholderText="Name...",
            sizePolicy=QSizePolicy(QSizePolicy.Minimum,
                                   QSizePolicy.Fixed)
        )
        # NOTE(review): ``ExpressionTooltip`` is not defined in the visible
        # part of this class; presumably provided elsewhere -- confirm.
        self.expressionedit = QLineEdit(
            placeholderText="Expression...",
            toolTip=self.ExpressionTooltip)

        # Variable picker; index 0 is a prompt entry, not a variable.
        self.attrs_model = itemmodels.VariableListModel(
            ["Select Feature"], parent=self)
        self.attributescb = ComboBoxSearch(
            minimumContentsLength=16,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            sizePolicy=QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        )
        self.attributescb.setModel(self.attrs_model)

        # Function picker; index 0 is a prompt entry with an empty tooltip,
        # the remaining tooltips are the functions' docstrings.
        sorted_funcs = sorted(self.FUNCTIONS)
        self.funcs_model = itemmodels.PyListModelTooltip()
        self.funcs_model.setParent(self)

        self.funcs_model[:] = chain(["Select Function"], sorted_funcs)
        self.funcs_model.tooltips[:] = chain(
            [''],
            [self.FUNCTIONS[func].__doc__ for func in sorted_funcs])

        self.functionscb = ComboBoxSearch(
            minimumContentsLength=16,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            sizePolicy=QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum))
        self.functionscb.setModel(self.funcs_model)

        hbox = QHBoxLayout()
        hbox.addWidget(self.attributescb)
        hbox.addWidget(self.functionscb)

        layout.addRow(self.nameedit, self.expressionedit)
        layout.addRow(self.tr(""), hbox)
        self.setLayout(layout)

        self.nameedit.editingFinished.connect(self._invalidate)
        self.expressionedit.textChanged.connect(self._invalidate)
        self.attributescb.currentIndexChanged.connect(self.on_attrs_changed)
        self.functionscb.currentIndexChanged.connect(self.on_funcs_changed)

        self._modified = False

    def setModified(self, modified):
        """Set the ``modified`` flag, emitting ``modifiedChanged`` on change.

        Raises:
            TypeError: if ``modified`` is not a ``bool``.
        """
        if not isinstance(modified, bool):
            raise TypeError

        if self._modified != modified:
            self._modified = modified
            self.modifiedChanged.emit(modified)

    def modified(self):
        """Return the current value of the ``modified`` flag."""
        return self._modified

    # Rebinds ``modified`` from the getter method above to a property built
    # from that getter and ``setModified``.
    modified = Property(bool, modified, setModified,
                        notify=modifiedChanged)

    def setEditorData(self, data, domain):
        """Load ``data`` into the editor and refill the variable picker.

        Args:
            data: object with ``name`` and ``expression`` attributes.
            domain: source of pickable variables (``attributes``,
                ``class_vars``, ``metas``); may be ``None`` or empty, in
                which case only the prompt entry is listed.
        """
        self.nameedit.setText(data.name)
        self.expressionedit.setText(data.expression)
        # Loading data is not a user edit, so clear the flag afterwards.
        self.setModified(False)
        self.featureChanged.emit()
        self.attrs_model[:] = ["Select Feature"]
        if domain is not None and not domain.empty():
            self.attrs_model[:] += chain(domain.attributes,
                                         domain.class_vars,
                                         domain.metas)

    def editorData(self):
        """Return a ``FeatureDescriptor`` holding the edited name/expression."""
        # The expression comes from the expression edit; it was previously
        # (incorrectly) read from the name edit.
        return FeatureDescriptor(name=self.nameedit.text(),
                                 expression=self.expressionedit.text())

    def _invalidate(self):
        """Mark the editor as modified and notify listeners."""
        self.setModified(True)
        self.featureEdited.emit()
        self.featureChanged.emit()

    def on_attrs_changed(self):
        """Insert the picked variable's sanitized name into the expression."""
        index = self.attributescb.currentIndex()
        # Index 0 is the "Select Feature" prompt.
        if index > 0:
            attr = sanitized_name(self.attrs_model[index].name)
            self.insert_into_expression(attr)
            # Reset to the prompt so the same variable can be picked again.
            self.attributescb.setCurrentIndex(0)

    def on_funcs_changed(self):
        """Insert the picked function (or constant) into the expression.

        Functions are inserted as a call template and the cursor is moved
        back inside the parentheses; ``e`` and ``pi`` are inserted as bare
        names.
        """
        index = self.functionscb.currentIndex()
        # Index 0 is the "Select Function" prompt.
        if index > 0:
            func = self.funcs_model[index]
            if func in ["atan2", "fmod", "ldexp", "log",
                        "pow", "copysign", "hypot"]:
                # Two-argument template; step back over ",)".
                self.insert_into_expression(func + "(,)")
                self.expressionedit.cursorBackward(False, 2)
            elif func in ["e", "pi"]:
                self.insert_into_expression(func)
            else:
                # One-argument template; step back over ")".
                self.insert_into_expression(func + "()")
                self.expressionedit.cursorBackward(False)
            self.functionscb.setCurrentIndex(0)

    def insert_into_expression(self, what):
        """Insert ``what`` at the cursor and leave the cursor right after it."""
        cp = self.expressionedit.cursorPosition()
        ct = self.expressionedit.text()
        text = ct[:cp] + what + ct[cp:]
        self.expressionedit.setText(text)
        # setText() moves the cursor to the end of the line; restore it to
        # just after the inserted text so the callers' cursorBackward()
        # calls land inside the inserted parentheses even when the insertion
        # happened in the middle of the expression.
        self.expressionedit.setCursorPosition(cp + len(what))
        self.expressionedit.setFocus()
示例#7
0
    def set_new_values(self, oper_combo, adding_all, selected_values=None):
        """Rebuild the value editor in column 2 of ``oper_combo``'s row.

        The widget placed in the cell depends on the selected operator and
        on the type of the selected variable: an empty label, a drop-down
        button, a combo box, or a box with one or two editors.

        Args:
            oper_combo: operator combo box; its ``row`` and ``attr_combo``
                attributes are read here.
            adding_all: when true, ``conditions_changed`` is not called at
                the end.
            selected_values: optional initial values for the editor.
        """
        # def remove_children():
        #     for child in box.children()[1:]:
        #         box.layout().removeWidget(child)
        #         child.setParent(None)

        def add_textual(contents):
            # ``box`` is resolved when this helper is called, i.e. it is the
            # box created further down, not the one fetched from the cell.
            le = gui.lineEdit(box,
                              self,
                              None,
                              sizePolicy=QSizePolicy(QSizePolicy.Expanding,
                                                     QSizePolicy.Expanding))
            if contents:
                le.setText(contents)
            le.setAlignment(Qt.AlignRight)
            le.editingFinished.connect(self.conditions_changed)
            return le

        def add_numeric(contents):
            # Same as add_textual, plus a validator on the line edit.
            le = add_textual(contents)
            le.setValidator(OWSelectRows.QDoubleValidatorEmpty())
            return le

        # Widget currently in the value cell, if any.
        box = self.cond_list.cellWidget(oper_combo.row, 2)
        # Initial contents of the (up to two) editors.
        lc = ["", ""]
        oper = oper_combo.currentIndex()
        attr_name = oper_combo.attr_combo.currentText()
        if attr_name in self.AllTypes:
            # Entry from AllTypes: only a type, no concrete variable.
            vtype = self.AllTypes[attr_name]
            var = None
        else:
            var = self.data.domain[attr_name]
            var_idx = self.data.domain.index(attr_name)
            vtype = vartype(var)
            if selected_values is not None:
                # Pad to two entries; values are converted to str except
                # when vtype is 4 (time, per the branch comments below).
                lc = list(selected_values) + ["", ""]
                lc = [str(x) if vtype != 4 else x for x in lc[:2]]
        if box and vtype == box.var_type:
            # Same type as the existing editor: its contents take precedence.
            lc = self._get_lineedit_contents(box) + lc

        if oper_combo.currentText().endswith(" defined"):
            # No value needed; an empty label still records the type.
            label = QLabel()
            label.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, label)
        elif var is not None and var.is_discrete:
            if oper_combo.currentText().endswith(" one of"):
                if selected_values:
                    lc = list(selected_values)
                button = DropDownToolButton(self, var, lc)
                button.var_type = vtype
                self.cond_list.setCellWidget(oper_combo.row, 2, button)
            else:
                combo = ComboBoxSearch()
                # Leading empty item, so value indices are shifted by one.
                combo.addItems(("", ) + var.values)
                if lc[0]:
                    combo.setCurrentIndex(int(lc[0]))
                else:
                    combo.setCurrentIndex(0)
                combo.var_type = vartype(var)
                self.cond_list.setCellWidget(oper_combo.row, 2, combo)
                combo.currentIndexChanged.connect(self.conditions_changed)
        else:
            box = gui.hBox(self, addToLayout=False)
            box.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, box)
            if vtype == 2:  # continuous:
                box.controls = [add_numeric(lc[0])]
                # NOTE(review): operator indices above 5 presumably take two
                # operands -- confirm against the operator list.
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_numeric(lc[1]))
            elif vtype == 3:  # string:
                box.controls = [add_textual(lc[0])]
                if oper in [6, 7]:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_textual(lc[1]))
            elif vtype == 4:  # time:

                def invalidate_datetime():
                    # ``w`` and ``w_`` are looked up when this runs; both are
                    # assigned further down in this branch.
                    if w_:
                        # Raise the second widget to the first one's value
                        # if the first is later.
                        if w.dateTime() > w_.dateTime():
                            w_.setDateTime(w.dateTime())
                        if w.format == (1, 1):
                            w.calendarWidget.timeedit.setTime(w.time())
                            w_.calendarWidget.timeedit.setTime(w_.time())
                    elif w.format == (1, 1):
                        w.calendarWidget.timeedit.setTime(w.time())

                def datetime_changed():
                    self.conditions_changed()
                    invalidate_datetime()

                datetime_format = (var.have_date, var.have_time)
                column = self.data.get_column_view(var_idx)[0]
                w = DateTimeWidget(self, column, datetime_format)
                # The value is set before the signal is connected.
                w.set_datetime(lc[0])
                box.controls = [w]
                box.layout().addWidget(w)
                w.dateTimeChanged.connect(datetime_changed)
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    w_ = DateTimeWidget(self, column, datetime_format)
                    w_.set_datetime(lc[1])
                    box.layout().addWidget(w_)
                    box.controls.append(w_)
                    invalidate_datetime()
                    w_.dateTimeChanged.connect(datetime_changed)
                else:
                    # No second widget; invalidate_datetime checks for this.
                    w_ = None
            else:
                box.controls = []
        if not adding_all:
            self.conditions_changed()
示例#8
0
    def set_new_values(self, oper_combo, adding_all, selected_values=None):
        """Rebuild the value editor in column 2 of ``oper_combo``'s row.

        Depending on the selected operator and the variable's type, the cell
        receives an empty label, a drop-down button, a combo box, or a box
        holding one or two line edits.

        Args:
            oper_combo: operator combo box; its ``row`` and ``attr_combo``
                attributes are read here.
            adding_all: when true, ``conditions_changed`` is not called at
                the end.
            selected_values: optional initial values for the editor.
        """

        def make_text_edit(text):
            # ``box`` is resolved at call time, i.e. after it has been
            # replaced by the freshly created container below.
            edit = gui.lineEdit(
                box, self, None,
                sizePolicy=QSizePolicy(QSizePolicy.Expanding,
                                       QSizePolicy.Expanding))
            if text:
                edit.setText(text)
            edit.setAlignment(Qt.AlignRight)
            edit.editingFinished.connect(self.conditions_changed)
            return edit

        def make_numeric_edit(text):
            edit = make_text_edit(text)
            edit.setValidator(OWSelectRows.QDoubleValidatorEmpty())
            return edit

        def make_datetime_edit(text):
            edit = make_text_edit(text)
            edit.setValidator(QRegExpValidator(QRegExp(TimeVariable.REGEX)))
            return edit

        row = oper_combo.row
        box = self.cond_list.cellWidget(row, 2)
        oper = oper_combo.currentIndex()
        attr_name = oper_combo.attr_combo.currentText()

        # Initial contents of the (up to two) editors.
        contents = ["", ""]
        if attr_name in self.AllTypes:
            # Entry from AllTypes: only a type, no concrete variable.
            var = None
            vtype = self.AllTypes[attr_name]
        else:
            var = self.data.domain[attr_name]
            vtype = vartype(var)
            if selected_values is not None:
                padded = list(selected_values) + ["", ""]
                contents = [str(value) for value in padded[:2]]
        if box and vtype == box.var_type:
            # Same type as the existing editor: its contents take precedence.
            contents = self._get_lineedit_contents(box) + contents

        oper_text = oper_combo.currentText()
        if oper_text.endswith(" defined"):
            # No value needed; an empty label still records the type.
            placeholder = QLabel()
            placeholder.var_type = vtype
            self.cond_list.setCellWidget(row, 2, placeholder)
        elif var is not None and var.is_discrete:
            if oper_text.endswith(" one of"):
                if selected_values:
                    contents = list(selected_values)
                button = DropDownToolButton(self, var, contents)
                button.var_type = vtype
                self.cond_list.setCellWidget(row, 2, button)
            else:
                value_combo = ComboBoxSearch()
                # Leading empty item, so value indices are shifted by one.
                value_combo.addItems(("", ) + var.values)
                value_combo.setCurrentIndex(
                    int(contents[0]) if contents[0] else 0)
                value_combo.var_type = vartype(var)
                self.cond_list.setCellWidget(row, 2, value_combo)
                value_combo.currentIndexChanged.connect(
                    self.conditions_changed)
        else:
            box = gui.hBox(self, addToLayout=False)
            box.var_type = vtype
            self.cond_list.setCellWidget(row, 2, box)

            # Pick the editor factory and whether a second operand is shown.
            if vtype in (2, 4):  # continuous, time
                if isinstance(var, TimeVariable):
                    make_edit = make_datetime_edit
                else:
                    make_edit = make_numeric_edit
                two_operands = oper > 5
            elif vtype == 3:  # string
                make_edit = make_text_edit
                two_operands = oper in (6, 7)
            else:
                make_edit = None
                two_operands = False

            if make_edit is None:
                box.controls = []
            else:
                box.controls = [make_edit(contents[0])]
                if two_operands:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(make_edit(contents[1]))

        if not adding_all:
            self.conditions_changed()
示例#9
0
    def __init__(self):
        """Initialise widget state and build the controls and heatmap view.

        Restores the clustering enums from their stored names, resets the
        data and cache attributes, creates the control-area widgets (color,
        merge, clustering, split, annotations, resize, auto-send), and
        finally sets up the graphics scene/view and the font-size actions.
        """
        super().__init__()
        # Capture the stored selection now; ``selected_rows`` is reset to an
        # empty list further down in this method.
        self.__pending_selection = self.selected_rows

        # A kingdom for a save_state/restore_state
        # Convert the stored method names to Clustering members, with
        # Clustering.None_ passed as the default.
        self.col_clustering = enum_get(Clustering, self.col_clustering_method,
                                       Clustering.None_)
        self.row_clustering = enum_get(Clustering, self.row_clustering_method,
                                       Clustering.None_)

        # Write the enum names back into the settings just before packing.
        @self.settingsAboutToBePacked.connect
        def _():
            self.col_clustering_method = self.col_clustering.name
            self.row_clustering_method = self.row_clustering.name

        self.keep_aspect = False

        #: The original data with all features (retained to
        #: preserve the domain on the output)
        self.input_data = None
        #: The effective data stripped of discrete features, and often
        #: merged using k-means
        self.data = None
        self.effective_data = None
        #: kmeans model used to merge rows of input_data
        self.kmeans_model = None
        #: merge indices derived from kmeans
        #: a list (len==k) of int ndarray where the i-th item contains
        #: the indices which merge the input_data into the heatmap row i
        self.merge_indices = None
        self.parts: Optional[Parts] = None
        self.__rows_cache = {}
        self.__columns_cache = {}

        # GUI definition
        # --- "Color" box: palette selector and low/high threshold sliders.
        colorbox = gui.vBox(self.controlArea, "Color")
        self.color_cb = gui.palette_combo_box(self.palette_name)
        self.color_cb.currentIndexChanged.connect(self.update_color_schema)
        colorbox.layout().addWidget(self.color_cb)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        lowslider = gui.hSlider(colorbox,
                                self,
                                "threshold_low",
                                minValue=0.0,
                                maxValue=1.0,
                                step=0.05,
                                ticks=True,
                                intOnly=False,
                                createLabel=False,
                                callback=self.update_lowslider)
        highslider = gui.hSlider(colorbox,
                                 self,
                                 "threshold_high",
                                 minValue=0.0,
                                 maxValue=1.0,
                                 step=0.05,
                                 ticks=True,
                                 intOnly=False,
                                 createLabel=False,
                                 callback=self.update_highslider)

        form.addRow("Low:", lowslider)
        form.addRow("High:", highslider)

        colorbox.layout().addLayout(form)

        # --- "Merge" box: k-means merge toggle and cluster count.
        mergebox = gui.vBox(
            self.controlArea,
            "Merge",
        )
        gui.checkBox(mergebox,
                     self,
                     "merge_kmeans",
                     "Merge by k-means",
                     callback=self.__update_row_clustering)
        ibox = gui.indentedBox(mergebox)
        gui.spin(ibox,
                 self,
                 "merge_kmeans_k",
                 minv=5,
                 maxv=500,
                 label="Clusters:",
                 keyboardTracking=False,
                 callbackOnReturn=True,
                 callback=self.update_merge)

        # --- "Clustering" box: row and column clustering selectors.
        cluster_box = gui.vBox(self.controlArea, "Clustering")
        # Row clustering
        self.row_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.row_clustering, ClusteringRole)
        # ``cb=cb`` binds the current combo as a default argument because
        # the name ``cb`` is reused for other combos below.
        self.connect_control(
            "row_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_row_clustering(cb.itemData(idx, ClusteringRole))

        # Column clustering
        self.col_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.col_clustering, ClusteringRole)
        self.connect_control(
            "col_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_col_clustering(cb.itemData(idx, ClusteringRole))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )
        form.addRow("Rows:", self.row_cluster_cb)
        form.addRow("Columns:", self.col_cluster_cb)
        cluster_box.layout().addLayout(form)
        # --- "Split By" box: discrete variable used to split the heatmap.
        box = gui.vBox(self.controlArea, "Split By")

        self.row_split_model = DomainModel(
            placeholder="(None)",
            valid_types=(Orange.data.DiscreteVariable, ),
            parent=self,
        )
        # Created disabled when merging by k-means is on.
        self.row_split_cb = cb = ComboBox(
            enabled=not self.merge_kmeans,
            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=14,
            toolTip="Split the heatmap vertically by a categorical column")
        self.row_split_cb.setModel(self.row_split_model)
        self.connect_control("split_by_var",
                             lambda value, cb=cb: cbselect(cb, value))
        # ``merge_kmeans`` changes are passed to the combo's setDisabled.
        self.connect_control("merge_kmeans", self.row_split_cb.setDisabled)
        self.split_by_var = None

        self.row_split_cb.activated.connect(self.__on_split_rows_activated)
        box.layout().addWidget(self.row_split_cb)

        # --- "Annotation & Legends" box.
        box = gui.vBox(self.controlArea, 'Annotation && Legends')

        gui.checkBox(box,
                     self,
                     'legend',
                     'Show legend',
                     callback=self.update_legend)

        gui.checkBox(box,
                     self,
                     'averages',
                     'Stripes with averages',
                     callback=self.update_averages_stripe)
        # Row annotations: a text variable and a color variable.
        annotbox = QGroupBox("Row Annotations", flat=True)
        form = QFormLayout(annotbox,
                           formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        self.annotation_model = DomainModel(placeholder="(None)")
        self.annotation_text_cb = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength)
        self.annotation_text_cb.setModel(self.annotation_model)
        self.annotation_text_cb.activated.connect(self.set_annotation_var)
        self.connect_control("annotation_var", self.annotation_var_changed)

        self.row_side_color_model = DomainModel(
            order=(DomainModel.CLASSES, DomainModel.Separator,
                   DomainModel.METAS),
            placeholder="(None)",
            valid_types=DomainModel.PRIMITIVE,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
            parent=self,
        )
        self.row_side_color_cb = ComboBoxSearch(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            minimumContentsLength=12)
        self.row_side_color_cb.setModel(self.row_side_color_model)
        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
        self.connect_control("annotation_color_var",
                             self.annotation_color_var_changed)
        form.addRow("Text", self.annotation_text_cb)
        form.addRow("Color", self.row_side_color_cb)
        box.layout().addWidget(annotbox)
        # Column label position selector.
        posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
        posbox.setFlat(True)
        cb = gui.comboBox(posbox,
                          self,
                          "column_label_pos",
                          callback=self.update_column_annotations)
        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
        # The index is set explicitly after the model is replaced.
        cb.setCurrentIndex(self.column_label_pos)
        # --- "Resize" box and auto-send controls.
        gui.checkBox(self.controlArea,
                     self,
                     "keep_aspect",
                     "Keep aspect ratio",
                     box="Resize",
                     callback=self.__aspect_mode_changed)

        gui.rubber(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

        # Scene with heatmap
        class HeatmapScene(QGraphicsScene):
            widget: Optional[HeatmapGridWidget] = None

        # NOTE(review): the chained assignment binds ``self.scene`` twice;
        # the repetition looks redundant.
        self.scene = self.scene = HeatmapScene(parent=self)
        self.view = GraphicsView(
            self.scene,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
            widgetResizable=True,
        )
        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
        self.view.customContextMenuRequested.connect(
            self._on_view_context_menu)
        self.mainArea.layout().addWidget(self.view)
        # The stored selection was saved in ``__pending_selection`` above.
        self.selected_rows = []
        # Font size actions with keyboard shortcuts.
        self.__font_inc = QAction("Increase Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+>"))
        self.__font_dec = QAction("Decrease Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+<"))
        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
        # Only call setShortcutVisibleInContextMenu where QAction has it.
        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
            apply_all([self.__font_inc, self.__font_dec],
                      lambda a: a.setShortcutVisibleInContextMenu(True))
        self.addActions([self.__font_inc, self.__font_dec])
示例#10
0
class OWHeatMap(widget.OWWidget):
    """Widget that plots a data matrix as a heat map."""
    # Widget metadata
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
    icon = "icons/Heatmap.svg"
    priority = 260
    keywords = []

    # Input signals
    class Inputs:
        data = Input("Data", Table)

    # Output signals
    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settings_version = 3

    settingsHandler = settings.DomainContextHandler()

    # Disable clustering for inputs bigger than this
    MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    MaxOrderedClustering = 1000

    # Low/high thresholds passed to the color map (see `color_map`)
    threshold_low = settings.Setting(0.0)
    threshold_high = settings.Setting(1.0)

    # Merge rows with k-means, and the requested number of clusters
    merge_kmeans = settings.Setting(False)
    merge_kmeans_k = settings.Setting(50)

    # Display column with averages
    averages: bool = settings.Setting(True)
    # Display legend
    legend: bool = settings.Setting(True)
    # Annotations
    #: text row annotation (row names)
    annotation_var = settings.ContextSetting(None)
    #: color row annotation
    annotation_color_var = settings.ContextSetting(None)
    # Discrete variable used to split the data/heatmaps (vertically)
    split_by_var = settings.ContextSetting(None)
    # Selected row/column clustering method (name)
    col_clustering_method: str = settings.Setting(Clustering.None_.name)
    row_clustering_method: str = settings.Setting(Clustering.None_.name)

    # Name of the selected color palette
    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    # Index into `ColumnLabelsPosData` (see `_column_label_pos`)
    column_label_pos: int = settings.Setting(1)
    # Selected rows; stored with the schema only
    selected_rows: List[int] = settings.Setting(None, schema_only=True)

    auto_commit = settings.Setting(True)

    graph_name = "scene"

    left_side_scrolling = True

    # Informational messages shown by the widget
    class Information(widget.OWWidget.Information):
        sampled = Msg("Data has been sampled")
        discrete_ignored = Msg("{} categorical feature{} ignored")
        row_clust = Msg("{}")
        col_clust = Msg("{}")
        sparse_densified = Msg("Showing this data may require a lot of memory")

    # Error messages shown by the widget
    class Error(widget.OWWidget.Error):
        no_continuous = Msg("No numeric features")
        not_enough_features = Msg("Not enough features for column clustering")
        not_enough_instances = Msg("Not enough instances for clustering")
        not_enough_instances_k_means = Msg(
            "Not enough instances for k-means merging")
        not_enough_memory = Msg("Not enough memory to show this data")

    # Warning messages shown by the widget
    class Warning(widget.OWWidget.Warning):
        empty_clusters = Msg("Empty clusters were removed")

    def __init__(self):
        """Initialize widget state and build the control and main areas."""
        super().__init__()
        # Keep the stored selection aside; it is applied in `set_dataset`
        # once a heatmap for the incoming data exists.
        self.__pending_selection = self.selected_rows

        # A kingdom for a save_state/restore_state
        # The stored settings hold the clustering method *names*; convert
        # them to `Clustering` members for use at runtime.
        self.col_clustering = enum_get(Clustering, self.col_clustering_method,
                                       Clustering.None_)
        self.row_clustering = enum_get(Clustering, self.row_clustering_method,
                                       Clustering.None_)

        # Write the enum names back into the settings before they are packed.
        @self.settingsAboutToBePacked.connect
        def _():
            self.col_clustering_method = self.col_clustering.name
            self.row_clustering_method = self.row_clustering.name

        self.keep_aspect = False

        #: The original data with all features (retained to
        #: preserve the domain on the output)
        self.input_data = None
        #: The effective data stripped of discrete features, and often
        #: merged using k-means
        self.data = None
        self.effective_data = None
        #: kmeans model used to merge rows of input_data
        self.kmeans_model = None
        #: merge indices derived from kmeans
        #: a list (len==k) of int ndarray where the i-th item contains
        #: the indices which merge the input_data into the heatmap row i
        self.merge_indices = None
        self.parts: Optional[Parts] = None
        self.__rows_cache = {}
        self.__columns_cache = {}

        # GUI definition
        # -- Color: palette selection and low/high threshold sliders
        colorbox = gui.vBox(self.controlArea, "Color")
        self.color_cb = gui.palette_combo_box(self.palette_name)
        self.color_cb.currentIndexChanged.connect(self.update_color_schema)
        colorbox.layout().addWidget(self.color_cb)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        lowslider = gui.hSlider(colorbox,
                                self,
                                "threshold_low",
                                minValue=0.0,
                                maxValue=1.0,
                                step=0.05,
                                ticks=True,
                                intOnly=False,
                                createLabel=False,
                                callback=self.update_lowslider)
        highslider = gui.hSlider(colorbox,
                                 self,
                                 "threshold_high",
                                 minValue=0.0,
                                 maxValue=1.0,
                                 step=0.05,
                                 ticks=True,
                                 intOnly=False,
                                 createLabel=False,
                                 callback=self.update_highslider)

        form.addRow("Low:", lowslider)
        form.addRow("High:", highslider)

        colorbox.layout().addLayout(form)

        # -- Merge: optional k-means row merging
        mergebox = gui.vBox(
            self.controlArea,
            "Merge",
        )
        gui.checkBox(mergebox,
                     self,
                     "merge_kmeans",
                     "Merge by k-means",
                     callback=self.__update_row_clustering)
        ibox = gui.indentedBox(mergebox)
        gui.spin(ibox,
                 self,
                 "merge_kmeans_k",
                 minv=5,
                 maxv=500,
                 label="Clusters:",
                 keyboardTracking=False,
                 callbackOnReturn=True,
                 callback=self.update_merge)

        # -- Clustering: row and column clustering method selection
        cluster_box = gui.vBox(self.controlArea, "Clustering")
        # Row clustering
        self.row_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.row_clustering, ClusteringRole)
        # Keep the combo box in sync when `row_clustering` is changed in code
        self.connect_control(
            "row_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_row_clustering(cb.itemData(idx, ClusteringRole))

        # Column clustering
        self.col_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.col_clustering, ClusteringRole)
        # Keep the combo box in sync when `col_clustering` is changed in code
        self.connect_control(
            "col_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_col_clustering(cb.itemData(idx, ClusteringRole))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )
        form.addRow("Rows:", self.row_cluster_cb)
        form.addRow("Columns:", self.col_cluster_cb)
        cluster_box.layout().addLayout(form)
        # -- Split By: discrete variable used to split the rows
        box = gui.vBox(self.controlArea, "Split By")

        self.row_split_model = DomainModel(
            placeholder="(None)",
            valid_types=(Orange.data.DiscreteVariable, ),
            parent=self,
        )
        self.row_split_cb = cb = ComboBox(
            enabled=not self.merge_kmeans,
            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=14,
            toolTip="Split the heatmap vertically by a categorical column")
        self.row_split_cb.setModel(self.row_split_model)
        self.connect_control("split_by_var",
                             lambda value, cb=cb: cbselect(cb, value))
        # The split combo box is disabled while k-means merging is on
        self.connect_control("merge_kmeans", self.row_split_cb.setDisabled)
        self.split_by_var = None

        self.row_split_cb.activated.connect(self.__on_split_rows_activated)
        box.layout().addWidget(self.row_split_cb)

        # -- Annotation & Legends
        box = gui.vBox(self.controlArea, 'Annotation && Legends')

        gui.checkBox(box,
                     self,
                     'legend',
                     'Show legend',
                     callback=self.update_legend)

        gui.checkBox(box,
                     self,
                     'averages',
                     'Stripes with averages',
                     callback=self.update_averages_stripe)
        annotbox = QGroupBox("Row Annotations", flat=True)
        form = QFormLayout(annotbox,
                           formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        # Text row annotation
        self.annotation_model = DomainModel(placeholder="(None)")
        self.annotation_text_cb = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength)
        self.annotation_text_cb.setModel(self.annotation_model)
        self.annotation_text_cb.activated.connect(self.set_annotation_var)
        self.connect_control("annotation_var", self.annotation_var_changed)

        # Color row annotation
        self.row_side_color_model = DomainModel(
            order=(DomainModel.CLASSES, DomainModel.Separator,
                   DomainModel.METAS),
            placeholder="(None)",
            valid_types=DomainModel.PRIMITIVE,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
            parent=self,
        )
        self.row_side_color_cb = ComboBoxSearch(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            minimumContentsLength=12)
        self.row_side_color_cb.setModel(self.row_side_color_model)
        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
        self.connect_control("annotation_color_var",
                             self.annotation_color_var_changed)
        form.addRow("Text", self.annotation_text_cb)
        form.addRow("Color", self.row_side_color_cb)
        box.layout().addWidget(annotbox)
        # Column labels position
        posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
        posbox.setFlat(True)
        cb = gui.comboBox(posbox,
                          self,
                          "column_label_pos",
                          callback=self.update_column_annotations)
        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
        cb.setCurrentIndex(self.column_label_pos)
        # -- Resize
        gui.checkBox(self.controlArea,
                     self,
                     "keep_aspect",
                     "Keep aspect ratio",
                     box="Resize",
                     callback=self.__aspect_mode_changed)

        gui.rubber(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

        # Scene with heatmap
        # The scene subclass only adds a `widget` attribute that holds the
        # current heatmap grid widget (None when no heatmap is shown).
        class HeatmapScene(QGraphicsScene):
            widget: Optional[HeatmapGridWidget] = None

        self.scene = self.scene = HeatmapScene(parent=self)
        self.view = GraphicsView(
            self.scene,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
            widgetResizable=True,
        )
        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
        self.view.customContextMenuRequested.connect(
            self._on_view_context_menu)
        self.mainArea.layout().addWidget(self.view)
        self.selected_rows = []
        # Font size actions with keyboard shortcuts
        self.__font_inc = QAction("Increase Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+>"))
        self.__font_dec = QAction("Decrease Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+<"))
        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
        # `setShortcutVisibleInContextMenu` is only used when QAction has it
        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
            apply_all([self.__font_inc, self.__font_dec],
                      lambda a: a.setShortcutVisibleInContextMenu(True))
        self.addActions([self.__font_inc, self.__font_dec])

    @property
    def center_palette(self):
        """True when the currently selected palette has the Diverging flag."""
        current = self.color_cb.currentData()
        diverging_bits = current.flags & current.Diverging
        return bool(diverging_bits)

    @property
    def _column_label_pos(self) -> HeatmapGridWidget.Position:
        """Label position stored under `Qt.UserRole` for the selected entry."""
        entry = ColumnLabelsPosData[self.column_label_pos]
        return entry[Qt.UserRole]

    def annotation_color_var_changed(self, value):
        """Select `value` in the row side color combo box."""
        combo = self.row_side_color_cb
        cbselect(combo, value, Qt.EditRole)

    def annotation_var_changed(self, value):
        """Select `value` in the text annotation combo box."""
        combo = self.annotation_text_cb
        cbselect(combo, value, Qt.EditRole)

    def set_row_clustering(self, method: Clustering) -> None:
        """Switch the row clustering method and refresh when it changed."""
        assert isinstance(method, Clustering)
        changed = self.row_clustering != method
        if not changed:
            return
        self.row_clustering = method
        cbselect(self.row_cluster_cb, method, ClusteringRole)
        self.__update_row_clustering()

    def set_col_clustering(self, method: Clustering) -> None:
        """Switch the column clustering method and refresh when it changed."""
        assert isinstance(method, Clustering)
        changed = self.col_clustering != method
        if not changed:
            return
        self.col_clustering = method
        cbselect(self.col_cluster_cb, method, ClusteringRole)
        self.__update_column_clustering()

    def sizeHint(self) -> QSize:
        """Return the base size hint expanded to at least 900x700."""
        base_hint = super().sizeHint()
        floor = QSize(900, 700)
        return base_hint.expandedTo(floor)

    def color_palette(self):
        """Return the lookup table of the currently selected palette."""
        selected = self.color_cb.currentData()
        return selected.lookup_table()

    def color_map(self) -> GradientColorMap:
        """Build the color map from the palette and threshold settings."""
        table = self.color_palette()
        thresholds = (self.threshold_low, self.threshold_high)
        # The third argument is 0 for a diverging palette, None otherwise.
        if self.center_palette:
            center = 0
        else:
            center = None
        return GradientColorMap(table, thresholds, center)

    def clear(self):
        """Reset data, models, selection, caches and the scene."""
        # Data and k-means merge state
        self.data = self.input_data = self.effective_data = None
        self.kmeans_model = self.merge_indices = None

        # Each variable model is emptied before its setting is reset
        self.annotation_model.set_domain(None)
        self.annotation_var = None

        self.row_side_color_model.set_domain(None)
        self.annotation_color_var = None

        self.row_split_model.set_domain(None)
        self.split_by_var = None

        # Displayed heatmap and selection
        self.parts = None
        self.clear_scene()
        self.selected_rows = []

        # Cached clustering results
        self.__columns_cache.clear()
        self.__rows_cache.clear()

        self.__update_clustering_enable_state(None)

    def clear_scene(self):
        """Detach the current heatmap widget and empty the scene and view."""
        current = self.scene.widget
        if current is not None:
            # Disconnect the signals that `setup_scene` connected
            current.layoutDidActivate.disconnect(self.__on_layout_activate)
            current.selectionFinished.disconnect(self.on_selection_finished)
        self.scene.widget = None
        self.scene.clear()

        # Reset all three view rectangles to an empty rect
        view = self.view
        view.setSceneRect(QRectF())
        view.setHeaderSceneRect(QRectF())
        view.setFooterSceneRect(QRectF())

    @Inputs.data
    def set_dataset(self, data=None):
        """Set the input dataset to display.

        The previous state is cleared, the input is converted and filtered
        (SQL tables are downloaded or sampled, sparse data is densified,
        discrete attributes are dropped), the heatmap is rebuilt and the
        outputs are committed.

        A selection restored from the settings is applied only when a
        heatmap widget was actually created for this data.
        """
        self.closeContext()
        self.clear()
        self.clear_messages()

        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)

        if data is not None and not len(data):
            data = None

        if data is not None and sp.issparse(data.X):
            try:
                data = data.to_dense()
            except MemoryError:
                data = None
                self.Error.not_enough_memory()
            else:
                self.Information.sparse_densified()

        input_data = data

        # Data contains no attributes or meta attributes only
        if data is not None and len(data.domain.attributes) == 0:
            self.Error.no_continuous()
            input_data = data = None

        # Data contains some discrete attributes which must be filtered
        if data is not None:
            # Count once instead of scanning the attributes twice
            ndisc = sum(var.is_discrete for var in data.domain.attributes)
            if ndisc:
                data = data.transform(
                    Domain([
                        var for var in data.domain.attributes
                        if var.is_continuous
                    ], data.domain.class_vars, data.domain.metas))
                if not data.domain.attributes:
                    self.Error.no_continuous()
                    input_data = data = None
                else:
                    self.Information.discrete_ignored(
                        ndisc, "s" if ndisc > 1 else "")

        self.data = data
        self.input_data = input_data

        if data is not None:
            self.annotation_model.set_domain(self.input_data.domain)
            self.row_side_color_model.set_domain(self.input_data.domain)
            self.annotation_var = None
            self.annotation_color_var = None
            self.row_split_model.set_domain(data.domain)
            if data.domain.has_discrete_class:
                self.split_by_var = data.domain.class_var
            else:
                self.split_by_var = None
            self.openContext(self.input_data)
            if self.split_by_var not in self.row_split_model:
                self.split_by_var = None

        self.update_heatmaps()
        if data is not None and self.__pending_selection is not None:
            grid = self.scene.widget
            # `update_heatmaps` creates no widget when it reports an error
            # (e.g. not enough features/instances); in that case there is
            # nothing to select in, so the restored selection is dropped
            # rather than failing on a missing widget.
            if grid is not None:
                grid.selectRows(self.__pending_selection)
                self.selected_rows = self.__pending_selection
            self.__pending_selection = None

        self.unconditional_commit()

    def __on_split_rows_activated(self):
        """Apply the variable currently chosen in the split combo box."""
        chosen = self.row_split_cb.currentData(Qt.EditRole)
        self.set_split_variable(chosen)

    def set_split_variable(self, var):
        """Set the split variable and rebuild the heatmaps if it differs."""
        differs = var != self.split_by_var
        if not differs:
            return
        self.split_by_var = var
        self.update_heatmaps()

    def update_heatmaps(self):
        """Rebuild the heatmap scene from `self.data`.

        Without data everything is cleared. Otherwise the scene and messages
        are reset and either an error is shown or the heatmaps are built.
        """
        if self.data is None:
            self.clear()
            return

        self.clear_scene()
        self.clear_messages()

        cols_clustered = self.col_clustering != Clustering.None_
        rows_clustered = self.row_clustering != Clustering.None_
        n_rows = len(self.data)

        if cols_clustered and len(self.data.domain.attributes) < 2:
            self.Error.not_enough_features()
        elif (cols_clustered or rows_clustered) and n_rows < 2:
            self.Error.not_enough_instances()
        elif self.merge_kmeans and n_rows < 3:
            self.Error.not_enough_instances_k_means()
        else:
            built = self.construct_heatmaps(self.data, self.split_by_var)
            self.construct_heatmaps_scene(built, self.effective_data)
            self.selected_rows = []

    def update_merge(self):
        """Drop the k-means state and rebuild when merging is active."""
        self.kmeans_model = self.merge_indices = None
        if self.data is None or not self.merge_kmeans:
            return
        self.update_heatmaps()
        self.commit()

    def _make_parts(self, data, group_var=None):
        """
        Make the initial `Parts` for `data`.

        Rows are split into one `RowPart` per value of `group_var` when it
        is given, otherwise a single part covers all rows. There is always
        one `ColumnPart` spanning all attributes. No clustering is attached.
        """
        if group_var is None:
            row_groups = [
                RowPart(title=None,
                        indices=range(0, len(data)),
                        cluster=None,
                        cluster_ordered=None)
            ]
        else:
            assert group_var.is_discrete
            column = table_column_data(data, group_var)
            # Row indices for each value index of the grouping variable
            per_value = [
                np.flatnonzero(column == value_index)
                for value_index in range(len(group_var.values))
            ]
            row_groups = []
            for title, members in zip(group_var.values, per_value):
                row_groups.append(
                    RowPart(title=title,
                            indices=members,
                            cluster=None,
                            cluster_ordered=None))

        n_columns = len(data.domain.attributes)
        col_groups = [
            ColumnPart(title=None,
                       indices=range(0, n_columns),
                       domain=data.domain,
                       cluster=None,
                       cluster_ordered=None)
        ]

        # Value span of the data, ignoring NaNs
        low = np.nanmin(data.X)
        high = np.nanmax(data.X)
        return Parts(row_groups, col_groups, span=(low, high))

    def cluster_rows(self,
                     data: Table,
                     parts: 'Parts',
                     ordered=False) -> 'Parts':
        """Return `parts` with row clustering filled in where it is missing.

        For every row part that can be clustered, a hierarchical clustering
        (and, when `ordered` is true, its optimal leaf ordering) is computed
        unless the part already carries one. Parts that cannot be clustered
        are passed through with their existing values.
        """
        row_groups = []
        for row in parts.rows:
            # Start from whatever the part already holds (may be None)
            cluster = row.cluster
            cluster_ord = row.cluster_ordered

            if row.can_cluster:
                matrix = None
                need_dist = cluster is None or (ordered
                                                and cluster_ord is None)
                if need_dist:
                    subset = data[row.indices]
                    matrix = Orange.distance.Euclidean(subset)

                # The limits are inclusive, matching the `<=` checks that
                # enable the clustering options in
                # `__update_clustering_enable_state`.
                if cluster is None:
                    assert len(matrix) <= self.MaxClustering
                    cluster = hierarchical.dist_matrix_clustering(
                        matrix, linkage=hierarchical.WARD)
                if ordered and cluster_ord is None:
                    assert len(matrix) <= self.MaxOrderedClustering
                    cluster_ord = hierarchical.optimal_leaf_ordering(
                        cluster,
                        matrix,
                    )
            row_groups.append(
                row._replace(cluster=cluster, cluster_ordered=cluster_ord))

        return parts._replace(rows=row_groups)

    def cluster_columns(self, data, parts, ordered=False):
        """Return `parts` with column clustering filled in where missing.

        Only a single column part is supported. The clustering (and, when
        `ordered` is true, its optimal leaf ordering) is computed from the
        `PearsonR` distances between columns unless the part already has one.
        """
        assert len(parts.columns) == 1, "columns split is no longer supported"
        assert all(var.is_continuous for var in data.domain.attributes)

        col0 = parts.columns[0]
        # Start from whatever the part already holds (may be None)
        cluster = col0.cluster
        cluster_ord = col0.cluster_ordered

        need_dist = cluster is None or (ordered and cluster_ord is None)
        matrix = None
        if need_dist:
            data = Orange.distance._preprocess(data)
            matrix = np.asarray(Orange.distance.PearsonR(data, axis=0))
            # nan values break clustering below
            matrix = np.nan_to_num(matrix)

        # The limits are inclusive, matching the `<=` checks that enable the
        # clustering options in `__update_clustering_enable_state`.
        if cluster is None:
            assert matrix is not None
            assert len(matrix) <= self.MaxClustering
            cluster = hierarchical.dist_matrix_clustering(
                matrix, linkage=hierarchical.WARD)
        if ordered and cluster_ord is None:
            assert len(matrix) <= self.MaxOrderedClustering
            cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix)

        col_groups = [
            col._replace(cluster=cluster, cluster_ordered=cluster_ord)
            for col in parts.columns
        ]
        return parts._replace(columns=col_groups)

    def construct_heatmaps(self, data, group_var=None) -> 'Parts':
        """Compute the `Parts` to display for `data`.

        With k-means merging enabled the rows are replaced by the centroids
        of the non-empty clusters and `group_var` is ignored; otherwise
        `data` is used as is. Row/column clustering is then added as
        selected, reusing cached row parts when available.
        """
        if not self.merge_kmeans:
            self.kmeans_model = None
            self.merge_indices = None
            effective_data = data
        else:
            if self.kmeans_model is not None:
                # Reuse the data produced by the existing model
                effective_data = self.effective_data
            else:
                source = self.input_data
                numeric_attrs = [
                    var for var in source.domain.attributes
                    if var.is_continuous
                ]
                effective_data = source.transform(
                    Orange.data.Domain(numeric_attrs,
                                       source.domain.class_vars,
                                       source.domain.metas))
                nclust = min(self.merge_kmeans_k, len(effective_data) - 1)
                model = kmeans_compress(effective_data, k=nclust)
                self.kmeans_model = model
                effective_data.domain = model.domain
                members_by_cluster = [
                    np.flatnonzero(model.labels == label)
                    for label in range(nclust)
                ]
                kept = [
                    pos for pos, members in enumerate(members_by_cluster)
                    if len(members) > 0
                ]
                self.merge_indices = [members_by_cluster[pos] for pos in kept]
                if len(members_by_cluster) != len(self.merge_indices):
                    self.Warning.empty_clusters()
                centroid_domain = Orange.data.Domain(
                    effective_data.domain.attributes)
                effective_data = Orange.data.Table(centroid_domain,
                                                   model.centroids[kept])

            # Splitting is not applied to merged rows
            group_var = None

        self.effective_data = effective_data

        self.__update_clustering_enable_state(effective_data)

        parts = self._make_parts(effective_data, group_var)
        # Restore/update the row/columns items descriptions from cache if
        # available
        merge_key = self.merge_kmeans_k if self.merge_kmeans else None
        rows_cache_key = (group_var, merge_key)
        cached = self.__rows_cache.get(rows_cache_key)
        if cached is not None:
            parts = parts._replace(rows=cached.rows)

        if self.row_clustering != Clustering.None_:
            rows_ordered = self.row_clustering == Clustering.OrderedClustering
            parts = self.cluster_rows(effective_data,
                                      parts,
                                      ordered=rows_ordered)
        if self.col_clustering != Clustering.None_:
            cols_ordered = self.col_clustering == Clustering.OrderedClustering
            parts = self.cluster_columns(effective_data,
                                         parts,
                                         ordered=cols_ordered)

        # Cache the updated parts
        self.__rows_cache[rows_cache_key] = parts
        return parts

    def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None:
        """Pick the clustering to show for each part and set up the scene."""
        _T = TypeVar("_T", bound=Union[RowPart, ColumnPart])

        def select_cluster(clustering: Clustering, item: _T) -> _T:
            # Choose which stored clustering ends up in `cluster`;
            # `cluster_ordered` is always cleared on the result.
            if clustering == Clustering.None_:
                shown = None
            elif clustering == Clustering.Clustering:
                shown = item.cluster
            elif clustering == Clustering.OrderedClustering:
                shown = item.cluster_ordered
            else:  # pragma: no cover
                raise TypeError()
            return item._replace(cluster=shown, cluster_ordered=None)

        rows = []
        for rowitem in parts.rows:
            rows.append(select_cluster(self.row_clustering, rowitem))
        cols = []
        for colitem in parts.columns:
            cols.append(select_cluster(self.col_clustering, colitem))

        selected = Parts(columns=cols, rows=rows, span=parts.span)
        self.setup_scene(selected, data)

    def setup_scene(self, parts, data):
        # type: (Parts, Table) -> None
        """Create a heatmap grid widget for `parts`/`data` and show it."""
        grid = HeatmapGridWidget()
        grid.setColorMap(self.color_map())
        self.scene.addItem(grid)
        self.scene.widget = grid

        column_names = [var.name for var in data.domain.attributes]
        row_items = [
            HeatmapGridWidget.RowItem(row.title, row.indices, row.cluster)
            for row in parts.rows
        ]
        column_items = [
            HeatmapGridWidget.ColumnItem(col.title, col.indices, col.cluster)
            for col in parts.columns
        ]
        grid_parts = HeatmapGridWidget.Parts(
            rows=row_items,
            columns=column_items,
            data=data.X,
            span=parts.span,
            row_names=None,
            col_names=column_names,
        )
        grid.setHeatmaps(grid_parts)

        side = self.row_side_colors()
        if side is not None:
            grid.setRowSideColorAnnotations(side[0],
                                            side[1],
                                            name=side[2].name)

        # Apply the current display settings
        grid.setColumnLabelsPosition(self._column_label_pos)
        if self.keep_aspect:
            grid.setAspectRatioMode(Qt.KeepAspectRatio)
        else:
            grid.setAspectRatioMode(Qt.IgnoreAspectRatio)
        grid.setShowAverages(self.averages)
        grid.setLegendVisible(self.legend)

        grid.layoutDidActivate.connect(self.__on_layout_activate)
        grid.selectionFinished.connect(self.on_selection_finished)

        self.update_annotations()
        self.view.setCentralWidget(grid)
        # Note: the widget-level parts are what gets stored
        self.parts = grid_parts

    def __update_scene_rects(self):
        """Sync the scene and view rectangles with the heatmap geometry."""
        grid = self.scene.widget
        if grid is not None:
            bounds = grid.geometry()
            self.scene.setSceneRect(bounds)
            self.view.setSceneRect(bounds)
            self.view.setHeaderSceneRect(grid.headerGeometry())
            self.view.setFooterSceneRect(grid.footerGeometry())

    def __on_layout_activate(self):
        """Slot for the grid widget's `layoutDidActivate` signal."""
        # The scene/view rectangles are refreshed after each layout pass
        self.__update_scene_rects()

    def __aspect_mode_changed(self):
        """Apply the `keep_aspect` setting to the current heatmap widget."""
        grid = self.scene.widget
        if grid is None:
            return
        if self.keep_aspect:
            aspect_mode = Qt.KeepAspectRatio
        else:
            aspect_mode = Qt.IgnoreAspectRatio
        grid.setAspectRatioMode(aspect_mode)
        # when the aspect is fixed the vertical size hint is fixed; when
        # not, the widget can shrink vertically
        policy = grid.sizePolicy()
        vertical = (QSizePolicy.Fixed if self.keep_aspect
                    else QSizePolicy.Preferred)
        policy.setVerticalPolicy(vertical)
        grid.setSizePolicy(policy)

    def __update_clustering_enable_state(self, data):
        """Downgrade clustering options that `data` is too large for.

        Row options depend on the number of rows, column options on the
        number of attributes; with ``data`` being None both counts are 0.
        A selected method above its limit is downgraded (ordered clustering
        to plain clustering, plain clustering to none), an information
        message is shown, and the matching combo box items are disabled.
        """
        if data is not None:
            N = len(data)
            M = len(data.domain.attributes)
        else:
            N = M = 0

        rc_enabled = N <= self.MaxClustering
        rco_enabled = N <= self.MaxOrderedClustering
        cc_enabled = M <= self.MaxClustering
        cco_enabled = M <= self.MaxOrderedClustering
        row_clust, col_clust = self.row_clustering, self.col_clustering

        row_clust_msg = ""
        col_clust_msg = ""

        # The checks cascade: ordered clustering may first fall back to
        # plain clustering and then, if still too large, to none. The
        # message of the last downgrade wins.
        if not rco_enabled and row_clust == Clustering.OrderedClustering:
            row_clust = Clustering.Clustering
            row_clust_msg = "Row cluster ordering was disabled due to the " \
                            "input matrix being too big"
        if not rc_enabled and row_clust == Clustering.Clustering:
            row_clust = Clustering.None_
            row_clust_msg = "Row clustering was disabled due to the " \
                            "input matrix being too big"

        if not cco_enabled and col_clust == Clustering.OrderedClustering:
            col_clust = Clustering.Clustering
            col_clust_msg = "Column cluster ordering was disabled due to " \
                            "the input matrix being too big"
        if not cc_enabled and col_clust == Clustering.Clustering:
            col_clust = Clustering.None_
            col_clust_msg = "Column clustering was disabled due to the " \
                            "input matrix being too big"

        self.col_clustering = col_clust
        self.row_clustering = row_clust

        # An empty message hides the corresponding information entry
        self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg))
        self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg))

        # Disable/enable the combobox items for the clustering methods
        def setenabled(cb: QComboBox, clu: bool, clu_op: bool):
            model = cb.model()
            assert isinstance(model, QStandardItemModel)
            idx = cb.findData(Clustering.OrderedClustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu_op)
            idx = cb.findData(Clustering.Clustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu)

        setenabled(self.row_cluster_cb, rc_enabled, rco_enabled)
        setenabled(self.col_cluster_cb, cc_enabled, cco_enabled)

    def update_averages_stripe(self):
        """Apply `self.averages` to the heatmap widget, if one exists.
        """
        grid = self.scene.widget
        if grid is None:
            return
        grid.setShowAverages(self.averages)

    def update_lowslider(self):
        """Keep the low threshold strictly below the high one.

        If the low slider has reached (or passed) the high slider, it is moved
        to one step below it. The color schema is updated in either case.
        """
        controls = self.controls
        low_slider = controls.threshold_low
        high_slider = controls.threshold_high
        low_value = low_slider.value()
        high_value = high_slider.value()
        if low_value >= high_value:
            low_slider.setSliderPosition(high_value - 1)
        self.update_color_schema()

    def update_highslider(self):
        """Keep the high threshold strictly above the low one.

        If the high slider has reached (or dropped below) the low slider, it
        is moved to one step above it. The color schema is updated in either
        case.
        """
        controls = self.controls
        low_slider = controls.threshold_low
        high_slider = controls.threshold_high
        low_value = low_slider.value()
        high_value = high_slider.value()
        if low_value >= high_value:
            high_slider.setSliderPosition(low_value + 1)
        self.update_color_schema()

    def update_color_schema(self):
        """Store the selected palette's name and recolor the heatmap.

        The palette is taken from the current item data of `self.color_cb`;
        the new color map is pushed to the heatmap widget only if one exists.
        """
        selected = self.color_cb.currentData()
        self.palette_name = selected.name
        grid = self.scene.widget
        if grid is None:
            return
        grid.setColorMap(self.color_map())

    def __update_column_clustering(self):
        """Call `update_heatmaps` and then `commit`.

        NOTE(review): presumably invoked when the column clustering setting
        changes -- confirm against the signal connections.
        """
        self.update_heatmaps()
        self.commit()

    def __update_row_clustering(self):
        """Call `update_heatmaps` and then `commit`.

        NOTE(review): presumably invoked when the row clustering setting
        changes -- confirm against the signal connections.
        """
        self.update_heatmaps()
        self.commit()

    def update_legend(self):
        """Apply `self.legend` to the heatmap widget's legend visibility.
        """
        grid = self.scene.widget
        if grid is None:
            return
        grid.setLegendVisible(self.legend)

    def row_annotation_var(self):
        """Return the variable used for row annotations (`self.annotation_var`).
        """
        return self.annotation_var

    def row_annotation_data(self):
        """Return the row annotation column of the input data.

        The column is extracted with `column_str_from_table` for the variable
        given by `row_annotation_var()`; None is returned when no annotation
        variable is set.
        """
        annotation = self.row_annotation_var()
        if annotation is not None:
            return column_str_from_table(self.input_data, annotation)
        return None

    def _merge_row_indices(self):
        """Return `self.merge_indices` if k-means merging is in effect.

        Merging is in effect when `self.merge_kmeans` is set and a k-means
        model is available; otherwise None is returned.
        """
        merging = self.merge_kmeans and self.kmeans_model is not None
        return self.merge_indices if merging else None

    def set_annotation_var(self, var: Union[None, Variable, int]):
        """Set the row annotation variable and refresh the annotations.

        An integer `var` is interpreted as an index into
        `self.annotation_model`. Nothing happens when the resulting variable
        equals the current one.
        """
        chosen = self.annotation_model[var] if isinstance(var, int) else var
        if self.annotation_var != chosen:
            self.annotation_var = chosen
            self.update_annotations()

    def update_annotations(self):
        """Push the row labels to the heatmap widget.

        When rows are merged, the labels of every merged group are joined
        into a single elided string. Without an annotation column the row
        labels are hidden and cleared.
        """
        grid = self.scene.widget
        if grid is None:
            return

        labels = self.row_annotation_data()
        groups = self._merge_row_indices()
        if labels is not None and groups is not None:
            def join(items):
                return join_elided(", ", 42, items, " ({} more)")

            labels = aggregate_apply(join, labels, groups)

        if labels is None:
            grid.setRowLabelsVisible(False)
            grid.setRowLabels(None)
        else:
            grid.setRowLabels(labels)
            grid.setRowLabelsVisible(True)

    def row_side_colors(self):
        """Return the row side color annotation data.

        Returns None when `self.annotation_color_var` is unset, otherwise a
        `(data, colormap, var)` tuple. For a continuous variable the color
        map's span is the (nan-ignoring) min/max of the column as it was
        before any aggregation over merged rows.
        """
        var = self.annotation_color_var
        if var is None:
            return None

        values = column_data_from_table(self.input_data, var)
        # The span must be taken before the values are aggregated below.
        lowest, highest = np.nanmin(values), np.nanmax(values)

        groups = self._merge_row_indices()
        if groups is not None:
            values = aggregate(var, values, groups)

        colored, cmap = self._colorize(var, values)
        if var.is_continuous:
            cmap.span = (lowest, highest)
        return colored, cmap, var

    def set_annotation_color_var(self, var: Union[None, Variable, int]):
        """Set the current side color annotation variable.

        An integer `var` is interpreted as an index into
        `self.row_side_color_model`. The row side colors are refreshed only
        when the resulting variable differs from the current one.
        """
        chosen = self.row_side_color_model[var] if isinstance(var, int) else var
        if self.annotation_color_var != chosen:
            self.annotation_color_var = chosen
            self.update_row_side_colors()

    def update_row_side_colors(self):
        """Push the row side color annotations to the heatmap widget.

        The annotations are cleared (set to None) when `row_side_colors()`
        returns None; otherwise the data, the color map and the variable's
        name are passed on.
        """
        grid = self.scene.widget
        if grid is None:
            return

        side_colors = self.row_side_colors()
        if side_colors is None:
            grid.setRowSideColorAnnotations(None)
            return

        data, colormap, var = side_colors
        grid.setRowSideColorAnnotations(data, colormap, var.name)

    def _colorize(self, var: Variable,
                  data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
        """Return `data` prepared for display together with its color map.

        For a discrete variable, NaN entries are replaced by -1 and the
        result is cast to int; the categorical color map gets an extra "N/A"
        value (and keeps the palette's trailing NaN color) only when NaNs are
        present. For a continuous variable `data` is returned as is with a
        gradient color map built from the palette without its trailing NaN
        color.

        The input array is never modified; the discrete branch works on a
        copy.

        :param var: Variable providing the palette (and values if discrete).
        :param data: Column data for `var`.
        :return: A `(data, colormap)` tuple.
        :raises TypeError: If `var` is neither discrete nor continuous.
        """
        palette = var.palette  # type: Palette
        # The last entry of `qcolors_w_nan` is the color used for NaN; it is
        # sliced off below wherever it is not needed.
        colors = np.array(
            [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
            dtype=np.uint8,
        )
        if var.is_discrete:
            mask = np.isnan(data)
            # np.where allocates a new array, so the caller's data (which
            # may be shared or read-only) is left untouched.
            data = np.where(mask, -1, data).astype(int)
            if mask.any():
                values = (*var.values, "N/A")
            else:
                values = var.values
                colors = colors[:-1]
            return data, CategoricalColorMap(colors, values)
        elif var.is_continuous:
            cmap = GradientColorMap(colors[:-1])
            return data, cmap
        else:
            raise TypeError(
                "Cannot colorize a variable of type "
                f"{type(var).__name__!r}; expected discrete or continuous"
            )

    def update_column_annotations(self):
        """Apply `self._column_label_pos` to the heatmap widget.

        Nothing is done without data or without a heatmap widget.
        """
        grid = self.scene.widget
        if self.data is None or grid is None:
            return
        grid.setColumnLabelsPosition(self._column_label_pos)

    def __adjust_font_size(self, diff):
        """Change the heatmap widget's font point size by `diff`.

        The decrease action is enabled only while the new size is above 1.0
        and the increase action only while it does not exceed 32. The font is
        applied only if the new size is above 1.0.
        """
        grid = self.scene.widget
        if grid is None:
            return

        size = grid.font().pointSizeF() + diff
        can_apply = size > 1.0

        self.__font_dec.setEnabled(can_apply)
        self.__font_inc.setEnabled(size <= 32)
        if can_apply:
            resized = QFont()
            resized.setPointSizeF(size)
            grid.setFont(resized)

    def _on_view_context_menu(self, pos):
        """Pop up the context menu for the view at viewport position `pos`.

        The menu holds the view's own actions, the font size actions and a
        checkable "Keep aspect ratio" action. No menu is shown when there is
        no heatmap widget.
        """
        grid = self.scene.widget
        if grid is None:
            return
        assert isinstance(grid, HeatmapGridWidget)

        viewport = self.view.viewport()
        menu = QMenu(viewport)
        menu.setAttribute(Qt.WA_DeleteOnClose)
        menu.addActions(self.view.actions())
        menu.addSeparator()
        menu.addActions([self.__font_inc, self.__font_dec])
        menu.addSeparator()

        def set_keep_aspect(checked):
            self.keep_aspect = checked
            self.__aspect_mode_changed()

        aspect_action = QAction("Keep aspect ratio", menu, checkable=True)
        # Initialize the check state before connecting, so that this call
        # does not invoke the handler.
        aspect_action.setChecked(self.keep_aspect)
        aspect_action.toggled.connect(set_keep_aspect)
        menu.addAction(aspect_action)

        menu.popup(viewport.mapToGlobal(pos))

    def on_selection_finished(self):
        """Store the widget's selected rows and commit the outputs.

        Without a heatmap widget the selection is reset to an empty list.
        """
        grid = self.scene.widget
        self.selected_rows = [] if grid is None else list(grid.selectedRows())
        self.commit()

    def commit(self):
        """Send the selected and the annotated data to the outputs.

        When rows are merged (`self.merge_kmeans`), every selected row stands
        for a group of input rows and is expanded through
        `self.merge_indices` before indexing the input data. Without input
        data or without a selection, None is sent as the selected data.
        """
        groups = self.merge_indices if self.merge_kmeans else None

        selected = None
        rows = None
        if self.input_data is not None and self.selected_rows:
            rows = self.selected_rows
            if groups is not None:
                # expand merged indices
                rows = np.hstack([groups[i] for i in rows])
            selected = self.input_data[rows]

        self.Outputs.selected_data.send(selected)
        annotated = create_annotated_table(self.input_data, rows)
        self.Outputs.annotated_data.send(annotated)

    def onDeleteWidget(self):
        """Call `self.clear()` and then the base class `onDeleteWidget`.
        """
        self.clear()
        super().onDeleteWidget()

    def send_report(self):
        """Report the clustering, split and annotation settings and the plot.

        Clustering is reported as "Clustering" when the respective setting is
        truthy and as "No sorting" otherwise. The split and annotation entries
        hold the variable's name, or False when the variable is not set.
        """
        def sorting_label(clustering):
            return "Clustering" if clustering else "No sorting"

        def name_or_false(var):
            return var is not None and var.name

        items = (
            ("Columns:", sorting_label(self.col_clustering)),
            ("Rows:", sorting_label(self.row_clustering)),
            ("Split:", name_or_false(self.split_by_var)),
            ("Row annotation", name_or_false(self.annotation_var)),
        )
        self.report_items(items)
        self.report_plot()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Migrate stored `settings` from a version older than 3, in place.

        The boolean "row_clustering" / "col_clustering" entries are removed
        and replaced by "row_clustering_method" / "col_clustering_method",
        holding the name of `Clustering.OrderedClustering` for a true value
        and of `Clustering.None_` otherwise (also when the entry is missing).
        Settings without a version, or of version 3 and above, are left
        untouched.
        """
        if version is None or version >= 3:
            return

        for axis in ("row", "col"):
            was_clustered = settings.pop(axis + "_clustering", False)
            if was_clustered:
                method = Clustering.OrderedClustering
            else:
                method = Clustering.None_
            settings[axis + "_clustering_method"] = method.name
示例#11
0
    def add_row(self, attr=None, condition_type=None, condition_value=None):
        """Append a new condition row to `self.cond_list`.

        A searchable attribute combo box is placed in column 0 and a remove
        button in column 3 of the new row; `set_new_operators` is then called
        for the new combo box.

        :param attr: Item to preselect in the attribute combo box. A string
            is applied with `setCurrentText`; anything else is used as an
            index, where a falsy value (None or 0) selects a default index.
        :param condition_type: Forwarded to `set_new_operators`.
        :param condition_value: Forwarded to `set_new_operators`.
        """
        model = self.cond_list.model()
        # The new row is appended after the existing ones.
        row = model.rowCount()
        model.insertRow(row)

        attr_combo = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
        # Remember the row index on the combo box itself.
        attr_combo.row = row
        # Variables are added with the item data from `gui.attributeItem`;
        # any other entry is added as is.
        for var in self._visible_variables(self.data.domain):
            if isinstance(var, Variable):
                attr_combo.addItem(*gui.attributeItem(var))
            else:
                attr_combo.addItem(var)
        if isinstance(attr, str):
            attr_combo.setCurrentText(attr)
        else:
            # `or` binds looser than `-`: a falsy `attr` falls back to
            # `len(self.AllTypes)`, reduced by one when the combo box holds
            # exactly `len(self.AllTypes)` items.
            attr_combo.setCurrentIndex(
                attr or len(self.AllTypes) -
                (attr_combo.count() == len(self.AllTypes)))
        self.cond_list.setCellWidget(row, 0, attr_combo)

        # The button resolves its row through a persistent index at click
        # time instead of capturing `row` -- presumably so that it stays
        # correct after other rows are removed (TODO confirm).
        index = QPersistentModelIndex(model.index(row, 3))
        temp_button = QPushButton(
            '×',
            self,
            flat=True,
            styleSheet='* {font-size: 16pt; color: silver}'
            '*:hover {color: black}')
        temp_button.clicked.connect(lambda: self.remove_one(index.row()))
        self.cond_list.setCellWidget(row, 3, temp_button)

        self.remove_all_button.setDisabled(False)
        self.set_new_operators(attr_combo, attr is not None, condition_type,
                               condition_value)
        # Connected only after the initial `set_new_operators` call above, so
        # the index set during construction does not go through this handler.
        attr_combo.currentIndexChanged.connect(
            lambda _: self.set_new_operators(attr_combo, False))

        self.cond_list.resizeRowToContents(row)