Beispiel #1
0
    def test_read_only(self):
        """DomainModel refuses structural edits but accepts non-edit roles."""
        model = DomainModel()
        domain = Domain([ContinuousVariable(x) for x in "abc"])
        model.set_domain(domain)
        index = model.index(0, 0)

        # every list-like mutator must raise TypeError
        mutators = (
            (model.append, (42,)),
            (model.extend, ([42],)),
            (model.insert, (0, 42)),
            (model.remove, (0,)),
            (model.pop, ()),
            (model.clear, ()),
            (model.reverse, ()),
            (model.sort, ()),
        )
        for method, args in mutators:
            self.assertRaises(TypeError, method, *args)
        with self.assertRaises(TypeError):
            model[0] = 1
        with self.assertRaises(TypeError):
            del model[0]

        # setting data with the default role fails; tooltips may be set
        self.assertRaises(TypeError, model.setData, index, domain[0])
        self.assertTrue(model.setData(index, "foo", Qt.ToolTipRole))

        self.assertRaises(TypeError, model.setItemData, index,
                          {Qt.EditRole: domain[0], Qt.ToolTipRole: "foo"})
        self.assertTrue(model.setItemData(index, {Qt.ToolTipRole: "foo"}))

        # row insertion and removal are rejected as well
        self.assertRaises(TypeError, model.insertRows, 0, 0)
        self.assertRaises(TypeError, model.removeRows, 0, 0)
Beispiel #2
0
    def test_separators(self):
        """Separators appear only between non-empty sections (and after a placeholder)."""
        attrs = [ContinuousVariable(n) for n in "abg"]
        classes = [ContinuousVariable(n) for n in "deh"]
        metas = [ContinuousVariable(n) for n in "ijf"]
        sep = [DomainModel.Separator]

        def contents(domain, **model_kwargs):
            # build a fresh model for `domain` and return its items
            fresh = DomainModel(**model_kwargs)
            fresh.set_domain(domain)
            return list(fresh)

        self.assertEqual(contents(Domain(attrs, classes, metas)),
                         classes + sep + metas + sep + attrs)
        self.assertEqual(contents(Domain(attrs, [], metas)),
                         metas + sep + attrs)
        self.assertEqual(contents(Domain([], [], metas)), metas)
        self.assertEqual(contents(Domain([], [], metas), placeholder="foo"),
                         [None] + sep + metas)
        self.assertEqual(contents(Domain(attrs, [], metas), placeholder="foo"),
                         [None] + sep + metas + sep + attrs)
Beispiel #3
0
    def test_no_separators(self):
        """
        GH-2697: with ``separators=False`` the sections are concatenated
        without separators; with ``separators=True`` they are separated.
        """
        attrs = [ContinuousVariable(n) for n in "abg"]
        classes = [ContinuousVariable(n) for n in "deh"]
        metas = [ContinuousVariable(n) for n in "ijf"]
        sep = PyListModel.Separator

        cases = (
            (False, classes + metas + attrs),
            (True, classes + [sep] + metas + [sep] + attrs),
        )
        for use_separators, expected in cases:
            model = DomainModel(order=DomainModel.SEPARATED,
                                separators=use_separators)
            model.set_domain(Domain(attrs, classes, metas))
            self.assertEqual(list(model), expected)
Beispiel #4
0
    def test_subparts(self):
        """A custom ``order`` merges sections; ``alphabetical`` sorts each group."""
        attrs = [ContinuousVariable(n) for n in "abg"]
        classes = [ContinuousVariable(n) for n in "deh"]
        metas = [ContinuousVariable(n) for n in "ijf"]

        m = DomainModel
        sep = m.Separator
        order = (m.ATTRIBUTES | m.METAS, sep, m.CLASSES)

        def by_name(variables):
            return sorted(variables, key=lambda var: var.name)

        model = DomainModel(order=order)
        model.set_domain(Domain(attrs, classes, metas))
        self.assertEqual(list(model), attrs + metas + [sep] + classes)

        model = DomainModel(order=order, alphabetical=True)
        model.set_domain(Domain(attrs, classes, metas))
        self.assertEqual(list(model),
                         by_name(attrs + metas) + [sep] + by_name(classes))
Beispiel #5
0
    def test_filtering(self):
        """``valid_types`` filters by type; hidden variables are skipped by default."""
        cont = [ContinuousVariable(n) for n in "abc"]
        disc = [DiscreteVariable(n) for n in "def"]
        attrs = cont + disc

        def make_model(**model_kwargs):
            # new model already populated with a fresh domain over `attrs`
            built = DomainModel(**model_kwargs)
            built.set_domain(Domain(attrs))
            return built

        self.assertEqual(list(make_model(valid_types=(ContinuousVariable, ))),
                         cont)

        model = make_model(valid_types=(DiscreteVariable, ))
        self.assertEqual(list(model), disc)

        # marking a variable hidden removes it after the next set_domain
        disc[0].attributes["hidden"] = True
        model.set_domain(Domain(attrs))
        self.assertEqual(list(model), disc[1:])

        unfiltered = make_model(valid_types=(DiscreteVariable, ),
                                skip_hidden_vars=False)
        self.assertEqual(list(unfiltered), disc)
Beispiel #6
0
class OWHyper(OWWidget):
    """Hyperspectral image widget.

    Draws an image whose values either come from the spectra via the chosen
    integration method (``value_type == 0``) or from a chosen meta/class
    feature (``value_type == 1``), plus a curve plot of the spectra. The
    image selection is sent to the outputs.
    """
    name = "HyperSpectra"

    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        selected_data = Output("Selection", Orange.data.Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    icon = "icons/hyper.svg"
    priority = 20
    replaces = ["orangecontrib.infrared.widgets.owhyper.OWHyper"]

    settings_version = 3
    settingsHandler = DomainContextHandler()

    imageplot = SettingProvider(ImagePlot)
    curveplot = SettingProvider(CurvePlotHyper)

    integration_method = Setting(0)  # index into integration_methods
    integration_methods = Integrate.INTEGRALS
    value_type = Setting(0)  # 0: "From spectra", 1: "Use feature"
    attr_value = ContextSetting(None)

    # positions of the three movable vertical lines in the curve plot
    lowlim = Setting(None)
    highlim = Setting(None)
    choose = Setting(None)

    graph_name = "imageplot.plotview"  # defined so that the save button is shown

    class Warning(OWWidget.Warning):
        threshold_error = Msg("Low slider should be less than High")

    # Must derive from OWWidget.Error so these messages are shown as errors.
    class Error(OWWidget.Error):
        image_too_big = Msg("Image for chosen features is too big ({} x {}).")

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade settings saved by older versions of the widget.

        * version < 2: drop the saved ``attr_value`` of the first context.
        * version <= 2: convert the first context's per-row ``selection``
          (nonzero entries are selected) into ``(row, 1)`` pairs stored under
          ``settings_["imageplot"]["selection_group_saved"]``.

        Both steps are best-effort: malformed old settings are ignored.
        """
        if version < 2:
            # delete the saved attr_value to prevent crashes
            try:
                del settings_["context_settings"][0].values["attr_value"]
            except Exception:  # best effort; missing or malformed context
                pass

        # migrate selection
        if version <= 2:
            try:
                current_context = settings_["context_settings"][0]
                selection = getattr(current_context, "selection", None)
                if selection is not None:
                    selection = [(i, 1) for i in np.flatnonzero(np.array(selection))]
                    settings_.setdefault("imageplot", {})["selection_group_saved"] = selection
            except Exception:  # best effort; keep the remaining settings usable
                pass

    def __init__(self):
        super().__init__()

        dbox = gui.widgetBox(self.controlArea, "Image values")

        # radio group choosing where image values come from (value_type)
        rbox = gui.radioButtons(
            dbox, self, "value_type", callback=self._change_integration)

        gui.appendRadioButton(rbox, "From spectra")

        self.box_values_spectra = gui.indentedBox(rbox)

        gui.comboBox(
            self.box_values_spectra, self, "integration_method", valueType=int,
            items=(a.name for a in self.integration_methods),
            callback=self._change_integral_type)
        gui.rubber(self.controlArea)

        gui.appendRadioButton(rbox, "Use feature")

        self.box_values_feature = gui.indentedBox(rbox)

        # primitive metas and class variables can serve as image values
        self.feature_value_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                               valid_types=DomainModel.PRIMITIVE)
        self.feature_value = gui.comboBox(
            self.box_values_feature, self, "attr_value",
            callback=self.update_feature_value, model=self.feature_value_model,
            sendSelectedValue=True, valueType=str)

        splitter = QSplitter(self)
        splitter.setOrientation(Qt.Vertical)
        self.imageplot = ImagePlot(self)
        self.imageplot.selection_changed.connect(self.output_image_selection)

        self.curveplot = CurvePlotHyper(self, select=SELECTONE)
        self.curveplot.selection_changed.connect(self.redraw_data)
        self.curveplot.plot.vb.x_padding = 0.005  # pad view so that lines are not hidden
        splitter.addWidget(self.imageplot)
        splitter.addWidget(self.curveplot)
        self.mainArea.layout().addWidget(splitter)

        # line1/line2 delimit the integration range; line3 is a single
        # position, shown instead of the range for Integrate.PeakAt
        self.line1 = MovableVline(position=self.lowlim, label="", report=self.curveplot)
        self.line1.sigMoved.connect(lambda v: setattr(self, "lowlim", v))
        self.line2 = MovableVline(position=self.highlim, label="", report=self.curveplot)
        self.line2.sigMoved.connect(lambda v: setattr(self, "highlim", v))
        self.line3 = MovableVline(position=self.choose, label="", report=self.curveplot)
        self.line3.sigMoved.connect(lambda v: setattr(self, "choose", v))
        for line in [self.line1, self.line2, self.line3]:
            line.sigMoveFinished.connect(self.changed_integral_range)
            self.curveplot.add_marking(line)
            line.hide()

        self.data = None
        # set while lines are repositioned programmatically to skip redraws
        self.disable_integral_range = False

        self.resize(900, 700)
        self._update_integration_type()

        # prepare interface according to the new context
        self.contextAboutToBeOpened.connect(lambda x: self.init_interface_data(x[0]))

    def init_interface_data(self, data):
        """Refill the feature combo unless `data` shares the current domain.

        Connected to ``contextAboutToBeOpened``, so it runs before context
        settings are restored.
        """
        same_domain = (self.data and data and
                       data.domain == self.data.domain)
        if not same_domain:
            self.init_attr_values(data)

    def output_image_selection(self):
        """Send the image selection and show the selected spectra.

        With no (or empty) data both outputs are cleared. When nothing is
        selected the curve plot shows all data.
        """
        if not self.data:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            self.curveplot.set_data(None)
            return

        indices = np.flatnonzero(self.imageplot.selection_group)

        annotated_data = groups_or_annotated_table(self.data, self.imageplot.selection_group)
        self.Outputs.annotated_data.send(annotated_data)

        selected = self.data[indices]
        self.Outputs.selected_data.send(selected if selected else None)
        if selected:
            self.curveplot.set_data(selected)
        else:
            self.curveplot.set_data(self.data)

    def init_attr_values(self, data):
        """Fill the feature model from `data` (None clears it) and pick its
        first entry, or None when the model is empty."""
        domain = data.domain if data is not None else None
        self.feature_value_model.set_domain(domain)
        self.attr_value = self.feature_value_model[0] if self.feature_value_model else None

    def redraw_data(self):
        """Recompute and redraw the image."""
        self.imageplot.update_view()

    def update_feature_value(self):
        """Callback of the feature combo: redraw the image."""
        self.redraw_data()

    def _update_integration_type(self):
        """Show the line markers and enable the control box for the current
        ``value_type`` and integration method."""
        self.line1.hide()
        self.line2.hide()
        self.line3.hide()
        if self.value_type == 0:
            self.box_values_spectra.setDisabled(False)
            self.box_values_feature.setDisabled(True)
            if self.integration_methods[self.integration_method] != Integrate.PeakAt:
                self.line1.show()
                self.line2.show()
            else:
                self.line3.show()
        elif self.value_type == 1:
            self.box_values_spectra.setDisabled(True)
            self.box_values_feature.setDisabled(False)
        QTest.qWait(1)  # first update the interface

    def _change_integration(self):
        # change what to show on the image
        self._update_integration_type()
        self.redraw_data()

    def changed_integral_range(self):
        """Redraw after a line was moved, unless lines are being reset."""
        if self.disable_integral_range:
            return
        self.redraw_data()

    def _change_integral_type(self):
        self._change_integration()

    @Inputs.data
    def set_data(self, data):
        """Receive new data: restore or initialize the context, pass the
        data to both plots, reset line positions and update the outputs."""
        self.closeContext()

        def valid_context(data):
            # a context is only useful if there is a feature to choose
            if data is None:
                return False
            annotation_features = [v for v in data.domain.metas + data.domain.class_vars
                                   if isinstance(v, (DiscreteVariable, ContinuousVariable))]
            return len(annotation_features) >= 1

        if valid_context(data):
            self.openContext(data)
        else:
            # to generate valid interface even if context was not loaded
            self.contextAboutToBeOpened.emit([data])
        self.data = data
        self.imageplot.set_data(data)
        self.curveplot.set_data(data)
        self._init_integral_boundaries()
        self.imageplot.update_view()
        self.output_image_selection()

    def _init_integral_boundaries(self):
        """Move the lines into the x-range of the curve plot data.

        The range is ``[data_x[0], data_x[-1]]``, or ``[0, 1]`` without data.
        ``lowlim``/``highlim`` outside the range are reset to its ends;
        ``choose`` defaults to the middle and is clamped into the range.
        """
        # requires data in curveplot
        self.disable_integral_range = True
        if self.curveplot.data_x is not None and len(self.curveplot.data_x):
            minx = self.curveplot.data_x[0]
            maxx = self.curveplot.data_x[-1]
        else:
            minx = 0.
            maxx = 1.

        if self.lowlim is None or not minx <= self.lowlim <= maxx:
            self.lowlim = minx
        self.line1.setValue(self.lowlim)

        if self.highlim is None or not minx <= self.highlim <= maxx:
            self.highlim = maxx
        self.line2.setValue(self.highlim)

        if self.choose is None:
            self.choose = (minx + maxx)/2
        elif self.choose < minx:
            self.choose = minx
        elif self.choose > maxx:
            self.choose = maxx
        self.line3.setValue(self.choose)
        self.disable_integral_range = False

    def save_graph(self):
        """Delegate saving of the graph to the image plot."""
        self.imageplot.save_graph()
Beispiel #7
0
class OWMosaicDisplay(OWWidget):
    name = "Mosaic Display"
    description = "Display data in a mosaic plot."
    icon = "icons/MosaicDisplay.svg"
    priority = 220

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    PEARSON, CLASS_DISTRIBUTION = 0, 1

    settingsHandler = DomainContextHandler()
    use_boxes = Setting(True)
    interior_coloring = Setting(CLASS_DISTRIBUTION)
    variable1 = ContextSetting("")
    variable2 = ContextSetting("")
    variable3 = ContextSetting("")
    variable4 = ContextSetting("")
    variable_color = ContextSetting("")
    selection = ContextSetting(set())

    BAR_WIDTH = 5
    SPACING = 4
    ATTR_NAME_OFFSET = 20
    ATTR_VAL_OFFSET = 3
    BLUE_COLORS = [QColor(255, 255, 255), QColor(210, 210, 255),
                   QColor(110, 110, 255), QColor(0, 0, 255)]
    RED_COLORS = [QColor(255, 255, 255), QColor(255, 200, 200),
                  QColor(255, 100, 100), QColor(255, 0, 0)]

    vizrank = SettingProvider(MosaicVizRank)

    graph_name = "canvas"

    class Warning(OWWidget.Warning):
        incompatible_subset = Msg("Data subset is incompatible with Data")
        no_valid_data = Msg("No valid data")
        no_cont_selection_sql = \
            Msg("Selection of numeric features on SQL is not supported")

    def __init__(self):
        super().__init__()

        # input tables and derived data
        self.data = None
        self.discrete_data = None
        self.subset_data = None
        self.subset_indices = None

        self.color_data = None

        # (attributes, values, rectangle) for every drawn tile
        self.areas = []

        # drawing surface; clicking empty space clears the selection
        self.canvas = QGraphicsScene()
        view = ViewWithPress(self.canvas, handler=self.clear_selection)
        self.canvas_view = view
        self.mainArea.layout().addWidget(view)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setRenderHint(QPainter.Antialiasing)

        # four combos bound to variable1 .. variable4
        variables_box = gui.vBox(self.controlArea, box=True)
        combos = []
        for position in range(1, 5):
            combos.append(gui.comboBox(
                variables_box, self, value="variable{}".format(position),
                orientation=Qt.Horizontal, contentsLength=12,
                callback=self.reset_graph,
                sendSelectedValue=True, valueType=str, emptyString="(None)"))
        self.attr_combos = combos
        self.vizrank, self.vizrank_button = MosaicVizRank.add_vizrank(
            variables_box, self, "Find Informative Mosaics", self.set_attr)

        coloring_box = gui.vBox(self.controlArea, box="Interior Coloring")
        self.color_model = DomainModel(order=DomainModel.MIXED,
                                       valid_types=DomainModel.PRIMITIVE,
                                       placeholder="(Pearson residuals)")
        self.cb_attr_color = gui.comboBox(
            coloring_box, self, value="variable_color",
            orientation=Qt.Horizontal, contentsLength=12, labelWidth=50,
            callback=self.set_color_data,
            sendSelectedValue=True, model=self.color_model, valueType=str)
        self.bar_button = gui.checkBox(
            coloring_box, self, 'use_boxes', label='Compare with total',
            callback=self._compare_with_total)
        gui.rubber(self.controlArea)

    def sizeHint(self):
        """Preferred widget size: 720 x 530."""
        width, height = 720, 530
        return QSize(width, height)

    def _compare_with_total(self):
        """Callback of "Compare with total".

        If a class variable exists and class-distribution coloring is not
        active yet, switch to it (which redraws); otherwise just redraw.
        """
        data = self.data
        switch_coloring = (data is not None
                           and data.domain.class_var is not None
                           and self.interior_coloring != self.CLASS_DISTRIBUTION)
        if not switch_coloring:
            self.update_graph()
            return
        self.interior_coloring = self.CLASS_DISTRIBUTION
        self.coloring_changed()  # This also calls self.update_graph

    def _get_discrete_data(self, data):
        """
        Discretize the continuous variables of `data` for the mosaic.

        Returns None when there is no data, no rows, or no discrete or
        continuous variable among attributes, class variables and metas.
        Returns `data` unchanged when none of those variables is continuous;
        otherwise returns a copy discretized with ``EqualFreq(n=4)``
        (class variables and metas included).
        """
        if data is None or not len(data):
            return None
        # check metas in both tests: Discretize below also discretizes metas
        variables = list(chain(data.domain, data.domain.metas))
        if not any(var.is_discrete or var.is_continuous for var in variables):
            return None
        if any(var.is_continuous for var in variables):
            return Discretize(
                method=EqualFreq(n=4), remove_const=False, discretize_classes=True,
                discretize_metas=True)(data)
        return data

    def init_combos(self, data):
        """
        Refill the four variable combos and the color model for `data`.

        Only primitive variables (attributes, class variables and metas) are
        offered. Combos 2-4 start with a "(None)" entry. Chooses default
        variables and, if there is a class variable, selects it for coloring.
        With `data` None the combos are emptied and the color model cleared.
        """
        for combo in self.attr_combos:
            combo.clear()
        if data is None:
            self.color_model.set_domain(None)
            return
        domain = data.domain
        self.color_model.set_domain(domain)
        for combo in self.attr_combos[1:]:
            combo.addItem("(None)")

        icons = gui.attributeIconDict
        for attr in chain(domain, domain.metas):
            # is_primitive must be called; the bare bound method is always
            # truthy and would let string variables through
            if attr.is_primitive():
                for combo in self.attr_combos:
                    combo.addItem(icons[attr], attr.name)

        if self.attr_combos[0].count() > 0:
            self.variable1 = self.attr_combos[0].itemText(0)
            # index 2 (second variable, after "(None)") if it exists, else "(None)"
            self.variable2 = self.attr_combos[1].itemText(
                2 * (self.attr_combos[1].count() > 2))
        self.variable3 = self.attr_combos[2].itemText(0)
        self.variable4 = self.attr_combos[3].itemText(0)
        if domain.class_var:
            self.variable_color = domain.class_var.name
            idx = self.cb_attr_color.findText(self.variable_color)
        else:
            idx = 0
        self.cb_attr_color.setCurrentIndex(idx)

    def get_attr_list(self):
        """Return the chosen variable names, skipping empty and "(None)" ones."""
        chosen = []
        for name in (self.variable1, self.variable2,
                     self.variable3, self.variable4):
            if not name or name == "(None)":
                continue
            chosen.append(name)
        return chosen

    def set_attr(self, *attrs):
        """Set the four mosaic variables (callback of VizRank) and redraw.

        Exactly four `attrs` are expected; a falsy entry means "unused".
        """
        names = [attr.name if attr else "" for attr in attrs]
        (self.variable1, self.variable2,
         self.variable3, self.variable4) = names
        self.reset_graph()

    def resizeEvent(self, e):
        """Let the base widget handle the resize, then redraw the mosaic."""
        base_handler = OWWidget.resizeEvent
        base_handler(self, e)
        self.update_graph()

    def showEvent(self, ev):
        """Let the base widget handle showing, then redraw the mosaic."""
        base_handler = OWWidget.showEvent
        base_handler(self, ev)
        self.update_graph()

    @Inputs.data
    def set_data(self, data):
        """Receive the main input.

        Large SQL tables are sampled first. Resets VizRank, refills the
        combos and opens the context for the new data.
        """
        too_large_sql = type(data) == SqlTable and data.approx_len() > LARGE_TABLE
        if too_large_sql:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data

        self.vizrank.stop_and_reset()
        has_data = data is not None
        self.vizrank_button.setEnabled(
            has_data and len(data) > 1 and len(data.domain.attributes) >= 1)

        if not has_data:
            self.discrete_data = None
            self.init_combos(None)
            return

        self.init_combos(data)
        self.openContext(data)

    @Inputs.data_subset
    def set_subset_data(self, data):
        """Remember the subset input; it is matched in handleNewSignals."""
        self.subset_data = data

    # Called by the widget framework after the input handlers (set_data,
    # set_subset_data) have run, so the graph is redrawn only once.
    def handleNewSignals(self):
        """Match subset rows to data rows by id, then recolor and redraw."""
        self.Warning.incompatible_subset.clear()
        self.subset_indices = None
        if self.data is not None and self.subset_data:
            transformed = self.subset_data.transform(self.data.domain)
            # a subset from an unrelated domain transforms to all-NaN values
            all_missing = (np.all(np.isnan(transformed.X))
                           and np.all(np.isnan(transformed.Y)))
            if all_missing:
                self.Warning.incompatible_subset()
            else:
                subset_ids = {row.id for row in transformed}
                self.subset_indices = [row.id in subset_ids for row in self.data]

        self.set_color_data()
        self.reset_graph()

    def clear_selection(self):
        """Deselect every tile, refresh outlines and send the empty selection."""
        self.selection = set()
        self.update_selection_rects()
        self.send_selection()

    def coloring_changed(self):
        """Notify VizRank about the new coloring, then redraw."""
        self.vizrank.coloring_changed()
        self.update_graph()

    def reset_graph(self):
        """Drop the selection and redraw the mosaic from scratch."""
        self.clear_selection()
        self.update_graph()

    def set_color_data(self):
        """Rebuild ``color_data`` and ``discrete_data`` for the color combo.

        Index 0 of the color combo is the "(Pearson residuals)" placeholder:
        it selects Pearson coloring and disables "Compare with total". Any
        other entry becomes the class variable of ``color_data`` and selects
        class-distribution coloring. Does nothing with fewer than two rows
        or no attributes.
        """
        data = self.data
        if data is None or len(data) < 2 or len(data.domain.attributes) < 1:
            # NOTE(review): color_data/discrete_data keep their previous
            # values here; confirm update_graph copes with such stale tables.
            return
        use_pearson = self.cb_attr_color.currentIndex() <= 0
        if use_pearson:
            color_var = None
            self.interior_coloring = self.PEARSON
        else:
            color_var = data.domain[self.cb_attr_color.currentText()]
            self.interior_coloring = self.CLASS_DISTRIBUTION
        self.bar_button.setEnabled(not use_pearson)

        domain = data.domain
        kept = [var for var in domain.attributes + domain.class_vars + domain.metas
                if var != color_var and var.is_primitive()]
        self.color_data = color_data = data.from_table(
            Domain(kept, color_var, None), data)
        self.discrete_data = self._get_discrete_data(color_data)
        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(True)
        self.coloring_changed()

    def update_selection_rects(self):
        """Outline selected tiles with a thick dotted black pen; reset others."""
        selected_pen_args = (Qt.black, 3, Qt.DotLine)
        for position, (_, _, rect) in enumerate(self.areas):
            pen = QPen(*selected_pen_args) if position in self.selection else QPen()
            rect.setPen(pen)

    def select_area(self, index, ev):
        """Update the selection after a mouse click on tile `index`.

        Only left clicks are handled. Ctrl+click toggles the tile in the
        selection; a plain click makes it the only selected tile. Then the
        outlines are refreshed and the selection is sent.
        """
        if ev.button() != Qt.LeftButton:
            return
        if ev.modifiers() & Qt.ControlModifier:
            # Rebind to a new set instead of `^=`: in-place mutation would
            # also change any other holder of the same set object (possibly
            # the setting's default `set()` or a stored context value).
            self.selection = self.selection ^ {index}
        else:
            self.selection = {index}
        self.update_selection_rects()
        self.send_selection()

    def send_selection(self):
        """Send the rows that fall into the selected tiles.

        Each tile is a conjunction of per-variable value conditions; tiles
        are OR-ed. Filtering runs on ``discrete_data``; if that differs from
        ``data``, the matching rows are taken from ``data`` by id instead.
        """
        if not self.selection or self.data is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(create_annotated_table(self.data, []))
            return
        self.Warning.no_cont_selection_sql.clear()
        discretized = self.discrete_data is not self.data
        if discretized and isinstance(self.data, SqlTable):
            self.Warning.no_cont_selection_sql()
        tile_filters = [
            filter.Values(filter.FilterDiscrete(col, [val])
                          for col, val in zip(cols, vals))
            for cols, vals, _ in (self.areas[i] for i in self.selection)]
        if len(tile_filters) > 1:
            combined = filter.Values(tile_filters, conjunction=False)
        else:
            combined = tile_filters[0]
        selection = combined(self.discrete_data)
        selected_ids = set(selection.ids)
        sel_idx = [row for row, row_id in enumerate(self.data.ids)
                   if row_id in selected_ids]
        if discretized:
            selection = self.data[sel_idx]
        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(create_annotated_table(self.data, sel_idx))

    def send_report(self):
        """Add the mosaic canvas to the report."""
        canvas = self.canvas
        self.report_plot(canvas)

    def update_graph(self):
        spacing = self.SPACING
        bar_width = self.BAR_WIDTH

        def get_counts(attr_vals, values):
            """This function calculates rectangles' widths.
            If all widths are zero then all widths are set to 1."""
            if attr_vals == "":
                counts = [conditionaldict[val] for val in values]
            else:
                counts = [conditionaldict[attr_vals + "-" + val]
                          for val in values]
            total = sum(counts)
            if total == 0:
                counts = [1] * len(values)
                total = sum(counts)
            return total, counts

        def draw_data(attr_list, x0_x1, y0_y1, side, condition,
                      total_attrs, used_attrs, used_vals, attr_vals=""):
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if conditionaldict[attr_vals] == 0:
                add_rect(x0, x1, y0, y1, "",
                         used_attrs, used_vals, attr_vals=attr_vals)
                # store coordinates for later drawing of labels
                draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                          used_attrs, used_vals, attr_vals)
                return

            attr = attr_list[0]
            # how much smaller rectangles do we draw
            edge = len(attr_list) * spacing
            values = get_variable_values_sorted(data.domain[attr])
            if side % 2:
                values = values[::-1]  # reverse names if necessary

            if side % 2 == 0:  # we are drawing on the x axis
                # remove the space needed for separating different attr. values
                whole = max(0, (x1 - x0) - edge * (
                    len(values) - 1))
                if whole == 0:
                    edge = (x1 - x0) / float(len(values) - 1)
            else:  # we are drawing on the y axis
                whole = max(0, (y1 - y0) - edge * (len(values) - 1))
                if whole == 0:
                    edge = (y1 - y0) / float(len(values) - 1)

            total, counts = get_counts(attr_vals, values)

            # if we are visualizing the third attribute and the first attribute
            # has the last value, we have to reverse the order in which the
            # boxes will be drawn otherwise, if the last cell, nearest to the
            # labels of the fourth attribute, is empty, we wouldn't be able to
            # position the labels
            valrange = list(range(len(values)))
            if len(attr_list + used_attrs) == 4 and len(used_attrs) == 2:
                attr1values = get_variable_values_sorted(
                    data.domain[used_attrs[0]])
                if used_vals[0] == attr1values[-1]:
                    valrange = valrange[::-1]

            for i in valrange:
                start = i * edge + whole * float(sum(counts[:i]) / total)
                end = i * edge + whole * float(sum(counts[:i + 1]) / total)
                val = values[i]
                htmlval = to_html(val)
                if attr_vals != "":
                    newattrvals = attr_vals + "-" + val
                else:
                    newattrvals = val

                tooltip = condition + 4 * "&nbsp;" + attr + \
                    ": <b>" + htmlval + "</b><br>"
                attrs = used_attrs + [attr]
                vals = used_vals + [val]
                common_args = attrs, vals, newattrvals
                if side % 2 == 0:  # if we are moving horizontally
                    if len(attr_list) == 1:
                        add_rect(x0 + start, x0 + end, y0, y1,
                                 tooltip, *common_args)
                    else:
                        draw_data(attr_list[1:], (x0 + start, x0 + end),
                                  (y0, y1), side + 1,
                                  tooltip, total_attrs, *common_args)
                else:
                    if len(attr_list) == 1:
                        add_rect(x0, x1, y0 + start, y0 + end,
                                 tooltip, *common_args)
                    else:
                        draw_data(attr_list[1:], (x0, x1),
                                  (y0 + start, y0 + end), side + 1,
                                  tooltip, total_attrs, *common_args)

            draw_text(side, attr_list[0], (x0, x1), (y0, y1),
                      total_attrs, used_attrs, used_vals, attr_vals)

        def draw_text(side, attr, x0_x1, y0_y1,
                      total_attrs, used_attrs, used_vals, attr_vals):
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if side in drawn_sides:
                return

            # the text on the right will be drawn when we are processing
            # visualization of the last value of the first attribute
            if side == 3:
                attr1values = \
                    get_variable_values_sorted(data.domain[used_attrs[0]])
                if used_vals[0] != attr1values[-1]:
                    return

            if not conditionaldict[attr_vals]:
                if side not in draw_positions:
                    draw_positions[side] = (x0, x1, y0, y1)
                return
            else:
                if side in draw_positions:
                    # restore the positions of attribute values and name
                    (x0, x1, y0, y1) = draw_positions[side]

            drawn_sides.add(side)

            values = get_variable_values_sorted(data.domain[attr])
            if side % 2:
                values = values[::-1]

            spaces = spacing * (total_attrs - side) * (len(values) - 1)
            width = x1 - x0 - spaces * (side % 2 == 0)
            height = y1 - y0 - spaces * (side % 2 == 1)

            # calculate position of first attribute
            currpos = 0

            total, counts = get_counts(attr_vals, values)

            aligns = [Qt.AlignTop | Qt.AlignHCenter,
                      Qt.AlignRight | Qt.AlignVCenter,
                      Qt.AlignBottom | Qt.AlignHCenter,
                      Qt.AlignLeft | Qt.AlignVCenter]
            align = aligns[side]
            for i, val in enumerate(values):
                perc = counts[i] / float(total)
                if distributiondict[val] != 0:
                    if side == 0:
                        CanvasText(self.canvas, str(val),
                                   x0 + currpos + width * 0.5 * perc,
                                   y1 + self.ATTR_VAL_OFFSET, align)
                    elif side == 1:
                        CanvasText(self.canvas, str(val),
                                   x0 - self.ATTR_VAL_OFFSET,
                                   y0 + currpos + height * 0.5 * perc, align)
                    elif side == 2:
                        CanvasText(self.canvas, str(val),
                                   x0 + currpos + width * perc * 0.5,
                                   y0 - self.ATTR_VAL_OFFSET, align)
                    else:
                        CanvasText(self.canvas, str(val),
                                   x1 + self.ATTR_VAL_OFFSET,
                                   y0 + currpos + height * 0.5 * perc, align)

                if side % 2 == 0:
                    currpos += perc * width + spacing * (total_attrs - side)
                else:
                    currpos += perc * height + spacing * (total_attrs - side)

            if side == 0:
                CanvasText(
                    self.canvas, attr,
                    x0 + (x1 - x0) / 2,
                    y1 + self.ATTR_VAL_OFFSET + self.ATTR_NAME_OFFSET,
                    align, bold=1)
            elif side == 1:
                CanvasText(
                    self.canvas, attr,
                    x0 - max_ylabel_w1 - self.ATTR_VAL_OFFSET,
                    y0 + (y1 - y0) / 2,
                    align, bold=1, vertical=True)
            elif side == 2:
                CanvasText(
                    self.canvas, attr,
                    x0 + (x1 - x0) / 2,
                    y0 - self.ATTR_VAL_OFFSET - self.ATTR_NAME_OFFSET,
                    align, bold=1)
            else:
                CanvasText(
                    self.canvas, attr,
                    x1 + max_ylabel_w2 + self.ATTR_VAL_OFFSET,
                    y0 + (y1 - y0) / 2,
                    align, bold=1, vertical=True)

        def add_rect(x0, x1, y0, y1, condition,
                     used_attrs, used_vals, attr_vals=""):
            area_index = len(self.areas)
            if x0 == x1:
                x1 += 1
            if y0 == y1:
                y1 += 1

            # rectangles of width and height 1 are not shown - increase
            if x1 - x0 + y1 - y0 == 2:
                y1 += 1

            if class_var:
                colors = [QColor(*col) for col in class_var.colors]
            else:
                colors = None

            def select_area(_, ev):
                self.select_area(area_index, ev)

            def rect(x, y, w, h, z, pen_color=None, brush_color=None, **args):
                if pen_color is None:
                    return CanvasRectangle(
                        self.canvas, x, y, w, h, z=z, onclick=select_area,
                        **args)
                if brush_color is None:
                    brush_color = pen_color
                return CanvasRectangle(
                    self.canvas, x, y, w, h, pen_color, brush_color, z=z,
                    onclick=select_area, **args)

            def line(x1, y1, x2, y2):
                r = QGraphicsLineItem(x1, y1, x2, y2, None)
                self.canvas.addItem(r)
                r.setPen(QPen(Qt.white, 2))
                r.setZValue(30)

            outer_rect = rect(x0, y0, x1 - x0, y1 - y0, 30)
            self.areas.append((used_attrs, used_vals, outer_rect))
            if not conditionaldict[attr_vals]:
                return

            if self.interior_coloring == self.PEARSON:
                s = sum(apriori_dists[0])
                expected = s * reduce(
                    mul,
                    (apriori_dists[i][used_vals[i]] / float(s)
                     for i in range(len(used_vals))))
                actual = conditionaldict[attr_vals]
                pearson = (actual - expected) / sqrt(expected)
                if pearson == 0:
                    ind = 0
                else:
                    ind = max(0, min(int(log(abs(pearson), 2)), 3))
                color = [self.RED_COLORS, self.BLUE_COLORS][pearson > 0][ind]
                rect(x0, y0, x1 - x0, y1 - y0, -20, color)
                outer_rect.setToolTip(
                    condition + "<hr/>" +
                    "Expected instances: %.1f<br>"
                    "Actual instances: %d<br>"
                    "Standardized (Pearson) residual: %.1f" %
                    (expected, conditionaldict[attr_vals], pearson))
            else:
                cls_values = get_variable_values_sorted(class_var)
                prior = get_distribution(data, class_var.name)
                total = 0
                for i, value in enumerate(cls_values):
                    val = conditionaldict[attr_vals + "-" + value]
                    if val == 0:
                        continue
                    if i == len(cls_values) - 1:
                        v = y1 - y0 - total
                    else:
                        v = ((y1 - y0) * val) / conditionaldict[attr_vals]
                    rect(x0, y0 + total, x1 - x0, v, -20, colors[i])
                    total += v

                if self.use_boxes and \
                        abs(x1 - x0) > bar_width and \
                        abs(y1 - y0) > bar_width:
                    total = 0
                    line(x0 + bar_width, y0, x0 + bar_width, y1)
                    n = sum(prior)
                    for i, (val, color) in enumerate(zip(prior, colors)):
                        if i == len(prior) - 1:
                            h = y1 - y0 - total
                        else:
                            h = (y1 - y0) * val / n
                        rect(x0, y0 + total, bar_width, h, 20, color)
                        total += h

                if conditionalsubsetdict:
                    if conditionalsubsetdict[attr_vals]:
                        if self.subset_indices is not None:
                            line(x1 - bar_width, y0, x1 - bar_width, y1)
                            total = 0
                            n = conditionalsubsetdict[attr_vals]
                            if n:
                                for i, (cls, color) in \
                                        enumerate(zip(cls_values, colors)):
                                    val = conditionalsubsetdict[
                                        attr_vals + "-" + cls]
                                    if val == 0:
                                        continue
                                    if i == len(prior) - 1:
                                        v = y1 - y0 - total
                                    else:
                                        v = ((y1 - y0) * val) / n
                                    rect(x1 - bar_width, y0 + total,
                                         bar_width, v, 15, color)
                                    total += v

                actual = [conditionaldict[attr_vals + "-" + cls_values[i]]
                          for i in range(len(prior))]
                n_actual = sum(actual)
                if n_actual > 0:
                    apriori = [prior[key] for key in cls_values]
                    n_apriori = sum(apriori)
                    text = "<br/>".join(
                        "<b>%s</b>: %d / %.1f%% (Expected %.1f / %.1f%%)" %
                        (cls, act, 100.0 * act / n_actual,
                         apr / n_apriori * n_actual, 100.0 * apr / n_apriori)
                        for cls, act, apr in zip(cls_values, actual, apriori))
                else:
                    text = ""
                outer_rect.setToolTip(
                    "{}<hr>Instances: {}<br><br>{}".format(
                        condition, n_actual, text[:-4]))

        def draw_legend(x0_x1, y0_y1):
            x0, x1 = x0_x1
            _, y1 = y0_y1
            if self.interior_coloring == self.PEARSON:
                names = ["<-8", "-8:-4", "-4:-2", "-2:2", "2:4", "4:8", ">8",
                         "Residuals:"]
                colors = self.RED_COLORS[::-1] + self.BLUE_COLORS[1:]
            else:
                names = get_variable_values_sorted(class_var) + \
                        [class_var.name + ":"]
                colors = [QColor(*col) for col in class_var.colors]

            names = [CanvasText(self.canvas, name, alignment=Qt.AlignVCenter)
                     for name in names]
            totalwidth = sum(text.boundingRect().width() for text in names)

            # compute the x position of the center of the legend
            y = y1 + self.ATTR_NAME_OFFSET + self.ATTR_VAL_OFFSET + 35
            distance = 30
            startx = (x0 + x1) / 2 - (totalwidth + (len(names)) * distance) / 2

            names[-1].setPos(startx + 15, y)
            names[-1].show()
            xoffset = names[-1].boundingRect().width() + distance

            size = 8

            for i in range(len(names) - 1):
                if self.interior_coloring == self.PEARSON:
                    edgecolor = Qt.black
                else:
                    edgecolor = colors[i]

                CanvasRectangle(self.canvas, startx + xoffset, y - size / 2,
                                size, size, edgecolor, colors[i])
                names[i].setPos(startx + xoffset + 10, y)
                xoffset += distance + names[i].boundingRect().width()

        self.canvas.clear()
        self.areas = []

        data = self.discrete_data
        if data is None:
            return
        attr_list = self.get_attr_list()
        class_var = data.domain.class_var
        if class_var:
            sql = type(data) == SqlTable
            name = not sql and data.name
            # save class_var because it is removed in the next line
            data = data[:, attr_list + [class_var]]
            data.domain.class_var = class_var
            if not sql:
                data.name = name
        else:
            data = data[:, attr_list]
        # TODO: check this
        # data = Preprocessor_dropMissing(data)
        if len(data) == 0:
            self.Warning.no_valid_data()
            return
        else:
            self.Warning.no_valid_data.clear()

        attrs = [attr for attr in attr_list if not data.domain[attr].values]
        if attrs:
            CanvasText(self.canvas,
                       "Feature {} has no values".format(attrs[0]),
                       (self.canvas_view.width() - 120) / 2,
                       self.canvas_view.height() / 2)
            return
        if self.interior_coloring == self.PEARSON:
            apriori_dists = [get_distribution(data, attr) for attr in attr_list]
        else:
            apriori_dists = []

        def get_max_label_width(attr):
            values = get_variable_values_sorted(data.domain[attr])
            maxw = 0
            for val in values:
                t = CanvasText(self.canvas, val, 0, 0, bold=0, show=False)
                maxw = max(int(t.boundingRect().width()), maxw)
            return maxw

        # get the maximum width of rectangle
        xoff = 20
        width = 20
        if len(attr_list) > 1:
            text = CanvasText(self.canvas, attr_list[1], bold=1, show=0)
            max_ylabel_w1 = min(get_max_label_width(attr_list[1]), 150)
            width = 5 + text.boundingRect().height() + \
                self.ATTR_VAL_OFFSET + max_ylabel_w1
            xoff = width
            if len(attr_list) == 4:
                text = CanvasText(self.canvas, attr_list[3], bold=1, show=0)
                max_ylabel_w2 = min(get_max_label_width(attr_list[3]), 150)
                width += text.boundingRect().height() + \
                    self.ATTR_VAL_OFFSET + max_ylabel_w2 - 10

        # get the maximum height of rectangle
        height = 100
        yoff = 45
        square_size = min(self.canvas_view.width() - width - 20,
                          self.canvas_view.height() - height - 20)

        if square_size < 0:
            return  # canvas is too small to draw rectangles
        self.canvas_view.setSceneRect(
            0, 0, self.canvas_view.width(), self.canvas_view.height())

        drawn_sides = set()
        draw_positions = {}

        conditionaldict, distributiondict = \
            get_conditional_distribution(data, attr_list)
        conditionalsubsetdict = None
        if self.subset_indices:
            conditionalsubsetdict, _ = \
                get_conditional_distribution(self.discrete_data[self.subset_indices], attr_list)

        # draw rectangles
        draw_data(
            attr_list, (xoff, xoff + square_size), (yoff, yoff + square_size),
            0, "", len(attr_list), [], [])
        draw_legend((xoff, xoff + square_size), (yoff, yoff + square_size))
        self.update_selection_rects()
Beispiel #8
0
class OWChoropleth(widget.OWWidget):
    """Choropleth map widget.

    Each input row is assigned to an administrative region from its
    latitude/longitude (via ``latlon2region``); the chosen attribute is then
    aggregated per region and the map regions are shaded by that value.
    """
    name = 'Choropleth'
    description = 'A thematic map in which areas are shaded in proportion ' \
                  'to the measurement of the statistical variable being displayed.'
    icon = "icons/Choropleth.svg"
    priority = 120

    inputs = [("Data", Table, "set_data", widget.Default)]

    outputs = [("Selected Data", Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Table)]

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    AGG_FUNCS = (
        'Count',
        'Count defined',
        'Sum',
        'Mean',
        'Median',
        'Mode',
        'Max',
        'Min',
        'Std',
    )
    # Aggregations whose pandas name differs from ``agg_func.lower()``.
    AGG_FUNCS_TRANSFORM = {
        'Count': 'size',
        'Count defined': 'count',
        'Mode': lambda x: stats.mode(x, nan_policy='omit').mode[0],
    }
    AGG_FUNCS_DISCRETE = ('Count', 'Count defined', 'Mode')
    # Aggregations whose result is not shown as a time value.
    AGG_FUNCS_CANT_TIME = ('Count', 'Count defined', 'Sum', 'Std')

    autocommit = settings.Setting(True)
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    attr = settings.ContextSetting('')
    agg_func = settings.ContextSetting(AGG_FUNCS[0])
    admin = settings.Setting(0)
    opacity = settings.Setting(70)
    color_steps = settings.Setting(5)
    color_quantization = settings.Setting('equidistant')
    show_labels = settings.Setting(True)
    show_legend = settings.Setting(True)
    show_details = settings.Setting(True)
    selection = settings.ContextSetting([])

    class Error(widget.OWWidget.Error):
        aggregation_discrete = widget.Msg(
            "Only certain types of aggregation defined on categorical attributes: {}"
        )

    class Warning(widget.OWWidget.Warning):
        logarithmic_nonpositive = widget.Msg(
            "Logarithmic quantization requires all values > 0. Using 'equidistant' quantization instead."
        )

    graph_name = "map"

    def __init__(self):
        super().__init__()
        self.map = map_widget = LeafletChoropleth(self)
        self.mainArea.layout().addWidget(map_widget)
        self.selection = []
        self.data = None
        self.latlon = None
        self.result_min_nonpositive = False
        # Row indices (into self.data) that fall into the selected regions.
        # Initialised here because commit() reads it even before the user
        # has selected anything.
        self._indices = []
        # Region id of every data row; rebuilt whenever data is aggregated.
        self.ids = pd.Series([], dtype=object)

        def selectionChanged(selection):
            # Translate the selected region ids into row indices.
            # (pandas removed Series.nonzero() in 1.0, hence np.flatnonzero.)
            self._indices = np.flatnonzero(self.ids.isin(selection))
            self.selection = selection
            self.commit()

        map_widget.selectionChanged.connect(selectionChanged)

        box = gui.vBox(self.controlArea, 'Aggregation')

        self._latlon_model = DomainModel(parent=self,
                                         valid_types=ContinuousVariable)
        self._combo_lat = combo = gui.comboBox(box,
                                               self,
                                               'lat_attr',
                                               orientation=Qt.Horizontal,
                                               label='Latitude:',
                                               sendSelectedValue=True,
                                               callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_lon = combo = gui.comboBox(box,
                                               self,
                                               'lon_attr',
                                               orientation=Qt.Horizontal,
                                               label='Longitude:',
                                               sendSelectedValue=True,
                                               callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_attr = combo = gui.comboBox(box,
                                                self,
                                                'attr',
                                                orientation=Qt.Horizontal,
                                                label='Attribute:',
                                                sendSelectedValue=True,
                                                callback=self.aggregate)
        combo.setModel(
            DomainModel(parent=self,
                        valid_types=(ContinuousVariable, DiscreteVariable)))

        gui.comboBox(box,
                     self,
                     'agg_func',
                     orientation=Qt.Horizontal,
                     items=self.AGG_FUNCS,
                     label='Aggregation:',
                     sendSelectedValue=True,
                     callback=self.aggregate)

        self._detail_slider = gui.hSlider(box,
                                          self,
                                          'admin',
                                          None,
                                          0,
                                          2,
                                          1,
                                          label='Administrative level:',
                                          labelFormat=' %d',
                                          callback=self.aggregate)

        box = gui.vBox(self.controlArea, 'Visualization')

        gui.spin(box,
                 self,
                 'color_steps',
                 3,
                 15,
                 1,
                 label='Color steps:',
                 callback=lambda: self.map.set_color_steps(self.color_steps))

        def _set_quantization():
            self.Warning.logarithmic_nonpositive(
                shown=(self.color_quantization.startswith('log')
                       and self.result_min_nonpositive))
            self.map.set_quantization(self.color_quantization)

        gui.comboBox(box,
                     self,
                     'color_quantization',
                     label='Color quantization:',
                     orientation=Qt.Horizontal,
                     sendSelectedValue=True,
                     items=('equidistant', 'logarithmic', 'quantile',
                            'k-means'),
                     callback=_set_quantization)

        self._opacity_slider = gui.hSlider(
            box,
            self,
            'opacity',
            None,
            20,
            100,
            5,
            label='Opacity:',
            labelFormat=' %d%%',
            callback=lambda: self.map.set_opacity(self.opacity))

        gui.checkBox(box,
                     self,
                     'show_legend',
                     label='Show legend',
                     callback=lambda: self.map.toggle_legend(self.show_legend))
        gui.checkBox(
            box,
            self,
            'show_labels',
            label='Show map labels',
            callback=lambda: self.map.toggle_map_labels(self.show_labels))
        gui.checkBox(box,
                     self,
                     'show_details',
                     label='Show region details in tooltip',
                     callback=lambda: self.map.toggle_tooltip_details(
                         self.show_details))

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        self.map.toggle_legend(self.show_legend)
        self.map.toggle_map_labels(self.show_labels)
        self.map.toggle_tooltip_details(self.show_details)
        self.map.set_quantization(self.color_quantization)
        self.map.set_color_steps(self.color_steps)
        self.map.set_opacity(self.opacity)

    def __del__(self):
        self.progressBarFinished(None)
        self.map = None

    def commit(self):
        """Send the rows inside the selected regions and the annotated table."""
        self.send(
            'Selected Data', self.data[self._indices]
            if self.data is not None and self.selection else None)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.data, self._indices))

    def set_data(self, data):
        """Handle new input data.

        Resets the widget, picks default latitude/longitude and aggregation
        attributes, restores the saved context and redraws the map.
        """
        self.data = data
        # Coordinates of a previous table must not leak into this one:
        # aggregate() treats ``latlon is None`` as "no coordinates".
        self.latlon = None

        self.closeContext()

        self.clear()

        if data is None:
            return

        self._combo_attr.model().set_domain(data.domain)
        self._latlon_model.set_domain(data.domain)

        lat, lon = find_lat_lon(data)
        if lat or lon:
            self._combo_lat.setCurrentIndex(
                -1 if lat is None else self._latlon_model.indexOf(lat))
            self._combo_lon.setCurrentIndex(
                -1 if lon is None else self._latlon_model.indexOf(lon))
            self.lat_attr = lat.name if lat else None
            self.lon_attr = lon.name if lon else None
            if lat and lon:
                self.latlon = np.c_[
                    self.data.get_column_view(self.lat_attr)[0],
                    self.data.get_column_view(self.lon_attr)[0]]

        if data.domain.class_var:
            self.attr = data.domain.class_var.name
        else:
            self.attr = self._combo_attr.itemText(0)

        self.openContext(data)

        if self.selection:
            self.map.preset_region_selection(self.selection)
        self.aggregate()
        self.map.fit_to_bounds()

    def aggregate(self):
        """Aggregate ``self.attr`` per region and send the results to the map.

        Clears the map (keeping the caches) when there is no data, no
        latitude/longitude pair, or the attribute is not in the domain, and
        shows an error for aggregations not defined on categorical values.
        """
        if (self.data is None or self.latlon is None
                or self.attr not in self.data.domain):
            self.clear(caches=False)
            return

        attr = self.data.domain[self.attr]

        if attr.is_discrete and self.agg_func not in self.AGG_FUNCS_DISCRETE:
            self.Error.aggregation_discrete(', '.join(
                map(str.lower, self.AGG_FUNCS_DISCRETE)))
            self.Warning.logarithmic_nonpositive.clear()
            self.clear(caches=False)
            return
        else:
            self.Error.aggregation_discrete.clear()

        try:
            regions, adm0, result, self.map.bounds = \
                self.get_grouped(self.lat_attr, self.lon_attr, self.admin, self.attr, self.agg_func)
            # get_regions() is memoized, so the ``self.ids`` it assigns is not
            # refreshed on a cache hit; rebuild it from the (cached) ids so the
            # selection maps onto rows of the current admin level.
            _, ids, _, _ = self.get_regions(self.lat_attr, self.lon_attr,
                                            self.admin)
            self.ids = pd.Series(ids)
        except ValueError:
            # This might happen if widget scheme File→Choropleth, and
            # some attr is selected in choropleth, and then the same attr
            # is set to string attr in File and dataset reloaded.
            # The stale context setting cannot be aggregated; skip redrawing.
            return
        discrete_values = list(
            attr.values) if attr.is_discrete and not self.agg_func.startswith(
                'Count') else []

        self.result_min_nonpositive = attr.is_continuous and result.min() <= 0
        force_quantization = self.color_quantization.startswith(
            'log') and self.result_min_nonpositive
        self.Warning.logarithmic_nonpositive(shown=force_quantization)

        repr_time = isinstance(
            attr,
            TimeVariable) and self.agg_func not in self.AGG_FUNCS_CANT_TIME

        self.map.exposeObject(
            'results',
            dict(
                discrete=discrete_values,
                colors=[
                    color_to_hex(i)
                    for i in (attr.colors if discrete_values else (
                        (0, 0,
                         255), (255, 255,
                                0)) if attr.is_discrete else attr.colors[:-1])
                ],  # ???
                regions=list(adm0),
                attr=attr.name,
                have_nonpositive=self.result_min_nonpositive
                or discrete_values,
                values=result.to_dict(),
                repr_vals=result.map(attr.repr_val).to_dict()
                if repr_time else {},
                minmax=([result.min(), result.max()]
                        if attr.is_discrete and not discrete_values else [
                            attr.repr_val(result.min()),
                            attr.repr_val(result.max())
                        ] if repr_time or not discrete_values else [])))

        self.map.evalJS('replot();')

    @memoize_method(3)
    def get_regions(self, lat_attr, lon_attr, admin):
        """Map every row to a region at administrative level *admin*.

        Also stores the per-row region ids in ``self.ids`` (only when the
        body actually runs, i.e. not on a cache hit).

        :return: (set of region ids, list of per-row region ids or None,
            set of map layer keys, bounding rect of the regions or None)
        """
        latlon = np.c_[self.data.get_column_view(lat_attr)[0],
                       self.data.get_column_view(lon_attr)[0]]
        regions = latlon2region(latlon, admin)
        adm0 = ({'0'} if admin == 0 else {
            '1-' + a3
            for a3 in (i.get('adm0_a3') for i in regions) if a3
        } if admin == 1 else {('2-' if a3 in ADMIN2_COUNTRIES else '1-') + a3
                              for a3 in (i.get('adm0_a3') for i in regions)
                              if a3})
        ids = [i.get('_id') for i in regions]
        self.ids = pd.Series(ids)
        regions = set(ids) - {None}
        bounds = get_bounding_rect(regions) if regions else None
        return regions, ids, adm0, bounds

    @memoize_method(6)
    def get_grouped(self, lat_attr, lon_attr, admin, attr, agg_func):
        """Aggregate column *attr* with *agg_func* per region.

        :return: (region ids, map layer keys, pandas Series indexed by
            region id with aggregated values, region bounds)
        """
        log.debug('Grouping %s(%s) by (%s, %s; admin%d)', agg_func, attr,
                  lat_attr, lon_attr, admin)
        regions, ids, adm0, bounds = self.get_regions(lat_attr, lon_attr,
                                                      admin)
        attr = self.data.domain[attr]
        result = pd.Series(self.data.get_column_view(attr)[0], dtype=float)\
            .groupby(ids)\
            .agg(self.AGG_FUNCS_TRANSFORM.get(agg_func, agg_func.lower()))
        return regions, adm0, result, bounds

    def clear(self, caches=True):
        """Clear the selection and the plotted results.

        :param caches: also drop memoized region/aggregation results
        """
        if caches:
            try:
                self.get_regions.cache_clear()
                self.get_grouped.cache_clear()
            except AttributeError:
                pass  # back-compat https://github.com/biolab/orange3/pull/2229
        self.selection = []
        # Keep the row indices consistent with the (now empty) selection so
        # the annotated output does not mark stale rows as selected.
        self._indices = []
        self.map.exposeObject('results', {})
        self.map.evalJS('replot();')
class ImagePlot(QWidget, OWComponent, SelectionGroupMixin,
                ImageColorSettingMixin, ImageZoomMixin, ConcurrentMixin):

    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    gamma = Setting(0)

    selection_changed = Signal()
    image_updated = Signal()

    def __init__(self, parent):
        """Build the plot, image items, context menu and axis selectors.

        :param parent: owning widget (presumably the OWWidget hosting this
            plot — confirm); its ``contextAboutToBeOpened`` signal is
            connected here and its ``graph_writers`` are used by save_graph.
        """
        QWidget.__init__(self)
        OWComponent.__init__(self, parent)
        SelectionGroupMixin.__init__(self)
        ImageColorSettingMixin.__init__(self)
        ImageZoomMixin.__init__(self)
        ConcurrentMixin.__init__(self)
        self.parent = parent

        self.selection_type = SELECTMANY
        self.saving_enabled = True
        self.selection_enabled = True
        self.viewtype = INDIVIDUAL  # required by InteractiveViewBox
        self.highlighted = None
        # Per-point arrays read by help_event/refresh_img_selection;
        # presumably filled elsewhere when the image is built — only reset here.
        self.data_points = None
        self.data_values = None
        self.data_imagepixels = None
        self.data_valid_positions = None

        # Plot item with the interactive view box that provides the
        # square/polygon selection modes used by the menu actions below.
        self.plotview = pg.GraphicsLayoutWidget()
        self.plot = pg.PlotItem(background="w",
                                viewBox=InteractiveViewBox(self))
        self.plotview.addItem(self.plot)

        self.legend = ImageColorLegend()
        self.plotview.addItem(self.legend)

        # Route tooltip (help) events on the plot scene to self.help_event.
        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        layout = QVBoxLayout()
        self.setLayout(layout)
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().addWidget(self.plotview)

        # Main image item; row-major means image arrays are indexed
        # [row, column].
        self.img = ImageItemNan()
        self.img.setOpts(axisOrder='row-major')
        self.plot.addItem(self.img)
        # Secondary image item; it is not added to the plot in this method.
        self.vis_img = pg.ImageItem()
        self.vis_img.setOpts(axisOrder='row-major')
        self.plot.vb.setAspectLocked()
        self.plot.scene().sigMouseMoved.connect(self.plot.vb.mouseMovedEvent)

        # Overlay a "Menu" button in the top-left cell of the plot view.
        layout = QGridLayout()
        self.plotview.setLayout(layout)
        self.button = QPushButton("Menu", self.plotview)
        self.button.setAutoDefault(False)

        layout.setRowStretch(1, 1)
        layout.setColumnStretch(1, 1)
        layout.addWidget(self.button, 0, 0)
        view_menu = MenuFocus(self)
        self.button.setMenu(view_menu)

        # prepare interface according to the new context
        self.parent.contextAboutToBeOpened.connect(
            lambda x: self.init_interface_data(x[0]))

        actions = []

        self.add_zoom_actions(view_menu)

        select_square = QAction(
            "Select (square)",
            self,
            triggered=self.plot.vb.set_mode_select_square,
        )
        select_square.setShortcuts([Qt.Key_S])
        select_square.setShortcutContext(Qt.WidgetWithChildrenShortcut)
        actions.append(select_square)

        select_polygon = QAction(
            "Select (polygon)",
            self,
            triggered=self.plot.vb.set_mode_select_polygon,
        )
        select_polygon.setShortcuts([Qt.Key_P])
        select_polygon.setShortcutContext(Qt.WidgetWithChildrenShortcut)
        actions.append(select_polygon)

        # saving_enabled is set to True above, so this branch always runs here.
        if self.saving_enabled:
            save_graph = QAction(
                "Save graph",
                self,
                triggered=self.save_graph,
            )
            save_graph.setShortcuts(
                [QKeySequence(Qt.ControlModifier | Qt.Key_I)])
            actions.append(save_graph)

        view_menu.addActions(actions)
        self.addActions(actions)

        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True)

        # Axis selectors (from metas and class vars) plus color settings,
        # embedded into the menu as a widget action.
        choose_xy = QWidgetAction(self)
        box = gui.vBox(self)
        box.setFocusPolicy(Qt.TabFocus)
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=DomainModel.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self.update_attr,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(box,
                                      self,
                                      "attr_y",
                                      label="Axis y:",
                                      callback=self.update_attr,
                                      model=self.xy_model,
                                      **common_options)
        box.setFocusProxy(self.cb_attr_x)

        box.layout().addWidget(self.color_settings_box())

        choose_xy.setDefaultWidget(box)
        view_menu.addAction(choose_xy)

        self.lsx = None  # info about the X axis
        self.lsy = None  # info about the Y axis

        self.data = None
        self.data_ids = {}

    def init_interface_data(self, data):
        """Reinitialize the axis choices unless *data* keeps the same domain."""
        if self.data and data and data.domain == self.data.domain:
            return
        self.init_attr_values(data)

    def help_event(self, ev):
        """Show a tooltip for the data points under the cursor.

        Each point is listed as ``(x, y): value`` followed by its meta and
        class values, excluding the variables chosen as image axes.

        :return: True if a tooltip was shown, False otherwise
        """
        pos = self.plot.vb.mapSceneToView(ev.scenePos())
        sel = self._points_at_pos(pos)
        prepared = []
        if sel is not None:
            domain = self.data.domain
            # The listed variables are the same for every point, so build the
            # list once rather than once per point under the cursor.
            variables = [
                var for var in domain.metas + domain.class_vars
                if var not in [self.attr_x, self.attr_y]
            ]
            data, vals, points = (self.data[sel], self.data_values[sel],
                                  self.data_points[sel])
            for d, v, p in zip(data, vals, points):
                basic = "({}, {}): {}".format(p[0], p[1], v)
                features = [
                    '{} = {}'.format(attr.name, d[attr]) for attr in variables
                ]
                prepared.append("\n".join([basic] + features))
        text = "\n\n".join(prepared)
        if not text:
            return False
        text = ('<span style="white-space:pre">{}</span>'.format(
            escape(text)))
        QToolTip.showText(ev.screenPos(), text, widget=self.plotview)
        return True

    def update_attr(self):
        """React to a new x/y axis choice by delegating to update_view()."""
        return self.update_view() and None

    def init_attr_values(self, data):
        """Load *data*'s domain into the axis model and choose default axes.

        The first model variable becomes the x axis and the second the y
        axis; with fewer than two variables, y falls back to x (or None).
        """
        self.xy_model.set_domain(None if data is None else data.domain)
        if self.xy_model:
            self.attr_x = self.xy_model[0]
        else:
            self.attr_x = None
        if len(self.xy_model) >= 2:
            self.attr_y = self.xy_model[1]
        else:
            self.attr_y = self.attr_x

    def save_graph(self):
        """Save the plot view with the writers offered by the parent widget."""
        view = self.plotview
        writers = self.parent.graph_writers
        saveplot.save_plot(view, writers)

    def set_data(self, data):
        """Store *data*, index its row ids and restore any saved selection.

        Empty or missing data resets the stored table and the id index.
        """
        if not data:
            self.data = None
            self.data_ids = {}
            return
        self.data = data
        self.data_ids = {row_id: row for row, row_id in enumerate(data.ids)}
        self.restore_selection_settings()

    def refresh_img_selection(self):
        """Copy each valid point's selection group onto its image pixel.

        The mask has shape (lsy[2], lsx[2]); presumably index 2 of the axis
        info is the number of pixels along that axis — confirm.
        """
        mask = np.zeros((self.lsy[2], self.lsx[2]), dtype=np.uint8)
        valid = self.data_valid_positions
        pixel_rows = self.data_imagepixels[valid, 0]
        pixel_cols = self.data_imagepixels[valid, 1]
        mask[pixel_rows, pixel_cols] = self.selection_group[valid]
        self.img.setSelection(mask)

    def make_selection(self, selected):
        """Add selected indices to the selection.

        The keyboard modifiers reported by ``selection_modifiers()`` choose
        the group number:
        - add to the highest existing group;
        - start a new group (highest + 1);
        - remove (group 0);
        - or, with no modifier, clear all groups and use group 1.

        ``prepare_settings_for_saving`` is called and ``selection_changed`` is
        emitted even when no image is shown.
        """
        add_to_group, add_group, remove = selection_modifiers()
        if self.data and self.lsx and self.lsy:
            # add_to_group must be tested first: it is set when both keys
            # are held, in which case add_group is set too
            if add_to_group:
                group = np.max(self.selection_group)
            elif add_group:
                group = np.max(self.selection_group) + 1
            elif remove:
                group = 0
            else:
                # plain selection: forget every existing group
                self.selection_group *= 0
                group = 1
            if selected is not None:
                self.selection_group[selected] = group
            self.refresh_img_selection()
        self.prepare_settings_for_saving()
        self.selection_changed.emit()

    def select_square(self, p1, p2):
        """Select elements within the rectangle with opposite corners ``p1``
        and ``p2``; like any polygon selection, only whole pixels count."""
        x_from, y_from = p1.x(), p1.y()
        x_to, y_to = p2.x(), p2.y()
        # closed outline: the first corner is repeated at the end
        outline = [(x_from, y_from), (x_to, y_from), (x_to, y_to),
                   (x_from, y_to), (x_from, y_from)]
        self.select_polygon([QPointF(px, py) for px, py in outline])

    def select_polygon(self, polygon):
        """Select the points whose whole pixel lies inside ``polygon``.

        A point is selected only if all four corners of its pixel (the point
        shifted by +/- half a pixel on each axis) are inside the polygon.
        Does nothing when no image is shown.
        """
        if not (self.data and self.lsx and self.lsy):
            return
        vertices = [(p.x(), p.y()) for p in polygon]
        half_x = _shift(self.lsx)
        half_y = _shift(self.lsy)
        corner_offsets = ((half_x, half_y), (-half_x, half_y),
                          (half_x, -half_y), (-half_x, -half_y))
        inside = None
        for dx, dy in corner_offsets:
            corner_inside = in_polygon(self.data_points + [[dx, dy]], vertices)
            inside = corner_inside if inside is None else inside * corner_inside
        self.make_selection(inside)

    def _points_at_pos(self, pos):
        """Return a boolean mask of the points whose pixel contains ``pos``.

        A point matches when it is closer than ``_shift`` of the axis
        linspace on both x and y. Returns None when no image is shown.
        """
        if not (self.data and self.lsx and self.lsy):
            return None
        offsets = np.abs(self.data_points - [[pos.x(), pos.y()]])
        near_x = offsets[:, 0] < _shift(self.lsx)
        near_y = offsets[:, 1] < _shift(self.lsy)
        return near_x & near_y

    def select_by_click(self, pos):
        """Select the point(s) lying under a click at ``pos``."""
        self.make_selection(self._points_at_pos(pos))

    def update_view(self):
        """Reset the image and start recomputing it.

        Any running task is cancelled, and the related messages, the image,
        its selection overlay and the legend are cleared. Results cached from
        the previous computation are dropped. When data and both axes are
        set, ``compute_image`` is started through ``self.start``; otherwise
        ``image_updated`` is emitted immediately.
        """
        self.cancel()
        parent = self.parent
        parent.Error.image_too_big.clear()
        parent.Information.not_shown.clear()
        self.img.clear()
        self.img.setSelection(None)
        self.legend.set_colors(None)
        for cached in ("lsx", "lsy", "data_points", "data_values",
                       "data_imagepixels", "data_valid_positions"):
            setattr(self, cached, None)

        if not (self.data and self.attr_x and self.attr_y):
            self.image_updated.emit()
            return
        self.start(self.compute_image, self.data, self.attr_x, self.attr_y,
                   parent.image_values(),
                   parent.image_values_fixed_levels())

    def set_visible_image(self, img: np.ndarray, rect: QRectF):
        """Load ``img`` into the visible-image overlay and stretch it over ``rect``."""
        overlay = self.vis_img
        overlay.setImage(img)
        overlay.setRect(rect)

    def show_visible_image(self):
        """Add the visible-image overlay to the plot unless it is already there."""
        plot = self.plot
        if self.vis_img in plot.items:
            return
        plot.addItem(self.vis_img)

    def hide_visible_image(self):
        """Remove the visible-image overlay from the plot."""
        plot, overlay = self.plot, self.vis_img
        plot.removeItem(overlay)

    def set_visible_image_opacity(self, opacity: int):
        """Set the overlay opacity from an alpha intensity in 0..255.

        The value is scaled to the 0..1 range before it is applied.
        """
        alpha = opacity / 255
        self.vis_img.setOpacity(alpha)

    def set_visible_image_comp_mode(self, comp_mode: QPainter.CompositionMode):
        """Use ``comp_mode`` when compositing the visible-image overlay."""
        overlay = self.vis_img
        overlay.setCompositionMode(comp_mode)

    @staticmethod
    def compute_image(data: Orange.data.Table, attr_x, attr_y, image_values,
                      image_values_fixed_levels, state: TaskState):
        """Compute everything ``on_done`` needs to draw the image.

        Started through ``self.start`` from ``update_view``. ``state`` is
        polled so the task can be cancelled between steps.

        Parameters
        ----------
        data : Orange.data.Table
        attr_x, attr_y
            Anything ``data.domain[...]`` resolves to the x / y coordinate
            variable.
        image_values : callable
            Maps a table to a table whose first column holds the values to
            draw.
        image_values_fixed_levels
            Passed through unchanged to the result.
        state : TaskState

        Returns
        -------
        An object with attributes ``coorx``, ``coory``, ``data_points``
        (x and y stacked as columns), ``lsx``, ``lsy``,
        ``image_values_fixed_levels`` and ``d`` (one value per row).

        Raises
        ------
        InterruptException
            When interruption was requested through ``state``.
        ImageTooBigException
            When the image would exceed ``IMAGE_TOO_BIG`` pixels.
        """
        def progress_interrupt(i: float):
            # `i` (progress) is not used; this only checks for cancellation
            if state.is_interruption_requested():
                raise InterruptException

        class Result():
            pass

        res = Result()

        xat = data.domain[attr_x]
        yat = data.domain[attr_y]

        def extract_col(table, var):
            """Return the values of ``var`` in ``table`` as a 1-D array."""
            return table.transform(Domain([var])).X[:, 0]

        progress_interrupt(0)

        res.coorx = extract_col(data, xat)
        res.coory = extract_col(data, yat)
        res.data_points = np.hstack(
            [res.coorx.reshape(-1, 1),
             res.coory.reshape(-1, 1)])
        res.lsx = lsx = values_to_linspace(res.coorx)
        res.lsy = lsy = values_to_linspace(res.coory)
        res.image_values_fixed_levels = image_values_fixed_levels
        progress_interrupt(0)

        # the last linspace element is presumably the number of pixels along
        # that axis (on_done sizes the image with lsx[2] / lsy[2])
        if lsx[-1] * lsy[-1] > IMAGE_TOO_BIG:
            raise ImageTooBigException((lsx[-1], lsy[-1]))

        # Equivalent to `image_values(data).X[:, 0]`, computed in chunks of
        # 10000 rows so that interruption is checked between chunks.
        parts = []
        for rows in split_to_size(len(data), 10000):
            parts.append(image_values(data[rows]).X[:, 0])
            progress_interrupt(0)
        res.d = np.concatenate(parts)
        progress_interrupt(0)

        return res

    def on_done(self, res):
        """Draw the image from the result produced by ``compute_image``.

        The steps are:
        - Store the linspaces, point coordinates and values on the widget.
        - Place every point whose x and y are both valid into its pixel.
          Points with invalid coordinates are counted and reported through
          ``Information.not_shown``.
        - Update levels, colour schema and legend.
        - Position the image in data coordinates.
        - Redraw the selection and emit ``image_updated``.
        """
        self.lsx, self.lsy = res.lsx, res.lsy
        lsx, lsy = self.lsx, self.lsy

        d = res.d

        # must be set before update_levels() below
        self.fixed_levels = res.image_values_fixed_levels

        self.data_points = res.data_points

        xindex, xnan = index_values_nan(res.coorx, self.lsx)
        yindex, ynan = index_values_nan(res.coory, self.lsy)
        # a point can be drawn only when both of its coordinates are valid
        self.data_valid_positions = valid = np.logical_not(
            np.logical_or(xnan, ynan))
        invalid_positions = len(d) - np.sum(valid)
        if invalid_positions:
            self.parent.Information.not_shown(invalid_positions)

        # rows follow y, columns follow x; pixels without a point stay NaN
        imdata = np.full((lsy[2], lsx[2]), np.nan)
        imdata[yindex[valid], xindex[valid]] = d[valid]
        self.data_values = d
        self.data_imagepixels = np.vstack((yindex, xindex)).T

        self.img.setImage(imdata, autoLevels=False)
        self.update_levels()
        self.update_color_schema()
        self.update_legend_visible()

        # shift centres of the pixels so that the axes are useful
        shiftx = _shift(lsx)
        shifty = _shift(lsy)
        left = lsx[0] - shiftx
        bottom = lsy[0] - shifty
        width = (lsx[1] - lsx[0]) + 2 * shiftx
        height = (lsy[1] - lsy[0]) + 2 * shifty
        self.img.setRect(QRectF(left, bottom, width, height))

        self.refresh_img_selection()
        self.image_updated.emit()

    def on_partial_result(self, result):
        """Ignore partial results; the image is drawn only in ``on_done``."""

    def on_exception(self, ex: Exception):
        """Handle a failure of the background image computation.

        Interruptions are ignored. An ``ImageTooBigException`` is reported on
        the parent widget with the pixel counts from ``ex.args[0]``, and
        ``image_updated`` is emitted. Any other exception is re-raised.
        """
        if isinstance(ex, InterruptException):
            return
        if not isinstance(ex, ImageTooBigException):
            raise ex
        size_x, size_y = ex.args[0][0], ex.args[0][1]
        self.parent.Error.image_too_big(size_x, size_y)
        self.image_updated.emit()
Beispiel #10
0
class OWFeatureStatistics(widget.OWWidget):
    """Show a table of basic statistics for the features of the input data.

    Outputs the data reduced to the selected features and a table with the
    statistics of those features.
    """
    name = 'Feature Statistics'
    description = 'Show basic statistics for data features.'
    icon = 'icons/FeatureStatistics.svg'

    class Inputs:
        data = Input('Data', Table, default=True)

    class Outputs:
        reduced_data = Output('Reduced Data', Table, default=True)
        statistics = Output('Statistics', Table)

    want_main_area = True
    buttons_area_orientation = Qt.Vertical

    settingsHandler = DomainContextHandler()

    auto_commit = ContextSetting(True)
    color_var = ContextSetting(None)  # type: Optional[Variable]
    # filter_string = ContextSetting('')

    # (sort column, Qt sort order) of the table
    sorting = ContextSetting((0, Qt.DescendingOrder))
    # selected table rows, stored as source-model row indices
    selected_rows = ContextSetting([])

    def __init__(self):
        super().__init__()

        self.data = None  # type: Optional[Table]

        # Information panel
        info_box = gui.vBox(self.controlArea, 'Info')
        info_box.setMinimumWidth(200)
        self.info_summary = gui.widgetLabel(info_box, wordWrap=True)
        self.info_attr = gui.widgetLabel(info_box, wordWrap=True)
        self.info_class = gui.widgetLabel(info_box, wordWrap=True)
        self.info_meta = gui.widgetLabel(info_box, wordWrap=True)
        self.set_info()

        # TODO: Implement filtering on the model
        # filter_box = gui.vBox(self.controlArea, 'Filter')
        # self.filter_text = gui.lineEdit(
        #     filter_box, self, value='filter_string',
        #     placeholderText='Filter variables by name',
        #     callback=self._filter_table_variables, callbackOnType=True,
        # )
        # shortcut = QShortcut(QKeySequence('Ctrl+f'), self, self.filter_text.setFocus)
        # shortcut.setWhatsThis('Filter variables by name')

        # Histogram colouring: any continuous or discrete variable, or none
        self.color_var_model = DomainModel(
            valid_types=(ContinuousVariable, DiscreteVariable),
            placeholder='None',
        )
        box = gui.vBox(self.controlArea, 'Histogram')
        self.cb_color_var = gui.comboBox(
            box, master=self, value='color_var', model=self.color_var_model,
            label='Color:', orientation=Qt.Horizontal,
        )
        self.cb_color_var.activated.connect(self.__color_var_changed)

        gui.rubber(self.controlArea)
        gui.auto_commit(
            self.buttonsArea, self, 'auto_commit', 'Send Selected Rows',
            'Send Automatically',
        )

        # Main area
        self.model = FeatureStatisticsTableModel(parent=self)
        self.table_view = FeatureStatisticsTableView(self.model, parent=self)
        self.table_view.selectionModel().selectionChanged.connect(self.on_select)
        self.table_view.horizontalHeader().sectionClicked.connect(self.on_header_click)

        self.mainArea.layout().addWidget(self.table_view)

    def sizeHint(self):
        return QSize(1050, 500)

    def _filter_table_variables(self):
        # NOTE(review): not connected anywhere at the moment. It reads
        # `self.filter_string`, whose setting is commented out above, and the
        # regex it builds is not applied to the model yet (see the TODO in
        # __init__).
        regex = QRegExp(self.filter_string)
        # If the user explicitly types different cases, we assume they know
        # what they are searching for and account for letter case in filter
        different_case = (
            any(c.islower() for c in self.filter_string) and
            any(c.isupper() for c in self.filter_string)
        )
        if not different_case:
            regex.setCaseSensitivity(Qt.CaseInsensitive)

    @Inputs.data
    def set_data(self, data):
        """Show statistics for ``data`` and restore context settings."""
        # Clear outputs and reset widget state
        self.closeContext()
        self.selected_rows = []
        self.model.resetSorting()
        self.Outputs.reduced_data.send(None)
        self.Outputs.statistics.send(None)

        # Setup widget state for new data and restore settings
        self.data = data

        if data is not None:
            self.color_var_model.set_domain(data.domain)
            # Reset first: otherwise, for data without a class variable, the
            # variable chosen for the previous domain would be kept
            self.color_var = None
            if self.data.domain.class_vars:
                self.color_var = self.data.domain.class_vars[0]
        else:
            self.color_var_model.set_domain(None)
            self.color_var = None
        self.model.set_data(data)

        self.openContext(self.data)
        self.__restore_selection()
        self.__restore_sorting()
        # self._filter_table_variables()
        self.__color_var_changed()

        self.set_info()
        self.commit()

    def __restore_selection(self):
        """Restore the selection on the table view from saved settings."""
        selection_model = self.table_view.selectionModel()
        selection = QItemSelection()
        if len(self.selected_rows):
            for row in self.model.mapFromSourceRows(self.selected_rows):
                selection.append(QItemSelectionRange(
                    self.model.index(row, 0),
                    self.model.index(row, self.model.columnCount() - 1)
                ))
        selection_model.select(selection, QItemSelectionModel.ClearAndSelect)

    def __restore_sorting(self):
        """Restore the sort column and order from saved settings."""
        sort_column, sort_order = self.sorting
        if sort_column < self.model.columnCount():
            self.model.sort(sort_column, sort_order)
            self.table_view.horizontalHeader().setSortIndicator(sort_column, sort_order)

    @pyqtSlot(int)
    def on_header_click(self, *_):
        # Store the header states
        sort_order = self.model.sortOrder()
        sort_column = self.model.sortColumn()
        self.sorting = sort_column, sort_order

    @pyqtSlot(int)
    def __color_var_changed(self, *_):
        """Pass the current colour variable to the table model."""
        if self.model is not None:
            self.model.set_target_var(self.color_var)

    def _format_variables_string(self, variables):
        """Describe ``variables`` as counts per type.

        Example: "2 categorical and 3 numeric variables". Types in
        ``self.model.HIDDEN_VAR_TYPES`` are marked "(not shown)".
        """
        agg = []
        for var_type_name, var_type in [
                ('categorical', DiscreteVariable),
                ('numeric', ContinuousVariable),
                ('time', TimeVariable),
                ('string', StringVariable)
        ]:
            # Disable pylint here because a `TimeVariable` is also a
            # `ContinuousVariable`, and should be labelled as such. That is why
            # it is necessary to check the type this way instead of using
            # `isinstance`, which would fail in the above case
            var_type_list = [v for v in variables if type(v) is var_type]  # pylint: disable=unidiomatic-typecheck
            if var_type_list:
                hidden = var_type in self.model.HIDDEN_VAR_TYPES
                suffix = ' (not shown)' if hidden else ''
                agg.append((
                    '%d %s%s' % (len(var_type_list), var_type_name, suffix),
                    len(var_type_list)
                ))

        if not agg:
            return 'No variables'

        attrs, counts = list(zip(*agg))
        if len(attrs) > 1:
            var_string = ', '.join(attrs[:-1]) + ' and ' + attrs[-1]
        else:
            var_string = attrs[0]
        return plural('%s variable{s}' % var_string, sum(counts))

    def set_info(self):
        """Fill the info panel with a summary of the current data."""
        if self.data is not None:
            self.info_summary.setText('<b>%s</b> contains %s with %s' % (
                self.data.name,
                plural('{number} instance{s}', self.model.n_instances),
                plural('{number} feature{s}', self.model.n_attributes)
            ))

            self.info_attr.setText(
                '<b>Attributes:</b><br>%s' %
                self._format_variables_string(self.data.domain.attributes)
            )
            self.info_class.setText(
                '<b>Class variables:</b><br>%s' %
                self._format_variables_string(self.data.domain.class_vars)
            )
            self.info_meta.setText(
                '<b>Metas:</b><br>%s' %
                self._format_variables_string(self.data.domain.metas)
            )
        else:
            self.info_summary.setText('No data on input.')
            self.info_attr.setText('')
            self.info_class.setText('')
            self.info_meta.setText('')

    def on_select(self):
        """Remember the selected rows (as source rows) and commit."""
        self.selected_rows = self.model.mapToSourceRows([
            i.row() for i in self.table_view.selectionModel().selectedRows()
        ])
        self.commit()

    def commit(self):
        """Send the selected features and their statistics; None if none selected."""
        if not len(self.selected_rows):
            self.Outputs.reduced_data.send(None)
            self.Outputs.statistics.send(None)
            return

        # Send a table with only selected columns to output
        variables = self.model.variables[self.selected_rows]
        self.Outputs.reduced_data.send(self.data[:, variables])

        # Send the statistics of the selected variables to ouput
        labels, data = self.model.get_statistics_matrix(variables, return_labels=True)
        var_names = np.atleast_2d([var.name for var in variables]).T
        domain = Domain(
            attributes=[ContinuousVariable(name) for name in labels],
            metas=[StringVariable('Feature')]
        )
        statistics = Table(domain, data, metas=var_names)
        statistics.name = '%s (Feature Statistics)' % self.data.name
        self.Outputs.statistics.send(statistics)

    def send_report(self):
        pass
class OWPieChart(widget.OWWidget):
    """Draw pie charts of a discrete variable.

    Optionally draws one pie per value of a second discrete ("split by")
    variable.
    """
    name = "Pie Chart"
    description = "Make fun of Pie Charts."
    icon = "icons/PieChart.svg"
    priority = 100

    class Inputs:
        data = Input("Data", Orange.data.Table)

    settingsHandler = DomainContextHandler()
    attribute = ContextSetting(None)
    split_var = ContextSetting(None)
    explode = Setting(False)
    graph_name = "scene"

    def __init__(self):
        super().__init__()
        self.dataset = None

        # discrete variables that can be shown; no placeholder
        self.attrs = DomainModel(
            valid_types=Orange.data.DiscreteVariable, separators=False)
        cb = gui.comboBox(
            self.controlArea, self, "attribute", box=True,
            model=self.attrs, callback=self.update_scene, contentsLength=12)
        grid = QGridLayout()
        self.legend = gui.widgetBox(gui.indentedBox(cb.box), orientation=grid)
        grid.setColumnStretch(1, 1)
        grid.setHorizontalSpacing(6)
        self.legend_items = []
        # same variables as self.attrs, preceded by a "None" placeholder
        self.split_vars = DomainModel(
            valid_types=Orange.data.DiscreteVariable, separators=False,
            placeholder="None", )
        gui.comboBox(
            self.controlArea, self, "split_var", box="Split by",
            model=self.split_vars, callback=self.update_scene)
        gui.checkBox(
            self.controlArea, self, "explode", "Explode pies", box=True,
            callback=self.update_scene)
        gui.rubber(self.controlArea)
        gui.widgetLabel(
            gui.hBox(self.controlArea, box=True),
            "The aim of this widget is to\n"
            "demonstrate that pie charts are\n"
            "a terrible visualization. Please\n"
            "don't use it for any other purpose.")

        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHints(
            QPainter.Antialiasing | QPainter.TextAntialiasing |
            QPainter.SmoothPixmapTransform)
        self.mainArea.layout().addWidget(self.view)
        self.mainArea.setMinimumWidth(600)

    def sizeHint(self):
        return QSize(200, 150)  # Horizontal size is regulated by mainArea

    @Inputs.data
    def set_data(self, dataset):
        """Show ``dataset``; empty data or an empty domain count as no data."""
        if dataset is not None and (
                not bool(dataset) or not len(dataset.domain)):
            dataset = None
        self.closeContext()
        self.dataset = dataset
        self.attribute = None
        self.split_var = None
        domain = dataset.domain if dataset is not None else None
        self.attrs.set_domain(domain)
        self.split_vars.set_domain(domain)
        if dataset is not None:
            self.select_default_variables(domain)
            self.openContext(self.dataset)
        self.update_scene()

    def select_default_variables(self, domain):
        """Choose the default shown and split-by variables for ``domain``.

        The first entry of the attribute model is shown. When that entry is a
        class variable, the pies are split by the first candidate that is not
        a class variable. Otherwise no split is used. With no discrete
        variables, both are None.
        """
        # `self.split_vars` holds the same variables as `self.attrs`, but is
        # prefixed by the placeholder; pick from `self.attrs` so the two
        # models cannot be confused by an off-by-one index
        candidates = list(self.attrs)
        if not candidates:
            self.attribute, self.split_var = None, None
            return
        class_vars = domain.class_vars
        self.attribute = candidates[0]
        if self.attribute in class_vars:
            self.split_var = next(
                (var for var in candidates if var not in class_vars), None)
        else:
            self.split_var = None

    def update_scene(self):
        """Redraw all pies and the legend for the current settings."""
        self.scene.clear()
        if self.dataset is None or self.attribute is None:
            return
        dists, labels = self.compute_box_data()
        colors = self.attribute.colors
        for x, (dist, label) in enumerate(zip(dists, labels)):
            self.pie_chart(SCALE * x, 0, 0.8 * SCALE, dist, colors)
            self.pie_label(SCALE * x, 0, label)
        self.update_legend(
            [QColor(*col) for col in colors], self.attribute.values)
        self.view.centerOn(SCALE * len(dists) / 2, 0)

    def update_legend(self, colors, labels):
        """Replace the legend with one colour swatch and label per value."""
        layout = self.legend.layout()
        while self.legend_items:
            w = self.legend_items.pop()
            layout.removeWidget(w)
            w.deleteLater()
        for row, (color, label) in enumerate(zip(colors, labels)):
            icon = QLabel()
            p = QPixmap(12, 12)
            p.fill(color)
            icon.setPixmap(p)
            label = QLabel(label)
            layout.addWidget(icon, row, 0)
            layout.addWidget(label, row, 1, alignment=Qt.AlignLeft)
            self.legend_items += (icon, label)

    def pie_chart(self, x, y, r, dist, colors):
        """Draw one pie of diameter ``r`` centred at (``x``, ``y``).

        Slices are proportional to ``dist`` and filled with ``colors``.
        Angles are in 1/16 of a degree, the unit of Qt's ellipse items.
        """
        start_angle = 0
        dist = np.asarray(dist)
        spans = dist / (float(np.sum(dist)) or 1) * 360 * 16
        for span, color in zip(spans, colors):
            if not span:
                continue
            if self.explode:
                mid_ang = (start_angle + span / 2) / 360 / 16 * 2 * pi
                dx = r / 30 * cos(mid_ang)
                dy = r / 30 * sin(mid_ang)
            else:
                dx = dy = 0
            ellipse = QGraphicsEllipseItem(x - r / 2 + dx, y - r / 2 - dy, r, r)
            if len(spans) > 1:
                # Qt's angle setters take ints; numpy floats raise TypeError
                # with newer Python/PyQt. Round both ends so adjacent slices
                # share their boundary exactly.
                first = int(round(start_angle))
                last = int(round(start_angle + span))
                ellipse.setStartAngle(first)
                ellipse.setSpanAngle(last - first)
            ellipse.setBrush(QColor(*color))
            self.scene.addItem(ellipse)
            start_angle += span

    def pie_label(self, x, y, label):
        """Put ``label`` under the pie at ``x``, shortened with "..." to fit."""
        if not label:
            return
        text = QGraphicsSimpleTextItem(label)
        for cut in range(1, len(label)):
            if text.boundingRect().width() < 0.95 * SCALE:
                break
            text = QGraphicsSimpleTextItem(label[:-cut] + "...")
        text.setPos(x - text.boundingRect().width() / 2, y + 0.5 * SCALE)
        self.scene.addItem(text)

    def compute_box_data(self):
        """Return (distributions, labels), one entry per pie.

        With a split variable this is the contingency of ``attribute`` by
        ``split_var`` and the values of ``split_var``. Otherwise it is a
        single distribution with an empty label.
        """
        if self.split_var is not None:
            return (
                contingency.get_contingency(
                    self.dataset, self.attribute, self.split_var),
                self.split_var.values)
        else:
            return [
                distribution.get_distribution(
                    self.dataset, self.attribute)], [""]

    def send_report(self):
        """Report the plot with a caption naming the shown and split variables."""
        self.report_plot()
        parts = []
        if self.attribute is not None:
            parts.append("Pie chart for '{}'".format(self.attribute.name))
        if self.split_var is not None:
            parts.append("split by '{}'".format(self.split_var.name))
        if parts:
            self.report_caption(" ".join(parts))
Beispiel #12
0
class OWTestLearners(OWWidget):
    name = "Test & Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100

    class Inputs:
        # Input signals. `learner` accepts multiple connections
        # (multiple=True); each one is delivered to `set_learner` together
        # with a key identifying the connection.
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        # Output signals: predictions as a table and the raw evaluation results
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    settings_version = 3
    UserAdviceMessages = [
        widget.Message(
            "Click on the table header to select shown columns",
            "click_header")]

    settingsHandler = settings.PerfectDomainContextHandler(metas_in_res=True)

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    fold_feature_selected = settings.ContextSetting(False)

    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    BUILTIN_ORDER = {
        DiscreteVariable: ("AUC", "CA", "F1", "Precision", "Recall"),
        ContinuousVariable: ("MSE", "RMSE", "MAE", "R2")}

    shown_scores = \
        settings.Setting(set(chain(*BUILTIN_ORDER.values())))

    class Error(OWWidget.Error):
        # Errors about unusable inputs; raised and cleared by the input
        # handlers (e.g. set_train_data, set_test_data)
        train_data_empty = Msg("Train data set is empty.")
        test_data_empty = Msg("Test data set is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg("Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train data sets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        only_one_class_var_value = Msg("Target variable has only one value.")

    class Warning(OWWidget.Warning):
        # The `{}` in missing_data is filled with the result of
        # `self._which_missing_data()` (see set_train_data)
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        # Shown when an SqlTable input is too large and only a sample of it
        # is downloaded
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")

    def __init__(self):
        """Initialise widget state and build the GUI.

        The control area gets the "Sampling" radio group and the "Target
        Class" combo. The main area gets the "Evaluation Results" table.
        """
        super().__init__()

        # Inputs; populated by the input signal handlers (e.g. set_train_data)
        self.data = None
        self.test_data = None
        self.preprocessor = None
        # whether the train / test input contains unknown target values
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, InputLearner]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[Task]
        self.__executor = ThreadExecutor()

        # Sampling method. The radio buttons are appended in the same order
        # as the KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain,
        # TestOnTest constants (0..5), so `resampling` holds one of them.
        # Do not reorder.
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(
            sbox, self, "resampling", callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        # NOTE(review): `n_folds` presumably stores an index into NFolds
        # (the combo items are built from it), not the fold count — confirm
        gui.comboBox(
            ibox, self, "n_folds", label="Number of folds: ",
            items=[str(x) for x in self.NFolds], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.kfold_changed)
        gui.checkBox(
            ibox, self, "cv_stratified", "Stratified",
            callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # candidate fold features: discrete variables, listed with
        # order=DomainModel.METAS (presumably metas only — confirm)
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(
            ibox, self, "fold_feature", model=self.feature_model,
            orientation=Qt.Horizontal, callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        # NOTE(review): as with n_folds, `n_repeats` and `sample_size`
        # presumably hold indices into NRepeats / SampleSizes — confirm
        gui.comboBox(
            ibox, self, "n_repeats", label="Repeat train/test: ",
            items=[str(x) for x in self.NRepeats], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.shuffle_split_changed)
        gui.comboBox(
            ibox, self, "sample_size", label="Training set size: ",
            items=["{} %".format(x) for x in self.SampleSizes],
            maximumContentsLength=5, orientation=Qt.Horizontal,
            callback=self.shuffle_split_changed)
        gui.checkBox(
            ibox, self, "shuffle_stratified", "Stratified",
            callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class for the scores. With sendSelectedValue=True and
        # valueType=str the setting holds the item text (TARGET_AVERAGE or
        # a class value).
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox, self, "class_selection", items=[],
            sendSelectedValue=True, valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)

        # Results table; right-clicking the header opens the column chooser
        self.view = gui.TableView(
            wordWrap=True,
        )
        header = self.view.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setStretchLastSection(False)
        header.setContextMenuPolicy(Qt.CustomContextMenu)
        header.customContextMenuRequested.connect(self.show_column_chooser)

        self.result_model = QStandardItemModel(self)
        self.result_model.setHorizontalHeaderLabels(["Method"])
        self.view.setModel(self.result_model)
        self.view.setItemDelegate(ItemDelegate())

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.view)

    def sizeHint(self):
        """Preferred size: 780 px wide with a minimal (1 px) height hint."""
        width, height = 780, 1
        return QSize(width, height)

    def _update_controls(self):
        """Refresh the "cross validation by feature" controls for ``self.data``.

        The fold-feature model is repopulated and its first entry becomes the
        default fold feature. The feature-fold radio button and combo are
        enabled only when the model is non-empty. If that option is selected
        but disabled, resampling falls back to k-fold.
        """
        self.fold_feature = None
        model = self.feature_model
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                self.fold_feature = model[0]
        has_candidates = bool(model)
        fold_button = self.controls.resampling.buttons[OWTestLearners.FeatureFold]
        fold_button.setEnabled(has_candidates)
        self.features_combo.setEnabled(has_candidates)
        if not has_candidates and self.resampling == OWTestLearners.FeatureFold:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        A `None` learner removes the learner previously registered under
        `key`. A `None` for a key that was never registered is ignored;
        previously it was stored as a learner-less entry.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if learner is None:
            if key in self.learners:
                # Removed
                self._invalidate([key])
                del self.learners[key]
        else:
            self.learners[key] = InputLearner(learner, None, None)
            self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Set the input training dataset.

        The data is rejected (with an error) in three cases:
        - it is empty;
        - it has no target, or more than one target;
        - its discrete target has a single value.

        An SqlTable is converted to a Table; when it is not smaller than
        AUTO_DL_LIMIT only a sample is downloaded. When either input has
        unknown target values, a warning is shown and the data is passed
        through `HasClass()`.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        for message in (self.Information.data_sampled,
                        self.Error.train_data_empty,
                        self.Error.class_required,
                        self.Error.too_many_classes,
                        self.Error.only_one_class_var_value):
            message.clear()

        if data is not None and not len(data):
            self.Error.train_data_empty()
            data = None
        if data:
            domain = data.domain
            checks = (
                (not domain.class_vars, self.Error.class_required),
                (len(domain.class_vars) > 1, self.Error.too_many_classes),
                (domain.has_discrete_class
                 and len(domain.class_var.values) == 1,
                 self.Error.only_one_class_var_value),
            )
            failed = next((error for failed_cond, error in checks
                           if failed_cond), None)
            if failed is not None:
                failed()
                data = None

        if isinstance(data, SqlTable):
            if data.approx_len() >= AUTO_DL_LIMIT:
                self.Information.data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)
            else:
                data = Table(data)

        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if not (self.train_data_missing_vals or self.test_data_missing_vals):
            self.Warning.missing_data.clear()
        else:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            if self.fold_feature_selected and self.feature_model:
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Optional[Orange.data.Table]) -> None
        """
        Set the input separate testing dataset.

        Empty data, or data without a class variable, is rejected (an error
        is shown and `data` becomes None). SQL tables are converted to
        `Table`, or sampled when their approximate length is not below
        AUTO_DL_LIMIT. Missing class values in either input raise the
        missing-data warning, and rows without a class are removed from
        `data`. Results are invalidated only when testing on test data is
        the selected resampling.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not len(data):
            self.Error.test_data_empty()
            data = None
        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.Information.test_data_sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(data_sample)

        # The warning covers both inputs, so it depends on the stored
        # training-data flag as well.
        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        # Test data only affects results of the TestOnTest scheme.
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {(True, True): " ",  # both, don't specify
                (True, False): " train ",
                (False, True): " test "}[(self.train_data_missing_vals,
                                          self.test_data_missing_vals)]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        if self.data is None or self.data.domain.class_var is None:
            self.scorers = []
            return
        class_var = self.data and self.data.domain.class_var
        order = {name: i
                 for i, name in enumerate(self.BUILTIN_ORDER[type(class_var)])}
        # 'abstract' is retrieved from __dict__ to avoid inheriting
        usable = (cls for cls in scoring.Score.registry.values()
                  if cls.is_scalar and not cls.__dict__.get("abstract")
                  and isinstance(class_var, cls.class_types))
        self.scorers = sorted(usable, key=lambda cls: order.get(cls.name, 99))

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Store the input preprocessor and invalidate all learner results.

        The stored preprocessor is passed to the evaluation routines as
        their `preprocessor` argument, i.e. it is applied to training data.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """
        Refresh the view after a batch of input signals.

        Reimplemented from OWWidget.handleNewSignals: rebuilds the target
        class combo, the score header and the score rows, and restarts the
        evaluation if some input invalidated the results.
        """
        self._update_class_selection()
        self._update_header()
        self._update_stats_model()
        if not self.__needupdate:
            return
        self.__update()

    def kfold_changed(self):
        """Select k-fold cross validation and re-evaluate."""
        self.resampling = OWTestLearners.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Select cross validation by feature and re-evaluate."""
        self.resampling = OWTestLearners.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Select random sampling (shuffle split) and re-evaluate."""
        self.resampling = OWTestLearners.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """Drop all learner results and start a new evaluation right away."""
        self._invalidate()
        self.__update()

    def _update_header(self):
        """
        Label the score columns of the results model.

        The model gets one column more than there are scorers. Column
        ``i + 1`` is titled with ``self.scorers[i].name`` and shows its
        ``long_name`` as a tooltip; column 0 is left untouched. Column
        visibility is refreshed afterwards.
        """
        model = self.result_model
        scorers = self.scorers
        model.setColumnCount(len(scorers) + 1)
        for section, scorer in enumerate(scorers, start=1):
            header_item = QStandardItem(scorer.name)
            header_item.setToolTip(scorer.long_name)
            model.setHorizontalHeaderItem(section, header_item)
        self._update_shown_columns()

    def _update_shown_columns(self):
        """Hide every score column whose title is not in `shown_scores`."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        model = self.result_model
        header = self.view.horizontalHeader()
        for section in range(1, model.columnCount()):
            title = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            hidden = title not in self.shown_scores
            header.setSectionHidden(section, hidden)

    def _update_stats_model(self):
        """
        Rebuild the rows of the score table from the learners' results.

        One row is appended per entry of `self.learners`. The first cell holds
        the learner name, with the input key stored under Qt.UserRole, and is
        followed by one cell per score. If a specific target class of a
        discrete class variable is selected, scores are computed here on
        one-vs-rest results; otherwise the stored `slot.stats` are shown.
        Failed learners are marked red and their errors are joined into the
        widget's error message. A warning is shown when any score could not
        be computed.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.view.model()
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                # Position of the selected target among the class values
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            head = QStandardItem(name)
            head.setData(key, Qt.UserRole)
            if isinstance(slot.results, Try.Fail):
                head.setToolTip(str(slot.results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                errors.append("{name} failed with error:\n"
                              "{exc.__class__.__name__}: {exc!s}"
                              .format(name=name, exc=slot.results.exception))

            row = [head]

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                # Target-specific scores: recompute on one-vs-rest results
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(
                        slot.results.value, target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [Try(scorer_caller(scorer, ovr_results))
                             for scorer in self.scorers]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat in stats:
                    item = QStandardItem()
                    if stat.success:
                        item.setText("{:.3f}".format(stat.value[0]))
                    else:
                        # Leave the cell blank; the reason is in the tooltip
                        item.setToolTip(str(stat.exception))
                        has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _update_class_selection(self):
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            items = [self.TARGET_AVERAGE] + class_var.values
            self.class_selection_combo.addItems(items)

            class_index = 0
            if self.class_selection in class_var.values:
                class_index = class_var.values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        self._update_stats_model()

    def _invalidate(self, which=None):
        """
        Discard stored results and stats of learners and schedule an update.

        The score cells in the rows of the affected learners are blanked
        (text and tooltip), but the cells themselves are kept. The current
        resampling choice is also remembered in `fold_feature_selected`.

        Parameters
        ----------
        which : Optional[Iterable]
            Keys of `self.learners` to invalidate; all of them if None.
        """
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        keys = self.learners.keys() if which is None else which

        model = self.view.model()
        row_keys = [model.item(r, 0).data(Qt.UserRole)
                    for r in range(model.rowCount())]

        for key in keys:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)
            if key not in row_keys:
                continue
            row = row_keys.index(key)
            for col in range(1, model.columnCount()):
                cell = model.item(row, col)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.__needupdate = True

    def show_column_chooser(self, pos):
        """
        Pop up a menu of checkable score names at header position `pos`.

        Toggling an entry adds the name to, or removes it from,
        `self.shown_scores`, and then refreshes the column visibility.
        """
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        def toggle(name, checked):
            if not checked:
                self.shown_scores.remove(name)
            else:
                self.shown_scores.add(name)
            self._update_shown_columns()

        menu = QMenu()
        model = self.result_model
        header = self.view.horizontalHeader()
        for section in range(1, model.columnCount()):
            title = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            entry = menu.addAction(title)
            entry.setCheckable(True)
            entry.setChecked(title in self.shown_scores)
            entry.triggered.connect(partial(toggle, title))
        menu.exec(header.mapToGlobal(pos))

    def commit(self):
        """
        Send the merged evaluation results and the predictions table.

        Only learners whose evaluation succeeded are included. If none did,
        None is sent on both outputs. If building the augmented data raises
        MemoryError, an error is shown and None is sent as predictions.
        """
        self.Error.memory_error.clear()
        successful = [slot for slot in self.learners.values()
                      if slot.results is not None and slot.results.success]
        combined = predictions = None
        if successful:
            # Evaluation results
            combined = results_merge(
                [slot.results.value for slot in successful])
            names = [learner_name(slot.learner) for slot in successful]
            combined.learner_names = names

            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """
        Report on the testing schema and results.

        Nothing is reported without training data or learners. Otherwise the
        report lists the sampling type (including cross validation by
        feature), the target class for discrete class variables, and the
        score table.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".
                      format(stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            # Folds are defined by the values of the selected feature
            sampling = "Cross validation by feature"
            if self.fold_feature is not None:
                sampling += " {}".format(self.fold_feature.name)
            items = [("Sampling type", sampling)]
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [("Sampling type",
                      "{}Shuffle split, {} random samples with {}% data "
                      .format(stratified, self.NRepeats[self.n_repeats],
                              self.SampleSizes[self.sample_size]))]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """
        Upgrade stored settings from an older `version` in place.

        Before version 2, every stored resampling choice other than the
        first is shifted up by one, because an option was inserted after the
        first one. A missing "resampling" entry is left missing, so the
        default is used. Before version 3, contexts that have a `classes`
        attribute (from an incompatible handler) are dropped.
        """
        if version < 2:
            if settings_.get("resampling", 0) > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')]

    @Slot(float)
    def setProgressValue(self, value):
        """Move the progress bar to `value` without processing events.

        Invoked through a queued call from the evaluation's progress
        callback.
        """
        self.progressBarSet(value, processEvents=False)

    def __update(self):
        """
        (Re)start the evaluation of learners whose results are missing.

        Any running evaluation is cancelled first and previous messages are
        cleared. A precondition can fail: no training data, no learners,
        fewer rows than k-fold folds, missing test data, or a test class
        variable that differs from the training one. In that case the
        matching message is shown, the state becomes `State.Waiting` and the
        current results are committed. Otherwise an evaluation for the
        selected resampling scheme is built and submitted.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Warning.test_data_missing.clear()
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                # Don't add the "missing" warning on top of the "empty" error
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            self.Warning.test_data_unused()

        rstate = 42
        common_args = dict(
            store_data=True,
            preprocessor=self.preprocessor,
        )
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see restore
        # learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.KFold:
            folds = self.NFolds[self.n_folds]
            # Honour the stratification option, as ShuffleSplit does below
            # and as the report claims.
            test_f = partial(
                Orange.evaluation.CrossValidation,
                self.data, learners_c, k=folds,
                stratified=self.cv_stratified,
                random_state=rstate, **common_args)
        elif self.resampling == OWTestLearners.FeatureFold:
            test_f = partial(
                Orange.evaluation.CrossValidationFeature,
                self.data, learners_c, self.fold_feature,
                **common_args
            )
        elif self.resampling == OWTestLearners.LeaveOneOut:
            test_f = partial(
                Orange.evaluation.LeaveOneOut,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.ShuffleSplit:
            train_size = self.SampleSizes[self.sample_size] / 100
            test_f = partial(
                Orange.evaluation.ShuffleSplit,
                self.data, learners_c,
                n_resamples=self.NRepeats[self.n_repeats],
                train_size=train_size, test_size=None,
                stratified=self.shuffle_stratified,
                random_state=rstate, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTrain:
            test_f = partial(
                Orange.evaluation.TestOnTrainingData,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData,
                self.data, self.test_data, learners_c, **common_args
            )
        else:
            assert False, "self.resampling %s" % self.resampling

        def replace_learners(evalfunc, *args, **kwargs):
            # Put the original learner objects back into the results so
            # they can be matched to the input slots.
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[float]], Results]) -> None
        """
        Run `testfunc` on the executor as the current evaluation task.

        Must not be called while an evaluation is pending or running; cancel
        the existing task first.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results]
            Called with a single `callback` keyword argument (a progress
            callback taking the finished fraction); must return a Results
            instance.
        """
        assert self.__state != State.Running
        task = Task()

        def report_progress(finished):
            # Raise UserInterrupt once the task has been cancelled; otherwise
            # forward the progress (as a percentage) through a queued call.
            if task.cancelled:
                raise UserInterrupt()
            QMetaObject.invokeMethod(
                self, "setProgressValue", Qt.QueuedConnection,
                Q_ARG(float, 100 * finished)
            )

        def notify_done(_future):
            # Hand the finished task to `__task_complete` via a queued call.
            QMetaObject.invokeMethod(
                self, "__task_complete", Qt.QueuedConnection,
                Q_ARG(object, task))

        task.future = self.__executor.submit(
            partial(testfunc, callback=report_progress))
        task.future.add_done_callback(notify_done)

        self.progressBarInit(processEvents=None)
        self.setBlocking(True)
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, task):
        """
        Handle a finished evaluation `task` in the widget's thread.

        A task that is no longer the current one must have been cancelled,
        and its result is discarded. If the evaluation raised, the error is
        logged and shown, and nothing is committed. Otherwise the results
        are split per learner, scores are computed for each learner that did
        not fail, and the table and outputs are refreshed.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        if self.__task is not task:
            assert task.cancelled
            # Log the task itself rather than an uninformative placeholder
            log.debug("Reaping cancelled task: %r", task)
            return

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)
        self.setStatusMessage("")
        result = task.future
        assert result.done()
        self.__task = None
        try:
            results = result.result()    # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:
            log.exception("testing error (in __task_complete):",
                          exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er), er)))
            self.__state = State.Done
            return

        self.__state = State.Done

        learner_key = {slot.learner: key for key, slot in
                       self.learners.items()}
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            stats = None
            if class_var.is_primitive():
                # Each split holds a single model, so failed[0] is its status
                ex = result.failed[0]
                if ex:
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [Try(scorer_caller(scorer, result))
                             for scorer in self.scorers]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self._update_header()
        self._update_stats_model()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation, if there is one.

        The state becomes `State.Cancelled` and the widget forgets the task,
        so a late completion of it is discarded.
        """
        task = self.__task
        if task is None:
            return
        assert self.__state == State.Running
        self.__state = State.Cancelled
        self.__task = None
        task.cancel()
        assert task.future.done()

    def onDeleteWidget(self):
        """Stop any running evaluation, then let the base class clean up."""
        self.cancel()
        super().onDeleteWidget()
class OWSelectRows(widget.OWWidget):
    name = "Select Rows"
    # NOTE(review): this id is the File widget's module path, not this
    # widget's; confirm whether that is intentional before changing it.
    id = "Orange.widgets.data.file"
    description = "Select rows from the data based on values of variables."
    icon = "icons/SelectRows.svg"
    priority = 100
    category = "Data"
    keywords = ["filter"]

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        matching_data = Output("Matching Data", Table, default=True)
        unmatched_data = Output("Unmatched Data", Table)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    want_main_area = False

    settingsHandler = SelectRowsContextHandler()
    conditions = ContextSetting([])
    update_on_change = Setting(True)
    purge_attributes = Setting(False, schema_only=True)
    purge_classes = Setting(False, schema_only=True)
    auto_commit = Setting(True)

    settings_version = 2

    # For each variable type: the (filter operator, label) pairs offered in
    # the operator combo, in display order.
    Operators = {
        ContinuousVariable: [
            (FilterContinuous.Equal, "equals"),
            (FilterContinuous.NotEqual, "is not"),
            (FilterContinuous.Less, "is below"),
            (FilterContinuous.LessEqual, "is at most"),
            (FilterContinuous.Greater, "is greater than"),
            (FilterContinuous.GreaterEqual, "is at least"),
            (FilterContinuous.Between, "is between"),
            (FilterContinuous.Outside, "is outside"),
            (FilterContinuous.IsDefined, "is defined"),
        ],
        DiscreteVariable: [
            (FilterDiscreteType.Equal, "is"),
            (FilterDiscreteType.NotEqual, "is not"),
            (FilterDiscreteType.In, "is one of"),
            (FilterDiscreteType.IsDefined, "is defined")
        ],
        StringVariable: [
            (FilterString.Equal, "equals"),
            (FilterString.NotEqual, "is not"),
            (FilterString.Less, "is before"),
            (FilterString.LessEqual, "is equal or before"),
            (FilterString.Greater, "is after"),
            (FilterString.GreaterEqual, "is equal or after"),
            (FilterString.Between, "is between"),
            (FilterString.Outside, "is outside"),
            (FilterString.Contains, "contains"),
            (FilterString.StartsWith, "begins with"),
            (FilterString.EndsWith, "ends with"),
            (FilterString.IsDefined, "is defined"),
        ]
    }

    # Time variables offer the numeric operators (the same list object).
    Operators[TimeVariable] = Operators[ContinuousVariable]

    # Register the "All ..." pseudo-attributes. Each one gets its own
    # operator list (the numeric and string labels are pluralized with
    # `_plural`) and an integer code in `AllTypes`, which `set_new_values`
    # reads as the value type of such a row. The loop variables remain as
    # class attributes.
    AllTypes = {}
    for _all_name, _all_type, _all_ops in (
            ("All variables", 0,
             [(None, "are defined")]),
            ("All numeric variables", 2,
             [(v, _plural(t)) for v, t in Operators[ContinuousVariable]]),
            ("All string variables", 3,
             [(v, _plural(t)) for v, t in Operators[StringVariable]])):
        Operators[_all_name] = _all_ops
        AllTypes[_all_name] = _all_type

    # Operator labels only, per key of `Operators`; used to fill the
    # operator combo.
    operator_names = {vtype: [name for _, name in filters]
                      for vtype, filters in Operators.items()}

    class Error(widget.OWWidget.Error):
        parsing_error = Msg("{}")

    def __init__(self):
        """Build the conditions table, its buttons and the output options."""
        super().__init__()

        self.old_purge_classes = True

        self.conditions = []
        self.last_output_conditions = None
        self.data = None
        self.data_desc = self.match_desc = self.nonmatch_desc = None
        # Attribute choices: the "All ..." pseudo-attributes, a separator,
        # then the class, attribute and meta variables.
        self.variable_model = DomainModel(
            [list(self.AllTypes), DomainModel.Separator,
             DomainModel.CLASSES, DomainModel.ATTRIBUTES, DomainModel.METAS])

        # One row per condition: attribute, operator, value(s), remove button
        box = gui.vBox(self.controlArea, 'Conditions', stretch=100)
        self.cond_list = QTableWidget(
            box, showGrid=False, selectionMode=QTableWidget.NoSelection)
        box.layout().addWidget(self.cond_list)
        self.cond_list.setColumnCount(4)
        self.cond_list.setRowCount(0)
        self.cond_list.verticalHeader().hide()
        self.cond_list.horizontalHeader().hide()
        for i in range(3):
            self.cond_list.horizontalHeader().setSectionResizeMode(i, QHeaderView.Stretch)
        # The remove-button column keeps a fixed, narrow width
        self.cond_list.horizontalHeader().resizeSection(3, 30)
        self.cond_list.viewport().setBackgroundRole(QPalette.Window)

        box2 = gui.hBox(box)
        gui.rubber(box2)
        self.add_button = gui.button(
            box2, self, "Add Condition", callback=self.add_row)
        self.add_all_button = gui.button(
            box2, self, "Add All Variables", callback=self.add_all)
        self.remove_all_button = gui.button(
            box2, self, "Remove All", callback=self.remove_all)
        gui.rubber(box2)

        box_setting = gui.vBox(self.buttonsArea)
        self.cb_pa = gui.checkBox(
            box_setting, self, "purge_attributes", "Remove unused features",
            callback=self.conditions_changed)
        self.cb_pc = gui.checkBox(
            box_setting, self, "purge_classes", "Remove unused classes",
            callback=self.conditions_changed)

        self.report_button.setFixedWidth(120)
        gui.rubber(self.buttonsArea.layout())

        gui.auto_send(self.buttonsArea, self, "auto_commit")

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        self.set_data(None)
        self.resize(600, 400)

    def add_row(self, attr=None, condition_type=None, condition_value=None):
        """
        Append a condition row to the conditions table.

        Parameters
        ----------
        attr : optional
            Variable to preselect. If not given, the first entry after the
            "All ..." pseudo-attributes and the separator is selected.
        condition_type : Optional[int]
            Index of the operator to preselect.
        condition_value : optional
            Values forwarded to `set_new_values` for the value editor.
        """
        model = self.cond_list.model()
        row = model.rowCount()
        model.insertRow(row)

        attr_combo = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
        attr_combo.setModel(self.variable_model)
        attr_combo.row = row
        if attr:
            initial = self.variable_model.indexOf(attr)
        else:
            initial = len(self.AllTypes) + 1
        attr_combo.setCurrentIndex(initial)
        self.cond_list.setCellWidget(row, 0, attr_combo)

        # A persistent index follows the row when earlier rows are removed
        persistent = QPersistentModelIndex(model.index(row, 3))
        remove_button = QPushButton('×', self, flat=True,
                                    styleSheet='* {font-size: 16pt; color: silver}'
                                               '*:hover {color: black}')
        remove_button.clicked.connect(lambda: self.remove_one(persistent.row()))
        self.cond_list.setCellWidget(row, 3, remove_button)

        self.remove_all_button.setDisabled(False)
        self.set_new_operators(attr_combo, attr is not None,
                               condition_type, condition_value)
        attr_combo.currentIndexChanged.connect(
            lambda _: self.set_new_operators(attr_combo, False))

        self.cond_list.resizeRowToContents(row)

    def add_all(self):
        """Replace all conditions with one default condition per variable.

        If conditions already exist, the user must confirm first.
        """
        if self.cond_list.rowCount():
            box = QMessageBox
            answer = box.question(
                self, "Remove existing filters",
                "This will replace the existing filters with "
                "filters for all variables.", box.Ok | box.Cancel)
            if answer != box.Ok:
                return
            self.remove_all()
        first_variable = len(self.AllTypes) + 1
        for attr in self.variable_model[first_variable:]:
            self.add_row(attr)
        self.conditions_changed()

    def remove_one(self, rownum):
        """Delete the condition in row `rownum` and re-apply the conditions."""
        self.remove_one_row(rownum)
        self.conditions_changed()

    def remove_all(self):
        """Delete every condition and re-apply the (now empty) conditions."""
        self.remove_all_rows()
        self.conditions_changed()

    def remove_one_row(self, rownum):
        """
        Delete the condition in row `rownum` from the table.

        The attribute and operator combos carry their row number in a `row`
        attribute, which `set_new_operators` and `set_new_values` use to find
        sibling cells. After the rows below `rownum` shift up, these numbers
        are renumbered so that they keep pointing at their own row.
        """
        self.cond_list.removeRow(rownum)
        for row in range(rownum, self.cond_list.rowCount()):
            for col in (0, 1):
                widg = self.cond_list.cellWidget(row, col)
                if widg is not None and hasattr(widg, "row"):
                    widg.row = row
        if self.cond_list.model().rowCount() == 0:
            self.remove_all_button.setDisabled(True)

    def remove_all_rows(self):
        """Delete every condition row and disable the "Remove All" button."""
        # Disconnect signals to avoid stray emits when changing variable_model
        combos = (self.cond_list.cellWidget(row, col)
                  for row in range(self.cond_list.rowCount())
                  for col in (0, 1))
        for combo in combos:
            if combo:
                combo.currentIndexChanged.disconnect()
        self.cond_list.clear()
        self.cond_list.setRowCount(0)
        self.remove_all_button.setDisabled(True)

    def set_new_operators(self, attr_combo, adding_all,
                          selected_index=None, selected_values=None):
        """
        Replace the operator combo in the row of `attr_combo`.

        The offered operators depend on the selected attribute, either an
        "All ..." pseudo-attribute or a domain variable's type. Unless
        `selected_index` is given, the operator with the same text as in the
        previous combo stays selected, falling back to the first one. The
        value editor is then rebuilt by `set_new_values`.
        """
        row = attr_combo.row
        previous = self.cond_list.cellWidget(row, 1)
        prev_text = previous.currentText() if previous else ""

        oper_combo = QComboBox()
        oper_combo.row = row
        oper_combo.attr_combo = attr_combo
        attr_name = attr_combo.currentText()
        if attr_name in self.AllTypes:
            labels = self.operator_names[attr_name]
        else:
            labels = self.operator_names[type(self.data.domain[attr_name])]
        oper_combo.addItems(labels)
        if selected_index is None:
            # findText returns -1 when the old operator is not offered
            selected_index = max(oper_combo.findText(prev_text), 0)
        oper_combo.setCurrentIndex(selected_index)
        self.cond_list.setCellWidget(row, 1, oper_combo)
        self.set_new_values(oper_combo, adding_all, selected_values)
        oper_combo.currentIndexChanged.connect(
            lambda _: self.set_new_values(oper_combo, False))

    @staticmethod
    def _get_lineedit_contents(box):
        """
        Collect the contents of the editors in a value box.

        `box` is either a container with a `controls` attribute or a single
        editor. A line edit contributes its text. A date/time widget
        contributes its time, date or date-time when its `format` is
        (0, 1), (1, 0) or (1, 1), respectively. Other children are ignored.
        """
        getters = (((0, 1), "time"), ((1, 0), "date"), ((1, 1), "dateTime"))
        contents = []
        for child in getattr(box, "controls", [box]):
            if isinstance(child, QLineEdit):
                contents.append(child.text())
            elif isinstance(child, DateTimeWidget):
                for fmt, getter in getters:
                    if child.format == fmt:
                        contents.append(getattr(child, getter)())
                        break
        return contents

    @staticmethod
    def _get_value_contents(box):
        """
        Return the values entered in a condition's value editor(s) as a tuple.

        `box` is a container with a `controls` attribute or a single editor.
        Children contribute, in order:
        - a line edit: its text;
        - a combo box: its current index;
        - a date/time widget: its time, date or date-time, by `format`;
        - a tool button with a popup: the 1-based rows of the checked items.
          Its description text is also set to their comma-separated names.
        Labels and None are skipped.

        Raises
        ------
        TypeError
            If a child of any other type is found.
        """
        cont = []
        for child in getattr(box, "controls", [box]):
            if isinstance(child, QLineEdit):
                cont.append(child.text())
            elif isinstance(child, QComboBox):
                cont.append(child.currentIndex())
            elif isinstance(child, QToolButton):
                if child.popup is not None:
                    # Names are collected per button, so one button's
                    # description never includes another button's items.
                    names = []
                    model = child.popup.list_view.model()
                    for row in range(model.rowCount()):
                        item = model.item(row)
                        if item.checkState():
                            cont.append(row + 1)
                            names.append(item.text())
                    child.desc_text = ', '.join(names)
                    child.set_text()
            elif isinstance(child, DateTimeWidget):
                if child.format == (0, 1):
                    cont.append(child.time())
                elif child.format == (1, 0):
                    cont.append(child.date())
                elif child.format == (1, 1):
                    cont.append(child.dateTime())
            elif isinstance(child, QLabel) or child is None:
                pass
            else:
                raise TypeError('Type %s not supported.' % type(child))
        return tuple(cont)

    class QDoubleValidatorEmpty(QDoubleValidator):
        """QDoubleValidator that also accepts an empty string and rejects
        any input containing the locale's group separator."""

        def validate(self, input_, pos):
            if input_:
                if self.locale().groupSeparator() not in input_:
                    return super().validate(input_, pos)
                verdict = QDoubleValidator.Invalid
            else:
                verdict = QDoubleValidator.Acceptable
            return verdict, input_, pos

    def set_new_values(self, oper_combo, adding_all, selected_values=None):
        """Rebuild the value editor (column 2) of one condition row.

        The editor depends on the variable selected in
        ``oper_combo.attr_combo`` and the operator selected in ``oper_combo``:
        a bare QLabel for operators ending in " defined", a
        DropDownToolButton ("... one of") or a value combo for discrete
        variables, and one or two line edits / DateTimeWidgets otherwise.

        Args:
            oper_combo: the row's operator combo; ``oper_combo.row`` is the
                table row and ``oper_combo.attr_combo`` the variable combo.
            adding_all: if true, ``conditions_changed`` is not called at the
                end (presumably because the caller adds several rows at once).
            selected_values: optional initial operand(s) for the new editor.
        """
        # def remove_children():
        #     for child in box.children()[1:]:
        #         box.layout().removeWidget(child)
        #         child.setParent(None)

        def add_textual(contents):
            # Right-aligned line edit that re-evaluates conditions once
            # editing is finished.
            le = gui.lineEdit(box, self, None,
                              sizePolicy=QSizePolicy(QSizePolicy.Expanding,
                                                     QSizePolicy.Expanding))
            if contents:
                le.setText(contents)
            le.setAlignment(Qt.AlignRight)
            le.editingFinished.connect(self.conditions_changed)
            return le

        def add_numeric(contents):
            # Same as add_textual, but restricted to (possibly empty) doubles.
            le = add_textual(contents)
            le.setValidator(OWSelectRows.QDoubleValidatorEmpty())
            return le

        # Current editor of this row; may be falsy for a freshly added row.
        box = self.cond_list.cellWidget(oper_combo.row, 2)
        # Initial contents for up to two operand editors.
        lc = ["", ""]
        oper = oper_combo.currentIndex()
        attr_name = oper_combo.attr_combo.currentText()
        if attr_name in self.AllTypes:
            # Not a single variable: the type comes from the AllTypes table.
            vtype = self.AllTypes[attr_name]
            var = None
        else:
            var = self.data.domain[attr_name]
            var_idx = self.data.domain.index(attr_name)
            vtype = vartype(var)
            if selected_values is not None:
                # Keep two operands; non-time values are stringified for
                # the line edits, time values (vtype 4) are kept as-is.
                lc = list(selected_values) + ["", ""]
                lc = [str(x) if vtype != 4 else x for x in lc[:2]]
        if box and vtype == box.var_type:
            # Same value type as the existing editor: carry over what was
            # already entered; it is prepended, so it takes precedence.
            lc = self._get_lineedit_contents(box) + lc

        if oper_combo.currentText().endswith(" defined"):
            # No operand needed.
            label = QLabel()
            label.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, label)
        elif var is not None and var.is_discrete:
            if oper_combo.currentText().endswith(" one of"):
                if selected_values:
                    lc = list(selected_values)
                button = DropDownToolButton(self, var, lc)
                button.var_type = vtype
                self.cond_list.setCellWidget(oper_combo.row, 2, button)
            else:
                # Index 0 is a blank entry, so value i is at index i + 1.
                combo = ComboBoxSearch()
                combo.addItems(("", ) + var.values)
                if lc[0]:
                    combo.setCurrentIndex(int(lc[0]))
                else:
                    combo.setCurrentIndex(0)
                combo.var_type = vartype(var)
                self.cond_list.setCellWidget(oper_combo.row, 2, combo)
                combo.currentIndexChanged.connect(self.conditions_changed)
        else:
            box = gui.hBox(self, addToLayout=False)
            box.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, box)
            # NOTE(review): operators with index > 5 get a second operand
            # (presumably the "between"-style operators) -- confirm against
            # the operator lists.
            if vtype == 2:  # continuous:
                box.controls = [add_numeric(lc[0])]
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_numeric(lc[1]))
            elif vtype == 3:  # string:
                box.controls = [add_textual(lc[0])]
                if oper in [6, 7]:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_textual(lc[1]))
            elif vtype == 4:  # time:
                # The two closures below read `w` and `w_` at call time;
                # both names are bound further down in this branch (w_ is
                # None when there is only one operand).
                def invalidate_datetime():
                    # Keep the upper bound (w_) not earlier than the lower
                    # bound (w) and sync the calendars' time editors for
                    # date-time formats.
                    if w_:
                        if w.dateTime() > w_.dateTime():
                            w_.setDateTime(w.dateTime())
                        if w.format == (1, 1):
                            w.calendarWidget.timeedit.setTime(w.time())
                            w_.calendarWidget.timeedit.setTime(w_.time())
                    elif w.format == (1, 1):
                        w.calendarWidget.timeedit.setTime(w.time())

                def datetime_changed():
                    self.conditions_changed()
                    invalidate_datetime()

                datetime_format = (var.have_date, var.have_time)
                column = self.data.get_column_view(var_idx)[0]
                w = DateTimeWidget(self, column, datetime_format)
                w.set_datetime(lc[0])
                box.controls = [w]
                box.layout().addWidget(w)
                w.dateTimeChanged.connect(datetime_changed)
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    w_ = DateTimeWidget(self, column, datetime_format)
                    w_.set_datetime(lc[1])
                    box.layout().addWidget(w_)
                    box.controls.append(w_)
                    invalidate_datetime()
                    w_.dateTimeChanged.connect(datetime_changed)
                else:
                    w_ = None
            else:
                box.controls = []
        if not adding_all:
            self.conditions_changed()

    @Inputs.data
    def set_data(self, data):
        """Handle a new input table (or None).

        Clears the condition rows, updates which buttons and purge check
        boxes are usable and, for non-empty data, restores the conditions
        of the matching context, falling back to a single blank row.
        Always ends with a commit.
        """
        self.closeContext()
        self.data = data

        # Purging is not offered for SQL tables.
        purge_allowed = not isinstance(data, SqlTable)
        self.cb_pa.setEnabled(purge_allowed)
        self.cb_pc.setEnabled(purge_allowed)
        self.remove_all_rows()

        missing = data is None
        self.add_button.setDisabled(missing)
        # "Add all" is only offered for domains of at most 100 columns.
        too_wide = not missing and \
            len(data.domain.variables) + len(data.domain.metas) > 100
        self.add_all_button.setDisabled(missing or too_wide)

        if not data:
            self.info.set_input_summary(self.info.NoInput)
            self.data_desc = None
            self.variable_model.set_domain(None)
            self.commit()
            return

        self.data_desc = report.describe_data_brief(data)
        self.variable_model.set_domain(data.domain)

        # openContext presumably refills self.conditions from a stored
        # context; start from an empty list otherwise.
        self.conditions = []
        self.openContext(data)
        for attr, cond_type, cond_value in self.conditions:
            # Skip stored conditions on variables absent from this domain.
            if attr not in self.variable_model:
                continue
            self.add_row(attr, cond_type, cond_value)
        if self.cond_list.model().rowCount() == 0:
            self.add_row()

        n_rows = data.approx_len()
        details = format_summary_details(data)
        self.info.set_input_summary(n_rows, details)
        self.unconditional_commit()

    def conditions_changed(self):
        """Re-read every condition row into ``self.conditions`` and, when
        auto-update is on and the conditions differ from the last output
        ones, commit.

        Change signals may fire while a row's controls are still being
        constructed; reading such a half-built row raises AttributeError,
        in which case this update is skipped. Only the reading of the
        controls is guarded: an AttributeError raised by ``commit`` itself
        is a genuine bug and must not be silently swallowed.
        """
        try:
            cells_by_rows = (
                [self.cond_list.cellWidget(row, col) for col in range(3)]
                for row in range(self.cond_list.rowCount())
            )
            conditions = [
                (var_cell.currentData(gui.TableVariable) or var_cell.currentText(),
                 oper_cell.currentIndex(),
                 self._get_value_contents(val_cell))
                for var_cell, oper_cell, val_cell in cells_by_rows]
        except AttributeError:
            # Controls are still being constructed; try again on the next
            # change signal.
            return
        self.conditions = conditions
        if self.update_on_change and (
                self.last_output_conditions is None or
                self.last_output_conditions != self.conditions):
            self.commit()

    @staticmethod
    def _values_to_floats(attr, values):
        """Convert the operands of a continuous or time condition to floats.

        Args:
            attr: the condition's variable (a TimeVariable gets date/time
                parsing), or None for multi-variable conditions.
            values: the stored operands. These are presumably strings from
                line edits for numeric variables, and objects with
                ``toString`` (Qt date/time) for time variables.

        Returns:
            ``values`` itself if it is empty; None if any operand is empty
            (an incomplete condition); otherwise a tuple of floats.

        Raises:
            ValueError: if a value cannot be parsed in the current locale.
        """
        if len(values) == 0:
            return values
        if not all(values):
            return None
        if isinstance(attr, TimeVariable):
            # Materialize: `values` is reused below (in the error message
            # and the TypeError fallback), so it must not be a one-shot
            # generator.
            values = tuple(value.toString(format=Qt.ISODate)
                           for value in values)

            def parse(text):
                return attr.parse(text), True
        else:
            # Returns a (value, ok) pair.
            parse = QLocale().toDouble

        try:
            floats, ok = zip(*[parse(v) for v in values])
        except TypeError:
            floats = values  # values already floats
        else:
            if not all(ok):
                raise ValueError('Some values could not be parsed as floats '
                                 'in the current locale: {}'.format(values))
        assert all(isinstance(v, float) for v in floats)
        return floats

    def commit(self):
        """Filter the input by the current conditions and send the outputs.

        One filter is built per entry of ``self.conditions`` (incomplete
        conditions are skipped). Together they select the matching rows;
        the negated filter selects the non-matching rows, and the full table
        is annotated with the selection. Constant attributes and unused
        values are optionally purged. Then the matching, unmatched and
        annotated data are sent, and the report descriptions and output
        summary are updated. On a parsing error, an error is shown and the
        method returns before sending anything.
        """
        matching_output = self.data
        non_matching_output = None
        annotated_output = None

        self.Error.clear()
        if self.data:
            domain = self.data.domain
            conditions = []
            for attr_name, oper_idx, values in self.conditions:
                # attr_name is either a key of AllTypes (no single domain
                # column, so attr_index stays None) or a domain variable.
                if attr_name in self.AllTypes:
                    attr_index = attr = None
                    attr_type = self.AllTypes[attr_name]
                    operators = self.Operators[attr_name]
                else:
                    attr_index = domain.index(attr_name)
                    attr = domain[attr_index]
                    attr_type = vartype(attr)
                    operators = self.Operators[type(attr)]
                opertype, _ = operators[oper_idx]
                if attr_type == 0:
                    filt = data_filter.IsDefined()
                elif attr_type in (2, 4):  # continuous, time
                    try:
                        floats = self._values_to_floats(attr, values)
                    except ValueError as e:
                        self.Error.parsing_error(e.args[0])
                        return
                    if floats is None:
                        # Some operand is still empty: ignore this condition.
                        continue
                    filt = data_filter.FilterContinuous(
                        attr_index, opertype, *floats)
                elif attr_type == 3:  # string
                    filt = data_filter.FilterString(
                        attr_index, opertype, *[str(v) for v in values])
                else:
                    if opertype == FilterDiscreteType.IsDefined:
                        f_values = None
                    else:
                        # Stored operands are 1-based indices into
                        # attr.values; 0 (or nothing) means no selection.
                        if not values or not values[0]:
                            continue
                        values = [attr.values[i-1] for i in values]
                        if opertype == FilterDiscreteType.Equal:
                            f_values = {values[0]}
                        elif opertype == FilterDiscreteType.NotEqual:
                            f_values = set(attr.values)
                            f_values.remove(values[0])
                        elif opertype == FilterDiscreteType.In:
                            f_values = set(values)
                        else:
                            raise ValueError("invalid operand")
                    filt = data_filter.FilterDiscrete(attr_index, f_values)
                conditions.append(filt)

            if conditions:
                filters = data_filter.Values(conditions)
                matching_output = filters(self.data)
                # The same filter set, negated, yields the complement.
                filters.negate = True
                non_matching_output = filters(self.data)

                row_sel = np.in1d(self.data.ids, matching_output.ids)
                annotated_output = create_annotated_table(self.data, row_sel)

            # if hasattr(self.data, "name"):
            #     matching_output.name = self.data.name
            #     non_matching_output.name = self.data.name

            purge_attrs = self.purge_attributes
            purge_classes = self.purge_classes
            # Purging is skipped for SQL tables.
            if (purge_attrs or purge_classes) and \
                    not isinstance(self.data, SqlTable):
                # Each flag is enabled only when its purge option is on.
                attr_flags = sum([Remove.RemoveConstant * purge_attrs,
                                  Remove.RemoveUnusedValues * purge_attrs])
                class_flags = sum([Remove.RemoveConstant * purge_classes,
                                   Remove.RemoveUnusedValues * purge_classes])
                # same settings used for attributes and meta features
                remover = Remove(attr_flags, class_flags, attr_flags)

                matching_output = remover(matching_output)
                non_matching_output = remover(non_matching_output)
                annotated_output = remover(annotated_output)

        # Empty (falsy) tables are sent as None.
        if not matching_output:
            matching_output = None
        if not non_matching_output:
            non_matching_output = None
        if not annotated_output:
            annotated_output = None

        self.Outputs.matching_data.send(matching_output)
        self.Outputs.unmatched_data.send(non_matching_output)
        self.Outputs.annotated_data.send(annotated_output)

        # Descriptions consumed by send_report.
        self.match_desc = report.describe_data_brief(matching_output)
        self.nonmatch_desc = report.describe_data_brief(non_matching_output)

        summary = matching_output.approx_len() if matching_output else \
            self.info.NoOutput
        details = format_summary_details(matching_output) if matching_output else ""
        self.info.set_output_summary(summary, details)

    def send_report(self):
        """Report the input data, the conditions in readable form, and the
        outputs.

        The outputs are described fully when their domain descriptions
        differ from each other. Otherwise only instance counts are given.
        """
        if not self.data:
            self.report_paragraph("No data.")
            return

        # describe_domain becomes true if a non-empty description differs
        # from the previous non-empty one in anything but the instance count.
        pdesc = None
        describe_domain = False
        for d in (self.data_desc, self.match_desc, self.nonmatch_desc):
            if not d or not d["Data instances"]:
                continue
            ndesc = d.copy()
            del ndesc["Data instances"]
            if pdesc is not None and pdesc != ndesc:
                describe_domain = True
            pdesc = ndesc

        conditions = []
        for attr, oper, values in self.conditions:
            # attr is a string key of AllTypes or a domain variable.
            if isinstance(attr, str):
                attr_name = attr
                var_type = self.AllTypes[attr]
                names = self.operator_names[attr_name]
            else:
                attr_name = attr.name
                var_type = vartype(attr)
                names = self.operator_names[type(attr)]
            name = names[oper]
            # The last operator is reported without operands (presumably
            # the "is defined" operator).
            if oper == len(names) - 1:
                conditions.append("{} {}".format(attr_name, name))
            elif var_type == 1:  # discrete
                # Discrete operands are 1-based indices into attr.values.
                if name == "is one of":
                    valnames = [attr.values[v - 1] for v in values]
                    if not valnames:
                        continue
                    if len(valnames) == 1:
                        valstr = valnames[0]
                    else:
                        valstr = f"{', '.join(valnames[:-1])} or {valnames[-1]}"
                    conditions.append(f"{attr} is {valstr}")
                elif values and values[0]:
                    value = values[0] - 1
                    conditions.append(f"{attr} {name} {attr.values[value]}")
            elif var_type == 3:  # string variable
                conditions.append(
                    f"{attr} {name} {' and '.join(map(repr, values))}")
            elif var_type == 4:  # time
                values = (value.toString(format=Qt.ISODate) for value in values)
                conditions.append(f"{attr} {name} {' and '.join(values)}")
            elif all(x for x in values):  # numeric variable
                # Incomplete numeric conditions (an empty operand) are omitted.
                conditions.append(f"{attr} {name} {' and '.join(values)}")
        items = OrderedDict()
        if describe_domain:
            items.update(self.data_desc)
        else:
            items["Instances"] = self.data_desc["Data instances"]
        items["Condition"] = " AND ".join(conditions) or "no conditions"
        self.report_items("Data", items)
        if describe_domain:
            self.report_items("Matching data", self.match_desc)
            self.report_items("Non-matching data", self.nonmatch_desc)
        else:
            match_inst = \
                bool(self.match_desc) and \
                self.match_desc["Data instances"]
            nonmatch_inst = \
                bool(self.nonmatch_desc) and \
                self.nonmatch_desc["Data instances"]
            self.report_items(
                "Output",
                (("Matching data",
                  "{} instances".format(match_inst) if match_inst else "None"),
                 ("Non-matching data",
                  nonmatch_inst > 0 and "{} instances".format(nonmatch_inst))))
class OWStackAlign(OWWidget):
    """Align and crop a stack of images.

    The alignment itself is delegated to ``process_stack``. This widget
    collects its parameters (x/y axis attributes, optional Sobel filter,
    1-based reference frame), sends the aligned stack, and plots the
    per-frame shifts returned by ``process_stack``.
    """
    # Widget's name as displayed in the canvas
    name = "Align Stack"

    # Short widget description
    description = (
        "Aligns and crops a stack of images using various methods.")

    icon = "icons/stackalign.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Stack of images", Table, default=True)

    class Outputs:
        newstack = Output("Aligned image stack", Table, default=True)

    class Error(OWWidget.Error):
        nan_in_image = Msg("Unknown values within images: {} unknowns")
        invalid_axis = Msg("Invalid axis: {}")

    autocommit = settings.Setting(True)

    want_main_area = True
    want_control_area = True
    resizing_enabled = False

    settingsHandler = DomainContextHandler()

    sobel_filter = settings.Setting(False)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    # 1-based reference frame number. The stored default (0) is out of
    # range; _sanitize_ref_frame clamps it into 1..number of frames.
    ref_frame_num = settings.Setting(0)

    def __init__(self):
        super().__init__()

        box = gui.widgetBox(self.controlArea, "Axes")

        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str)
        # Axis candidates: continuous meta and class variables.
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=ContinuousVariable)
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self._update_attr,
            model=self.xy_model, **common_options)
        self.cb_attr_y = gui.comboBox(
            box, self, "attr_y", label="Axis y:", callback=self._update_attr,
            model=self.xy_model, **common_options)

        # Reset axis choices when a context for a different domain opens.
        self.contextAboutToBeOpened.connect(self._init_interface_data)

        box = gui.widgetBox(self.controlArea, "Parameters")

        gui.checkBox(box, self, "sobel_filter",
                     label="Use sobel filter",
                     callback=self._sobel_changed)
        gui.separator(box)
        hbox = gui.hBox(box)
        self.le1 = lineEditIntRange(box, self, "ref_frame_num", bottom=1, default=1,
                                    callback=self._ref_frame_changed)
        hbox.layout().addWidget(QLabel("Reference frame:", self))
        hbox.layout().addWidget(self.le1)

        gui.rubber(self.controlArea)

        plot_box = gui.widgetBox(self.mainArea, "Shift curves")
        self.plotview = pg.PlotWidget(background="w")
        plot_box.layout().addWidget(self.plotview)
        # TODO:  resize widget to make it a bit smaller

        self.data = None

        gui.auto_commit(self.controlArea, self, "autocommit", "Send Data")

    def _sanitize_ref_frame(self):
        """Clamp ``ref_frame_num`` into the 1-based range of frames.

        The number of frames is taken as the number of columns of
        ``data.X``. This does nothing while there is no data, so the user
        may edit the field before data arrives. The lower bound matters:
        an out-of-range value such as the default 0 would otherwise reach
        ``process_stack`` as the 0-based index -1.
        """
        if self.data is None:
            return
        n_frames = self.data.X.shape[1]
        clamped = max(1, min(self.ref_frame_num, n_frames))
        if clamped != self.ref_frame_num:
            self.ref_frame_num = clamped

    def _ref_frame_changed(self):
        self._sanitize_ref_frame()
        self.commit()

    def _sobel_changed(self):
        self.commit()

    def _init_attr_values(self, data):
        """Repopulate the axis model from `data` and choose default axes:
        the first two candidates, or the first one for both axes if only
        one exists (None if there are none)."""
        domain = data.domain if data is not None else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def _init_interface_data(self, args):
        """Slot for ``contextAboutToBeOpened``; ``args[0]`` is the incoming
        data. Axis choices are reset unless the domain is unchanged."""
        data = args[0]
        same_domain = (self.data and data and
                       data.domain == self.data.domain)
        if not same_domain:
            self._init_attr_values(data)

    def _update_attr(self):
        self.commit()

    @Inputs.data
    def set_data(self, dataset):
        """Receive a new image stack (or None) and recompute the output."""
        self.closeContext()
        # Opened before self.data is replaced so that _init_interface_data
        # can compare the previous and the incoming domain.
        self.openContext(dataset)
        self.data = dataset
        self._sanitize_ref_frame()
        self.Error.nan_in_image.clear()
        self.Error.invalid_axis.clear()
        self.commit()

    def commit(self):
        """Align the stack, plot the per-frame shifts and send the result.

        None is sent when there is no data, the data has no attributes,
        an axis is not selected, or ``process_stack`` reports unknown
        values or an invalid axis (the matching error is shown).
        """
        new_stack = None

        self.Error.nan_in_image.clear()
        self.Error.invalid_axis.clear()

        self.plotview.plotItem.clear()

        if self.data and len(self.data.domain.attributes) and self.attr_x and self.attr_y:
            try:
                # process_stack expects a 0-based reference frame index.
                shifts, new_stack = process_stack(self.data, self.attr_x, self.attr_y,
                                                  upsample_factor=100, use_sobel=self.sobel_filter,
                                                  ref_frame_num=self.ref_frame_num-1)
            except NanInsideHypercube as e:
                self.Error.nan_in_image(e.args[0])
            except InvalidAxisException as e:
                self.Error.invalid_axis(e.args[0])
            else:
                # TODO: label axes
                frames = np.linspace(1, shifts.shape[0], shifts.shape[0])
                self.plotview.plotItem.plot(frames, shifts[:, 0],
                                            pen=pg.mkPen(color=(255, 40, 0), width=3),
                                            symbol='o', symbolBrush=(255, 40, 0), symbolPen='w',
                                            symbolSize=7)
                self.plotview.plotItem.plot(frames, shifts[:, 1],
                                            pen=pg.mkPen(color=(0, 139, 139), width=3),
                                            symbol='o', symbolBrush=(0, 139, 139), symbolPen='w',
                                            symbolSize=7)
                self.plotview.getPlotItem().setLabel('bottom', 'Frame number')
                self.plotview.getPlotItem().setLabel('left', 'Shift / pixel')
                # Vertical marker at the (1-based) reference frame.
                self.plotview.getPlotItem().addLine(self.ref_frame_num,
                                                    pen=pg.mkPen(color=(150, 150, 150), width=3,
                                                                 style=Qt.DashDotDotLine))

        self.Outputs.newstack.send(new_stack)

    def send_report(self):
        self.report_items((
            ("Use sobel filter", str(self.sobel_filter)),
        ))
Beispiel #15
0
class OWAggregateColumns(widget.OWWidget):
    """Append a column that aggregates the selected numeric columns row-wise.

    The aggregate is one of ``Operations`` (NaN-ignoring NumPy reductions).
    It is added under a name derived from ``var_name`` that is made unique
    within the input domain.
    """
    name = "Aggregate Columns"
    description = "Compute a sum, max, min ... of selected columns."
    icon = "icons/AggregateColumns.svg"
    priority = 100
    keywords = [
        "aggregate", "sum", "product", "max", "min", "mean", "median",
        "variance"
    ]

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        data = Output("Data", Table)

    want_main_area = False

    settingsHandler = DomainContextHandler()
    variables: List[Variable] = ContextSetting([])
    operation = Setting("Sum")
    var_name = Setting("agg")
    auto_apply = Setting(True)

    Operations = {
        "Sum": np.nansum,
        "Product": np.nanprod,
        "Min": np.nanmin,
        "Max": np.nanmax,
        "Mean": np.nanmean,
        "Variance": np.nanvar,
        "Median": np.nanmedian
    }
    # Operations whose result is still a time when all inputs are times.
    TimePreserving = ("Min", "Max", "Mean", "Median")

    def __init__(self):
        super().__init__()
        self.data = None

        box = gui.vBox(self.controlArea, box=True)

        self.variable_model = DomainModel(order=DomainModel.MIXED,
                                          valid_types=(ContinuousVariable, ))
        var_list = gui.listView(box,
                                self,
                                "variables",
                                model=self.variable_model,
                                callback=self.commit)
        var_list.setSelectionMode(var_list.ExtendedSelection)

        combo = gui.comboBox(box,
                             self,
                             "operation",
                             label="Operator: ",
                             orientation=Qt.Horizontal,
                             items=list(self.Operations),
                             sendSelectedValue=True,
                             callback=self.commit)
        combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)

        gui.lineEdit(box,
                     self,
                     "var_name",
                     label="Variable name: ",
                     orientation=Qt.Horizontal,
                     callback=self.commit)

        gui.auto_apply(self.controlArea, self)

    @Inputs.data
    def set_data(self, data: Table = None):
        """Receive a new table (or None), restore the selection stored for
        its domain and send the result."""
        self.closeContext()
        self.variables.clear()
        self.data = data
        if self.data:
            self.variable_model.set_domain(data.domain)
            self.openContext(data)
        else:
            self.variable_model.set_domain(None)
        self.unconditional_commit()

    def commit(self):
        augmented = self._compute_data()
        self.Outputs.data.send(augmented)

    def _compute_data(self):
        """Return the input with the aggregate column added, or the input
        unchanged when there is no data or no selected variable."""
        if not self.data or not self.variables:
            return self.data

        new_col = self._compute_column()
        new_var = self._new_var()
        return self.data.add_column(new_var, new_col)

    def _compute_column(self):
        """Apply the chosen operation across the selected columns, one
        result per row (NaNs are ignored by the nan* reductions)."""
        arr = np.empty((len(self.data), len(self.variables)))
        for i, var in enumerate(self.variables):
            arr[:, i] = self.data.get_column_view(var)[0].astype(float)
        func = self.Operations[self.operation]
        return func(arr, axis=1)

    def _new_var_name(self):
        return get_unique_names(self.data.domain, self.var_name)

    def _new_var(self):
        """Create the output variable: a TimeVariable when the operation is
        time-preserving and every selected variable is a TimeVariable,
        otherwise a ContinuousVariable."""
        name = self._new_var_name()
        if self.operation in self.TimePreserving \
                and all(isinstance(var, TimeVariable) for var in self.variables):
            return TimeVariable(name)
        return ContinuousVariable(name)

    @staticmethod
    def _format_var_list(variables, max_listed=30):
        """Return the quoted names of `variables` as an English list.

        A single variable is returned on its own. At most `max_listed`
        names are spelled out; any remaining ones are summarized as
        "and N other(s)".
        """
        names = [f"'{var.name}'" for var in variables]
        if len(names) == 1:
            return names[0]
        if len(names) > max_listed:
            n_rest = len(names) - max_listed
            rest = f"{n_rest} other" + ("s" if n_rest > 1 else "")
            return f"{', '.join(names[:max_listed])} and {rest}"
        return f"{', '.join(names[:-1])} and {names[-1]}"

    def send_report(self):
        """Report the new column's name, the operation and its inputs."""
        # fp for self.variables, pylint: disable=unsubscriptable-object
        if not self.data or not self.variables:
            return
        var_list = self._format_var_list(self.variables)
        self.report_items(((
            "Output:",
            f"'{self._new_var_name()}' as {self.operation.lower()} of {var_list}"
        ), ))
class OWSignificantGroups(widget.OWWidget):
    """Test whether groups of instances, formed by the values of chosen
    discrete variables, differ significantly in a test variable."""

    name = 'Significant Groups'
    description = "Test whether instances grouped by nominal values are " \
                  "significantly different from random samples or the "\
                  "dataset in whole."
    icon = 'icons/SignificantGroups.svg'
    priority = 200

    class Inputs(widget.OWWidget.Inputs):
        data = widget.Input('Data', Table)

    class Outputs(widget.OWWidget.Outputs):
        selected_data = widget.Output('Selected Data', Table, default=True)
        data = widget.Output('Data', Table)
        results = widget.Output('Test Results', Table)

    want_main_area = True
    want_control_area = True

    class Information(widget.OWWidget.Information):
        nothing_significant = widget.Msg(
            'Chosen parameters reveal no significant groups')

    class Error(widget.OWWidget.Error):
        no_vars_selected = widget.Msg('No independent variables selected')
        no_class_selected = widget.Msg('No dependent variable selected')

    # Statistic name -> NaN-aware numpy reducer; the reducer itself is passed
    # to the permutation test, the name selects the parametric test.
    TEST_STATISTICS = OrderedDict((
        ('mean', np.nanmean),
        ('variance', np.nanvar),
        ('median', np.nanmedian),
        ('minimum', np.nanmin),
        ('maximum', np.nanmax),
    ))

    settingsHandler = settings.DomainContextHandler()

    # Grouping variables; each entry is used to index the data domain
    chosen_X = settings.ContextSetting([])
    # Test variable (set to the class variable's name on new data)
    chosen_y = settings.ContextSetting(0)
    is_permutation = settings.Setting(False)
    test_statistic = settings.Setting(next(iter(TEST_STATISTICS)))
    min_count = settings.Setting(20)

    def __init__(self):
        self._task = None  # type: Optional[self.Task]
        self._executor = ThreadExecutor(self)

        self.data = None
        self.test_type = ''

        self.discrete_model = DomainModel(separators=False,
                                          valid_types=(DiscreteVariable, ),
                                          parent=self)
        self.domain_model = DomainModel(valid_types=DomainModel.PRIMITIVE,
                                        parent=self)

        box = gui.vBox(self.controlArea, 'Hypotheses Testing')
        gui.listView(
            box,
            self,
            'chosen_X',
            model=self.discrete_model,
            box='Grouping Variables',
            selectionMode=QListView.ExtendedSelection,
            callback=self.Error.no_vars_selected.clear,
            toolTip='Select multiple variables with Ctrl+ or Shift+Click.')
        target = gui.comboBox(
            box,
            self,
            'chosen_y',
            sendSelectedValue=True,
            label='Test Variable',
            callback=[self.set_test_type, self.Error.no_class_selected.clear])
        target.setModel(self.domain_model)

        gui.checkBox(box,
                     self,
                     'is_permutation',
                     label='Permutation test',
                     callback=self.set_test_type)
        gui.comboBox(box,
                     self,
                     'test_statistic',
                     label='Statistic:',
                     items=tuple(self.TEST_STATISTICS),
                     orientation=Qt.Horizontal,
                     sendSelectedValue=True,
                     callback=self.set_test_type)
        gui.label(box, self, 'Test: %(test_type)s')

        box = gui.vBox(self.controlArea, 'Filter')
        gui.spin(box,
                 self,
                 'min_count',
                 5,
                 1000,
                 5,
                 label='Minimum group size (count):')

        self.btn_compute = gui.button(self.controlArea,
                                      self,
                                      '&Compute',
                                      callback=self.compute)
        gui.rubber(self.controlArea)

        class Model(PyTableModel):
            """Results model: shades the grouping-variable columns and
            shows count columns as integers."""
            _n_vars = 0
            _BACKGROUND = [QBrush(QColor('#eee')), QBrush(QColor('#ddd'))]

            def setHorizontalHeaderLabels(self, labels, n_vars):
                # The first `n_vars` columns hold grouping-variable values
                self._n_vars = n_vars
                super().setHorizontalHeaderLabels(labels)

            def data(self, index, role=Qt.DisplayRole):
                if role == Qt.BackgroundRole and index.column() < self._n_vars:
                    return self._BACKGROUND[index.row() % 2]
                if role == Qt.DisplayRole or role == Qt.ToolTipRole:
                    colname = self.headerData(index.column(), Qt.Horizontal)
                    if colname.lower() in ('count', 'count | class'):
                        row = self.mapToSourceRows(index.row())
                        return int(self[row][index.column()])
                return super().data(index, role)

        owwidget = self

        class View(gui.TableView):
            """Results view: selecting rows outputs the matching instances."""
            _vars = None

            def set_vars(self, variables):
                # Grouping variables, in the order of the leading columns
                self._vars = variables

            def selectionChanged(self, *args):
                super().selectionChanged(*args)

                rows = list({
                    index.row()
                    for index in self.selectionModel().selectedRows(0)
                })

                if not rows:
                    owwidget.Outputs.data.send(None)
                    return

                # Each selected row becomes a conjunction of equality
                # filters; rows are combined with a disjunction.
                model = self.model().tolist()
                filters = [
                    Values([
                        FilterDiscrete(self._vars[col], {model[row][col]})
                        for col in range(len(self._vars))
                    ]) for row in self.model().mapToSourceRows(rows)
                ]
                data = Values(filters, conjunction=False)(owwidget.data)

                annotated = create_annotated_table(owwidget.data, data.ids)

                owwidget.Outputs.selected_data.send(data)
                owwidget.Outputs.data.send(annotated)

        self.view = view = View(self)
        self.model = Model(parent=self)
        view.setModel(self.model)
        view.horizontalHeader().setStretchLastSection(False)
        self.mainArea.layout().addWidget(view)

        self.set_test_type()

    @Inputs.data
    def set_data(self, data):
        """Store the input table and refill the variable models."""
        self.data = data
        domain = None if data is None else data.domain

        self.closeContext()

        self.domain_model.set_domain(domain)
        self.discrete_model.set_domain(domain)
        if domain is not None:
            if domain.class_var:
                self.chosen_y = domain.class_var.name

        self.openContext(domain)

        self.set_test_type()

    def set_test_type(self):
        """Update the label describing which test `compute` will run.

        The statistic combo is enabled only for a continuous test variable.
        """
        if self.data is None:
            return

        yvar = self.data.domain[self.chosen_y]

        self.controls.test_statistic.setEnabled(yvar.is_continuous)

        if self.is_permutation:
            test = 'Permutation '
            if yvar.is_discrete:
                test += 'χ² '
            else:
                test += str(self.test_statistic) + ' '
        else:
            test = ''
            if yvar.is_discrete:
                test += 'χ² ' if len(yvar.values) > 2 else 'Hypergeometric '
            else:
                if self.test_statistic == 'mean':
                    test += "Student's t-"
                elif self.test_statistic == 'variance':
                    test += "Fligner–Killeen "
                elif self.test_statistic == 'median':
                    test += "Mann–Whitney U "
                elif self.test_statistic in ('minimum', 'maximum'):
                    test += "Gumbel distribution "
                else:
                    assert False, self.test_statistic
        test += 'test'
        self.test_type = test

    def compute(self):
        """Start the chosen test in the executor; `on_computed` takes the result.

        Does nothing without data; shows an error when grouping or test
        variables are missing.
        """
        # chosen_X may survive from a previous context while no data is
        # present; without this guard self.data.domain would raise.
        if self.data is None:
            return

        if not self.chosen_X:
            self.Error.no_vars_selected()
            return

        if not self.chosen_y:
            self.Error.no_class_selected()
            return

        # If listview selection was a single item the list of items is not a list,
        # but further below we expect it to be
        if not isinstance(self.chosen_X, (list, tuple)):
            self.chosen_X = [self.chosen_X]

        self.btn_compute.setEnabled(False)
        yvar = self.data.domain[self.chosen_y]

        def get_col(var, col):
            """Map float value codes of discrete `var` to its value strings;
            missing codes (NaN) map to NaN."""
            values = np.array(list(var.values) + [np.nan], dtype=object)
            # NaN -> -1, which indexes the trailing NaN in `values`
            col = pd.Series(col).fillna(-1).astype(int)
            return values[col]

        X = np.column_stack([
            get_col(var,
                    self.data.get_column_view(var)[0])
            for var in (self.data.domain[i] for i in self.chosen_X)
        ])
        X = pd.DataFrame(X, columns=self.chosen_X)
        y = pd.Series(self.data.get_column_view(yvar)[0])

        test, args, kwargs = None, (X, y), dict(min_count=self.min_count)
        if self.is_permutation:
            statistic = 'chi2' if yvar.is_discrete else self.TEST_STATISTICS[
                self.test_statistic]
            test = perm_test
            kwargs.update(statistic=statistic,
                          n_jobs=-2,
                          callback=methodinvoke(self, "setProgressValue",
                                                (int, int)))
        else:
            if yvar.is_discrete:
                if len(yvar.values) > 2:
                    test = chi2_test
                else:
                    test = hyper_test
                    args = (X, y.astype(bool))
            else:
                test = {
                    'mean': t_test,
                    'variance': fligner_killeen_test,
                    'median': mannwhitneyu_test,
                    'minimum': gumbel_min_test,
                    'maximum': gumbel_max_test,
                }[self.test_statistic]

        self._task = task = self.Task()
        self.progressBarInit()
        task.future = self._executor.submit(test, *args, **kwargs)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self.on_computed)

    @Slot(int, int)
    def setProgressValue(self, n, N):
        """Progress callback for the permutation test (step n of N)."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(n / (N + 1) * 100)

    class Task:
        """Bookkeeping for one submitted computation."""
        future = ...  # type: concurrent.futures.Future
        watcher = ...  # type: FutureWatcher
        cancelled = False  # type: bool

        def cancel(self):
            self.cancelled = True
            # Cancel the future. Note this succeeds only if the execution has
            # not yet started (see `concurrent.futures.Future.cancel`) ..
            self.future.cancel()
            # ... and wait until computation finishes
            concurrent.futures.wait([self.future])

    @Slot(concurrent.futures.Future)
    def on_computed(self, future):
        """Show and output the test results once the computation is done."""
        assert self.thread() is QThread.currentThread()
        assert future.done()

        self._task = None
        self.progressBarFinished()
        # Re-enable before fetching the result: if the test raised,
        # future.result() re-raises and the button would stay disabled.
        self.btn_compute.setEnabled(True)

        df = future.result()
        # Only retain "significant" p-values
        df = df[df[CORRECTED_LABEL] < .2]

        # The index name holds the grouping variables, one per index level
        columns = [var.name for var in df.index.name] + list(df.columns)
        lst = [list(i) + list(j) for i, j in zip(df.index, df.values)]

        results_table = table_from_frame(pd.DataFrame(lst, columns=columns),
                                         force_nominal=True)
        results_table.name = 'Significant Groups'
        self.Outputs.results.send(results_table)

        self.view.set_vars(list(df.index.name))
        self.model.setHorizontalHeaderLabels(columns, len(df.index.name))
        self.model.wrap(lst)
        self.view.sortByColumn(len(columns) - 1, Qt.AscendingOrder)

        self.Information.nothing_significant(shown=not lst)

    def send_report(self):
        self.report_items([
            ('Test Variable', self.chosen_y),
            ('Test', self.test_type),
            ('Min. group size', self.min_count),
        ])
        self.report_table('Significant Groups', self.view)
Beispiel #17
0
class OWScatterPlot(OWWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        features = Output("Features", AttributeList, dynamic=False)

    settings_version = 2
    settingsHandler = DomainContextHandler()

    auto_send_selection = Setting(True)
    auto_sample = Setting(True)
    toolbar_selection = Setting(0)

    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)

    #: Serialized selection state to be restored: a list of
    #: (row index, selection group) pairs, or None
    selection_group = Setting(None, schema_only=True)

    graph = SettingProvider(OWScatterPlotGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    class Information(OWWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")

    def __init__(self):
        super().__init__()

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWScatterPlotGraph(self, box, "ScatterPlot")
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        axispen = QPen(self.palette().color(QPalette.Text))
        axis = plot.getAxis("bottom")
        axis.setPen(axispen)

        axis = plot.getAxis("left")
        axis.setPen(axispen)

        self.data = None  # Orange.data.Table
        self.subset_data = None  # Orange.data.Table
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        #: Remember the saved state to restore
        self.__pending_selection_restore = self.selection_group
        self.selection_group = None

        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str,
                              contentsLength=14)
        box = gui.vBox(self.controlArea, "Axis Data")
        dmod = DomainModel
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self.update_attr,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(box,
                                      self,
                                      "attr_y",
                                      label="Axis y:",
                                      callback=self.update_attr,
                                      model=self.xy_model,
                                      **common_options)

        vizrank_box = gui.hBox(box)
        gui.separator(vizrank_box, width=common_options["labelWidth"])
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

        gui.separator(box)

        g = self.graph.gui
        g.add_widgets([g.JitterSizeSlider, g.JitterNumericValues], box)

        self.sampling = gui.auto_commit(self.controlArea,
                                        self,
                                        "auto_sample",
                                        "Sample",
                                        box="Sampling",
                                        callback=self.switch_sampling,
                                        commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

        g.point_properties_box(self.controlArea)
        self.models = [self.xy_model] + g.points_models

        box_plot_prop = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([
            g.ShowLegend, g.ShowGridLines, g.ToolTipShowsAll, g.ClassDensity,
            g.RegressionLine, g.LabelOnlySelected
        ], box_plot_prop)

        self.graph.box_zoom_select(self.controlArea)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_send_selection",
                        "Send Selection", "Send Automatically")

        self.graph.zoom_actions(self)

        # manually register Matplotlib file writers
        self.graph_writers = self.graph_writers.copy()
        for w in [MatplotlibFormat, MatplotlibPDFFormat]:
            for ext in w.EXTENSIONS:
                self.graph_writers[ext] = w

    def keyPressEvent(self, event):
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def reset_graph_data(self, *_):
        if self.data is not None:
            self.graph.rescale_data()
            self.update_graph()

    def _vizrank_color_change(self):
        """Enable VizRank only for dense data with more than two primitive
        variables, more than one row, and a color variable that is not
        entirely missing."""
        self.vizrank.initialize()
        is_enabled = self.data is not None and not self.data.is_sparse() and \
                     len([v for v in chain(self.data.domain.variables, self.data.domain.metas)
                          if v.is_primitive]) > 2\
                     and len(self.data) > 1
        self.vizrank_button.setEnabled(
            is_enabled and self.graph.attr_color is not None and not np.isnan(
                self.data.get_column_view(
                    self.graph.attr_color)[0].astype(float)).all())
        if is_enabled and self.graph.attr_color is None:
            self.vizrank_button.setToolTip(
                "Color variable has to be selected.")
        else:
            self.vizrank_button.setToolTip("")

    @Inputs.data
    def set_data(self, data):
        """Receive new data; large SQL tables are replaced by a sample that
        is extended over time while auto-sampling is on."""
        self.clear_messages()
        self.Information.sampled_sql.clear()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled_sql()
                self.sql_data = data
                data_sample = data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if data is not None and (len(data) == 0 or len(data.domain) == 0):
            data = None
        if self.data and data and self.data.checksum() == data.checksum():
            return

        self.closeContext()
        same_domain = (self.data and data and data.domain.checksum()
                       == self.data.domain.checksum())
        self.data = data

        if not same_domain:
            self.init_attr_values()
        self.openContext(self.data)
        self._vizrank_color_change()

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Orange.data.Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.graph.attr_label, str):
            self.graph.attr_label = findvar(self.graph.attr_label,
                                            self.graph.gui.label_model)
        if isinstance(self.graph.attr_color, str):
            self.graph.attr_color = findvar(self.graph.attr_color,
                                            self.graph.gui.color_model)
        if isinstance(self.graph.attr_shape, str):
            self.graph.attr_shape = findvar(self.graph.attr_shape,
                                            self.graph.gui.shape_model)
        if isinstance(self.graph.attr_size, str):
            self.graph.attr_size = findvar(self.graph.attr_size,
                                           self.graph.gui.size_model)

    def add_data(self, time=0.4):
        """Download another sample from the SQL table and append it; stops
        the sampling timer once more than 2000 rows are present."""
        if self.data and len(self.data) > 2000:
            return self.__timer.stop()
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def switch_sampling(self):
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() < AUTO_DL_LIMIT:
                subset_data = Table(subset_data)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
        self.subset_data = subset_data
        self.controls.graph.alpha_value.setEnabled(subset_data is None)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        self.graph.new_data(self.data, self.subset_data)
        if self.attribute_selection_list and self.graph.domain is not None and \
                all(attr in self.graph.domain
                        for attr in self.attribute_selection_list):
            self.attr_x = self.attribute_selection_list[0]
            self.attr_y = self.attribute_selection_list[1]
        self.attribute_selection_list = None
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        if self.data is not None and self.__pending_selection_restore is not None:
            self.apply_selection(self.__pending_selection_restore)
            self.__pending_selection_restore = None
        self.unconditional_commit()

    def apply_selection(self, selection):
        """Apply `selection` to the current plot.

        `selection` is a sequence of (row index, selection group) pairs;
        pairs whose row index is out of range for the current data are
        dropped. When nothing remains, the plot is left with no selection.
        """
        if self.data is not None:
            self.graph.selection = np.zeros(len(self.data), dtype=np.uint8)
            restored = [x for x in selection if x[0] < len(self.data)]
            self.selection_group = restored or None
            # An empty list would make np.array(...).T one-dimensional and
            # indexing [0]/[1] below would raise IndexError.
            if restored:
                selection_array = np.array(restored).T
                self.graph.selection[selection_array[0]] = selection_array[1]
            self.graph.update_colors(keep_colors=True)

    @Inputs.features
    def set_shown_attributes(self, attributes):
        if attributes and len(attributes) >= 2:
            self.attribute_selection_list = attributes[:2]
        else:
            self.attribute_selection_list = None

    def init_attr_values(self):
        """Refill the axis model from the data and pick the first two
        variables (or the first one twice) as default axes."""
        data = self.data
        domain = data.domain if data and len(data) else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x
        self.graph.set_domain(data)

    def set_attr(self, attr_x, attr_y):
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def update_attr(self):
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        self.send_features()

    def update_colors(self):
        self._vizrank_color_change()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())

    def update_density(self):
        self.update_graph(reset_view=False)

    def update_regression_line(self):
        self.update_graph(reset_view=False)

    def update_graph(self, reset_view=True, **_):
        self.graph.zoomStack = []
        if self.graph.data is None:
            return
        self.graph.update_data(self.attr_x, self.attr_y, reset_view)

    def selection_changed(self):

        # Store current selection in a setting that is stored in workflow
        if isinstance(self.data, SqlTable):
            selection = None
        elif self.data is not None:
            selection = self.graph.get_selection()
        else:
            selection = None
        if selection is not None and len(selection):
            self.selection_group = list(
                zip(selection, self.graph.selection[selection]))
        else:
            self.selection_group = None

        self.commit()

    def send_data(self):
        """Send the selected rows and the annotated table; with more than
        one selection group both outputs carry a group column."""
        # TODO: Implement selection for sql data
        def _get_selected():
            if not len(selection):
                return None
            return create_groups_table(data, graph.selection, False, "Group")

        def _get_annotated():
            if graph.selection is not None and np.max(graph.selection) > 1:
                return create_groups_table(data, graph.selection)
            else:
                return create_annotated_table(data, selection)

        graph = self.graph
        data = self.data
        selection = graph.get_selection()
        self.Outputs.annotated_data.send(_get_annotated())
        self.Outputs.selected_data.send(_get_selected())

    def send_features(self):
        features = [attr for attr in [self.attr_x, self.attr_y] if attr]
        self.Outputs.features.send(features or None)

    def commit(self):
        self.send_data()
        self.send_features()

    def get_widget_name_extension(self):
        """Return "<x> vs <y>", or None when data or an axis is missing."""
        # Axes can be None when the data has no primitive variables
        if self.data is not None and self.attr_x is not None \
                and self.attr_y is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
        return None

    def send_report(self):
        if self.data is None:
            return

        def name(var):
            return var and var.name

        # Axes can be None when the data has no primitive variables
        axes = [var for var in (self.attr_x, self.attr_y) if var is not None]
        jittered = (any(var.is_discrete for var in axes)
                    or self.graph.jitter_continuous)
        caption = report.render_items_vert(
            (("Color", name(self.graph.attr_color)),
             ("Label", name(self.graph.attr_label)),
             ("Shape", name(self.graph.attr_shape)),
             ("Size", name(self.graph.attr_size)),
             ("Jittering", jittered and self.graph.jitter_size)))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def onDeleteWidget(self):
        super().onDeleteWidget()
        self.graph.plot_widget.getViewBox().deleteLater()
        self.graph.plot_widget.clear()

    @classmethod
    def migrate_settings(cls, settings, version):
        # Version < 2 stored a plain list of selected row indices; convert
        # it to (row index, group 1) pairs.
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1)
                                           for a in settings["selection"]]
Beispiel #18
0
class OWAverage(OWWidget):
    """Average the rows of the input table, optionally separately for each
    value of a discrete grouping variable."""
    # Widget's name as displayed in the canvas
    name = "Average Spectra"

    # Short widget description
    description = (
        "Calculates averages.")

    icon = "icons/average.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        averages = Output("Averages", Orange.data.Table, default=True)

    settingsHandler = settings.DomainContextHandler()
    group_var = settings.ContextSetting(None)

    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    class Warning(OWWidget.Warning):
        nodata = Msg("No useful data on input!")

    def __init__(self):
        super().__init__()

        self.data = None

        # The grouping model must exist before the first set_data() call,
        # because set_data() resets it.
        self.group_vars = DomainModel(
            placeholder="None", separators=False,
            valid_types=Orange.data.DiscreteVariable)
        self.group_view = gui.listView(
            self.controlArea, self, "group_var", box="Group by",
            model=self.group_vars, callback=self.grouping_changed)

        # Start in the "no data" state (shows the warning); kept before
        # gui.auto_commit, as in the original ordering.
        self.set_data(self.data)

        gui.auto_commit(self.controlArea, self, "autocommit", "Apply")

    @Inputs.data
    def set_data(self, dataset):
        """Store the input table, refresh the grouping variables and commit."""
        self.Warning.nodata.clear()
        self.closeContext()
        self.data = dataset
        self.group_var = None
        if dataset is None:
            # Drop variables of the previous domain from the "Group by" list
            self.group_vars.set_domain(None)
            self.Warning.nodata()
        else:
            self.group_vars.set_domain(dataset.domain)
            self.openContext(dataset.domain)

        self.commit()

    @staticmethod
    def average_table(table):
        """
        Return a features-averaged table.

        The attributes (X) are averaged column-wise, ignoring NaNs.
        For metas and class_vars,
          - return average value of ContinuousVariable (including
            TimeVariable, which subclasses ContinuousVariable)
          - return value of DiscreteVariable and StringVariable
            if all are the same.
          - return unknown otherwise.

        An empty table is returned unchanged.
        """
        if len(table) == 0:
            return table
        mean = np.nanmean(table.X, axis=0, keepdims=True)
        # Start from the first row's class and meta values; they are
        # overwritten or checked below.
        avg_table = Orange.data.Table.from_numpy(table.domain,
                                                 X=mean,
                                                 Y=np.atleast_2d(table.Y[0].copy()),
                                                 metas=np.atleast_2d(table.metas[0].copy()))
        cont_vars = [var for var in table.domain.class_vars + table.domain.metas
                     if isinstance(var, Orange.data.ContinuousVariable)]
        for var in cont_vars:
            index = table.domain.index(var)
            col, _ = table.get_column_view(index)
            try:
                avg_table[0, index] = np.nanmean(col)
            except AttributeError:
                # numpy.lib.nanfunctions._replace_nan just guesses and returns
                # a boolean array mask for object arrays because object arrays
                # do not support `isnan` (numpy-gh-9009)
                # Since we know that ContinuousVariable values must be np.float64
                # do an explicit cast here
                avg_table[0, index] = np.nanmean(col, dtype=np.float64)

        other_vars = [var for var in table.domain.class_vars + table.domain.metas
                      if not isinstance(var, Orange.data.ContinuousVariable)]
        for var in other_vars:
            index = table.domain.index(var)
            col, _ = table.get_column_view(index)
            # Keep the first row's value only if every row shares it
            val = var.to_val(avg_table[0, var])
            if not np.all(col == val):
                avg_table[0, var] = Orange.data.Unknown

        return avg_table

    def grouping_changed(self):
        """Calls commit() indirectly to respect auto_commit setting."""
        self.commit()

    def commit(self):
        """Send the averages.

        Without a grouping variable this is a single averaged row. Otherwise
        it is one row per group value present in the data, followed by a row
        for instances with an undefined group value (if any).
        """
        averages = None
        if self.data is not None:
            if self.group_var is None:
                averages = self.average_table(self.data)
            else:
                averages = Orange.data.Table.from_domain(self.data.domain)
                for value in self.group_var.values:
                    svfilter = SameValue(self.group_var, value)
                    v_table = self.average_table(svfilter(self.data))
                    averages.extend(v_table)
                # Using "None" as in OWSelectRows
                # Values is required because FilterDiscrete doesn't have
                # negate keyword or IsDefined method
                deffilter = Values(conditions=[FilterDiscrete(self.group_var, None)],
                                   negate=True)
                v_table = self.average_table(deffilter(self.data))
                averages.extend(v_table)
        self.Outputs.averages.send(averages)
Beispiel #19
0
class OWSNR(OWWidget):
    """Reduce the rows of a data table to their SNR, average or std.

    Rows can optionally be binned by one or two continuous meta attributes
    (selected in the "Select axis" combos); every bin is then reduced to a
    single output row whose axis values are the bin positions.
    """

    # Widget's name as displayed in the canvas
    name = "SNR"

    # Short widget description
    description = (
        "Calculates Signal-to-Noise Ratio (SNR), Averages or Standard Deviation by coordinates."
    )

    icon = "icons/snr.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        final_data = Output("SNR", Orange.data.Table, default=True)

    # Labels for the output combo box; ``out_choiced`` holds the selected
    # index (0 = SNR, 1 = average, 2 = std, see ``calc_table_np``).
    OUT_OPTIONS = {
        'Signal-to-noise ratio': 0,  # snr
        'Average': 1,  # average
        'Standard Deviation': 2,  # std
    }

    settingsHandler = settings.DomainContextHandler()
    group_x = settings.ContextSetting(None)
    group_y = settings.ContextSetting(None)
    out_choiced = settings.Setting(0)

    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    class Warning(OWWidget.Warning):
        nodata = Msg("No useful data on input!")

    def __init__(self):
        super().__init__()

        self.data = None
        self.set_data(self.data)  # show warning

        self.group_x = None
        self.group_y = None

        # methods in this widget assume group axes in metas
        self.xy_model = DomainModel(DomainModel.METAS,
                                    placeholder="None",
                                    separators=False,
                                    valid_types=Orange.data.ContinuousVariable)
        self.group_view_x = gui.comboBox(self.controlArea,
                                         self,
                                         "group_x",
                                         box="Select axis: x",
                                         model=self.xy_model,
                                         callback=self.grouping_changed)

        self.group_view_y = gui.comboBox(self.controlArea,
                                         self,
                                         "group_y",
                                         box="Select axis: y",
                                         model=self.xy_model,
                                         callback=self.grouping_changed)

        self.selected_out = gui.comboBox(self.controlArea,
                                         self,
                                         "out_choiced",
                                         box="Select Output:",
                                         items=self.OUT_OPTIONS,
                                         callback=self.out_choice_changed)

        gui.auto_commit(self.controlArea, self, "autocommit", "Apply")

        # prepare interface according to the new context
        self.contextAboutToBeOpened.connect(
            lambda x: self.init_attr_values(x[0]))

    def init_attr_values(self, domain):
        """Fill the axis combos from ``domain`` and reset both axes."""
        self.xy_model.set_domain(domain)
        self.group_x = None
        self.group_y = None

    @Inputs.data
    def set_data(self, dataset):
        """Store the new input, (re)open the settings context and commit."""
        self.Warning.nodata.clear()
        self.closeContext()
        self.data = dataset
        # NOTE(review): ``self.group`` is not read anywhere in this class;
        # kept in case external code relies on it.
        self.group = None
        if dataset is None:
            self.Warning.nodata()
        else:
            self.openContext(dataset.domain)

        self.commit()

    def calc_table_np(self, array):
        """Reduce ``array`` column-wise (NaN-aware) to a one-row table.

        The statistic depends on ``out_choiced``: mean/std (SNR), mean, or
        std. Returns ``None`` for an empty ``array`` so that no non-Table
        object is ever sent on the output.
        """
        if len(array) == 0:
            return None
        if self.out_choiced == 0:  # snr
            stat = (bottleneck.nanmean(array, axis=0) /
                    bottleneck.nanstd(array, axis=0))
        elif self.out_choiced == 1:  # avg
            stat = bottleneck.nanmean(array, axis=0)
        else:  # std
            stat = bottleneck.nanstd(array, axis=0)
        return self.make_table(stat.reshape(1, -1), self.data)

    @staticmethod
    def make_table(array, data_table):
        """Build a one-row table over ``data_table.domain`` with X = ``array``.

        Class and meta values are copied from the first row of
        ``data_table``; any such value that is not shared by every row is
        replaced by Unknown.
        """
        new_table = Orange.data.Table.from_numpy(
            data_table.domain,
            X=array.copy(),
            Y=np.atleast_2d(data_table.Y[0]).copy(),
            metas=np.atleast_2d(data_table.metas[0]).copy())
        # Not only continuous: every class variable and meta is checked.
        other_vars = data_table.domain.class_vars + data_table.domain.metas
        with new_table.unlocked():
            for var in other_vars:
                index = data_table.domain.index(var)
                col, _ = data_table.get_column_view(index)
                val = var.to_val(new_table[0, var])
                if not np.all(col == val):
                    new_table[0, var] = Orange.data.Unknown

        return new_table

    def grouping_changed(self):
        """Calls commit() indirectly to respect auto_commit setting."""
        self.commit()

    def out_choice_changed(self):
        """Recompute the output after the statistic selection changed."""
        self.commit()

    def _coordinate_values(self, attr):
        """Return the values of ``attr`` for all rows of ``self.data``."""
        var = self.data.domain[attr]
        return self.data.transform(Orange.data.Domain([var])).X[:, 0]

    def _aggregate_bins(self, coords):
        """Reduce rows of ``self.data`` that share identical bin coordinates.

        ``coords`` holds one row of integer bin indices per data row.
        Returns the concatenation of ``calc_table_np`` for every distinct
        coordinate row (in ``np.lexsort`` order) and those distinct rows.
        """
        # trick:
        # https://stackoverflow.com/questions/31878240/numpy-average-of-values-corresponding-to-unique-coordinate-positions
        sortidx = np.lexsort(coords.T)
        sorted_coo = coords[sortidx]
        # True at the first row of every run of identical coordinates
        starts = np.append(True, np.any(np.diff(sorted_coo, axis=0), axis=1))
        unq_coo = sorted_coo[starts]
        # Split the sorted row indices at the start of every run
        bins = np.split(sortidx, np.flatnonzero(starts)[1:])
        matrix = [self.calc_table_np(self.data.X[indices]) for indices in bins]
        return Orange.data.Table.concatenate(matrix, axis=0), unq_coo

    def select_2coordinates(self, attr_x, attr_y):
        """Aggregate ``self.data`` per (x, y) bin of ``attr_x``/``attr_y``."""
        coorx = self._coordinate_values(attr_x)
        coory = self._coordinate_values(attr_y)

        lsx = values_to_linspace(coorx)
        lsy = values_to_linspace(coory)

        # NOTE(review): the second value of index_values_nan is ignored, so
        # rows with missing coordinates are not excluded here -- confirm.
        xindex, _ = index_values_nan(coorx, lsx)
        yindex, _ = index_values_nan(coory, lsy)

        coo = np.hstack([xindex.reshape(-1, 1), yindex.reshape(-1, 1)])
        table_2_coord, unq_coo = self._aggregate_bins(coo)

        # Replace the axis columns with the bin positions
        with table_2_coord.unlocked():
            table_2_coord[:, attr_x] = \
                np.linspace(*lsx)[unq_coo[:, 0]].reshape(-1, 1)
            table_2_coord[:, attr_y] = \
                np.linspace(*lsy)[unq_coo[:, 1]].reshape(-1, 1)
        return table_2_coord

    def select_1coordinate(self, attr):
        """Aggregate ``self.data`` per bin of the single coordinate ``attr``."""
        coor = self._coordinate_values(attr)
        ls = values_to_linspace(coor)
        index, _ = index_values_nan(coor, ls)
        table_1_coord, unq_coo = self._aggregate_bins(index.reshape(-1, 1))

        # Replace the axis column with the bin positions
        with table_1_coord.unlocked():
            table_1_coord[:, attr] = \
                np.linspace(*ls)[unq_coo[:, 0]].reshape(-1, 1)

        return table_1_coord

    def select_coordinate(self):
        """Dispatch on how many axes are selected (none, one or two)."""
        if self.group_x is None and self.group_y is None:
            return self.calc_table_np(self.data.X)
        if self.group_x is None or self.group_y is None:
            group = self.group_y if self.group_x is None else self.group_x
            return self.select_1coordinate(group)
        return self.select_2coordinates(self.group_x, self.group_y)

    def commit(self):
        """Send the aggregated table (or ``None`` without data)."""
        final_data = None
        if self.data is not None:
            final_data = self.select_coordinate()

        self.Outputs.final_data.send(final_data)
Beispiel #20
0
class OWLinePlot(OWWidget):
    """Plot every row as a line across the numeric features.

    Lines can be grouped and colored by a discrete variable, and range,
    mean and standard-deviation overlays can be toggled.
    """
    # Judging by how it is used, this widget draws a profile of the data of
    # each class rather than a line chart in the usual sense.
    # Reference: https://orange.biolab.si/blog/2019/6/gene-expression-profiles-with-line-plot/
    name = "数据画像(Line Plot)"
    description = "数据画像的可视化(例如,时间序列)。"
    icon = "icons/LinePlot.svg"
    priority = 180
    keywords = ["shujuhuaxiang"]
    category = "可视化(Visualize)"
    buttons_area_orientation = Qt.Vertical
    enable_selection = Signal(bool)

    class Inputs:
        data = Input("数据(Data)", Table, default=True, replaces=["Data"])
        data_subset = Input("数据子集(Data Subset)",
                            Table,
                            replaces=["Data Subset"])

    class Outputs:
        selected_data = Output("选定的数据(Selected Data)",
                               Table,
                               default=True,
                               replaces=["Selected Data"])
        annotated_data = Output("数据(Data)", Table, replaces=["Data"])

    settingsHandler = DomainContextHandler()
    group_var = ContextSetting(None)
    show_profiles = Setting(False)
    show_range = Setting(True)
    show_mean = Setting(True)
    show_error = Setting(False)
    auto_commit = Setting(True)
    selection = Setting(None, schema_only=True)
    visual_settings = Setting({}, schema_only=True)

    graph_name = "graph.plotItem"

    class Error(OWWidget.Error):
        not_enough_attrs = Msg("Need at least one numeric feature.")

    class Warning(OWWidget.Warning):
        no_display_option = Msg("No display option is selected.")

    class Information(OWWidget.Information):
        too_many_features = Msg("Data has too many features. Only first {}"
                                " are shown.".format(MAX_FEATURES))

    def __init__(self, parent=None):
        super().__init__(parent)
        self.__groups = []
        self.data = None
        self.subset_data = None
        # Ids (not row positions) of data rows that are also in the subset
        self.subset_indices = None
        # Selection restored from settings; applied once data arrives
        self.__pending_selection = self.selection
        self.graph_variables = []
        self.graph = None
        self.group_vars = None
        self.group_view = None
        self.setup_gui()

        VisualSettingsDialog(self,
                             self.graph.parameter_setter.initial_settings)
        self.graph.view_box.selection_changed.connect(self.selection_changed)
        self.enable_selection.connect(self.graph.view_box.enable_selection)

    def setup_gui(self):
        """Create the plot area and the control panel."""
        self._add_graph()
        self._add_controls()

    def _add_graph(self):
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = LinePlotGraph(self)
        box.layout().addWidget(self.graph)

    def _add_controls(self):
        displaybox = gui.widgetBox(self.controlArea, "显示")
        gui.checkBox(
            displaybox,
            self,
            "show_profiles",
            "线",
            callback=self.__show_profiles_changed,
            tooltip="Plot lines",
        )
        gui.checkBox(
            displaybox,
            self,
            "show_range",
            "范围",
            callback=self.__show_range_changed,
            tooltip="Plot range between 10th and 90th percentile",
        )
        gui.checkBox(
            displaybox,
            self,
            "show_mean",
            "平均值",
            callback=self.__show_mean_changed,
            tooltip="Plot mean curve",
        )
        gui.checkBox(
            displaybox,
            self,
            "show_error",
            "误差线",
            callback=self.__show_error_changed,
            tooltip="Show standard deviation",
        )

        self.group_vars = DomainModel(placeholder="None",
                                      separators=False,
                                      valid_types=DiscreteVariable)
        self.group_view = gui.listView(
            self.controlArea,
            self,
            "group_var",
            box="分组依据",
            model=self.group_vars,
            callback=self.__group_var_changed,
            sizeHint=QSize(30, 100),
            viewType=ListViewSearch,
            sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Expanding),
        )
        self.group_view.setEnabled(False)

        plot_gui = OWPlotGUI(self)
        plot_gui.box_zoom_select(self.buttonsArea)
        gui.auto_send(self.buttonsArea, self, "auto_commit")

    def __show_profiles_changed(self):
        self.check_display_options()
        self._update_visibility("profiles")

    def __show_range_changed(self):
        self.check_display_options()
        self._update_visibility("range")

    def __show_mean_changed(self):
        self.check_display_options()
        self._update_visibility("mean")

    def __show_error_changed(self):
        self._update_visibility("error")

    def __group_var_changed(self):
        if self.data is None or not self.graph_variables:
            return
        self.plot_groups()
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self._update_sub_profiles()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Store new data, validate it, restore context and replot."""
        self.closeContext()
        self.data = data
        self.clear()
        self.check_data()
        self.check_display_options()

        if self.data is not None:
            self.group_vars.set_domain(self.data.domain)
            self.group_view.setEnabled(len(self.group_vars) > 1)
            self.group_var = (self.data.domain.class_var
                              if self.data.domain.has_discrete_class else None)

        self.openContext(data)
        self.setup_plot()
        self.commit.now()

    def check_data(self):
        """Collect plottable (continuous) attributes; drop unusable data."""
        def error(err):
            err()
            self.data = None

        self.clear_messages()
        if self.data is not None:
            self.graph_variables = [
                var for var in self.data.domain.attributes if var.is_continuous
            ]
            if len(self.graph_variables) < 1:
                error(self.Error.not_enough_attrs)
            else:
                if len(self.graph_variables) > MAX_FEATURES:
                    self.Information.too_many_features()
                    self.graph_variables = self.graph_variables[:MAX_FEATURES]

    def check_display_options(self):
        """Warn when nothing is shown; enable selection only when usable."""
        self.Warning.no_display_option.clear()
        if self.data is not None:
            if not (self.show_profiles or self.show_range or self.show_mean):
                self.Warning.no_display_option()
            enable = (self.show_profiles or
                      self.show_range) and len(self.data) < SEL_MAX_INSTANCES
            self.enable_selection.emit(enable)

    @Inputs.data_subset
    @check_sql_input
    def set_subset_data(self, subset):
        self.subset_data = subset

    def handleNewSignals(self):
        self.set_subset_ids()
        if self.data is not None:
            self._update_profiles_color()
            self._update_sel_profiles_color()
            self._update_sub_profiles()

    def set_subset_ids(self):
        """Store the ids of data rows that also appear in the subset."""
        sub_ids = ({e.id for e in self.subset_data}
                   if self.subset_data is not None else set())
        self.subset_indices = None
        if self.data is not None and sub_ids:
            self.subset_indices = [x.id for x in self.data if x.id in sub_ids]

    def setup_plot(self):
        if self.data is None:
            return

        ticks = [a.name for a in self.graph_variables]
        self.graph.getAxis("bottom").set_ticks(ticks)
        self.plot_groups()
        self.apply_selection()
        self.graph.view_box.enableAutoRange()
        self.graph.view_box.updateAutoRange()

    def plot_groups(self):
        """Replot all profiles, one group per value of ``group_var``."""
        self._remove_groups()
        data = self.data[:, self.graph_variables]
        if self.group_var is None:
            self._plot_group(data, np.arange(len(data)))
        else:
            class_col_data, _ = self.data.get_column_view(self.group_var)
            for index in range(len(self.group_var.values)):
                indices = np.flatnonzero(class_col_data == index)
                if len(indices) == 0:
                    continue
                group_data = self.data[indices, self.graph_variables]
                self._plot_group(group_data, indices, index)
        self.graph.update_legend(self.group_var)
        self.graph.groups = self.__groups
        self.graph.view_box.add_profiles(data.X)

    def _remove_groups(self):
        for group in self.__groups:
            group.remove_items()
        self.graph.view_box.remove_profiles()
        self.graph.groups = []
        self.__groups = []

    def _plot_group(self, data, indices, index=None):
        color = self.__get_group_color(index)
        group = ProfileGroup(data, indices, color, self.graph)
        kwargs = self.__get_visibility_flags()
        group.set_visible_error(**kwargs)
        group.set_visible_mean(**kwargs)
        group.set_visible_range(**kwargs)
        group.set_visible_profiles(**kwargs)
        self.__groups.append(group)

    def __get_group_color(self, index):
        if self.group_var is not None:
            return QColor(*self.group_var.colors[index])
        return QColor(LinePlotStyle.DEFAULT_COLOR)

    def __get_visibility_flags(self):
        return {
            "show_profiles": self.show_profiles,
            "show_range": self.show_range,
            "show_mean": self.show_mean,
            "show_error": self.show_error,
        }

    def _update_profiles_color(self):
        # color alpha depends on subset and selection; with selection or
        # subset profiles color has more opacity
        if not self.show_profiles:
            return
        has_sel = bool(self.subset_indices) or bool(self.selection)
        for group in self.__groups:
            group.update_profiles_color(has_sel)

    def _update_sel_profiles_and_range(self):
        # mark selected instances and selected range
        if not (self.show_profiles or self.show_range):
            return
        # A set keeps each membership test O(1) instead of O(len(selection))
        selected = set(self.selection) if self.selection is not None else None
        for group in self.__groups:
            inds = [i for i in group.indices if self.__in(i, selected)]
            table = self.data[inds, self.graph_variables].X if inds else None
            if self.show_profiles:
                group.update_sel_profiles(table)
            if self.show_range:
                group.update_sel_range(table)

    def _update_sel_profiles_color(self):
        # color depends on subset; when subset is present,
        # selected profiles are black
        if not self.selection or not self.show_profiles:
            return
        for group in self.__groups:
            group.update_sel_profiles_color(bool(self.subset_indices))

    def _update_sub_profiles(self):
        # mark subset instances
        if not (self.show_profiles or self.show_range):
            return
        # A set keeps each membership test O(1) instead of O(len(subset))
        subset = (set(self.subset_indices)
                  if self.subset_indices is not None else None)
        for group in self.__groups:
            inds = [
                i for i, _id in zip(group.indices, group.ids)
                if self.__in(_id, subset)
            ]
            table = self.data[inds, self.graph_variables].X if inds else None
            group.update_sub_profiles(table)

    def _update_visibility(self, obj_name):
        """Toggle visibility of ``obj_name`` items in every group."""
        if len(self.__groups) == 0:
            return
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        kwargs = self.__get_visibility_flags()
        for group in self.__groups:
            getattr(group, "set_visible_{}".format(obj_name))(**kwargs)
        self.graph.view_box.updateAutoRange()

    def apply_selection(self):
        """Apply the selection restored from settings, if any, once."""
        if self.data is not None and self.__pending_selection is not None:
            sel = [i for i in self.__pending_selection if i < len(self.data)]
            mask = np.zeros(len(self.data), dtype=bool)
            mask[sel] = True
            self.selection_changed(mask)
            self.__pending_selection = None

    def selection_changed(self, mask):
        """Select rows given by boolean ``mask`` and refresh the plot."""
        if self.data is None:
            return
        indices = np.arange(len(self.data))[mask]
        self.graph.select(indices)
        old = self.selection
        self.selection = (None
                          if self.data and isinstance(self.data, SqlTable) else
                          list(self.graph.selection))
        # Profile opacity only changes when selection appears or disappears
        if not old and self.selection or old and not self.selection:
            self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self.commit.deferred()

    @gui.deferred
    def commit(self):
        selected = (self.data[self.selection] if self.data is not None
                    and bool(self.selection) else None)
        annotated = create_annotated_table(self.data, self.selection)
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        if self.data is None:
            return

        caption = report.render_items_vert((("Group by", self.group_var), ))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def sizeHint(self):
        return QSize(1132, 708)

    def clear(self):
        """Reset selection, groups, plot and grouping controls."""
        self.selection = None
        self.__groups = []
        self.graph_variables = []
        self.graph.reset()
        self.group_vars.set_domain(None)
        self.group_view.setEnabled(False)

    @staticmethod
    def __in(obj, collection):
        return collection is not None and obj in collection

    def set_visual_settings(self, key, value):
        self.graph.parameter_setter.set_parameter(key, value)
        self.visual_settings[key] = value
Beispiel #21
0
class OWScatterPlot(OWDataProjectionWidget):
    """Scatter plot of two numeric variables.

    Adds axis selection, VizRank projection ranking, an optional regression
    line and incremental sampling of large SQL tables to the generic
    projection widget.
    """

    name = 'Scatter Plot'
    description = ("Interactive scatter plot visualization with "
                   "intelligent data visualization enhancements.")
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs(OWDataProjectionWidget.Inputs):
        features = Input("Features", AttributeList)

    class Outputs(OWDataProjectionWidget.Outputs):
        features = Output("Features", AttributeList, dynamic=False)

    settings_version = 4
    auto_sample = Setting(True)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    tooltip_shows_all = Setting(True)

    GRAPH_CLASS = OWScatterPlotGraph
    graph = SettingProvider(OWScatterPlotGraph)
    embedding_variables_names = None

    xy_changed_manually = Signal(Variable, Variable)

    class Warning(OWDataProjectionWidget.Warning):
        missing_coords = Msg(
            "Plot cannot be displayed because '{}' or '{}' "
            "is missing for all data points")
        no_continuous_vars = Msg("Data has no continuous variables")

    class Information(OWDataProjectionWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        # State is prepared before the base initializer runs; the ordering
        # is kept because base-class hooks may rely on these attributes.
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        sampling_timer = QTimer(self, interval=1200)
        sampling_timer.timeout.connect(self.add_data)
        self.__timer = sampling_timer
        super().__init__()

        # Register Matplotlib writers on a private copy of the writer map
        writers = self.graph_writers.copy()
        for fmt in (MatplotlibFormat, MatplotlibPDFFormat):
            for extension in fmt.EXTENSIONS:
                writers[extension] = fmt
        self.graph_writers = writers

    def _add_controls(self):
        self._add_controls_axis()
        self._add_controls_sampling()
        super()._add_controls()
        extra_widgets = [self.gui.ShowGridLines,
                         self.gui.ToolTipShowsAll,
                         self.gui.RegressionLine]
        self.gui.add_widgets(extra_widgets, self._plot_box)
        regression_tooltip = (
            "If checked, fit line to group (minimize distance from points);\n"
            "otherwise fit y as a function of x (minimize vertical distances)")
        gui.checkBox(
            gui.indentedBox(self._plot_box), self,
            value="graph.orthonormal_regression",
            label="Treat variables as independent",
            callback=self.graph.update_regression_line,
            tooltip=regression_tooltip)

    def _add_controls_axis(self):
        combo_options = {
            "labelWidth": 50,
            "orientation": Qt.Horizontal,
            "sendSelectedValue": True,
            "valueType": str,
            "contentsLength": 14,
        }
        box = gui.vBox(self.controlArea, True)
        self.xy_model = DomainModel(DomainModel.MIXED,
                                    valid_types=ContinuousVariable)

        def axis_combo(setting, label):
            # Both axis combos share the model, callback and layout options
            return gui.comboBox(
                box, self, setting, label=label,
                callback=self.set_attr_from_combo,
                model=self.xy_model, **combo_options)

        self.cb_attr_x = axis_combo("attr_x", "Axis x:")
        self.cb_attr_y = axis_combo("attr_y", "Axis y:")
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            gui.hBox(box), self, "Find Informative Projections",
            self.set_attr)

    def _add_controls_sampling(self):
        def sample_now():
            self.add_data(1)

        self.sampling = gui.auto_commit(
            self.controlArea, self, "auto_sample", "Sample",
            box="Sampling", callback=self.switch_sampling,
            commit=sample_now)
        self.sampling.setVisible(False)

    @property
    def effective_variables(self):
        x_var, y_var = self.attr_x, self.attr_y
        return [x_var, y_var]

    def _vizrank_color_change(self):
        """Enable the VizRank button only when ranking is meaningful."""
        self.vizrank.initialize()
        data = self.data
        is_enabled = (
            data is not None
            and not data.is_sparse()
            and len(self.xy_model) > 2
            and len(data[self.valid_data]) > 1
            and np.all(np.nan_to_num(np.nanstd(data.X, 0)) != 0)
        )
        self.vizrank_button.setEnabled(
            is_enabled
            and self.attr_color is not None
            and not np.isnan(data.get_column_view(
                self.attr_color)[0].astype(float)).all())
        if is_enabled and self.attr_color is None:
            hint = "Color variable has to be selected."
        else:
            hint = ""
        self.vizrank_button.setToolTip(hint)

    def set_data(self, data):
        unchanged = self.data and data and \
            self.data.checksum() == data.checksum()
        if unchanged:
            return
        super().set_data(data)

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            return next((candidate for candidate in iterable
                         if isinstance(candidate, Variable)
                         and candidate.name == name), None)

        # Settings restored from < 3.3.9 stored attr_* by name; map each
        # such name back to a variable of the matching model (models are
        # looked up lazily, only when a conversion is needed).
        legacy_settings = (
            ("attr_x", lambda: self.xy_model),
            ("attr_y", lambda: self.xy_model),
            ("attr_label", lambda: self.gui.label_model),
            ("attr_color", lambda: self.gui.color_model),
            ("attr_shape", lambda: self.gui.shape_model),
            ("attr_size", lambda: self.gui.size_model),
        )
        for setting, get_model in legacy_settings:
            stored = getattr(self, setting)
            if isinstance(stored, str):
                setattr(self, setting, findvar(stored, get_model()))

    def check_data(self):
        """Validate ``self.data``, sampling large SQL tables if needed."""
        self.clear_messages()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(self.data, SqlTable):
            sql_table = self.data
            if sql_table.approx_len() >= 4000:
                self.Information.sampled_sql()
                self.sql_data = sql_table
                sample = sql_table.sample_time(0.8, no_cache=True)
                sample.download_data(2000, partial=True)
                self.data = Table(sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()
            else:
                self.data = Table(sql_table)

        if self.data is None:
            return
        if not self.data.domain.has_continuous_attributes(True, True):
            self.Warning.no_continuous_vars()
            self.data = None
        elif not len(self.data) or not len(self.data.domain):
            self.data = None

    def get_embedding(self):
        """Return an (n, 2) array of x/y values, or None if unavailable."""
        self.valid_data = None
        if self.data is None:
            return None

        x_values = self.get_column(self.attr_x, filter_valid=False)
        y_values = self.get_column(self.attr_y, filter_valid=False)
        if x_values is None or y_values is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        finite = np.isfinite(x_values) & np.isfinite(y_values)
        self.valid_data = finite
        if not np.all(finite):
            # All points missing is a warning; only some is informational
            target = self.Information if np.any(finite) else self.Warning
            target.missing_coords(self.attr_x.name, self.attr_y.name)
        return np.vstack((x_values, y_values)).T

    # Tooltip
    def _point_tooltip(self, point_id, skip_attrs=()):
        row = self.data[point_id]
        xy_attrs = (self.attr_x, self.attr_y)
        parts = [escape('{} = {}'.format(var.name, row[var]))
                 for var in xy_attrs]
        text = "<br/>".join(parts)
        if not self.tooltip_shows_all:
            return text
        others = super()._point_tooltip(point_id, skip_attrs=xy_attrs)
        if not others:
            return text
        return "<b>{}</b><br/><br/>{}".format(text, others)

    def add_data(self, time=0.4):
        """Append another SQL sample to the data (up to ~2000 rows)."""
        if self.data and len(self.data) > 2000:
            self.__timer.stop()
            return
        sample = self.sql_data.sample_time(time, no_cache=True)
        if not sample:
            return
        sample.download_data(2000, partial=True)
        new_rows = Table(sample)
        self.data = Table.concatenate((self.data, new_rows), axis=0)
        self.handleNewSignals()

    def init_attr_values(self):
        super().init_attr_values()
        data = self.data
        if data and len(data):
            domain = data.domain
        else:
            domain = None
        model = self.xy_model
        model.set_domain(domain)
        if model:
            self.attr_x = model[0]
        else:
            self.attr_x = None
        if len(model) >= 2:
            self.attr_y = model[1]
        else:
            self.attr_y = self.attr_x

    def switch_sampling(self):
        self.__timer.stop()
        if not (self.auto_sample and self.sql_data):
            return
        self.add_data()
        self.__timer.start()

    def set_subset_data(self, subset_data):
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() >= AUTO_DL_LIMIT:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
            else:
                subset_data = Table(subset_data)
        super().set_subset_data(subset_data)

    def handleNewSignals(self):
        """Called when all signals are received, so the graph is updated
        only once."""
        pending = self.attribute_selection_list
        data = self.data
        if pending and data is not None and data.domain is not None \
                and all(attr in data.domain for attr in pending):
            self.attr_x, self.attr_y = pending[:2]
            self.attribute_selection_list = None
        super().handleNewSignals()
        self._vizrank_color_change()

    @Inputs.features
    def set_shown_attributes(self, attributes):
        if not attributes or len(attributes) < 2:
            self.attribute_selection_list = None
            return
        self.attribute_selection_list = attributes[:2]
        self._invalidated = (self._invalidated
                             or self.attr_x != attributes[0]
                             or self.attr_y != attributes[1])

    def set_attr(self, attr_x, attr_y):
        if attr_x == self.attr_x and attr_y == self.attr_y:
            return
        self.attr_x, self.attr_y = attr_x, attr_y
        self.attr_changed()

    def set_attr_from_combo(self):
        self.attr_changed()
        self.xy_changed_manually.emit(self.attr_x, self.attr_y)

    def attr_changed(self):
        self.setup_plot()
        self.commit()

    def get_axes(self):
        return dict(bottom=self.attr_x, left=self.attr_y)

    def colors_changed(self):
        super().colors_changed()
        self._vizrank_color_change()

    def commit(self):
        super().commit()
        self.send_features()

    def send_features(self):
        chosen = list(filter(None, (self.attr_x, self.attr_y)))
        self.Outputs.features.send(chosen or None)

    def get_widget_name_extension(self):
        if self.data is None:
            return None
        return "{} vs {}".format(self.attr_x.name, self.attr_y.name)

    @classmethod
    def migrate_settings(cls, settings, version):
        if version < 2 and settings.get("selection"):
            settings["selection_group"] = \
                [(index, 1) for index in settings["selection"]]
        if version < 3:
            for old_key, new_key in (("auto_send_selection", "auto_commit"),
                                     ("selection_group", "selection")):
                if old_key in settings:
                    settings[new_key] = settings[old_key]

    @classmethod
    def migrate_context(cls, context, version):
        values = context.values
        if version < 3:
            graph_values = values["graph"]
            for key in ("attr_color", "attr_size", "attr_shape", "attr_label"):
                values[key] = graph_values[key]
        if version < 4:
            # Contexts whose axes were stored as discrete variables can't be
            # migrated
            if any(values[key][1] % 100 == 1 for key in ("attr_x", "attr_y")):
                raise IncompatibleContext()
Beispiel #22
0
class OWScatterPlot(OWDataProjectionWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    name = '散点图(Scatter Plot)'
    description = "具有智能数据可视化增强功能的交互式散点图可视化工具。"
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs(OWDataProjectionWidget.Inputs):
        features = Input("特征(Features)", AttributeList, replaces=['Features'])

    class Outputs(OWDataProjectionWidget.Outputs):
        features = Output("特征(Features)",
                          AttributeList,
                          dynamic=False,
                          replaces=['Features'])

    settings_version = 4
    auto_sample = Setting(True)
    # Variables shown on the x and y axis (stored per data context).
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    # If set, tooltips list the other attributes besides the axis values.
    tooltip_shows_all = Setting(True)

    GRAPH_CLASS = OWScatterPlotGraph
    graph = SettingProvider(OWScatterPlotGraph)
    embedding_variables_names = None

    # Emitted with (attr_x, attr_y) after the user picks an axis in a combo.
    xy_changed_manually = Signal(Variable, Variable)

    class Warning(OWDataProjectionWidget.Warning):
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points")
        no_continuous_vars = Msg("Data has no continuous variables")

    class Information(OWDataProjectionWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # Periodically pulls further samples from a large SQL table.
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        super().__init__()

        # manually register Matplotlib file writers; copy first so the
        # inherited list is not modified in place
        self.graph_writers = self.graph_writers.copy()
        self.graph_writers.extend([MatplotlibFormat, MatplotlibPDFFormat])

    def _add_controls(self):
        """Build the control area: axes, SQL sampling, base controls and
        scatter-plot specific plot options."""
        self._add_controls_axis()
        self._add_controls_sampling()
        super()._add_controls()
        self.gui.add_widgets([
            self.gui.ShowGridLines, self.gui.ToolTipShowsAll,
            self.gui.RegressionLine
        ], self._plot_box)
        gui.checkBox(
            gui.indentedBox(self._plot_box),
            self,
            value="graph.orthonormal_regression",
            label="将变量视为独立变量",
            callback=self.graph.update_regression_line,
            tooltip=
            "If checked, fit line to group (minimize distance from points);\n"
            "otherwise fit y as a function of x (minimize vertical distances)")

    def _add_controls_axis(self):
        """Add the x/y combos (continuous variables only) and VizRank."""
        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str,
                              contentsLength=14)
        self.attr_box = gui.vBox(self.controlArea, True)
        dmod = DomainModel
        self.xy_model = DomainModel(dmod.MIXED, valid_types=ContinuousVariable)
        self.cb_attr_x = gui.comboBox(self.attr_box,
                                      self,
                                      "attr_x",
                                      label="x 轴:",
                                      callback=self.set_attr_from_combo,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(self.attr_box,
                                      self,
                                      "attr_y",
                                      label="y 轴:",
                                      callback=self.set_attr_from_combo,
                                      model=self.xy_model,
                                      **common_options)
        vizrank_box = gui.hBox(self.attr_box)
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "查找信息投影(Find Informative Projections)",
            self.set_attr)

    def _add_controls_sampling(self):
        """Add the SQL sampling box; it stays hidden until check_data
        decides the input is a large SQL table that must be sampled."""
        self.sampling = gui.auto_commit(self.controlArea,
                                        self,
                                        "auto_sample",
                                        "Sample",
                                        box="Sampling",
                                        callback=self.switch_sampling,
                                        commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

    @property
    def effective_variables(self):
        """The two axis variables, or an empty list unless both are set."""
        return [self.attr_x, self.attr_y
                ] if self.attr_x and self.attr_y else []

    def _vizrank_color_change(self):
        """Enable the VizRank button only when ranking is possible; otherwise
        disable it and show the reason as its tooltip."""
        self.vizrank.initialize()
        err_msg = ""
        if self.data is None:
            err_msg = "No data on input"
        elif self.data.is_sparse():
            err_msg = "Data is sparse"
        elif len(self.xy_model) < 3:
            err_msg = "Not enough features for ranking"
        elif self.attr_color is None:
            err_msg = "Color variable is not selected"
        elif np.isnan(
                self.data.get_column_view(
                    self.attr_color)[0].astype(float)).all():
            err_msg = "Color variable has no values"
        self.vizrank_button.setEnabled(not err_msg)
        self.vizrank_button.setToolTip(err_msg)

    def set_data(self, data):
        """Set input data and convert legacy by-name settings to variables."""
        super().set_data(data)
        self._vizrank_color_change()

        def findvar(name, iterable):
            """Find an Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.attr_label, str):
            self.attr_label = findvar(self.attr_label, self.gui.label_model)
        if isinstance(self.attr_color, str):
            self.attr_color = findvar(self.attr_color, self.gui.color_model)
        if isinstance(self.attr_shape, str):
            self.attr_shape = findvar(self.attr_shape, self.gui.shape_model)
        if isinstance(self.attr_size, str):
            self.attr_size = findvar(self.attr_size, self.gui.size_model)

    def check_data(self):
        """Validate input; download small SQL tables, sample large ones, and
        drop data without continuous variables, rows or columns."""
        super().check_data()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(self.data, SqlTable):
            if self.data.approx_len() < 4000:
                self.data = Table(self.data)
            else:
                self.Information.sampled_sql()
                self.sql_data = self.data
                data_sample = self.data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                self.data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if self.data is not None:
            if not self.data.domain.has_continuous_attributes(True, True):
                self.Warning.no_continuous_vars()
                self.data = None

        if self.data is not None and (len(self.data) == 0
                                      or len(self.data.domain) == 0):
            self.data = None

    def get_embedding(self):
        """Return the (x, y) coordinates as a two-column array, or None.

        Also sets ``valid_data`` to the mask of points with finite x and y
        and reports missing coordinates (a warning if no point is valid).
        """
        self.valid_data = None
        if self.data is None:
            return None

        x_data = self.get_column(self.attr_x, filter_valid=False)
        y_data = self.get_column(self.attr_y, filter_valid=False)
        if x_data is None or y_data is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        self.valid_data = np.isfinite(x_data) & np.isfinite(y_data)
        if not np.all(self.valid_data):
            msg = self.Information if np.any(self.valid_data) else self.Warning
            msg.missing_coords(self.attr_x.name, self.attr_y.name)
        return np.vstack((x_data, y_data)).T

    # Tooltip
    def _point_tooltip(self, point_id, skip_attrs=()):
        """Tooltip with the axis values first (bold when followed by the
        remaining attributes, if tooltip_shows_all is set)."""
        point_data = self.data[point_id]
        xy_attrs = (self.attr_x, self.attr_y)
        text = "<br/>".join(
            escape('{} = {}'.format(var.name, point_data[var]))
            for var in xy_attrs)
        if self.tooltip_shows_all:
            others = super()._point_tooltip(point_id, skip_attrs=xy_attrs)
            if others:
                text = "<b>{}</b><br/><br/>{}".format(text, others)
        return text

    def add_data(self, time=0.4):
        """Download another time-limited sample from the SQL table and
        append it to the shown data.

        Stops the sampling timer when there is no SQL table to sample from
        or when more than 2000 rows are already shown.
        """
        if self.sql_data is None or (self.data and len(self.data) > 2000):
            self.__timer.stop()
            return
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def init_attr_values(self):
        """Reset axis models; default to the first two continuous variables
        (or the first one twice when there is only one)."""
        super().init_attr_values()
        data = self.data
        domain = data.domain if data and len(data) else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def switch_sampling(self):
        """Restart (or stop) periodic SQL sampling after auto_sample toggles."""
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    def set_subset_data(self, subset_data):
        """Set the subset; small SQL tables are downloaded, large rejected."""
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() < AUTO_DL_LIMIT:
                subset_data = Table(subset_data)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
        super().set_subset_data(subset_data)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        # Axes given on the features input override (and lock) the combos
        # while all of them are present in the data.
        self.attr_box.setEnabled(True)
        self.vizrank.setEnabled(True)
        if self.attribute_selection_list and self.data is not None and \
                self.data.domain is not None and \
                all(attr in self.data.domain for attr
                        in self.attribute_selection_list):
            self.attr_x, self.attr_y = self.attribute_selection_list[:2]
            self.attr_box.setEnabled(False)
            self.vizrank.setEnabled(False)
        super().handleNewSignals()
        if self._domain_invalidated:
            self.graph.update_axes()
            self._domain_invalidated = False

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Remember the first two input features as requested axes."""
        if attributes and len(attributes) >= 2:
            self.attribute_selection_list = attributes[:2]
            self._invalidated = self._invalidated \
                or self.attr_x != attributes[0] \
                or self.attr_y != attributes[1]
        else:
            self.attribute_selection_list = None

    def set_attr(self, attr_x, attr_y):
        """Show the given axis pair; refresh only if something changed."""
        if attr_x != self.attr_x or attr_y != self.attr_y:
            self.attr_x, self.attr_y = attr_x, attr_y
            self.attr_changed()

    def set_attr_from_combo(self):
        """Refresh after a combo change and broadcast the new axis pair."""
        self.attr_changed()
        self.xy_changed_manually.emit(self.attr_x, self.attr_y)

    def attr_changed(self):
        """Rebuild the plot for the current axes, then commit outputs."""
        self.setup_plot()
        self.commit()

    def get_axes(self):
        """Map axis positions to the variables shown on them."""
        return {"bottom": self.attr_x, "left": self.attr_y}

    def colors_changed(self):
        """Handle a color change and re-check VizRank availability."""
        super().colors_changed()
        self._vizrank_color_change()

    def commit(self):
        """Send the base-class outputs and the selected axis features."""
        super().commit()
        self.send_features()

    def send_features(self):
        """Send the set axis variables, or None when neither is set."""
        features = [attr for attr in [self.attr_x, self.attr_y] if attr]
        self.Outputs.features.send(features or None)

    def get_widget_name_extension(self):
        """Return an "<x> vs <y>" title suffix, or None without data."""
        if self.data is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
        return None

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored settings in place from older settings versions."""
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1)
                                           for a in settings["selection"]]
        if version < 3:
            if "auto_send_selection" in settings:
                settings["auto_commit"] = settings["auto_send_selection"]
            if "selection_group" in settings:
                settings["selection"] = settings["selection_group"]

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context in place; contexts older than version 4
        whose axis type code % 100 == 1 are rejected (presumably
        non-continuous axes, which the continuous-only axis model cannot
        show — verify against vartype)."""
        values = context.values
        if version < 3:
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
        if version < 4:
            if values["attr_x"][1] % 100 == 1 or values["attr_y"][1] % 100 == 1:
                raise IncompatibleContext()
Beispiel #23
0
class OWLinePlot(OWWidget):
    """Plot the continuous features of each instance as profiles, with
    optional range, mean and error display, grouped by a discrete variable."""

    name = "Line Plot"
    description = "Visualization of data profiles (e.g., time series)."
    icon = "icons/LinePlot.svg"
    priority = 180

    # Emitted with True/False to enable/disable selection in the view box.
    enable_selection = Signal(bool)

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler()
    group_var = ContextSetting(None)
    show_profiles = Setting(False)
    show_range = Setting(True)
    show_mean = Setting(True)
    show_error = Setting(False)
    auto_commit = Setting(True)
    # Indices (into self.data) of selected instances; not kept as default.
    selection = Setting(None, schema_only=True)

    graph_name = "graph.plotItem"

    class Error(OWWidget.Error):
        not_enough_attrs = Msg("Need at least one continuous feature.")
        no_valid_data = Msg("No plot due to no valid data.")

    class Warning(OWWidget.Warning):
        no_display_option = Msg("No display option is selected.")

    class Information(OWWidget.Information):
        hidden_instances = Msg("Instances with unknown values are not shown.")
        too_many_features = Msg("Data has too many features. Only first {}"
                                " are shown.".format(MAX_FEATURES))

    def __init__(self, parent=None):
        super().__init__(parent)
        self.__groups = []
        self.data = None
        # Boolean row mask: rows of self.data without missing values in X.
        self.valid_data = None
        self.subset_data = None
        self.subset_indices = None
        # Selection restored from settings, applied once the plot exists.
        self.__pending_selection = self.selection
        self.graph_variables = []
        self.setup_gui()

        self.graph.view_box.selection_changed.connect(self.selection_changed)
        self.enable_selection.connect(self.graph.view_box.enable_selection)

    def setup_gui(self):
        """Create the plot area and the control area."""
        self._add_graph()
        self._add_controls()

    def _add_graph(self):
        """Put a LinePlotGraph into the main area."""
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = LinePlotGraph(self)
        box.layout().addWidget(self.graph)

    def _add_controls(self):
        """Build info, display options, grouping, zoom/select and commit."""
        infobox = gui.widgetBox(self.controlArea, "Info")
        self.infoLabel = gui.widgetLabel(infobox, "No data on input.")
        displaybox = gui.widgetBox(self.controlArea, "Display")
        gui.checkBox(displaybox, self, "show_profiles", "Lines",
                     callback=self.__show_profiles_changed,
                     tooltip="Plot lines")
        gui.checkBox(displaybox, self, "show_range", "Range",
                     callback=self.__show_range_changed,
                     tooltip="Plot range between 10th and 90th percentile")
        gui.checkBox(displaybox, self, "show_mean", "Mean",
                     callback=self.__show_mean_changed,
                     tooltip="Plot mean curve")
        gui.checkBox(displaybox, self, "show_error", "Error bars",
                     callback=self.__show_error_changed,
                     tooltip="Show standard deviation")

        self.group_vars = DomainModel(
            placeholder="None", separators=False, valid_types=DiscreteVariable)
        self.group_view = gui.listView(
            self.controlArea, self, "group_var", box="Group by",
            model=self.group_vars, callback=self.__group_var_changed)
        self.group_view.setEnabled(False)
        self.group_view.setMinimumSize(QSize(30, 100))
        self.group_view.setSizePolicy(QSizePolicy.Expanding,
                                      QSizePolicy.Ignored)

        plot_gui = OWPlotGUI(self)
        plot_gui.box_zoom_select(self.controlArea)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

    def __show_profiles_changed(self):
        self.check_display_options()
        self._update_visibility("profiles")

    def __show_range_changed(self):
        self.check_display_options()
        self._update_visibility("range")

    def __show_mean_changed(self):
        self.check_display_options()
        self._update_visibility("mean")

    def __show_error_changed(self):
        # error bars alone do not count as a display option, so no re-check
        self._update_visibility("error")

    def __group_var_changed(self):
        """Re-plot all groups and refresh colors and marks for new grouping."""
        if self.data is None or not self.graph_variables:
            return
        self.plot_groups()
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self._update_sub_profiles()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set input data, restore context and re-plot."""
        self.closeContext()
        self.data = data
        self.clear()
        self.check_data()
        self.check_display_options()

        if self.data is not None:
            self.group_vars.set_domain(self.data.domain)
            # the "None" placeholder is always present, hence > 1
            self.group_view.setEnabled(len(self.group_vars) > 1)
            self.group_var = self.data.domain.class_var \
                if self.data.domain.has_discrete_class else None

        self.openContext(data)
        self.setup_plot()
        self.commit()

    def check_data(self):
        """Validate data; on error show it and set self.data to None.

        Computes ``graph_variables`` (continuous attributes, at most
        MAX_FEATURES) and ``valid_data`` (rows without missing X values).
        """
        def error(err):
            err()
            self.data = None

        self.clear_messages()
        if self.data is not None:
            self.infoLabel.setText("%i instances on input\n%i features" % (
                len(self.data), len(self.data.domain.attributes)))
            self.graph_variables = [var for var in self.data.domain.attributes
                                    if var.is_continuous]
            self.valid_data = ~countnans(self.data.X, axis=1).astype(bool)
            if len(self.graph_variables) < 1:
                error(self.Error.not_enough_attrs)
            elif not np.sum(self.valid_data):
                error(self.Error.no_valid_data)
            else:
                if not np.all(self.valid_data):
                    self.Information.hidden_instances()
                if len(self.graph_variables) > MAX_FEATURES:
                    self.Information.too_many_features()
                    self.graph_variables = self.graph_variables[:MAX_FEATURES]

    def check_display_options(self):
        """Warn if nothing is shown; enable selection only when lines or
        ranges are shown and there are fewer than SEL_MAX_INSTANCES valid
        instances."""
        self.Warning.no_display_option.clear()
        if self.data is not None:
            if not (self.show_profiles or self.show_range or self.show_mean):
                self.Warning.no_display_option()
            # valid_data is a boolean row mask: count it instead of building
            # a filtered copy of the table just to take its length
            enable = (self.show_profiles or self.show_range) and \
                np.count_nonzero(self.valid_data) < SEL_MAX_INSTANCES
            self.enable_selection.emit(enable)

    @Inputs.data_subset
    @check_sql_input
    def set_subset_data(self, subset):
        """Store the subset; it is applied in handleNewSignals."""
        self.subset_data = subset

    def handleNewSignals(self):
        self.set_subset_ids()
        if self.data is not None:
            self._update_profiles_color()
            self._update_sel_profiles_color()
            self._update_sub_profiles()

    def set_subset_ids(self):
        """Set ``subset_indices`` to ids of valid data rows that are also in
        the subset, or None when there is no overlap to compute."""
        sub_ids = {e.id for e in self.subset_data} \
            if self.subset_data is not None else set()
        self.subset_indices = None
        if self.data is not None and sub_ids:
            self.subset_indices = [x.id for x in self.data[self.valid_data]
                                   if x.id in sub_ids]

    def setup_plot(self):
        """Set axis ticks, plot groups, apply pending selection, autoscale."""
        if self.data is None:
            return

        ticks = [a.name for a in self.graph_variables]
        self.graph.getAxis("bottom").set_ticks(ticks)
        self.plot_groups()
        self.apply_selection()
        self.graph.view_box.enableAutoRange()
        self.graph.view_box.updateAutoRange()

    def plot_groups(self):
        """Replace plotted groups: one for all valid rows, or one per value
        of ``group_var`` that has at least one valid row."""
        self._remove_groups()
        data = self.data[self.valid_data, self.graph_variables]
        if self.group_var is None:
            self._plot_group(data, np.where(self.valid_data)[0])
        else:
            class_col_data, _ = self.data.get_column_view(self.group_var)
            for index in range(len(self.group_var.values)):
                mask = np.logical_and(class_col_data == index, self.valid_data)
                indices = np.flatnonzero(mask)
                if not len(indices):
                    continue
                group_data = self.data[indices, self.graph_variables]
                self._plot_group(group_data, indices, index)
        self.graph.update_legend(self.group_var)
        self.graph.view_box.add_profiles(data.X)

    def _remove_groups(self):
        """Remove all group items and profiles from the graph."""
        for group in self.__groups:
            group.remove_items()
        self.graph.view_box.remove_profiles()
        self.__groups = []

    def _plot_group(self, data, indices, index=None):
        """Create a ProfileGroup for rows ``indices`` with the group color
        for value ``index`` (default color when ungrouped)."""
        color = self.__get_group_color(index)
        group = ProfileGroup(data, indices, color, self.graph)
        kwargs = self.__get_visibility_flags()
        group.set_visible_error(**kwargs)
        group.set_visible_mean(**kwargs)
        group.set_visible_range(**kwargs)
        group.set_visible_profiles(**kwargs)
        self.__groups.append(group)

    def __get_group_color(self, index):
        if self.group_var is not None:
            return QColor(*self.group_var.colors[index])
        return QColor(LinePlotStyle.DEFAULT_COLOR)

    def __get_visibility_flags(self):
        return {"show_profiles": self.show_profiles,
                "show_range": self.show_range,
                "show_mean": self.show_mean,
                "show_error": self.show_error}

    def _update_profiles_color(self):
        # color alpha depends on subset and selection; with selection or
        # subset profiles color has more opacity
        if not self.show_profiles:
            return
        has_sel = bool(self.subset_indices) or bool(self.selection)
        for group in self.__groups:
            group.update_profiles_color(has_sel)

    def _update_sel_profiles_and_range(self):
        # mark selected instances and selected range
        if not (self.show_profiles or self.show_range):
            return
        # build the lookup set once instead of scanning the list per row
        selected = set(self.selection) if self.selection else None
        for group in self.__groups:
            inds = [i for i in group.indices if self.__in(i, selected)]
            table = self.data[inds, self.graph_variables].X if inds else None
            if self.show_profiles:
                group.update_sel_profiles(table)
            if self.show_range:
                group.update_sel_range(table)

    def _update_sel_profiles_color(self):
        # color depends on subset; when subset is present,
        # selected profiles are black
        if not self.selection or not self.show_profiles:
            return
        for group in self.__groups:
            group.update_sel_profiles_color(bool(self.subset_indices))

    def _update_sub_profiles(self):
        # mark subset instances
        if not (self.show_profiles or self.show_range):
            return
        # build the lookup set once instead of scanning the list per row
        subset = set(self.subset_indices) if self.subset_indices else None
        for group in self.__groups:
            inds = [i for i, _id in zip(group.indices, group.ids)
                    if self.__in(_id, subset)]
            table = self.data[inds, self.graph_variables].X if inds else None
            group.update_sub_profiles(table)

    def _update_visibility(self, obj_name):
        """Show/hide the ``obj_name`` items ("profiles", "range", "mean" or
        "error") of every group according to the current flags."""
        if not len(self.__groups):
            return
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        kwargs = self.__get_visibility_flags()
        for group in self.__groups:
            getattr(group, "set_visible_{}".format(obj_name))(**kwargs)
        self.graph.view_box.updateAutoRange()

    def apply_selection(self):
        """Apply the selection restored from settings, once."""
        if self.data is not None and self.__pending_selection is not None:
            sel = [i for i in self.__pending_selection if i < len(self.data)]
            mask = np.zeros(len(self.data), dtype=bool)
            mask[sel] = True
            mask = mask[self.valid_data]
            self.selection_changed(mask)
            self.__pending_selection = None

    def selection_changed(self, mask):
        """Update selection from ``mask`` over the valid rows and commit."""
        if self.data is None:
            return
        # need indices for self.data: mask refers to self.data[self.valid_data]
        indices = np.arange(len(self.data))[self.valid_data][mask]
        self.graph.select(indices)
        old = self.selection
        self.selection = None if self.data and isinstance(self.data, SqlTable)\
            else list(self.graph.selection)
        if not old and self.selection or old and not self.selection:
            self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self.commit()

    def commit(self):
        """Send the selected rows and the data annotated with the selection."""
        selected = self.data[self.selection] \
            if self.data is not None and bool(self.selection) else None
        annotated = create_annotated_table(self.data, self.selection)
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        if self.data is None:
            return

        caption = report.render_items_vert((("Group by", self.group_var),))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def sizeHint(self):
        return QSize(1132, 708)

    def clear(self):
        """Reset data-dependent state, the graph and the grouping controls."""
        self.valid_data = None
        self.selection = None
        self.__groups = []
        self.graph_variables = []
        self.graph.reset()
        self.infoLabel.setText("No data on input.")
        self.group_vars.set_domain(None)
        self.group_view.setEnabled(False)

    @staticmethod
    def __in(obj, collection):
        """Membership test that treats a None collection as empty."""
        return collection is not None and obj in collection
Beispiel #24
0
class OWCorrelations(OWWidget):
    """Rank pairs of continuous features by correlation (via VizRank) and
    output the selected pair and a table of all computed correlations."""

    name = "Correlations"
    description = "Compute all pairwise attribute correlations."
    icon = "icons/Correlations.svg"
    priority = 1106

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList)
        correlations = Output("Correlations", Table)

    want_control_area = False

    correlation_type: int

    settings_version = 3
    settingsHandler = DomainContextHandler()
    # The selected feature pair (list of variables) and the optional feature
    # that every ranked pair must contain.
    selection = ContextSetting([])
    feature = ContextSetting(None)
    correlation_type = Setting(0)

    class Information(OWWidget.Information):
        removed_cons_feat = Msg("Constant features have been removed.")

    class Warning(OWWidget.Warning):
        not_enough_vars = Msg("At least two continuous features are needed.")
        not_enough_inst = Msg("At least two instances are needed.")

    def __init__(self):
        super().__init__()
        self.data = None  # type: Table
        # Continuous columns of data, constants removed, imputed.
        self.cont_data = None  # type: Table

        # GUI
        box = gui.vBox(self.mainArea)
        self.correlation_combo = gui.comboBox(
            box,
            self,
            "correlation_type",
            items=CorrelationType.items(),
            orientation=Qt.Horizontal,
            callback=self._correlation_combo_changed)

        self.feature_model = DomainModel(order=DomainModel.ATTRIBUTES,
                                         separators=False,
                                         placeholder="(All combinations)",
                                         valid_types=ContinuousVariable)
        gui.comboBox(box,
                     self,
                     "feature",
                     callback=self._feature_combo_changed,
                     model=self.feature_model)

        self.vizrank, _ = CorrelationRank.add_vizrank(
            None, self, None, self._vizrank_selection_changed)
        self.vizrank.button.setEnabled(False)
        self.vizrank.threadStopped.connect(self._vizrank_stopped)

        gui.separator(box)
        box.layout().addWidget(self.vizrank.filter)
        box.layout().addWidget(self.vizrank.rank_table)

        button_box = gui.hBox(self.mainArea)
        button_box.layout().addWidget(self.vizrank.button)

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    @staticmethod
    def sizeHint():
        return QSize(350, 400)

    def _correlation_combo_changed(self):
        self.apply()

    def _feature_combo_changed(self):
        self.apply()

    def _vizrank_selection_changed(self, *args):
        """Store the pair selected in the rank table and commit."""
        self.selection = list(args)
        self.commit()

    def _vizrank_stopped(self):
        self._vizrank_select()

    def _vizrank_select(self):
        """Select the row matching the stored selection, else the first row."""
        model = self.vizrank.rank_table.model()
        if not model.rowCount():
            return
        selection = QItemSelection()

        # This flag is needed because data in the model could be
        # filtered by a feature and therefore selection could not be found
        selection_in_model = False
        if self.selection:
            sel_names = sorted(var.name for var in self.selection)
            for i in range(model.rowCount()):
                # pylint: disable=protected-access
                names = sorted(x.name for x in model.data(
                    model.index(i, 0), CorrelationRank._AttrRole))
                if names == sel_names:
                    selection.select(model.index(i, 0),
                                     model.index(i,
                                                 model.columnCount() - 1))
                    selection_in_model = True
                    break
        if not selection_in_model:
            selection.select(model.index(0, 0),
                             model.index(0,
                                         model.columnCount() - 1))
        self.vizrank.rank_table.selectionModel().select(
            selection, QItemSelectionModel.ClearAndSelect)

    @Inputs.data
    def set_data(self, data):
        """Set input data and prepare the continuous, imputed copy.

        With fewer than two instances, or fewer than two non-constant
        continuous variables, ``cont_data`` stays None and a warning is shown.
        """
        self.closeContext()
        self.clear_messages()
        self.data = data
        self.cont_data = None
        self.selection = []
        if data is not None:
            if len(data) < 2:
                self.Warning.not_enough_inst()
            else:
                domain = data.domain
                cont_vars = [
                    a for a in domain.class_vars + domain.metas +
                    domain.attributes if a.is_continuous
                ]
                cont_data = Table.from_table(Domain(cont_vars), data)
                remover = Remove(Remove.RemoveConstant)
                cont_data = remover(cont_data)
                if remover.attr_results["removed"]:
                    self.Information.removed_cons_feat()
                if len(cont_data.domain.attributes) < 2:
                    self.Warning.not_enough_vars()
                else:
                    self.cont_data = SklImpute()(cont_data)
            self.info.set_input_summary(len(data),
                                        format_summary_details(data))
        else:
            self.info.set_input_summary(self.info.NoInput)
        self.set_feature_model()
        self.openContext(self.cont_data)
        self.apply()
        self.vizrank.button.setEnabled(self.cont_data is not None)

    def set_feature_model(self):
        """Fill the feature combo from ``cont_data`` and preselect the class
        variable if it is continuous and still present in ``cont_data``."""
        cont_domain = self.cont_data.domain \
            if self.cont_data is not None else None
        self.feature_model.set_domain(cont_domain)
        data = self.data
        class_var = data.domain.class_var \
            if cont_domain is not None and data.domain.has_continuous_class \
            else None
        # A constant class column is dropped by Remove in set_data, so it
        # can be missing from cont_data although the input has it.
        if class_var is not None and class_var.name in cont_domain:
            self.feature = cont_domain[class_var.name]
        else:
            self.feature = None

    def apply(self):
        """Restart ranking, or just commit when there is nothing to rank."""
        self.vizrank.initialize()
        if self.cont_data is not None:
            # this triggers self.commit() by changing vizrank selection
            self.vizrank.toggle()
        else:
            self.commit()

    def commit(self):
        """Send the input data, the selected pair and the correlation table
        (correlation and FDR-corrected p-value for every ranked pair)."""
        self.Outputs.data.send(self.data)
        # `is not None`, as in set_data: an empty table is still sent
        summary = len(self.data) if self.data is not None \
            else self.info.NoOutput
        details = format_summary_details(self.data) \
            if self.data is not None else ""
        self.info.set_output_summary(summary, details)

        if self.data is None or self.cont_data is None:
            self.Outputs.features.send(None)
            self.Outputs.correlations.send(None)
            return

        attrs = [ContinuousVariable("Correlation"), ContinuousVariable("FDR")]
        metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
        domain = Domain(attrs, metas=metas)
        model = self.vizrank.rank_model
        x = np.array([[
            float(model.data(model.index(row, 0), role))
            for role in (Qt.DisplayRole, CorrelationRank.PValRole)
        ] for row in range(model.rowCount())])
        x[:, 1] = FDR(list(x[:, 1]))
        # pylint: disable=protected-access
        m = np.array([[
            a.name
            for a in model.data(model.index(row, 0), CorrelationRank._AttrRole)
        ] for row in range(model.rowCount())],
                     dtype=object)
        corr_table = Table(domain, x, metas=m)
        corr_table.name = "Correlations"

        # data has been imputed; send original attributes
        self.Outputs.features.send(
            AttributeList(
                [self.data.domain[var.name] for var in self.selection]))
        self.Outputs.correlations.send(corr_table)

    def send_report(self):
        self.report_table(CorrelationType.items()[self.correlation_type],
                          self.vizrank.rank_table)

    @classmethod
    def migrate_context(cls, context, version):
        """Convert stored selections: < 2 from variables to (name, type)
        pairs, < 3 to type codes offset by 100 plus a trailing -3 marker."""
        if version < 2:
            sel = context.values["selection"]
            context.values["selection"] = [(var.name, vartype(var))
                                           for var in sel[0]]
        if version < 3:
            sel = context.values["selection"]
            context.values["selection"] = ([(name, vtype + 100)
                                            for name, vtype in sel], -3)
Beispiel #25
0
class OWHeatMap(widget.OWWidget):
    # Widget metadata shown in the canvas.
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
    icon = "icons/Heatmap.svg"
    priority = 260
    keywords = []

    class Inputs:
        # Input table; handled by set_dataset.
        data = Input("Data", Table)

    class Outputs:
        # Rows selected in the heatmap.
        selected_data = Output("Selected Data", Table, default=True)
        # Input data annotated with the selection.
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settings_version = 3

    settingsHandler = settings.DomainContextHandler()

    # Disable clustering for inputs bigger than this
    MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    MaxOrderedClustering = 1000

    # Lower/upper color thresholds; their sliders span 0.0 - 1.0.
    threshold_low = settings.Setting(0.0)
    threshold_high = settings.Setting(1.0)

    # Merge rows by k-means into `merge_kmeans_k` clusters.
    merge_kmeans = settings.Setting(False)
    merge_kmeans_k = settings.Setting(50)

    # Display column with averages
    averages: bool = settings.Setting(True)
    # Display legend
    legend: bool = settings.Setting(True)
    # Annotations
    #: text row annotation (row names)
    annotation_var = settings.ContextSetting(None)
    #: color row annotation
    annotation_color_var = settings.ContextSetting(None)
    # Discrete variable used to split the data/heatmaps (vertically)
    split_by_var = settings.ContextSetting(None)
    # Selected row/column clustering method (the Clustering member's name;
    # stored as a string and converted back in __init__)
    col_clustering_method: str = settings.Setting(Clustering.None_.name)
    row_clustering_method: str = settings.Setting(Clustering.None_.name)

    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    # Index into ColumnLabelsPosData.
    column_label_pos: int = settings.Setting(1)
    # Saved row selection; stored only in the workflow (schema_only).
    selected_rows: List[int] = settings.Setting(None, schema_only=True)

    auto_commit = settings.Setting(True)

    graph_name = "scene"

    left_side_scrolling = True

    class Information(widget.OWWidget.Information):
        # Set by set_dataset for large SQL tables.
        sampled = Msg("Data has been sampled")
        # Formatted with the count and a plural suffix.
        discrete_ignored = Msg("{} categorical feature{} ignored")
        # Free-form notes about clustering disabled for large inputs.
        row_clust = Msg("{}")
        col_clust = Msg("{}")
        sparse_densified = Msg("Showing this data may require a lot of memory")

    class Error(widget.OWWidget.Error):
        no_continuous = Msg("No numeric features")
        not_enough_features = Msg("Not enough features for column clustering")
        not_enough_instances = Msg("Not enough instances for clustering")
        not_enough_instances_k_means = Msg(
            "Not enough instances for k-means merging")
        not_enough_memory = Msg("Not enough memory to show this data")

    class Warning(widget.OWWidget.Warning):
        # Shown when k-means produced clusters with no members.
        empty_clusters = Msg("Empty clusters were removed")

    def __init__(self):
        """Restore clustering choices from settings and build the GUI.

        This covers the control area, the heatmap scene/view and the
        font-size actions.
        """
        super().__init__()
        # Selection saved with the workflow; applied after the first dataset
        # has been displayed (in set_dataset).
        self.__pending_selection = self.selected_rows

        # A kingdom for a save_state/restore_state
        # Convert the stored method names back to Clustering members
        # (presumably falling back to None_ for unknown names).
        self.col_clustering = enum_get(Clustering, self.col_clustering_method,
                                       Clustering.None_)
        self.row_clustering = enum_get(Clustering, self.row_clustering_method,
                                       Clustering.None_)

        # Write the current choices back to the string settings whenever
        # the settings are about to be packed (saved).
        @self.settingsAboutToBePacked.connect
        def _():
            self.col_clustering_method = self.col_clustering.name
            self.row_clustering_method = self.row_clustering.name

        # Not a setting: always starts unchecked.
        self.keep_aspect = False

        #: The original data with all features (retained to
        #: preserve the domain on the output)
        self.input_data = None
        #: The data stripped of discrete features
        self.data = None
        #: The data actually shown; may be merged using k-means
        self.effective_data = None
        #: kmeans model used to merge rows of input_data
        self.kmeans_model = None
        #: merge indices derived from kmeans
        #: a list (len==k) of int ndarray where the i-th item contains
        #: the indices which merge the input_data into the heatmap row i
        self.merge_indices = None
        self.parts: Optional[Parts] = None
        # Clustering results keyed by (split variable, k or None).
        self.__rows_cache = {}
        self.__columns_cache = {}

        # GUI definition
        # --- Color: palette combo and low/high threshold sliders
        colorbox = gui.vBox(self.controlArea, "Color")
        self.color_cb = gui.palette_combo_box(self.palette_name)
        self.color_cb.currentIndexChanged.connect(self.update_color_schema)
        colorbox.layout().addWidget(self.color_cb)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        lowslider = gui.hSlider(colorbox,
                                self,
                                "threshold_low",
                                minValue=0.0,
                                maxValue=1.0,
                                step=0.05,
                                ticks=True,
                                intOnly=False,
                                createLabel=False,
                                callback=self.update_lowslider)
        highslider = gui.hSlider(colorbox,
                                 self,
                                 "threshold_high",
                                 minValue=0.0,
                                 maxValue=1.0,
                                 step=0.05,
                                 ticks=True,
                                 intOnly=False,
                                 createLabel=False,
                                 callback=self.update_highslider)

        form.addRow("Low:", lowslider)
        form.addRow("High:", highslider)

        colorbox.layout().addLayout(form)

        # --- Merge: optional k-means row merging
        mergebox = gui.vBox(
            self.controlArea,
            "Merge",
        )
        gui.checkBox(mergebox,
                     self,
                     "merge_kmeans",
                     "Merge by k-means",
                     callback=self.__update_row_clustering)
        ibox = gui.indentedBox(mergebox)
        gui.spin(ibox,
                 self,
                 "merge_kmeans_k",
                 minv=5,
                 maxv=500,
                 label="Clusters:",
                 keyboardTracking=False,
                 callbackOnReturn=True,
                 callback=self.update_merge)

        # --- Clustering: row/column method combos
        cluster_box = gui.vBox(self.controlArea, "Clustering")
        # Row clustering
        self.row_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.row_clustering, ClusteringRole)
        # Keep the combo in sync when row_clustering is set programmatically.
        self.connect_control(
            "row_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_row_clustering(cb.itemData(idx, ClusteringRole))

        # Column clustering
        self.col_cluster_cb = cb = ComboBox()
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.col_clustering, ClusteringRole)
        self.connect_control(
            "col_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_col_clustering(cb.itemData(idx, ClusteringRole))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )
        form.addRow("Rows:", self.row_cluster_cb)
        form.addRow("Columns:", self.col_cluster_cb)
        cluster_box.layout().addLayout(form)
        # --- Split By: categorical variable splitting rows; disabled while
        # k-means merging is on
        box = gui.vBox(self.controlArea, "Split By")

        self.row_split_model = DomainModel(
            placeholder="(None)",
            valid_types=(Orange.data.DiscreteVariable, ),
            parent=self,
        )
        self.row_split_cb = cb = ComboBoxSearch(
            enabled=not self.merge_kmeans,
            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=14,
            toolTip="Split the heatmap vertically by a categorical column")
        self.row_split_cb.setModel(self.row_split_model)
        self.connect_control("split_by_var",
                             lambda value, cb=cb: cbselect(cb, value))
        self.connect_control("merge_kmeans", self.row_split_cb.setDisabled)
        self.split_by_var = None

        self.row_split_cb.activated.connect(self.__on_split_rows_activated)
        box.layout().addWidget(self.row_split_cb)

        # --- Annotation & legends
        box = gui.vBox(self.controlArea, 'Annotation && Legends')

        gui.checkBox(box,
                     self,
                     'legend',
                     'Show legend',
                     callback=self.update_legend)

        gui.checkBox(box,
                     self,
                     'averages',
                     'Stripes with averages',
                     callback=self.update_averages_stripe)
        annotbox = QGroupBox("Row Annotations", flat=True)
        form = QFormLayout(annotbox,
                           formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        # Text annotation: any variable of the input domain.
        self.annotation_model = DomainModel(placeholder="(None)")
        self.annotation_text_cb = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength)
        self.annotation_text_cb.setModel(self.annotation_model)
        self.annotation_text_cb.activated.connect(self.set_annotation_var)
        self.connect_control("annotation_var", self.annotation_var_changed)

        # Color annotation: primitive class and meta variables only.
        self.row_side_color_model = DomainModel(
            order=(DomainModel.CLASSES, DomainModel.Separator,
                   DomainModel.METAS),
            placeholder="(None)",
            valid_types=DomainModel.PRIMITIVE,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
            parent=self,
        )
        self.row_side_color_cb = ComboBoxSearch(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            minimumContentsLength=12)
        self.row_side_color_cb.setModel(self.row_side_color_model)
        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
        self.connect_control("annotation_color_var",
                             self.annotation_color_var_changed)
        form.addRow("Text", self.annotation_text_cb)
        form.addRow("Color", self.row_side_color_cb)
        box.layout().addWidget(annotbox)
        posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
        posbox.setFlat(True)
        cb = gui.comboBox(posbox,
                          self,
                          "column_label_pos",
                          callback=self.update_column_annotations)
        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
        cb.setCurrentIndex(self.column_label_pos)
        gui.checkBox(self.controlArea,
                     self,
                     "keep_aspect",
                     "Keep aspect ratio",
                     box="Resize",
                     callback=self.__aspect_mode_changed)

        gui.rubber(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

        # Scene with heatmap
        class HeatmapScene(QGraphicsScene):
            # The heatmap grid currently shown (None when the scene is empty).
            widget: Optional[HeatmapGridWidget] = None

        # NOTE(review): redundant double assignment; equivalent to a single one.
        self.scene = self.scene = HeatmapScene(parent=self)
        self.view = GraphicsView(
            self.scene,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
            widgetResizable=True,
        )
        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
        self.view.customContextMenuRequested.connect(
            self._on_view_context_menu)
        self.mainArea.layout().addWidget(self.view)
        # The saved selection is kept in __pending_selection (above).
        self.selected_rows = []
        # Font-size actions (Ctrl+> / Ctrl+<).
        self.__font_inc = QAction("Increase Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+>"))
        self.__font_dec = QAction("Decrease Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+<"))
        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
        # Only newer Qt versions provide setShortcutVisibleInContextMenu.
        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
            apply_all([self.__font_inc, self.__font_dec],
                      lambda a: a.setShortcutVisibleInContextMenu(True))
        self.addActions([self.__font_inc, self.__font_dec])

    @property
    def center_palette(self):
        """Whether the palette selected in the color combo is diverging."""
        current = self.color_cb.currentData()
        diverging_bits = current.flags & current.Diverging
        return bool(diverging_bits)

    @property
    def _column_label_pos(self) -> HeatmapGridWidget.Position:
        """Grid position for the column labels chosen in the settings."""
        entry = ColumnLabelsPosData[self.column_label_pos]
        return entry[Qt.UserRole]

    def annotation_color_var_changed(self, value):
        """Select `value` in the row-color annotation combo."""
        combo = self.row_side_color_cb
        cbselect(combo, value, Qt.EditRole)

    def annotation_var_changed(self, value):
        """Select `value` in the row-text annotation combo."""
        combo = self.annotation_text_cb
        cbselect(combo, value, Qt.EditRole)

    def set_row_clustering(self, method: Clustering) -> None:
        """Switch row clustering to `method` and rebuild if it changed."""
        assert isinstance(method, Clustering)
        if self.row_clustering == method:
            return
        self.row_clustering = method
        cbselect(self.row_cluster_cb, method, ClusteringRole)
        self.__update_row_clustering()

    def set_col_clustering(self, method: Clustering) -> None:
        """Switch column clustering to `method` and rebuild if it changed."""
        assert isinstance(method, Clustering)
        if self.col_clustering == method:
            return
        self.col_clustering = method
        cbselect(self.col_cluster_cb, method, ClusteringRole)
        self.__update_column_clustering()

    def sizeHint(self) -> QSize:
        """Return the base size hint, grown to at least 900 x 700."""
        minimum = QSize(900, 700)
        base = super().sizeHint()
        return base.expandedTo(minimum)

    def color_palette(self):
        """Lookup table of the palette currently selected in the combo."""
        palette = self.color_cb.currentData()
        return palette.lookup_table()

    def color_map(self) -> GradientColorMap:
        """Build the color map from the current palette and thresholds.

        The map is centered at 0 when the palette is diverging, and
        uncentered otherwise.
        """
        table = self.color_palette()
        thresholds = (self.threshold_low, self.threshold_high)
        center = 0 if self.center_palette else None
        return GradientColorMap(table, thresholds, center)

    def clear(self):
        """Reset the widget to its no-data state.

        Drops the data and k-means state, empties the variable models and
        resets their selections, clears the scene, selection and caches, and
        refreshes which clustering methods are enabled.
        """
        self.data = None
        self.input_data = None
        self.effective_data = None
        self.kmeans_model = None
        self.merge_indices = None
        # Each model is emptied before its setting is reset; the settings
        # are tied to their combos via connect_control in __init__.
        self.annotation_model.set_domain(None)
        self.annotation_var = None
        self.row_side_color_model.set_domain(None)
        self.annotation_color_var = None
        self.row_split_model.set_domain(None)
        self.split_by_var = None
        self.parts = None
        self.clear_scene()
        self.selected_rows = []
        self.__columns_cache.clear()
        self.__rows_cache.clear()
        # With no data, N = M = 0, so all clustering methods are enabled.
        self.__update_clustering_enable_state(None)

    def clear_scene(self):
        """Disconnect and drop the current heatmap widget.

        Empties the scene and resets the view's scene, header and footer
        rectangles.
        """
        old_grid = self.scene.widget
        if old_grid is not None:
            old_grid.layoutDidActivate.disconnect(self.__on_layout_activate)
            old_grid.selectionFinished.disconnect(self.on_selection_finished)
        self.scene.widget = None
        self.scene.clear()

        for set_rect in (self.view.setSceneRect,
                         self.view.setHeaderSceneRect,
                         self.view.setFooterSceneRect):
            set_rect(QRectF())

    @Inputs.data
    def set_dataset(self, data=None):
        """Set the input dataset to display.

        The widget is reset, and then the data is prepared for display:

        * an ``SqlTable`` is downloaded, and sampled when its approximate
          length is 4000 rows or more;
        * an empty table is treated as no data;
        * sparse data is densified (an error is shown on ``MemoryError``);
        * categorical attributes are dropped from ``self.data``, while
          ``self.input_data`` keeps them so the outputs preserve the domain.

        Afterwards the annotation/split models are populated, the context is
        restored, and the heatmaps are rebuilt. Any selection saved with the
        workflow is then re-applied, if a heatmap could be built, and the
        outputs are committed.
        """
        self.closeContext()
        self.clear()
        self.clear_messages()

        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)

        if data is not None and not len(data):
            data = None

        if data is not None and sp.issparse(data.X):
            try:
                data = data.to_dense()
            except MemoryError:
                data = None
                self.Error.not_enough_memory()
            else:
                self.Information.sparse_densified()

        input_data = data

        # Data contains no attributes or meta attributes only
        if data is not None and len(data.domain.attributes) == 0:
            self.Error.no_continuous()
            input_data = data = None

        # Data contains some discrete attributes which must be filtered
        if data is not None and \
                any(var.is_discrete for var in data.domain.attributes):
            ndisc = sum(var.is_discrete for var in data.domain.attributes)
            data = data.transform(
                Domain([
                    var for var in data.domain.attributes if var.is_continuous
                ], data.domain.class_vars, data.domain.metas))
            if not data.domain.attributes:
                self.Error.no_continuous()
                input_data = data = None
            else:
                self.Information.discrete_ignored(ndisc,
                                                  "s" if ndisc > 1 else "")

        self.data = data
        self.input_data = input_data

        if data is not None:
            self.annotation_model.set_domain(self.input_data.domain)
            self.row_side_color_model.set_domain(self.input_data.domain)
            self.annotation_var = None
            self.annotation_color_var = None
            self.row_split_model.set_domain(data.domain)
            if data.domain.has_discrete_class:
                self.split_by_var = data.domain.class_var
            else:
                self.split_by_var = None
            self.openContext(self.input_data)
            # A split variable restored from the context may not be valid
            # for this (continuous-only) domain.
            if self.split_by_var not in self.row_split_model:
                self.split_by_var = None

        self.update_heatmaps()
        if data is not None and self.__pending_selection is not None:
            # update_heatmaps leaves the scene empty when the chosen
            # clustering/merging cannot be applied (too few instances or
            # features). The saved selection can only be restored onto an
            # existing heatmap; either way it is consumed here, so it is not
            # applied to some unrelated later dataset.
            if self.scene.widget is not None:
                self.scene.widget.selectRows(self.__pending_selection)
                self.selected_rows = self.__pending_selection
            self.__pending_selection = None

        self.unconditional_commit()

    def __on_split_rows_activated(self):
        """Apply the split variable the user picked in the combo."""
        chosen = self.row_split_cb.currentData(Qt.EditRole)
        self.set_split_variable(chosen)

    def set_split_variable(self, var):
        """Split the heatmap rows by the categorical variable `var`.

        Passing None removes the split. Rebuilding the heatmaps resets the
        row selection, so the outputs are committed afterwards. This matches
        the other controls that rebuild the heatmaps (clustering and k-means
        merging), and keeps the outputs from reporting a stale selection.
        """
        if var != self.split_by_var:
            self.split_by_var = var
            self.update_heatmaps()
            self.commit()

    def update_heatmaps(self):
        """Rebuild the heatmap scene from the current data and settings.

        Without data the widget is cleared. Otherwise the scene is rebuilt,
        unless the requested column clustering (needs at least 2 features),
        clustering (needs at least 2 instances) or k-means merging (needs at
        least 3 instances) is impossible. In that case an error is shown and
        the scene stays empty. A rebuild resets the row selection.
        """
        if self.data is not None:
            self.clear_scene()
            # Reset only the messages recomputed while building the heatmaps.
            # A blanket clear_messages() here would also erase the data-level
            # notices that set_dataset shows just before calling this method
            # (sampled, ignored categorical features, densified sparse data),
            # so users would never see them.
            self.Error.clear()
            self.Warning.clear()
            self.Information.row_clust.clear()
            self.Information.col_clust.clear()
            if self.col_clustering != Clustering.None_ and \
                    len(self.data.domain.attributes) < 2:
                self.Error.not_enough_features()
            elif (self.col_clustering != Clustering.None_ or
                  self.row_clustering != Clustering.None_) and \
                    len(self.data) < 2:
                self.Error.not_enough_instances()
            elif self.merge_kmeans and len(self.data) < 3:
                self.Error.not_enough_instances_k_means()
            else:
                parts = self.construct_heatmaps(self.data, self.split_by_var)
                self.construct_heatmaps_scene(parts, self.effective_data)
                self.selected_rows = []
        else:
            self.clear()

    def update_merge(self):
        """Invalidate the k-means merge results.

        The heatmaps are rebuilt and committed only when merging is active
        and there is data.
        """
        self.kmeans_model = self.merge_indices = None
        if self.data is None or not self.merge_kmeans:
            return
        self.update_heatmaps()
        self.commit()

    def _make_parts(self, data, group_var=None):
        """
        Make initial `Parts` for data, split by group_var, group_key

        Without `group_var` there is a single row part covering all rows.
        Otherwise there is one row part per value of the (discrete)
        `group_var`, containing the row indices with that value. There is
        always a single column part over all attributes. No clustering is
        set. The span is the (NaN-ignoring) min/max of ``data.X``.
        """
        if group_var is None:
            row_groups = [
                RowPart(title=None,
                        indices=range(0, len(data)),
                        cluster=None,
                        cluster_ordered=None)
            ]
        else:
            assert group_var.is_discrete
            group_column = table_column_data(data, group_var)
            row_groups = [
                RowPart(title=value_name,
                        indices=np.flatnonzero(group_column == code),
                        cluster=None,
                        cluster_ordered=None)
                for code, value_name in enumerate(group_var.values)
            ]

        whole_columns = ColumnPart(title=None,
                                   indices=range(0, len(data.domain.attributes)),
                                   domain=data.domain,
                                   cluster=None,
                                   cluster_ordered=None)

        span = (np.nanmin(data.X), np.nanmax(data.X))
        return Parts(row_groups, [whole_columns], span=span)

    def cluster_rows(self,
                     data: Table,
                     parts: 'Parts',
                     ordered=False) -> 'Parts':
        """Compute, or reuse, hierarchical clustering for every row group.

        Existing ``cluster``/``cluster_ordered`` values on a row part are
        kept. Missing ones are computed from Euclidean distances between the
        group's rows, using Ward linkage and optionally optimal leaf ordering.
        Groups whose ``can_cluster`` is false are left unclustered.

        :param data: the (effective) table the row indices refer to
        :param parts: current parts
        :param ordered: also compute the optimal leaf ordering
        :return: `parts` with updated row groups
        """
        row_groups = []
        for row in parts.rows:
            cluster = row.cluster
            cluster_ord = row.cluster_ordered

            if row.can_cluster:
                matrix = None
                need_dist = cluster is None or (ordered
                                                and cluster_ord is None)
                if need_dist:
                    subset = data[row.indices]
                    matrix = Orange.distance.Euclidean(subset)

                # The widget enables clustering for inputs of up to *and
                # including* MaxClustering rows (and ordering for up to
                # MaxOrderedClustering), so these bounds are inclusive; a
                # strict `<` failed for an input of exactly the limit.
                if cluster is None:
                    assert len(matrix) <= self.MaxClustering
                    cluster = hierarchical.dist_matrix_clustering(
                        matrix, linkage=hierarchical.WARD)
                if ordered and cluster_ord is None:
                    assert len(matrix) <= self.MaxOrderedClustering
                    cluster_ord = hierarchical.optimal_leaf_ordering(
                        cluster,
                        matrix,
                    )
            row_groups.append(
                row._replace(cluster=cluster, cluster_ordered=cluster_ord))

        return parts._replace(rows=row_groups)

    def cluster_columns(self, data, parts, ordered=False):
        """Compute, or reuse, the hierarchical clustering of the columns.

        Column distances come from ``Orange.distance.PearsonR`` with
        ``axis=0`` on the preprocessed data, with NaNs replaced by 0.
        Clustering uses Ward linkage. An existing ``cluster`` or
        ``cluster_ordered`` on the single column part is reused.

        :param data: table whose attributes are all continuous
        :param parts: current parts; must hold exactly one column part
        :param ordered: also compute the optimal leaf ordering
        :return: `parts` with updated column clustering
        """
        assert len(parts.columns) == 1, "columns split is no longer supported"
        assert all(var.is_continuous for var in data.domain.attributes)

        col0 = parts.columns[0]
        cluster = col0.cluster
        cluster_ord = col0.cluster_ordered
        need_dist = cluster is None or (ordered and cluster_ord is None)
        matrix = None
        if need_dist:
            # pylint: disable=protected-access
            data = Orange.distance._preprocess(data)
            matrix = np.asarray(Orange.distance.PearsonR(data, axis=0))
            # nan values break clustering below
            matrix = np.nan_to_num(matrix)

        # The widget enables clustering for inputs of up to *and including*
        # MaxClustering columns (and ordering for up to MaxOrderedClustering),
        # so these bounds are inclusive; a strict `<` failed for an input of
        # exactly the limit.
        if cluster is None:
            assert matrix is not None
            assert len(matrix) <= self.MaxClustering
            cluster = hierarchical.dist_matrix_clustering(
                matrix, linkage=hierarchical.WARD)
        if ordered and cluster_ord is None:
            assert len(matrix) <= self.MaxOrderedClustering
            cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix)

        col_groups = [
            col._replace(cluster=cluster, cluster_ordered=cluster_ord)
            for col in parts.columns
        ]
        return parts._replace(columns=col_groups)

    def construct_heatmaps(self, data, group_var=None) -> 'Parts':
        """Prepare the effective data and compute the row/column `Parts`.

        With k-means merging on, the effective data are the k-means centroids
        of the continuous attributes of ``self.input_data``. Any cached model
        is reused, `data` is not used, and the row split (`group_var`) is
        ignored. Otherwise the effective data is `data` itself.
        The requested row/column clustering is then computed, and the result
        is cached per (split variable, k).

        :param data: the continuous-only data (``self.data``)
        :param group_var: optional discrete variable to split rows by
        :return: the computed parts
        """
        if self.merge_kmeans:
            if self.kmeans_model is None:
                effective_data = self.input_data.transform(
                    Orange.data.Domain([
                        var for var in self.input_data.domain.attributes
                        if var.is_continuous
                    ], self.input_data.domain.class_vars,
                                       self.input_data.domain.metas))
                # Ask for fewer clusters than there are instances
                # (update_heatmaps requires at least 3 instances here).
                nclust = min(self.merge_kmeans_k, len(effective_data) - 1)
                # NOTE(review): kmeans_compress is opaque here; the code below
                # relies on it providing .domain, .labels (one cluster label per
                # row) and .centroids — confirm against its definition.
                self.kmeans_model = kmeans_compress(effective_data, k=nclust)
                effective_data.domain = self.kmeans_model.domain
                merge_indices = [
                    np.flatnonzero(self.kmeans_model.labels == ind)
                    for ind in range(nclust)
                ]
                # Drop clusters that received no rows.
                not_empty_indices = [
                    i for i, x in enumerate(merge_indices) if len(x) > 0
                ]
                self.merge_indices = \
                    [merge_indices[i] for i in not_empty_indices]
                if len(merge_indices) != len(self.merge_indices):
                    self.Warning.empty_clusters()
                # One heatmap row per non-empty cluster centroid.
                effective_data = Orange.data.Table(
                    Orange.data.Domain(effective_data.domain.attributes),
                    self.kmeans_model.centroids[not_empty_indices])
            else:
                effective_data = self.effective_data

            # Row splitting is not supported while merging.
            group_var = None
        else:
            self.kmeans_model = None
            self.merge_indices = None
            effective_data = data

        self.effective_data = effective_data

        # May downgrade the chosen clustering methods for large inputs.
        self.__update_clustering_enable_state(effective_data)

        parts = self._make_parts(effective_data, group_var)
        # Restore/update the row/columns items descriptions from cache if
        # available (only the rows are restored; column clustering is
        # recomputed from the fresh parts)
        rows_cache_key = (group_var,
                          self.merge_kmeans_k if self.merge_kmeans else None)
        if rows_cache_key in self.__rows_cache:
            parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)

        if self.row_clustering != Clustering.None_:
            parts = self.cluster_rows(
                effective_data,
                parts,
                ordered=self.row_clustering == Clustering.OrderedClustering)
        if self.col_clustering != Clustering.None_:
            parts = self.cluster_columns(
                effective_data,
                parts,
                ordered=self.col_clustering == Clustering.OrderedClustering)

        # Cache the updated parts
        self.__rows_cache[rows_cache_key] = parts
        return parts

    def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None:
        """Show `parts` in the scene, each with the tree matching the selected
        clustering method (none, plain or leaf-ordered)."""
        _T = TypeVar("_T", bound=Union[RowPart, ColumnPart])

        def select_cluster(clustering: Clustering, item: _T) -> _T:
            # Move the chosen tree into `cluster`; `cluster_ordered` is
            # always cleared for display.
            if clustering == Clustering.None_:
                chosen = None
            elif clustering == Clustering.Clustering:
                chosen = item.cluster
            elif clustering == Clustering.OrderedClustering:
                chosen = item.cluster_ordered
            else:  # pragma: no cover
                raise TypeError()
            return item._replace(cluster=chosen, cluster_ordered=None)

        shown_rows = [select_cluster(self.row_clustering, r)
                      for r in parts.rows]
        shown_cols = [select_cluster(self.col_clustering, c)
                      for c in parts.columns]
        shown = Parts(columns=shown_cols, rows=shown_rows, span=parts.span)
        self.setup_scene(shown, data)

    def setup_scene(self, parts, data):
        # type: (Parts, Table) -> None
        """Create a HeatmapGridWidget for `parts`/`data` and show it.

        The widget is configured from the current settings (color map, row
        side colors, column label position, aspect mode, averages, legend),
        its signals are connected, annotations are updated, and it becomes
        the view's central widget.
        """
        widget = HeatmapGridWidget()
        widget.setColorMap(self.color_map())
        self.scene.addItem(widget)
        self.scene.widget = widget
        columns = [v.name for v in data.domain.attributes]
        # Rebind `parts` to the widget's own Parts structure.
        parts = HeatmapGridWidget.Parts(
            rows=[
                HeatmapGridWidget.RowItem(r.title, r.indices, r.cluster)
                for r in parts.rows
            ],
            columns=[
                HeatmapGridWidget.ColumnItem(c.title, c.indices, c.cluster)
                for c in parts.columns
            ],
            data=data.X,
            span=parts.span,
            row_names=None,
            col_names=columns,
        )
        widget.setHeatmaps(parts)
        # NOTE(review): row_side_colors() is defined elsewhere; the code
        # below uses it as a 3-tuple whose third item has a .name.
        side = self.row_side_colors()
        if side is not None:
            widget.setRowSideColorAnnotations(side[0],
                                              side[1],
                                              name=side[2].name)
        widget.setColumnLabelsPosition(self._column_label_pos)
        widget.setAspectRatioMode(
            Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio)
        widget.setShowAverages(self.averages)
        widget.setLegendVisible(self.legend)

        # Disconnected again in clear_scene.
        widget.layoutDidActivate.connect(self.__on_layout_activate)
        widget.selectionFinished.connect(self.on_selection_finished)

        self.update_annotations()
        self.view.setCentralWidget(widget)
        # NOTE(review): stores the HeatmapGridWidget.Parts built above, not
        # the `Parts` value passed in.
        self.parts = parts

    def __update_scene_rects(self):
        """Fit the scene and view rectangles to the heatmap widget's geometry."""
        grid = self.scene.widget
        if grid is not None:
            geometry = grid.geometry()
            self.scene.setSceneRect(geometry)
            self.view.setSceneRect(geometry)
            self.view.setHeaderSceneRect(grid.headerGeometry())
            self.view.setFooterSceneRect(grid.footerGeometry())

    def __on_layout_activate(self):
        """Slot: the heatmap widget re-laid out; refresh the scene rects."""
        return self.__update_scene_rects()

    def __aspect_mode_changed(self):
        """Apply the 'Keep aspect ratio' option to the shown heatmap."""
        grid = self.scene.widget
        if grid is None:
            return
        keep = self.keep_aspect
        grid.setAspectRatioMode(
            Qt.KeepAspectRatio if keep else Qt.IgnoreAspectRatio)
        # With a fixed aspect ratio the vertical size hint is fixed;
        # otherwise the widget may shrink vertically.
        policy = grid.sizePolicy()
        policy.setVerticalPolicy(
            QSizePolicy.Fixed if keep else QSizePolicy.Preferred)
        grid.setSizePolicy(policy)

    def __update_clustering_enable_state(self, data):
        """Enable or disable clustering methods according to the data size.

        Row clustering is allowed for at most ``MaxClustering`` rows, and
        ordered row clustering for at most ``MaxOrderedClustering`` rows. The
        same limits apply to the number of attributes for columns. A currently
        selected method that exceeds its limit is downgraded (ordered ->
        plain -> none) and an information message explains why. The
        unavailable items are disabled in the combo boxes.

        :param data: the effective data, or None (treated as empty)
        """
        if data is not None:
            N = len(data)
            M = len(data.domain.attributes)
        else:
            N = M = 0

        rc_enabled = N <= self.MaxClustering
        rco_enabled = N <= self.MaxOrderedClustering
        cc_enabled = M <= self.MaxClustering
        cco_enabled = M <= self.MaxOrderedClustering
        row_clust, col_clust = self.row_clustering, self.col_clustering

        row_clust_msg = ""
        col_clust_msg = ""

        # Downgrade step by step: ordered -> plain, then plain -> none.
        if not rco_enabled and row_clust == Clustering.OrderedClustering:
            row_clust = Clustering.Clustering
            row_clust_msg = "Row cluster ordering was disabled due to the " \
                            "input matrix being too big"
        if not rc_enabled and row_clust == Clustering.Clustering:
            row_clust = Clustering.None_
            row_clust_msg = "Row clustering was disabled due to the " \
                            "input matrix being too big"

        if not cco_enabled and col_clust == Clustering.OrderedClustering:
            col_clust = Clustering.Clustering
            col_clust_msg = "Column cluster ordering was disabled due to " \
                            "the input matrix being too big"
        if not cc_enabled and col_clust == Clustering.Clustering:
            col_clust = Clustering.None_
            col_clust_msg = "Column clustering was disabled due to the " \
                            "input matrix being too big"

        self.col_clustering = col_clust
        self.row_clustering = row_clust

        self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg))
        self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg))

        # Disable/enable the combobox items for the clustering methods
        def setenabled(cb: QComboBox, clu: bool, clu_op: bool):
            model = cb.model()
            assert isinstance(model, QStandardItemModel)
            idx = cb.findData(Clustering.OrderedClustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu_op)
            idx = cb.findData(Clustering.Clustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu)

        setenabled(self.row_cluster_cb, rc_enabled, rco_enabled)
        setenabled(self.col_cluster_cb, cc_enabled, cco_enabled)

    def update_averages_stripe(self):
        """Show or hide the averages stripe according to ``self.averages``.

        Nothing happens while the scene holds no heatmap widget.
        """
        grid = self.scene.widget
        if grid is None:
            return
        grid.setShowAverages(self.averages)

    def update_lowslider(self):
        """Keep the low threshold strictly below the high one, then recolor.

        If the low slider reached (or passed) the high slider it is pushed
        back to one step below it; the color map is refreshed in any case.
        """
        lo_slider = self.controls.threshold_low
        hi_value = self.controls.threshold_high.value()
        if not lo_slider.value() < hi_value:
            lo_slider.setSliderPosition(hi_value - 1)
        self.update_color_schema()

    def update_highslider(self):
        """Keep the high threshold strictly above the low one, then recolor.

        If the high slider reached (or went below) the low slider it is pushed
        up to one step above it; the color map is refreshed in any case.
        """
        lo_value = self.controls.threshold_low.value()
        hi_slider = self.controls.threshold_high
        if not hi_slider.value() > lo_value:
            hi_slider.setSliderPosition(lo_value + 1)
        self.update_color_schema()

    def update_color_schema(self):
        """Remember the palette chosen in the combo box and apply the color map.

        ``palette_name`` is always updated; the color map is pushed to the
        heatmap only when a heatmap widget exists.
        """
        chosen_palette = self.color_cb.currentData()
        self.palette_name = chosen_palette.name
        grid = self.scene.widget
        if grid is None:
            return
        grid.setColorMap(self.color_map())

    def __update_column_clustering(self):
        """React to a changed column clustering: rebuild heatmaps, then commit."""
        for step in (self.update_heatmaps, self.commit):
            step()

    def __update_row_clustering(self):
        """React to a changed row clustering: rebuild heatmaps, then commit."""
        for step in (self.update_heatmaps, self.commit):
            step()

    def update_legend(self):
        """Show or hide the heatmap legend according to ``self.legend``."""
        grid = self.scene.widget
        if grid is None:
            return
        grid.setLegendVisible(self.legend)

    def row_annotation_var(self):
        """Return the variable currently used for row labels (may be ``None``)."""
        current = self.annotation_var
        return current

    def row_annotation_data(self):
        """Return the row label strings for the annotation variable.

        Returns ``None`` when no annotation variable is selected; otherwise
        the variable's column of ``self.input_data`` as strings.
        """
        var = self.row_annotation_var()
        return None if var is None else column_str_from_table(self.input_data, var)

    def _merge_row_indices(self):
        """Return the k-means merge indices if rows are merged, else ``None``.

        Rows count as merged only when merging is enabled *and* a k-means
        model is available.
        """
        merged = self.merge_kmeans and self.kmeans_model is not None
        return self.merge_indices if merged else None

    def set_annotation_var(self, var: Union[None, Variable, int]):
        """Select the variable used for row labels.

        ``var`` may be a variable, ``None`` or an int, which is taken as an
        index into ``self.annotation_model``.  The annotations are refreshed
        only if the selection actually changes.
        """
        chosen = self.annotation_model[var] if isinstance(var, int) else var
        if self.annotation_var != chosen:
            self.annotation_var = chosen
            self.update_annotations()

    def update_annotations(self):
        """Push the row annotation labels to the heatmap widget.

        With k-means merged rows, the labels of every merged group are joined
        into one elided string.  Without labels, the row labels are hidden and
        cleared.
        """
        grid = self.scene.widget
        if grid is None:
            return
        labels = self.row_annotation_data()
        groups = self._merge_row_indices()
        if labels is not None and groups is not None:
            labels = aggregate_apply(
                lambda items: join_elided(", ", 42, items, " ({} more)"),
                labels, groups)
        if labels is None:
            grid.setRowLabelsVisible(False)
            grid.setRowLabels(None)
        else:
            grid.setRowLabels(labels)
            grid.setRowLabelsVisible(True)

    def row_side_colors(self):
        """Return ``(colors, colormap, var)`` for the row side color stripe.

        Returns ``None`` when no side color variable is selected.  For a
        continuous variable, the color map span is taken from the unmerged
        column (NaNs ignored), before any k-means aggregation is applied.
        """
        var = self.annotation_color_var
        if var is None:
            return None
        values = column_data_from_table(self.input_data, var)
        lo, hi = np.nanmin(values), np.nanmax(values)
        groups = self._merge_row_indices()
        if groups is not None:
            values = aggregate(var, values, groups)
        colored, cmap = self._colorize(var, values)
        if var.is_continuous:
            cmap.span = (lo, hi)
        return colored, cmap, var

    def set_annotation_color_var(self, var: Union[None, Variable, int]):
        """Set the current side color annotation variable.

        An int is taken as an index into ``self.row_side_color_model``.  The
        side colors are refreshed only if the selection actually changes.
        """
        chosen = self.row_side_color_model[var] if isinstance(var, int) else var
        if self.annotation_color_var != chosen:
            self.annotation_color_var = chosen
            self.update_row_side_colors()

    def update_row_side_colors(self):
        """Push the row side color annotation (or its absence) to the heatmap."""
        grid = self.scene.widget
        if grid is None:
            return
        spec = self.row_side_colors()
        if spec is None:
            grid.setRowSideColorAnnotations(None)
            return
        colors, cmap, var = spec[0], spec[1], spec[2]
        grid.setRowSideColorAnnotations(colors, cmap, var.name)

    def _colorize(self, var: Variable,
                  data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
        """Return ``(data, colormap)`` suitable for a side color annotation.

        Discrete variables: NaNs become code -1 and the result is an int array;
        an "N/A" category (with the palette's NaN color) is added only if NaNs
        are present.  Continuous variables: the data is returned unchanged
        with a gradient over the palette colors (NaN color excluded).

        ``data`` itself is never modified.

        Raises
        ------
        TypeError
            If ``var`` is neither discrete nor continuous.
        """
        palette = var.palette  # type: Palette
        colors = np.array(
            [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
            dtype=np.uint8,
        )
        if var.is_discrete:
            mask = np.isnan(data)
            # Build a new array rather than writing -1 into ``data``: the
            # caller may pass a view of the input table's column, and the NaNs
            # are replaced before the int cast so no NaN is ever cast to int.
            data = np.where(mask, -1, data).astype(int)
            if mask.any():
                values = (*var.values, "N/A")
            else:
                values = var.values
                colors = colors[:-1]  # drop the NaN color when unused
            return data, CategoricalColorMap(colors, values)
        elif var.is_continuous:
            cmap = GradientColorMap(colors[:-1])
            return data, cmap
        else:
            raise TypeError(
                f"cannot colorize variable {var.name!r} of type "
                f"{type(var).__name__}")

    def update_column_annotations(self):
        """Apply the configured column label position to the heatmap.

        Requires both input data and a heatmap widget; otherwise a no-op.
        """
        grid = self.scene.widget
        if self.data is None or grid is None:
            return
        grid.setColumnLabelsPosition(self._column_label_pos)

    def __adjust_font_size(self, diff):
        """Change the heatmap font size by ``diff`` points.

        The decrease action is enabled while the new size stays above 1 pt,
        the increase action while it is at most 32 pt.  The new font is
        applied only if its size is above 1 pt.
        """
        grid = self.scene.widget
        if grid is None:
            return
        size = grid.font().pointSizeF() + diff
        can_shrink = size > 1.0
        self.__font_dec.setEnabled(can_shrink)
        self.__font_inc.setEnabled(size <= 32)
        if not can_shrink:
            return
        font = QFont()
        font.setPointSizeF(size)
        grid.setFont(font)

    def _on_view_context_menu(self, pos):
        """Pop up the view's context menu at viewport position ``pos``.

        The menu lists the view's own actions, the font size actions and a
        checkable "Keep aspect ratio" toggle bound to ``self.keep_aspect``.
        The menu deletes itself when closed.
        """
        grid = self.scene.widget
        if grid is None:
            return
        assert isinstance(grid, HeatmapGridWidget)
        viewport = self.view.viewport()
        menu = QMenu(viewport)
        menu.setAttribute(Qt.WA_DeleteOnClose)
        menu.addActions(self.view.actions())
        menu.addSeparator()
        menu.addActions([self.__font_inc, self.__font_dec])
        menu.addSeparator()

        def set_keep_aspect(checked):
            self.keep_aspect = checked
            self.__aspect_mode_changed()

        aspect_action = QAction("Keep aspect ratio", menu, checkable=True)
        aspect_action.setChecked(self.keep_aspect)
        aspect_action.toggled.connect(set_keep_aspect)
        menu.addAction(aspect_action)
        menu.popup(viewport.mapToGlobal(pos))

    def on_selection_finished(self):
        """Store the rows selected in the heatmap and commit the selection."""
        grid = self.scene.widget
        self.selected_rows = [] if grid is None else list(grid.selectedRows())
        self.commit()

    def commit(self):
        """Send the selected rows and the annotated input table to the outputs.

        When k-means merging is enabled, selected rows stand for merged
        groups and are expanded into the input row indices of each group.
        """
        merge_indices = self.merge_indices if self.merge_kmeans else None
        selected, indices = None, None
        if self.input_data is not None and self.selected_rows:
            indices = self.selected_rows
            if merge_indices is not None:
                # expand merged indices
                indices = np.hstack([merge_indices[row] for row in indices])
            selected = self.input_data[indices]

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.input_data, indices))

    def onDeleteWidget(self):
        """Clear the widget's state, then let the base class finish deletion."""
        self.clear()
        super().onDeleteWidget()

    def send_report(self):
        """Report clustering, split and annotation settings, then the plot."""
        def sorting(method):
            return "Clustering" if method else "No sorting"

        split_var = self.split_by_var
        annot_var = self.annotation_var
        self.report_items((
            ("Columns:", sorting(self.col_clustering)),
            ("Rows:", sorting(self.row_clustering)),
            ("Split:", split_var is not None and split_var.name),
            ("Row annotation", annot_var is not None and annot_var.name),
        ))
        self.report_plot()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade settings stored by versions older than 3.

        Old versions kept boolean ``row_clustering``/``col_clustering`` flags.
        They are removed and replaced by the *name* of a Clustering member in
        ``row_clustering_method``/``col_clustering_method``: OrderedClustering
        when the flag was set, None_ otherwise (a missing flag counts as unset).
        """
        if version is None or version >= 3:
            return
        renames = (("row_clustering", "row_clustering_method"),
                   ("col_clustering", "col_clustering_method"))
        for old_key, new_key in renames:
            enabled = settings.pop(old_key, False)
            method = Clustering.OrderedClustering if enabled \
                else Clustering.None_
            settings[new_key] = method.name
class OWStackAlign(OWWidget):
    """Align and crop a stack of images.

    The two selected coordinate variables (metas or class variables) give the
    pixel positions.  ``process_stack`` computes the shift of every frame
    relative to a 1-based reference frame, and the shifts are plotted in the
    main area.  The reference frame number is bounded by the number of
    attribute columns of the input (presumably one column per frame).
    """
    # Widget's name as displayed in the canvas
    name = "Align Stack"

    # Short widget description
    description = ("Aligns and crops a stack of images using various methods.")

    icon = "icons/stackalign.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Stack of images", Table, default=True)

    class Outputs:
        newstack = Output("Aligned image stack", Table, default=True)

    class Error(OWWidget.Error):
        nan_in_image = Msg("Unknown values within images: {} unknowns")
        invalid_axis = Msg("Invalid axis: {}")

    autocommit = settings.Setting(True)

    want_main_area = True
    want_control_area = True
    resizing_enabled = False

    settingsHandler = DomainContextHandler()

    sobel_filter = settings.Setting(False)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    # 1-based reference frame number (the line edit below starts at 1).
    # Older workflows may still hold 0; _sanitize_ref_frame clamps it.
    ref_frame_num = settings.Setting(1)

    def __init__(self):
        super().__init__()

        box = gui.widgetBox(self.controlArea, "Axes")

        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str)
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=ContinuousVariable)
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self._update_attr,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(box,
                                      self,
                                      "attr_y",
                                      label="Axis y:",
                                      callback=self._update_attr,
                                      model=self.xy_model,
                                      **common_options)

        self.contextAboutToBeOpened.connect(self._init_interface_data)

        box = gui.widgetBox(self.controlArea, "Parameters")

        gui.checkBox(box,
                     self,
                     "sobel_filter",
                     label="Use sobel filter",
                     callback=self._sobel_changed)
        gui.separator(box)
        hbox = gui.hBox(box)
        self.le1 = lineEditIntRange(box,
                                    self,
                                    "ref_frame_num",
                                    bottom=1,
                                    default=1,
                                    callback=self._ref_frame_changed)
        hbox.layout().addWidget(QLabel("Reference frame:", self))
        hbox.layout().addWidget(self.le1)

        gui.rubber(self.controlArea)

        plot_box = gui.widgetBox(self.mainArea, "Shift curves")
        self.plotview = pg.PlotWidget(background="w")
        plot_box.layout().addWidget(self.plotview)
        # TODO:  resize widget to make it a bit smaller

        self.data = None

        gui.auto_commit(self.controlArea, self, "autocommit", "Send Data")

    def _sanitize_ref_frame(self):
        """Clamp ``ref_frame_num`` into ``[1, number of attribute columns]``.

        Does nothing without data (the reference frame edit can be used
        before any data arrives).  The lower bound matters because
        ``commit`` passes ``ref_frame_num - 1`` as an index: a value of 0
        would silently select the last frame while the plot marks x = 0.
        """
        if self.data is None:
            return
        n_frames = self.data.X.shape[1]
        self.ref_frame_num = max(1, min(self.ref_frame_num, n_frames))

    def _ref_frame_changed(self):
        self._sanitize_ref_frame()
        self.commit()

    def _sobel_changed(self):
        self.commit()

    def _init_attr_values(self, data):
        """Fill the axis model from ``data`` and preselect the first two axes."""
        domain = data.domain if data is not None else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def _init_interface_data(self, args):
        """Reset the axis choices before a context opens, unless the domain is unchanged.

        Connected to ``contextAboutToBeOpened``; ``args[0]`` is the data passed
        to ``openContext`` while ``self.data`` still holds the previous table.
        """
        data = args[0]
        same_domain = (self.data and data and data.domain == self.data.domain)
        if not same_domain:
            self._init_attr_values(data)

    def _update_attr(self):
        self.commit()

    @Inputs.data
    def set_data(self, dataset):
        self.closeContext()
        self.openContext(dataset)
        if dataset is not None:
            self.data = dataset
            self._sanitize_ref_frame()
        else:
            self.data = None
        self.Error.nan_in_image.clear()
        self.Error.invalid_axis.clear()
        self.commit()

    def commit(self):
        """Align the stack, plot the per-frame shifts and send the new stack.

        ``None`` is sent when there is no data, no attributes, no axes, or
        when ``process_stack`` reports undefined values or an invalid axis
        (shown as errors).
        """
        new_stack = None

        self.Error.nan_in_image.clear()
        self.Error.invalid_axis.clear()

        self.plotview.plotItem.clear()

        if self.data and len(
                self.data.domain.attributes) and self.attr_x and self.attr_y:
            try:
                shifts, new_stack = process_stack(
                    self.data,
                    self.attr_x,
                    self.attr_y,
                    upsample_factor=100,
                    use_sobel=self.sobel_filter,
                    # the setting is 1-based, process_stack takes an index
                    ref_frame_num=self.ref_frame_num - 1)
            except NanInsideHypercube as e:
                self.Error.nan_in_image(e.args[0])
            except InvalidAxisException as e:
                self.Error.invalid_axis(e.args[0])
            else:
                frames = np.linspace(1, shifts.shape[0], shifts.shape[0])
                self.plotview.plotItem.plot(frames,
                                            shifts[:, 0],
                                            pen=pg.mkPen(color=(255, 40, 0),
                                                         width=3),
                                            symbol='o',
                                            symbolBrush=(255, 40, 0),
                                            symbolPen='w',
                                            symbolSize=7)
                self.plotview.plotItem.plot(frames,
                                            shifts[:, 1],
                                            pen=pg.mkPen(color=(0, 139, 139),
                                                         width=3),
                                            symbol='o',
                                            symbolBrush=(0, 139, 139),
                                            symbolPen='w',
                                            symbolSize=7)
                self.plotview.getPlotItem().setLabel('bottom', 'Frame number')
                self.plotview.getPlotItem().setLabel('left', 'Shift / pixel')
                # vertical marker at the (1-based) reference frame
                self.plotview.getPlotItem().addLine(
                    self.ref_frame_num,
                    pen=pg.mkPen(color=(150, 150, 150),
                                 width=3,
                                 style=Qt.DashDotDotLine))

        self.Outputs.newstack.send(new_stack)

    def send_report(self):
        self.report_items((("Use sobel filter", str(self.sobel_filter)), ))
class NormalizeEditor(BaseEditorOrange):
    """
    Normalize spectra.

    Offers vector, area, attribute and reference normalization.  Area
    normalization integrates with one of ``IntegrateEditor``'s methods
    between a lower and an upper limit, editable both in spin boxes and by
    dragging lines in the preview plot.
    """
    # Normalization methods
    Normalizers = [
        ("Vector Normalization", Normalize.Vector),
        ("Area Normalization", Normalize.Area),
        ("Attribute Normalization", Normalize.Attribute),
        ("Normalize by Reference", NORMALIZE_BY_REFERENCE)]

    def __init__(self, parent=None, **kwargs):
        super().__init__(parent, **kwargs)
        layout = QVBoxLayout()
        self.controlArea.setLayout(layout)

        self.__method = Normalize.Vector
        self.lower = 0
        self.upper = 4000
        self.int_method = 0
        self.attrs = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                 valid_types=ContinuousVariable)
        self.attrform = QFormLayout()
        self.chosen_attr = None
        self.last_domain = None
        self.saved_attr = None
        self.attrcb = gui.comboBox(None, self, "chosen_attr", callback=self.edited.emit,
                                   model=self.attrs)
        self.attrform.addRow("Normalize to", self.attrcb)

        self.areaform = QFormLayout()
        self.int_method_cb = QComboBox(enabled=False)
        self.int_method_cb.addItems(IntegrateEditor.Integrators)
        minf, maxf = -sys.float_info.max, sys.float_info.max
        self.lspin = SetXDoubleSpinBox(
            minimum=minf, maximum=maxf, singleStep=0.5,
            value=self.lower, enabled=False)
        self.uspin = SetXDoubleSpinBox(
            minimum=minf, maximum=maxf, singleStep=0.5,
            value=self.upper, enabled=False)
        self.areaform.addRow("Normalize to", self.int_method_cb)
        self.areaform.addRow("Lower limit", self.lspin)
        self.areaform.addRow("Upper limit", self.uspin)

        self._group = group = QButtonGroup(self)

        # One radio button per method; the method value doubles as button id.
        for name, method in self.Normalizers:
            rb = QRadioButton(self, text=name, checked=self.__method == method)

            layout.addWidget(rb)
            if method is Normalize.Attribute:
                layout.addLayout(self.attrform)
            elif method is Normalize.Area:
                layout.addLayout(self.areaform)
            group.addButton(rb, method)

        group.buttonClicked.connect(self.__on_buttonClicked)

        self.lspin.focusIn = self.activateOptions
        self.uspin.focusIn = self.activateOptions
        self.focusIn = self.activateOptions

        self.lspin.valueChanged[float].connect(self.setL)
        self.lspin.editingFinished.connect(self.reorderLimits)
        self.uspin.valueChanged[float].connect(self.setU)
        self.uspin.editingFinished.connect(self.reorderLimits)
        self.int_method_cb.currentIndexChanged.connect(self.setinttype)
        self.int_method_cb.activated.connect(self.edited)

        self.lline = MovableVline(position=self.lower, label="Low limit")
        self.lline.sigMoved.connect(self.setL)
        self.lline.sigMoveFinished.connect(self.reorderLimits)
        self.uline = MovableVline(position=self.upper, label="High limit")
        self.uline.sigMoved.connect(self.setU)
        self.uline.sigMoveFinished.connect(self.reorderLimits)

        # Becomes True once the limits were set by the user (or parameters).
        self.user_changed = False

    def activateOptions(self):
        """Show the limit lines in the preview plot for area normalization.

        The upper line is omitted for the PeakAt integration method, which
        uses only the lower limit.
        """
        self.parent_widget.curveplot.clear_markings()
        if self.__method == Normalize.Area:
            if self.lline not in self.parent_widget.curveplot.markings:
                self.parent_widget.curveplot.add_marking(self.lline)
            if (self.uline not in self.parent_widget.curveplot.markings
                    and IntegrateEditor.Integrators_classes[self.int_method]
                    is not Integrate.PeakAt):
                self.parent_widget.curveplot.add_marking(self.uline)

    def setParameters(self, params):
        """Load the editor state from a parameter dict (see ``parameters``).

        Unknown methods (from old worksheets) fall back to vector
        normalization.  The attribute is only remembered here and applied
        once preview data with a matching domain arrives.
        """
        if params:  # parameters were manually set somewhere else
            self.user_changed = True
        method = params.get("method", Normalize.Vector)
        lower = params.get("lower", 0)
        upper = params.get("upper", 4000)
        int_method = params.get("int_method", 0)
        if method not in [known for _, known in self.Normalizers]:
            # handle old worksheets
            method = Normalize.Vector
        self.setMethod(method)
        self.int_method_cb.setCurrentIndex(int_method)
        self.setL(lower, user=False)
        self.setU(upper, user=False)
        self.saved_attr = params.get("attr")  # chosen_attr will be set when data are connected

    def parameters(self):
        return {"method": self.__method, "lower": self.lower,
                "upper": self.upper, "int_method": self.int_method,
                "attr": self.chosen_attr}

    def setMethod(self, method):
        """Switch to ``method``, enabling only the controls it uses."""
        if self.__method != method:
            self.__method = method
            b = self._group.button(method)
            b.setChecked(True)
            for widget in [self.attrcb, self.int_method_cb, self.lspin, self.uspin]:
                widget.setEnabled(False)
            if method is Normalize.Attribute:
                self.attrcb.setEnabled(True)
            elif method is Normalize.Area:
                self.int_method_cb.setEnabled(True)
                self.lspin.setEnabled(True)
                self.uspin.setEnabled(True)
            self.activateOptions()
            self.changed.emit()

    def setL(self, lower, user=True):
        """Set the lower limit; ``user=True`` marks the limits as user-edited."""
        if user:
            self.user_changed = True
        if self.lower != lower:
            self.lower = lower
            with blocked(self.lspin):
                self.lspin.setValue(lower)
                self.lline.setValue(lower)
            self.changed.emit()

    def setU(self, upper, user=True):
        """Set the upper limit; ``user=True`` marks the limits as user-edited."""
        if user:
            self.user_changed = True
        if self.upper != upper:
            self.upper = upper
            with blocked(self.uspin):
                self.uspin.setValue(upper)
                self.uline.setValue(upper)
            self.changed.emit()

    def reorderLimits(self):
        """Make sure ``lower <= upper`` and sync spin boxes and plot lines.

        For PeakAt, the upper limit is first set to ``lower + 10``.
        """
        if (IntegrateEditor.Integrators_classes[self.int_method]
                is Integrate.PeakAt):
            self.upper = self.lower + 10
        limits = [self.lower, self.upper]
        self.lower, self.upper = min(limits), max(limits)
        self.lspin.setValue(self.lower)
        self.uspin.setValue(self.upper)
        self.lline.setValue(self.lower)
        self.uline.setValue(self.upper)
        self.edited.emit()

    def setinttype(self):
        if self.int_method != self.int_method_cb.currentIndex():
            self.int_method = self.int_method_cb.currentIndex()
            self.reorderLimits()
            self.activateOptions()
            self.changed.emit()

    def __on_buttonClicked(self):
        method = self._group.checkedId()
        if method != self.__method:
            self.setMethod(method)
            self.edited.emit()

    @staticmethod
    def createinstance(params):
        """Build the preprocessor described by ``params``."""
        method = params.get("method", Normalize.Vector)
        lower = params.get("lower", 0)
        upper = params.get("upper", 4000)
        int_method_index = params.get("int_method", 0)
        int_method = IntegrateEditor.Integrators_classes[int_method_index]
        attr = params.get("attr", None)
        if method != NORMALIZE_BY_REFERENCE:
            return Normalize(method=method, lower=lower, upper=upper,
                             int_method=int_method, attr=attr)
        else:
            # avoids circular imports
            from orangecontrib.spectroscopy.widgets.owpreprocess import REFERENCE_DATA_PARAM
            reference = params.get(REFERENCE_DATA_PARAM, None)
            return NormalizeReference(reference=reference)

    def set_preview_data(self, data):
        """Adapt the editor to new preview data.

        Until the limits were set by the user, they follow the range of the
        data's x values.  On a domain change, the attribute model is refreshed
        and the saved attribute is restored if it still exists (otherwise the
        first available one is chosen).  ``None`` is ignored.
        """
        if data is None:
            # Nothing to adapt to; also keeps getx() from receiving None,
            # which the limit update below would otherwise pass to it.
            return
        edited = False
        if not self.user_changed:
            x = getx(data)
            if len(x):
                # NOTE(review): setL/setU use their default user=True, so this
                # marks the limits as user-set and they are initialized from
                # data only once; confirm this one-shot behaviour is intended.
                self.setL(min(x))
                self.setU(max(x))
                edited = True
        if data.domain != self.last_domain:
            self.last_domain = data.domain
            self.attrs.set_domain(data.domain)
            try:  # try to load the feature
                self.chosen_attr = self.saved_attr
            except ValueError:  # could not load the chosen attr
                self.chosen_attr = self.attrs[0] if self.attrs else None
                self.saved_attr = self.chosen_attr
            edited = True
        if edited:
            self.edited.emit()
Beispiel #28
0
class OWAlignDatasets(widget.OWWidget):
    name = "Align Datasets"
    description = "Alignment of multiple datasets with a diagram of correlation visualization."
    icon = "icons/AlignDatasets.svg"
    priority = 240

    class Inputs:
        # Single input signal: the (possibly multi-source) data table.
        data = Input("Data", Table)

    class Outputs:
        # Output signals; declaration order is the order they are listed in.
        transformed_data = Output("Transformed Data", Table)
        genes_components = Output("Genes per n. Components", Table)

    settingsHandler = DomainContextHandler()
    axis_labels = ContextSetting(10)
    source_id = ContextSetting(None)
    ncomponents = ContextSetting(20)
    ngenes = ContextSetting(30)
    scoring = ContextSetting(list(SCORINGS.keys())[0])
    quantile_normalization = ContextSetting(False)
    quantile_normalization_perc = ContextSetting(2.5)
    dynamic_time_warping = ContextSetting(False)

    auto_update = Setting(True)
    auto_commit = Setting(True)

    graph_name = "plot.plotItem"

    class Error(widget.OWWidget.Error):
        # Input validation errors shown by the widget.
        no_features = widget.Msg("At least 1 feature is required")
        no_instances = widget.Msg("At least 2 data instances are required for each class")
        no_class = widget.Msg("At least 1 Discrete class variable is required")
        nan_class = widget.Msg(
            "Data contains undefined instances for the selected Data source indicator")
        nan_input = widget.Msg("Input data contains non numeric values")
        sparse_data = widget.Msg("Sparse data is not supported")

    def __init__(self):
        """Build the control area (source indicator, CCA, shared genes,
        post-processing, apply) and the correlation plot in the main area."""
        super().__init__()
        self.data = None
        self.source_id = None
        self._mas = None
        self._Ws = None
        self._transformed = None
        self._components = None
        self._use_genes = None
        self._shared_correlations = None
        self._transformed_table = None
        self._line = False
        self._feature_model = DomainModel(valid_types=DiscreteVariable, separators=False)
        self._feature_model.set_domain(None)
        self._init_mas()
        self._legend = None
        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10
        )
        # Data source indicator
        box = gui.vBox(self.controlArea, "Data source indicator")

        gui.comboBox(
            box, self, "source_id", sendSelectedValue=True,
            callback=self._update_combo_source_id,
            model=self._feature_model,
        )

        # Canonical correlation analysis
        box = gui.vBox(self.controlArea, "Canonical correlation analysis")
        gui.spin(
            box, self, "ncomponents", 1, MAX_COMPONENTS,
            callback=self._update_selection_component_spin,
            keyboardTracking=False,
            label="Num. of components"
        )

        # Shared genes
        box = gui.vBox(self.controlArea, "Shared genes")
        gui.spin(
            box, self, "ngenes", 1, MAX_GENES,
            callback=self._update_ngenes_spin,
            keyboardTracking=False,
        )
        form.addRow(
            "Num. of genes",
            self.controls.ngenes
        )

        gui.comboBox(
            box, self, "scoring",
            callback=self._update_scoring_combo,
            items=list(SCORINGS.keys()), sendSelectedValue=True,
            editable=False,
        )
        form.addRow(
            "Scoring:",
            self.controls.scoring
        )

        box.layout().addLayout(form)

        # Post-processing
        box = gui.vBox(self.controlArea, "Post-processing")
        gui.doubleSpin(
            box, self, "quantile_normalization_perc", minv=0, maxv=49, step=5e-1,
            callback=self._update_quantile_normalization,
            checkCallback=self._update_quantile_normalization,
            controlWidth=80, alignment=Qt.AlignRight,
            label="Quantile normalization", checked="quantile_normalization",
        )
        self.controls.quantile_normalization_perc.setSuffix("%")
        gui.checkBox(
            box, self, "dynamic_time_warping",
            callback=self._update_dynamic_time_warping,
            label="Dynamic time warping"
        )

        self.controlArea.layout().addStretch()

        # Pass the bound method itself: the original code called it here
        # (``self._invalidate_selection()``) and handed its return value to
        # auto_commit instead of a callable.
        gui.auto_commit(self.controlArea, self, "auto_commit", "Apply",
                        callback=self._invalidate_selection,
                        checkbox_label="Apply automatically")

        self.plot = pg.PlotWidget(background="w")

        axis = self.plot.getAxis("bottom")
        axis.setLabel("Correlation components")
        axis = self.plot.getAxis("left")
        axis.setLabel("Correlation strength")
        self.plot_horlabels = []
        self.plot_horlines = []

        self.plot.getViewBox().setMenuEnabled(False)
        self.plot.getViewBox().setMouseEnabled(False, False)
        self.plot.showGrid(True, True, alpha=0.5)
        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0))

        self.mainArea.layout().addWidget(self.plot)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive the input table, validate it and fit the alignment model.

        Each failing check shows one error and stops:
        * the domain needs a discrete variable usable as the data source
          indicator (``no_class``);
        * at least one feature (``no_features``), at least one instance
          (``no_instances``), dense (``sparse_data``) and NaN-free
          (``nan_input``) feature values;
        * a source indicator without undefined values (``nan_class``) in which
          every class has at least two instances (``no_instances``).

        Without a source indicator from the context, the first discrete
        variable without undefined values whose classes all have at least
        two instances is chosen.  Features are checked before
        ``_reset_max_components`` runs, since with no attributes it would set
        the component limit to -1.
        """
        import scipy.sparse as sp

        def class_labels_and_counts(var):
            # Column of ``var`` as floats plus the number of rows per value.
            labels = np.array(self.data.get_column_view(var)[0], dtype=np.float64)
            _, counts = np.unique(labels, return_counts=True)
            return labels, counts

        self.closeContext()
        self.clear_messages()
        self.clear()
        self.information()
        self.clear_outputs()
        self._feature_model.set_domain(None)
        self.data = data

        if not self.data:
            return

        self._feature_model.set_domain(self.data.domain)
        if not self._feature_model:
            self.Error.no_class()
            self.clear()
            return

        self.openContext(self.data.domain)

        if len(self.data.domain.attributes) == 0:
            self.Error.no_features()
            return
        if len(self.data) == 0:
            self.Error.no_instances()
            return
        if sp.issparse(self.data.X):
            self.Error.sparse_data()
            return
        if np.isnan(self.data.X).any():
            self.Error.nan_input()
            return

        if self.source_id is None or self.source_id == '':
            for var in self._feature_model:
                labels, counts = class_labels_and_counts(var)
                if np.isfinite(labels).all() and min(counts) > 1:
                    self.source_id = var
                    break

        if not self.source_id:
            self.Error.nan_class()
            return
        labels, counts = class_labels_and_counts(self.source_id)
        # A source indicator restored from the context was not checked above.
        if not np.isfinite(labels).all():
            self.Error.nan_class()
            return
        if min(counts) < 2:
            self.Error.no_instances()
            return
        self._reset_max_components()
        self.fit()

    def fit(self):
        """Fit the alignment model on the current data and refresh plot/outputs.

        ``ncomponents`` is capped at MAX_COMPONENTS, a new model is created
        and fitted on the features with the source indicator as labels, and
        NaNs in the shared correlations are filled with ``interpolate_nans``.
        The outputs are committed only when ``auto_commit`` is on.
        """
        if self.data is None:
            return
        # MAX_COMPONENTS is only read here, so no ``global`` statement is needed.
        if self.ncomponents > MAX_COMPONENTS:
            self.ncomponents = MAX_COMPONENTS

        self._init_mas()
        features = self.data.X
        labels = self.data.get_column_view(self.source_id)[0]
        mas = self._mas
        self._Ws = mas.fit(features, labels)

        correlations = mas.shared_correlations
        if np.isnan(np.sum(correlations)):
            correlations = np.array([interpolate_nans(row) for row in correlations])
        self._shared_correlations = correlations
        self._use_genes = mas.use_genes

        self._setup_plot()
        if self.auto_commit:
            self.commit()

    def clear(self):
        """Forget the data, the fitted model and all derived results.

        Also empties the source indicator model and resets the plot.
        """
        (self.data, self.source_id, self._mas, self._Ws, self._transformed,
         self._transformed_table, self._components, self._use_genes,
         self._shared_correlations) = (None,) * 9
        self._feature_model.set_domain(None)
        self.clear_plot()

    def clear_plot(self):
        """Remove the legend and any horizontal markers and reset the plot.

        Legend removal is best-effort: a legend that is not in a scene is
        simply dropped, and a legend whose underlying Qt object is already
        deleted (RuntimeError) is ignored.  The legend reference is always
        cleared so a stale legend is never removed twice.
        """
        legend = self._legend
        if legend is not None:
            try:
                scene = legend.scene()
            except RuntimeError:
                # wrapped C++ object already deleted
                scene = None
            if scene is not None:
                scene.removeItem(legend)
        self._legend = None
        self._line = False
        self.plot_horlabels = []
        self.plot_horlines = []
        # With no model, _setup_plot only clears the plot.
        self._mas = None
        self._setup_plot()

    def clear_outputs(self):
        # Send None on both outputs (transformed data first, then gene components).
        outputs = self.Outputs
        for signal in (outputs.transformed_data, outputs.genes_components):
            signal.send(None)

    def _reset_max_components(self):
        # Recompute the module-level MAX_COMPONENTS for the current data and
        # source column, adjust ``ncomponents`` and the spin box maximum.
        # Small data (smallest source group or number of attributes below
        # MAX_COMPONENTS_DEFAULT) caps the limit at min(group, attributes) - 1.
        global MAX_COMPONENTS
        labels = np.array(self.data.get_column_view(self.source_id)[0], dtype=np.float64)
        smallest_group = min(np.unique(labels, return_counts=True)[1])
        n_attributes = len(self.data.domain.attributes)
        if smallest_group >= MAX_COMPONENTS_DEFAULT and n_attributes >= MAX_COMPONENTS_DEFAULT:
            MAX_COMPONENTS = MAX_COMPONENTS_DEFAULT
            self.ncomponents = 20
        else:
            MAX_COMPONENTS = min(smallest_group, n_attributes) - 1
            if self.ncomponents > MAX_COMPONENTS:
                self.ncomponents = MAX_COMPONENTS // 2
        self.controls.ncomponents.setMaximum(MAX_COMPONENTS)

    def _init_mas(self):
        # Replace ``_mas`` with a new, unfitted alignment model built from the
        # current settings (component limit, number of meta-genes, scoring).
        model_params = dict(
            n_components=MAX_COMPONENTS,
            n_metagenes=self.ngenes,
            gene_scoring=SCORINGS[self.scoring],
        )
        self._mas = SeuratAlignmentModel(**model_params)

    def get_model(self):
        # Refit the model from scratch and send outputs computed from it.
        # Does nothing without data.
        if self.data is None:
            return

        # The cached transform belongs to the previous model; drop it so that
        # both the commit inside ``fit`` (with auto-commit) and the one below
        # recompute it from the newly fitted model.
        self._transformed = None
        # ``fit`` already redraws the plot, so no separate ``_setup_plot`` call.
        self.fit()
        self.commit()

    def _setup_plot(self):
        # Redraw the plot for the fitted model:
        #  - one shared-correlation curve per dataset (smoothed when there are
        #    enough components), each with a legend entry,
        #  - a movable vertical cut line at ``ncomponents - 1``,
        #  - per dataset, a dashed horizontal line plus a value label at the cut.
        # Without a fitted model (``_mas is None``) the plot is left empty.
        self.plot.clear()
        if self._mas is None:
            return

        shared_correlations = self._shared_correlations
        p = MAX_COMPONENTS

        # Colors chosen based on: http://colorbrewer2.org/?type=qualitative&scheme=Set1&n=9
        colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628',
                  '#f781bf', '#999999']

        def color(i):
            # Cycle through the palette so that more datasets than colors do
            # not raise IndexError.
            return colors[i % len(colors)]

        # A legend already detached from its scene has scene() == None; only
        # remove it when it is still part of a scene.
        if self._legend is not None and self._legend.scene() is not None:
            self._legend.scene().removeItem(self._legend)
        self._legend = self.plot.addLegend(offset=(-1, 1))

        # correlation lines; smoothed only when there are more than
        # 2 * offset + 1 components
        offset = 2
        if MAX_COMPONENTS > 2 * offset + 1:
            smoothed_correlations = smooth_correlations(shared_correlations, offset=offset)
        else:
            smoothed_correlations = shared_correlations
        for i, corr in enumerate(smoothed_correlations):
            self.plot.plot(np.arange(p), corr,
                           pen=pg.mkPen(QColor(color(i)), width=2),
                           antialias=True)
            self._legend.addItem(MyLegendItem(pg.ScatterPlotItem(pen=color(i))),
                                 self.source_id.values[i])

        # vertical movable line
        cutpos = self.ncomponents - 1
        self._line = pg.InfiniteLine(
            angle=90, pos=cutpos, movable=True, bounds=(0, p - 1))
        self._line.setCursor(Qt.SizeHorCursor)
        self._line.setPen(pg.mkPen(QColor(Qt.black), width=2))
        self._line.sigPositionChanged.connect(self._on_cut_changed)
        self.plot.addItem(self._line)

        # horizontal lines (positioned by _set_horline_pos)
        self.plot_horlines = tuple(
            pg.PlotCurveItem(pen=pg.mkPen(QColor(color(i)), style=Qt.DashLine))
            for i in range(len(shared_correlations))
        )
        self.plot_horlabels = tuple(
            pg.TextItem(color=QColor('k'), anchor=(0, 1)) for _ in range(len(shared_correlations))
        )

        for item in self.plot_horlabels + self.plot_horlines:
            self.plot.addItem(item)
        self._set_horline_pos()

        self.plot.setXRange(0.0, p - 1, padding=0)
        self.plot.setYRange(0.0, 1.0, padding=0)
        self._update_axis()

    def _set_horline_pos(self):
        cutidx = self.ncomponents - 1
        for line, label, curve in zip(self.plot_horlines, self.plot_horlabels,
                                      self._shared_correlations):
            y = curve[cutidx]
            line.setData([-1, cutidx], 2 * [y])
            label.setPos(cutidx, y)
            label.setPlainText("{:.3f}".format(y))

    def _on_cut_changed(self, line):
        # cut changed by means of a cut line over the scree plot.
        value = int(round(line.value()))
        components = value + 1

        if not (self.ncomponents == 0 and
                components == len(self._components)):
            self.ncomponents = components

        self._line.setValue(value)
        self._set_horline_pos()
        self.commit()

    def _update_selection_component_spin(self):
        # cut changed by "ncomponents" spin.
        if self._mas is None:
            self._invalidate_selection()
            return

        if np.floor(self._line.value()) + 1 != self.ncomponents:
            self._line.setValue(self.ncomponents - 1)

        self.commit()

    def _invalidate_selection(self):
        if self.data is not None:
            self._transformed = None
            self.commit()

    def _update_scoring_combo(self):
        self.fit()
        self._invalidate_selection()

    def _update_dynamic_time_warping(self):
        self._invalidate_selection()

    def _update_quantile_normalization(self):
        self._invalidate_selection()

    def _update_ngenes_spin(self):
        self.clear_plot()
        if self.data is None:
            return
        if self._has_nan_classes():
            self.Error.nan_class()
            return
        self.clear_messages()
        self.fit()
        self._invalidate_selection()

    def _update_combo_source_id(self):
        self.clear_plot()
        if self.data is None:
            return
        y = np.array(self.data.get_column_view(self.source_id)[0], dtype=np.float64)
        _, counts = np.unique(y, return_counts=True)
        if min(counts) < 2:
            self.Error.no_instances()
            return
        self._reset_max_components()
        if self._has_nan_classes():
            self.Error.nan_class()
            return
        self.clear_messages()
        self.fit()
        self._invalidate_selection()

    def _update_axis(self):
        # Label the bottom axis with 1-based component numbers, placing ticks
        # every ``step`` positions so that roughly ``axis_labels`` are shown.
        n_positions = MAX_COMPONENTS
        bottom_axis = self.plot.getAxis("bottom")
        step = max((n_positions - 1) // (self.axis_labels - 1), 1)
        ticks = [(pos, str(pos + 1)) for pos in range(0, n_positions, step)]
        bottom_axis.setTicks([ticks])

    def _has_nan_classes(self):
        y = np.array(self.data.get_column_view(self.source_id)[0], dtype=np.float64)
        return not np.isfinite(y).all()

    def commit(self):
        # Send the aligned data and the per-component gene table.
        # The full transform (MAX_COMPONENTS components) and the tables derived
        # from it are computed once per fitted model and cached in
        # ``_transformed`` / ``transformed_table`` / ``meta_genes``; the outputs
        # are then restricted to the first ``ncomponents`` components.
        # Without a fitted model, None is sent on both outputs.
        transformed_table = meta_genes = None
        if self._mas is not None:
            # Compute the full transform (MAX_COMPONENTS components) only once.
            if self._transformed is None:
                X = self.data.X
                y = self.data.get_column_view(self.source_id)[0]
                self._transformed = self._mas.transform(X, y, normalize=self.quantile_normalization,
                                                        quantile=self.quantile_normalization_perc,
                                                        dtw=self.dynamic_time_warping)

                attributes = tuple(ContinuousVariable.make("CCA{}".format(x + 1)) for x in
                                   range(MAX_COMPONENTS))
                dom = Domain(
                    attributes,
                    self.data.domain.class_vars,
                    self.data.domain.metas
                )

                # Meta-genes: one row per input attribute, one column per
                # component. A cell holds the 1-based position of the gene in
                # that component's ``use_genes`` list (first occurrence), NaN
                # when the component does not use it.
                # NOTE(review): gene ids appear to be 1-based attribute indices
                # (hence ``gene - 1``) — confirm against SeuratAlignmentModel.
                genes_components = np.zeros((self.data.X.shape[1], MAX_COMPONENTS))
                for key, genes in self._mas.use_genes.items():
                    first_position = {}
                    for position, gene in enumerate(genes, start=1):
                        first_position.setdefault(gene, position)
                    for gene, position in first_position.items():
                        genes_components[gene - 1, key] = position
                genes_components[genes_components == 0] = np.nan
                self.meta_genes = Table.from_numpy(Domain(attributes), genes_components)

                # Transformed data: append the component columns to the input
                # domain, fill them with the transform, then keep only the
                # components (plus class vars and metas).
                new_domain = add_columns(self.data.domain, attributes=attributes)
                transformed_table_temp = self.data.transform(new_domain)
                transformed_table_temp.X[:, -MAX_COMPONENTS:] = self._transformed
                self.transformed_table = Table.from_table(dom, transformed_table_temp)

            ncomponents_attributes = tuple(ContinuousVariable.make("CCA{}".format(x + 1)) for x in
                                           range(self.ncomponents))
            ncomponents_domain = Domain(
                ncomponents_attributes,
                self.data.domain.class_vars,
                self.data.domain.metas
            )

            meta_genes = self.meta_genes.transform(Domain(ncomponents_attributes))
            transformed_table = self.transformed_table.transform(ncomponents_domain)

        self.Outputs.transformed_data.send(transformed_table)
        self.Outputs.genes_components.send(meta_genes)

    def send_report(self):
        # Add the current settings and the plot to the report; nothing is
        # reported without input data.
        if self.data is None:
            return
        # NOTE(review): the mix of ``True`` and the string "False" (and the
        # bare ``False`` percentage when normalization is off) looks
        # deliberate — Orange's report rendering appears to skip items whose
        # value is exactly ``False``. Confirm before normalizing these values.
        if self.quantile_normalization:
            qn_state, qn_percentage = True, self.quantile_normalization_perc
        else:
            qn_state, qn_percentage = "False", False
        dtw_state = True if self.dynamic_time_warping else "False"
        self.report_items((
            ("Source ID", self.source_id),
            ("Selected num. of components", self.ncomponents),
            ("Selected num. of genes", self.ngenes),
            ("Scoring", self.scoring),
            ("Quantile normalization", qn_state),
            ("Quantile normalization percentage", qn_percentage),
            ("Dynamic time warping", dtw_state),
        ))
        self.report_plot()

    """
class OWHyper(OWWidget):
    # Widget showing an image of values computed from the input rows (an
    # integral over a spectral range, or a chosen feature) above a curve plot,
    # optionally overlaid with a "visible" image listed in
    # ``data.attributes["visible_images"]``.
    name = "HyperSpectra"

    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        selected_data = Output("Selection", Orange.data.Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    icon = "icons/hyper.svg"
    priority = 20
    replaces = ["orangecontrib.infrared.widgets.owhyper.OWHyper"]
    keywords = ["image", "spectral", "chemical", "imaging"]

    settings_version = 5
    settingsHandler = DomainContextHandler()

    imageplot = SettingProvider(ImagePlot)
    curveplot = SettingProvider(CurvePlotHyper)

    # index into ``integration_methods``
    integration_method = Setting(0)
    integration_methods = Integrate.INTEGRALS
    # 0: image values from spectra (integrals), 1: from the feature ``attr_value``
    value_type = Setting(0)
    attr_value = ContextSetting(None)

    show_visible_image = Setting(False)
    visible_image_name = Setting(None)
    visible_image_composition = Setting('Normal')
    visible_image_opacity = Setting(120)

    # Integration limits (lowlim/highlim) and the single position used by
    # Integrate.PeakAt (choose); clamped to the curve-plot x range in
    # _init_integral_boundaries.
    lowlim = Setting(None)
    highlim = Setting(None)
    choose = Setting(None)

    graph_name = "imageplot.plotview"  # defined so that the save button is shown

    class Warning(OWWidget.Warning):
        threshold_error = Msg("Low slider should be less than High")

    class Error(OWWidget.Error):
        image_too_big = Msg("Image for chosen features is too big ({} x {}).")

    class Information(OWWidget.Information):
        not_shown = Msg("Undefined positions: {} data point(s) are not shown.")

    @classmethod
    def migrate_settings(cls, settings_, version):
        # Upgrade settings stored by older widget versions. Both steps are
        # best-effort: malformed or missing old settings are ignored.
        if version < 2:
            # delete the saved attr_value to prevent crashes
            try:
                del settings_["context_settings"][0].values["attr_value"]
            except Exception:  # pylint: disable=broad-except
                pass

        # migrate selection
        if version <= 2:
            try:
                current_context = settings_["context_settings"][0]
                selection = getattr(current_context, "selection", None)
                if selection is not None:
                    # old format: per-row flags; new format: (row index, group 1) pairs
                    selection = [(i, 1)
                                 for i in np.flatnonzero(np.array(selection))]
                    settings_.setdefault(
                        "imageplot", {})["selection_group_saved"] = selection
            except Exception:  # pylint: disable=broad-except
                pass

    @classmethod
    def migrate_context(cls, context, version):
        # Delegate migration of the stored curve-plot feature colors.
        if version <= 3 and "curveplot" in context.values:
            CurvePlot.migrate_context_sub_feature_color(
                context.values["curveplot"], version)

    def __init__(self):
        super().__init__()

        dbox = gui.widgetBox(self.controlArea, "Image values")

        rbox = gui.radioButtons(dbox,
                                self,
                                "value_type",
                                callback=self._change_integration)

        # value_type == 0
        gui.appendRadioButton(rbox, "From spectra")

        self.box_values_spectra = gui.indentedBox(rbox)

        gui.comboBox(self.box_values_spectra,
                     self,
                     "integration_method",
                     items=(a.name for a in self.integration_methods),
                     callback=self._change_integral_type)
        gui.rubber(self.controlArea)

        # value_type == 1
        gui.appendRadioButton(rbox, "Use feature")

        self.box_values_feature = gui.indentedBox(rbox)

        self.feature_value_model = DomainModel(
            DomainModel.SEPARATED, valid_types=DomainModel.PRIMITIVE)
        self.feature_value = gui.comboBox(self.box_values_feature,
                                          self,
                                          "attr_value",
                                          contentsLength=12,
                                          searchable=True,
                                          callback=self.update_feature_value,
                                          model=self.feature_value_model)

        splitter = QSplitter(self)
        splitter.setOrientation(Qt.Vertical)
        self.imageplot = ImagePlot(self)
        self.imageplot.selection_changed.connect(self.output_image_selection)

        # do not save visible image (a complex structure as a setting;
        # only save its name)
        self.visible_image = None
        self.setup_visible_image_controls()

        self.curveplot = CurvePlotHyper(self, select=SELECTONE)
        self.curveplot.selection_changed.connect(self.redraw_integral_info)
        self.curveplot.plot.vb.x_padding = 0.005  # pad view so that lines are not hidden
        splitter.addWidget(self.imageplot)
        splitter.addWidget(self.curveplot)
        self.mainArea.layout().addWidget(splitter)

        # Marker lines on the curve plot: line1/line2 are the integration
        # limits, line3 the PeakAt position; dragging updates the settings.
        self.line1 = MovableVline(position=self.lowlim,
                                  label="",
                                  report=self.curveplot)
        self.line1.sigMoved.connect(lambda v: setattr(self, "lowlim", v))
        self.line2 = MovableVline(position=self.highlim,
                                  label="",
                                  report=self.curveplot)
        self.line2.sigMoved.connect(lambda v: setattr(self, "highlim", v))
        self.line3 = MovableVline(position=self.choose,
                                  label="",
                                  report=self.curveplot)
        self.line3.sigMoved.connect(lambda v: setattr(self, "choose", v))
        for line in [self.line1, self.line2, self.line3]:
            line.sigMoveFinished.connect(self.changed_integral_range)
            self.curveplot.add_marking(line)
            line.hide()

        self.markings_integral = []

        self.data = None
        # set while lines are repositioned programmatically to suppress redraws
        self.disable_integral_range = False

        self.resize(900, 700)
        self._update_integration_type()

        # prepare interface according to the new context
        self.contextAboutToBeOpened.connect(
            lambda x: self.init_interface_data(x[0]))

    def setup_visible_image_controls(self):
        # Build the "visible image" box: show checkbox, image combo,
        # composition-mode combo and opacity slider.
        self.visbox = gui.widgetBox(self.controlArea, True)

        gui.checkBox(self.visbox,
                     self,
                     'show_visible_image',
                     label='Show visible image',
                     callback=lambda: (self.update_visible_image_interface(),
                                       self.update_visible_image()))

        self.visible_image_model = VisibleImageListModel()
        gui.comboBox(self.visbox,
                     self,
                     'visible_image',
                     model=self.visible_image_model,
                     callback=self.update_visible_image)

        self.visual_image_composition_modes = OrderedDict([
            ('Normal', QPainter.CompositionMode_Source),
            ('Overlay', QPainter.CompositionMode_Overlay),
            ('Multiply', QPainter.CompositionMode_Multiply),
            ('Difference', QPainter.CompositionMode_Difference)
        ])
        gui.comboBox(self.visbox,
                     self,
                     'visible_image_composition',
                     label='Composition mode:',
                     model=PyListModel(
                         self.visual_image_composition_modes.keys()),
                     callback=self.update_visible_image_composition_mode)

        gui.hSlider(self.visbox,
                    self,
                    'visible_image_opacity',
                    label='Opacity:',
                    minValue=0,
                    maxValue=255,
                    step=10,
                    createLabel=False,
                    callback=self.update_visible_image_opacity)

        self.update_visible_image_interface()
        self.update_visible_image_composition_mode()
        self.update_visible_image_opacity()

    def update_visible_image_interface(self):
        # Enable the visible-image controls only while the image is shown.
        controlled = [
            'visible_image', 'visible_image_composition',
            'visible_image_opacity'
        ]
        for c in controlled:
            getattr(self.controls, c).setEnabled(self.show_visible_image)

    def update_visible_image_composition_mode(self):
        # Apply the selected QPainter composition mode to the image plot.
        self.imageplot.set_visible_image_comp_mode(
            self.visual_image_composition_modes[
                self.visible_image_composition])

    def update_visible_image_opacity(self):
        self.imageplot.set_visible_image_opacity(self.visible_image_opacity)

    def init_interface_data(self, data):
        # Re-populate the feature combo only when the domain changed; the
        # visible-image list is always refreshed.
        # NOTE(review): ``self.data and data`` uses Table truthiness, so an
        # empty table counts as "no data" here — confirm this is intended.
        same_domain = (self.data and data and data.domain == self.data.domain)
        if not same_domain:
            self.init_attr_values(data)
        self.init_visible_images(data)

    def output_image_selection(self):
        # Send the rows selected on the image (None if nothing is selected)
        # and the annotated table; the curve plot shows the selection, or all
        # data when the selection is empty.
        if not self.data:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            self.curveplot.set_data(None)
            return

        indices = np.flatnonzero(self.imageplot.selection_group)

        annotated_data = groups_or_annotated_table(
            self.data, self.imageplot.selection_group)
        self.Outputs.annotated_data.send(annotated_data)

        selected = self.data[indices]
        self.Outputs.selected_data.send(selected if selected else None)
        if selected:
            self.curveplot.set_data(selected)
        else:
            self.curveplot.set_data(self.data)

    def init_attr_values(self, data):
        # Fill the feature combo from the data domain and select its first entry.
        domain = data.domain if data is not None else None
        self.feature_value_model.set_domain(domain)
        self.attr_value = self.feature_value_model[
            0] if self.feature_value_model else None

    def init_visible_images(self, data):
        # Load the images listed in data.attributes['visible_images'] into the
        # combo model; without any, disable the box and hide the overlay.
        self.visible_image_model.clear()
        if data is not None and 'visible_images' in data.attributes:
            self.visbox.setEnabled(True)
            for img in data.attributes['visible_images']:
                self.visible_image_model.append(img)
        else:
            self.visbox.setEnabled(False)
            self.show_visible_image = False
        self.update_visible_image_interface()
        self._choose_visible_image()
        self.update_visible_image()

    def _choose_visible_image(self):
        # choose an image according to visible_image_name setting,
        # falling back to the first image in the model
        if len(self.visible_image_model):
            for img in self.visible_image_model:
                if img["name"] == self.visible_image_name:
                    self.visible_image = img
                    break
            else:
                self.visible_image = self.visible_image_model[0]

    def redraw_integral_info(self):
        # Draw the integral markings for the first curve selected in the curve
        # plot (only when image values come from an Integrate).
        di = {}
        integrate = self.image_values()
        if isinstance(integrate, Integrate) and np.any(
                self.curveplot.selection_group):
            # curveplot can have a subset of curves on the input> match IDs
            ind = np.flatnonzero(self.curveplot.selection_group)[0]
            dind = self.imageplot.data_ids[self.curveplot.data[ind].id]
            dshow = self.data[dind:dind + 1]
            datai = integrate(dshow)
            draw_info = datai.domain.attributes[0].compute_value.draw_info
            di = draw_info(dshow)
        self.refresh_markings(di)

    def refresh_markings(self, di):
        refresh_integral_markings([{
            "draw": di
        }], self.markings_integral, self.curveplot)

    def image_values(self):
        # Return a callable that maps a data table to the values shown in the
        # image: an Integrate preprocessor (value_type 0) or a projection onto
        # the selected feature (otherwise).
        if self.value_type == 0:  # integrals
            imethod = self.integration_methods[self.integration_method]

            if imethod != Integrate.PeakAt:
                return Integrate(methods=imethod,
                                 limits=[[self.lowlim, self.highlim]])
            else:
                return Integrate(methods=imethod,
                                 limits=[[self.choose, self.choose]])
        else:
            return lambda data, attr=self.attr_value: \
                data.transform(Domain([data.domain[attr]]))

    def image_values_fixed_levels(self):
        # For a discrete feature, fix the color levels to its value indices.
        if self.value_type == 1 and isinstance(self.attr_value,
                                               DiscreteVariable):
            return 0, len(self.attr_value.values) - 1
        return None

    def redraw_data(self):
        self.redraw_integral_info()
        self.imageplot.update_view()

    def update_feature_value(self):
        self.redraw_data()

    def _update_integration_type(self):
        # Show the marker lines and enable the control box that belong to the
        # current value_type / integration method.
        self.line1.hide()
        self.line2.hide()
        self.line3.hide()
        if self.value_type == 0:
            self.box_values_spectra.setDisabled(False)
            self.box_values_feature.setDisabled(True)
            if self.integration_methods[
                    self.integration_method] != Integrate.PeakAt:
                self.line1.show()
                self.line2.show()
            else:
                self.line3.show()
        elif self.value_type == 1:
            self.box_values_spectra.setDisabled(True)
            self.box_values_feature.setDisabled(False)
        # NOTE(review): QTest is a testing helper; processing pending events
        # is presumably the intent here — confirm before replacing.
        QTest.qWait(1)  # first update the interface

    def _change_integration(self):
        # change what to show on the image
        self._update_integration_type()
        self.redraw_data()

    def changed_integral_range(self):
        # Redraw after a marker line was moved, unless lines are being
        # repositioned programmatically.
        if self.disable_integral_range:
            return
        self.redraw_data()

    def _change_integral_type(self):
        self._change_integration()

    @Inputs.data
    def set_data(self, data):
        self.closeContext()

        def valid_context(data):
            # A context is opened only for data with at least one discrete or
            # continuous meta / class variable.
            if data is None:
                return False
            annotation_features = [
                v for v in data.domain.metas + data.domain.class_vars
                if isinstance(v, (DiscreteVariable, ContinuousVariable))
            ]
            return len(annotation_features) >= 1

        if valid_context(data):
            self.openContext(data)
        else:
            # to generate valid interface even if context was not loaded
            self.contextAboutToBeOpened.emit([data])
        self.data = data
        self.imageplot.set_data(data)
        self.curveplot.set_data(data)
        self._init_integral_boundaries()
        self.imageplot.update_view()
        self.output_image_selection()
        self.update_visible_image()

    def _init_integral_boundaries(self):
        # Clamp lowlim / highlim / choose into the curve-plot x range
        # (0..1 without data) and move the marker lines there.
        # requires data in curveplot
        self.disable_integral_range = True
        try:
            if self.curveplot.data_x is not None and len(self.curveplot.data_x):
                minx = self.curveplot.data_x[0]
                maxx = self.curveplot.data_x[-1]
            else:
                minx = 0.
                maxx = 1.

            if self.lowlim is None or not minx <= self.lowlim <= maxx:
                self.lowlim = minx
            self.line1.setValue(self.lowlim)

            if self.highlim is None or not minx <= self.highlim <= maxx:
                self.highlim = maxx
            self.line2.setValue(self.highlim)

            if self.choose is None:
                self.choose = (minx + maxx) / 2
            elif self.choose < minx:
                self.choose = minx
            elif self.choose > maxx:
                self.choose = maxx
            self.line3.setValue(self.choose)
        finally:
            # re-enable redraws even if positioning a line fails
            self.disable_integral_range = False

    def save_graph(self):
        self.imageplot.save_graph()

    def onDeleteWidget(self):
        self.curveplot.shutdown()
        self.imageplot.shutdown()
        super().onDeleteWidget()

    def update_visible_image(self):
        # Show the selected visible image (positioned and sized from its
        # metadata) on the image plot, or hide it.
        img_info = self.visible_image
        if self.show_visible_image and img_info is not None:
            self.visible_image_name = img_info[
                "name"]  # save visual image name
            # Close the underlying file once the pixels are converted in memory.
            with Image.open(img_info['image_ref']) as pil_image:
                img = np.array(pil_image.convert('RGBA'))
            # image must be vertically flipped
            # https://github.com/pyqtgraph/pyqtgraph/issues/315#issuecomment-214042453
            # Behavior may change at pyqtgraph 1.0 version
            img = img[::-1]
            width = img_info['img_size_x'] if 'img_size_x' in img_info \
                else img.shape[1] * img_info['pixel_size_x']
            height = img_info['img_size_y'] if 'img_size_y' in img_info \
                else img.shape[0] * img_info['pixel_size_y']
            rect = QRectF(img_info['pos_x'], img_info['pos_y'], width, height)
            self.imageplot.set_visible_image(img, rect)
            self.imageplot.show_visible_image()
        else:
            self.imageplot.hide_visible_image()
Beispiel #30
0
class OWTestAndScore(OWWidget):
    # Widget metadata
    name = "Test and Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100
    keywords = ['Cross Validation', 'CV']
    # Former qualified name of this widget; presumably lets workflows saved
    # with the old widget load this one — TODO confirm
    replaces = ["Orange.widgets.evaluate.owtestlearners.OWTestLearners"]

    class Inputs:
        # Training data (handled by set_train_data)
        train_data = Input("Data", Table, default=True)
        # Separate test data (set_test_data); used only when
        # "Test on test data" is selected
        test_data = Input("Test Data", Table)
        # Several learners can be connected; set_learner gets (learner, key)
        learner = Input("Learner", Learner, multiple=True)
        # Preprocessor to apply on the training data (set_preprocessor)
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        # Data augmented with the predictions of all successful learners
        predictions = Output("Predictions", Table)
        # Merged results of all successfully evaluated learners (see commit)
        evaluations_results = Output("Evaluation Results", Results)

    # Version of the stored settings format; older settings are upgraded in
    # migrate_settings
    settings_version = 3
    UserAdviceMessages = [
        widget.Message("Click on the table header to select shown columns",
                       "click_header")
    ]

    # Context settings are keyed by the train data domain (openContext is
    # called with data.domain in set_train_data)
    settingsHandler = settings.PerfectDomainContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Resampling/testing types
    #: (indices of the "Sampling" radio buttons, in the order they are
    #: created in __init__)
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    #: (an index into NFolds; the default 2 selects 5 folds)
    n_folds = settings.Setting(2)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    #: (an index into NRepeats; the default 3 selects 10 repeats)
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    #: (an index into SampleSizes; the default 9 selects 66 %)
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    # Whether "Cross validation by feature" was selected; updated in
    # _invalidate and used by set_train_data to restore that choice
    fold_feature_selected = settings.ContextSetting(False)

    # Whether model comparison reports a negligible-difference probability;
    # `rope` is the threshold passed to baycomp.two_on_single
    use_rope = settings.Setting(False)
    rope = settings.Setting(0.1)
    # Index into self.scorers of the scorer used for model comparison
    comparison_criterion = settings.Setting(0, schema_only=True)

    # Class-selection entry meaning "average over classes" instead of
    # scoring a single target class
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    class Error(OWWidget.Error):
        test_data_empty = Msg("Test dataset is empty.")
        class_required_test = Msg(
            "Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train datasets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        test_data_incompatible = Msg(
            "Test data may be incompatible with train data.")
        # Filled with the message of the first failed check in set_train_data
        train_data_error = Msg("{}")

    class Warning(OWWidget.Warning):
        # "{}" is filled by _which_missing_data: " ", " train " or " test "
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")
        test_data_transformed = Msg(
            "Test data has been transformed to match the train data.")

    def __init__(self):
        """Initialize the widget state and build the GUI.

        The control area gets the sampling options, the target class
        selection and the model comparison options. The main area gets the
        score table and the pairwise model comparison table.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        # Whether train / test data contained rows with unknown target
        # values (set in set_train_data / set_test_data)
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # The stored criterion is re-applied by _update_scorers once the
        # scorers (and the combo items) are known
        self.__pending_comparison_criterion = self.comparison_criterion

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, InputLearner]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        # Sampling options; the order in which radio buttons are appended
        # defines the values KFold .. TestOnTest
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # Candidate fold features: discrete variables only; given
        # order=DomainModel.METAS presumably only meta attributes are
        # listed — TODO confirm. Its domain is set in _update_controls.
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           searchable=True,
                                           callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # Target class selection; items are filled (and the box hidden for a
        # non-discrete class) in _update_class_selection
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            contentsLength=8,
            searchable=True,
            callback=self._on_target_class_changed)

        # Model comparison options; the criterion combo is filled in
        # _update_scorers
        self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
        gui.comboBox(box,
                     self,
                     "comparison_criterion",
                     callback=self.update_comparison_table)

        hbox = gui.hBox(box)
        gui.checkBox(hbox,
                     self,
                     "use_rope",
                     "Negligible difference: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(hbox,
                     self,
                     "rope",
                     validator=QDoubleValidator(),
                     controlWidth=70,
                     callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # The threshold is editable only when the checkbox is on
        self.controls.rope.setEnabled(self.use_rope)

        gui.rubber(self.controlArea)
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)

        # Pairwise comparison table; cells are QLabel widgets set in _set_cell
        self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
        table = self.comparison_table = QTableWidget(
            wordWrap=False,
            editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        # Column widths are bounded to 8..15 average character widths
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        box.layout().addWidget(
            QLabel(
                "<small>Table shows probabilities that the score for the model in "
                "the row is higher than that of the model in the column. "
                "Small numbers show the probability that the difference is "
                "negligible.</small>",
                wordWrap=True))

    @staticmethod
    def sizeHint():
        """Return the preferred widget size, 780 x 1."""
        preferred_width, preferred_height = 780, 1
        return QSize(preferred_width, preferred_height)

    def _update_controls(self):
        """Refresh the "Cross validation by feature" controls for the data.

        The fold feature model is repopulated from the train data domain and
        its first entry is preselected. When it has no entries, the
        feature-fold radio button and combo are disabled and a selected
        feature-fold resampling falls back to k-fold cross validation.
        """
        model = self.feature_model
        self.fold_feature = None
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                self.fold_feature = model[0]
        has_features = bool(model)
        feature_fold_button = \
            self.controls.resampling.buttons[OWTestAndScore.FeatureFold]
        feature_fold_button.setEnabled(has_features)
        self.features_combo.setEnabled(has_features)
        if not has_features and self.resampling == OWTestAndScore.FeatureFold:
            self.resampling = OWTestAndScore.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        A learner is stored with empty results and then invalidated. A
        ``None`` learner for a known key invalidates and removes that input.
        A ``None`` learner for an unknown key is ignored.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if learner is not None:
            self.learners[key] = InputLearner(learner, None, None)
            self._invalidate([key])
        elif key in self.learners:
            # Removed
            self._invalidate([key])
            del self.learners[key]

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Set the input training dataset.

        Data that fails a sanity check is rejected: the first failing check's
        message is shown and the data is treated as missing. A large SQL
        table is replaced by a downloaded sample. Rows with unknown target
        values are removed if either dataset has them. Afterwards the
        context, scorers, controls and class selection are refreshed and all
        results are invalidated.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.cancel()
        self.Information.data_sampled.clear()
        self.Error.train_data_error.clear()

        if data is not None:
            # Every condition is evaluated up front; the first failing one
            # determines the reported message.
            checks = [
                ("Train dataset is empty.", len(data) == 0),
                ("Train data input requires a target variable.",
                 not data.domain.class_vars),
                ("Too many target variables.",
                 len(data.domain.class_vars) > 1),
                ("Target variable has no values.", np.isnan(data.Y).all()),
                ("Target variable has only one value.",
                 data.domain.has_discrete_class and len(unique(data.Y)) < 2),
                ("Data has no features to learn from.", data.X.shape[1] == 0),
            ]
            failures = [message for message, failed in checks if failed]
            if failures:
                self.Error.train_data_error(failures[0])
                data = None

        if isinstance(data, SqlTable):
            if data.approx_len() >= AUTO_DL_LIMIT:
                self.Information.data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)
            else:
                data = Table(data)

        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if not (self.train_data_missing_vals or self.test_data_missing_vals):
            self.Warning.missing_data.clear()
        else:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # Restore feature-fold resampling if the context recorded it
            if self.fold_feature_selected and self.feature_model:
                self.resampling = OWTestAndScore.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Set the input separate testing dataset.

        Empty data, or data whose domain has no ``class_var``, is rejected
        with an error. A large SQL table is replaced by a downloaded sample.
        Rows with unknown target values are removed if either dataset has
        them. Results are invalidated only when "Test on test data" is the
        selected resampling.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not data:
            self.Error.test_data_empty()
            data = None
        if not data or data.domain.class_var:
            self.Error.class_required_test.clear()
        else:
            self.Error.class_required_test()
            data = None

        if isinstance(data, SqlTable):
            if data.approx_len() >= AUTO_DL_LIMIT:
                self.Information.test_data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)
            else:
                data = Table(data)

        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if not (self.train_data_missing_vals or self.test_data_missing_vals):
            self.Warning.missing_data.clear()
        else:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)

        self.test_data = data
        if self.resampling == OWTestAndScore.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        """Return the infix for the ``missing_data`` warning.

        The infix is " train " or " test " when only that dataset has
        unknown target values, and " " when both do. It must only be called
        when at least one flag is set; otherwise the lookup raises KeyError.
        """
        flags = (self.train_data_missing_vals, self.test_data_missing_vals)
        infix_by_flags = {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test ",
        }
        return infix_by_flags[flags]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        """Refresh the scorers usable for the train data's class variable.

        The comparison-criterion combo is repopulated only when the list of
        scorers actually changes, because repopulating it resets
        ``comparison_criterion``. A criterion still pending from
        ``__init__`` is applied once, if it is a valid index. Finally, the
        comparison box title is updated.
        """
        scorers = []
        if self.data and self.data.domain.class_var:
            scorers = usable_scorers(self.data.domain.class_var)

        if scorers != self.scorers:
            self.scorers = scorers
            combo = self.controls.comparison_criterion
            combo.clear()
            combo.addItems(
                [scorer.long_name or scorer.name for scorer in scorers])
            if scorers:
                self.comparison_criterion = 0

        pending = self.__pending_comparison_criterion
        if pending is not None:
            # Some scorers may have been removed from modules since the
            # setting was stored, so only apply a valid index
            if pending < len(self.scorers):
                self.comparison_criterion = pending
            self.__pending_comparison_criterion = None
        self._update_compbox_title()

    def _update_compbox_title(self):
        """Title the model comparison box with the selected scorer's name.

        Falls back to the plain "Model Comparison" title when
        ``comparison_criterion`` is not a valid index into ``self.scorers``
        (e.g. when no scorers are available).
        """
        criterion = self.comparison_criterion
        # Guard the lower bound too: a negative index would silently name
        # the last scorer instead of the selected one.
        if 0 <= criterion < len(self.scorers):
            scorer = self.scorers[criterion]()
            self.compbox.setTitle(f"Model Comparison by {scorer.name}")
        else:
            # Plain literal: the original f-string had no placeholders
            self.compbox.setTitle("Model Comparison")

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Set the input preprocessor to apply on the training data.

        Storing a new preprocessor invalidates all learner results.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals.

        Refreshes the class selection, the score table header and contents,
        the enabled state of the views and the input summary. Then, if
        results were invalidated since the last update, runs the update.
        """
        self._update_class_selection()
        self.score_table.update_header(self.scorers)
        self._update_view_enabled()
        self.update_stats_model()
        self.set_input_summary()
        # Set by _invalidate; __update clears it
        if self.__needupdate:
            self.__update()

    def set_input_summary(self):
        """Show the sizes and details of the train and test data as the input summary.

        With both datasets, both sizes and a rich-text details listing are
        shown. With only one of them, its size and details are shown.
        Otherwise the summary is ``NoInput``.
        """
        has_train, has_test = bool(self.data), bool(self.test_data)
        kwargs = {}
        if has_train and has_test:
            fmt = self.info.format_number
            summary = f"{fmt(len(self.data))}, {fmt(len(self.test_data))}"
            details = format_multiple_summaries(
                [("Data", self.data), ("Test data", self.test_data)])
            kwargs["format"] = Qt.RichText
        elif has_train:
            summary = len(self.data)
            details = format_summary_details(self.data)
        elif has_test:
            summary = len(self.test_data)
            details = format_summary_details(self.test_data)
        else:
            summary, details = self.info.NoInput, ""
        self.info.set_input_summary(summary, details, **kwargs)

    def kfold_changed(self):
        """Select k-fold cross validation after one of its options changed."""
        self.resampling = OWTestAndScore.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Select cross validation by feature after the fold feature changed."""
        self.resampling = OWTestAndScore.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Select random sampling after one of its options changed."""
        self.resampling = OWTestAndScore.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """React to a change of the sampling parameters.

        Model comparison options are enabled only for k-fold cross
        validation. All results are invalidated and __update is called
        immediately.
        """
        self.modcompbox.setEnabled(self.resampling == OWTestAndScore.KFold)
        self._update_view_enabled()
        self._invalidate()
        self.__update()

    def _update_view_enabled(self):
        """Enable only the views that have something meaningful to show.

        The comparison table requires k-fold cross validation, train data
        and more than one learner. The score table requires train data.
        """
        has_data = self.data is not None
        comparable = (self.resampling == OWTestAndScore.KFold
                      and len(self.learners) > 1 and has_data)
        self.comparison_table.setEnabled(comparable)
        self.score_table.view.setEnabled(has_data)

    def update_stats_model(self):
        """Rebuild the score table from the current learner results.

        Each learner gets one row: its name, the train and test time (only
        for successful results), then one column per scorer. When a specific
        target class of a discrete class is selected, scores are recomputed
        here on one-vs-rest results; otherwise ``slot.stats`` is shown. The
        comparison table headers and the error/warning messages are updated
        as well.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.score_table.model
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # Index of the selected target class; None means "average"
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        names = []
        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            names.append(name)
            head = QStandardItem(name)
            # The learner key is stored under Qt.UserRole so rows can be
            # mapped back to self.learners (see _invalidate, _successful_slots)
            head.setData(key, Qt.UserRole)
            results = slot.results
            if results is not None and results.success:
                train = QStandardItem("{:.3f}".format(
                    results.value.train_time))
                train.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                train.setData(key, Qt.UserRole)
                test = QStandardItem("{:.3f}".format(results.value.test_time))
                test.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                test.setData(key, Qt.UserRole)
                row = [head, train, test]
            else:
                row = [head]
            if isinstance(results, Try.Fail):
                # Failed learners are marked in red with the exception as tooltip
                head.setToolTip(str(results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                # A domain transformation failure when testing on test data
                # is reported as incompatible test data, not as an error
                if isinstance(results.exception, DomainTransformationError) \
                        and self.resampling == self.TestOnTest:
                    self.Error.test_data_incompatible()
                    self.Information.test_data_transformed.clear()
                else:
                    errors.append("{name} failed with error:\n"
                                  "{exc.__class__.__name__}: {exc!s}".format(
                                      name=name, exc=slot.results.exception))

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(slot.results.value,
                                                      target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [
                        Try(scorer_caller(scorer, ovr_results, target=1))
                        for scorer in self.scorers
                    ]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat, scorer in zip(stats, self.scorers):
                    item = QStandardItem()
                    item.setTextAlignment(Qt.AlignRight | Qt.AlignVCenter)
                    if stat.success:
                        item.setData(float(stat.value[0]), Qt.DisplayRole)
                    else:
                        item.setToolTip(str(stat.exception))
                        # Only scores that are shown trigger the warning
                        if scorer.name in self.score_table.shown_scores:
                            has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        # Resort rows based on current sorting
        header = self.score_table.view.horizontalHeader()
        model.sort(header.sortIndicatorSection(), header.sortIndicatorOrder())
        self._set_comparison_headers(names)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _on_use_rope_changed(self):
        """Enable the threshold edit only while "Negligible difference" is
        checked, then recompute the comparison table."""
        self.controls.rope.setEnabled(self.use_rope)
        self.update_comparison_table()

    def update_comparison_table(self):
        """Recompute the pairwise model comparison table.

        The table contents are always cleared first. Headers are reset when
        there are successfully evaluated learners and scorers. Cells are
        filled only for k-fold cross validation.
        """
        table = self.comparison_table
        table.clearContents()
        slots = self._successful_slots()
        if not slots or not self.scorers:
            return
        names = [learner_name(s.learner) for s in slots]
        self._set_comparison_headers(names)
        if self.resampling != OWTestAndScore.KFold:
            return
        self._fill_table(names, self._scores_by_folds(slots))

    def _successful_slots(self):
        """Return the learner slots with successful results.

        Slots are returned in the row order of the (sorted) score table.
        """
        source = self.score_table.model
        sorted_proxy = self.score_table.sorted_model
        successful = []
        for row in range(sorted_proxy.rowCount()):
            source_index = sorted_proxy.mapToSource(sorted_proxy.index(row, 0))
            slot = self.learners[source.data(source_index, Qt.UserRole)]
            if slot.results is not None and slot.results.success:
                successful.append(slot)
        return successful

    def _set_comparison_headers(self, names):
        """Size the comparison table to len(names) x len(names) and label it.

        Columns stretch when more than two models are compared and have a
        fixed size otherwise. Repaints are suspended during the change and
        always re-enabled afterwards.
        """
        table = self.comparison_table
        try:
            # Prevent glitching during update
            table.setUpdatesEnabled(False)
            size = len(names)
            table.horizontalHeader().setSectionResizeMode(
                QHeaderView.Stretch if size > 2 else QHeaderView.Fixed)
            table.setRowCount(size)
            table.setColumnCount(size)
            table.setVerticalHeaderLabels(names)
            table.setHorizontalHeaderLabels(names)
        finally:
            table.setUpdatesEnabled(True)

    def _scores_by_folds(self, slots):
        """Return the per-fold scores of the comparison scorer for each slot.

        Binary scorers are computed for the selected target class, or with
        ``average='weighted'`` when averaging over classes. A slot whose
        scores cannot be computed gets ``None``, and ``scores_not_computed``
        is shown.
        """
        scorer = self.scorers[self.comparison_criterion]()
        self._update_compbox_title()
        kw = {}
        if scorer.is_binary:
            if self.class_selection == self.TARGET_AVERAGE:
                kw = {"average": "weighted"}
            else:
                values = self.data.domain.class_var.values
                kw = {"target": values.index(self.class_selection)}

        scores = []
        for slot in slots:
            # `res` is bound as a default argument, so the thunk does not
            # depend on the loop variable
            outcome = Try(lambda res=slot.results:
                          scorer.scores_by_folds(res.value, **kw).flatten())
            scores.append(outcome.value if outcome.success else None)
        # `None in scores` doesn't work -- these are np.arrays
        if any(score is None for score in scores):
            self.Warning.scores_not_computed()
        return scores

    def _fill_table(self, names, scores):
        """Fill the comparison table with pairwise Bayesian comparisons.

        For each pair (row model, earlier column model), cell (row, col)
        gets p(row > col) and cell (col, row) gets p(col > row), computed by
        ``baycomp.two_on_single``. With a nonzero negligible-difference
        threshold enabled, the probability of a negligible difference is
        added in small print. Pairs with missing scores are skipped; NaN
        results are shown as "NA".
        """
        table = self.comparison_table
        for row, (row_name, row_scores) in enumerate(zip(names, scores)):
            earlier = zip(names[:row], scores[:row])
            for col, (col_name, col_scores) in enumerate(earlier):
                if row_scores is None or col_scores is None:
                    continue
                if self.use_rope and self.rope:
                    p0, rope, p1 = baycomp.two_on_single(
                        row_scores, col_scores, self.rope)
                    if any(np.isnan(p) for p in (p0, rope, p1)):
                        self._set_cells_na(table, row, col)
                        continue
                    self._set_cell(
                        table, row, col,
                        f"{p0:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({row_name} > {col_name}) = {p0:.3f}\n"
                        f"p({row_name} = {col_name}) = {rope:.3f}")
                    self._set_cell(
                        table, col, row,
                        f"{p1:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({col_name} > {row_name}) = {p1:.3f}\n"
                        f"p({col_name} = {row_name}) = {rope:.3f}")
                    continue
                p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                if np.isnan(p0) or np.isnan(p1):
                    self._set_cells_na(table, row, col)
                    continue
                self._set_cell(table, row, col, f"{p0:.3f}",
                               f"p({row_name} > {col_name}) = {p0:.3f}")
                self._set_cell(table, col, row, f"{p1:.3f}",
                               f"p({col_name} > {row_name}) = {p1:.3f}")

    @classmethod
    def _set_cells_na(cls, table, row, col):
        """Mark both cells of the pair, (row, col) then (col, row), as "NA"."""
        for first, second in ((row, col), (col, row)):
            cls._set_cell(table, first, second, "NA",
                          "comparison cannot be computed")

    @staticmethod
    def _set_cell(table, row, col, label, tooltip):
        """Place a centered QLabel with `label` and `tooltip` in a table cell."""
        cell = QLabel(label)
        cell.setAlignment(Qt.AlignCenter)
        cell.setToolTip(tooltip)
        table.setCellWidget(row, col, cell)

    def _update_class_selection(self):
        """Repopulate the target class combo for the train data.

        The combo is always emptied. For a discrete class, it lists the
        "average" entry followed by the class values, and the previous
        selection is kept if it is still a class value (otherwise "average"
        is selected). For a non-discrete class, the box is hidden.
        """
        combo = self.class_selection_combo
        combo.setCurrentIndex(-1)
        combo.clear()
        if not self.data:
            return

        if not self.data.domain.has_discrete_class:
            self.cbox.setVisible(False)
            return

        self.cbox.setVisible(True)
        values = self.data.domain.class_var.values
        items = (self.TARGET_AVERAGE, ) + values
        combo.addItems(items)

        # Offset by one because the "average" entry comes first
        selected = values.index(self.class_selection) + 1 \
            if self.class_selection in values else 0
        combo.setCurrentIndex(selected)
        self.class_selection = items[selected]

    def _on_target_class_changed(self):
        """Refresh the score table and the comparison table for the newly
        selected target class."""
        self.update_stats_model()
        self.update_comparison_table()

    def _invalidate(self, which=None):
        """Discard the results of the learners whose keys are in `which`.

        `which` defaults to all learners. The method also cancels a running
        evaluation and records whether feature-fold cross validation is
        selected. It blanks the score cells of the affected rows, clears the
        comparison table and flags that an update is needed.
        """
        self.cancel()
        self.fold_feature_selected = \
            self.resampling == OWTestAndScore.FeatureFold
        keys = self.learners.keys() if which is None else which

        model = self.score_table.model
        # Learner keys of the score table rows, in row order
        row_keys = [model.item(r, 0).data(Qt.UserRole)
                    for r in range(model.rowCount())]

        for key in keys:
            self.learners[key] = \
                self.learners[key]._replace(results=None, stats=None)
            if key not in row_keys:
                continue
            row = row_keys.index(key)
            for col in range(1, model.columnCount()):
                cell = model.item(row, col)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.comparison_table.clearContents()
        self.__needupdate = True

    def commit(self):
        """
        Commit the results to output.

        The results of all successfully evaluated learners are merged and
        sent, together with the data augmented by their predictions. Without
        successful results, ``None`` is sent on both outputs. If the
        predictions do not fit into memory, an error is shown and only the
        evaluation results are sent.
        """
        self.Error.memory_error.clear()
        succeeded = [slot for slot in self.learners.values()
                     if slot.results is not None and slot.results.success]
        combined = predictions = None
        if succeeded:
            # Evaluation results
            combined = results_merge([s.results.value for s in succeeded])
            combined.learner_names = [learner_name(s.learner)
                                      for s in succeeded]
            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        if predictions:
            self.info.set_output_summary(
                len(predictions), format_summary_details(predictions))
        else:
            self.info.set_output_summary(self.info.NoOutput, "")

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """Report on the testing schema and results.

        Reports the sampling type and, for a discrete class, the target
        class, followed by the score table. Nothing is reported without
        train data or learners.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".format(
                stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            # Previously this option fell through to the empty `items`
            # branch, so the report silently omitted the sampling type.
            if self.fold_feature is not None:
                sampling = "Cross validation by feature {}".format(
                    self.fold_feature.name)
            else:
                sampling = "Cross validation by feature"
            items = [("Sampling type", sampling)]
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [
                ("Sampling type",
                 "{}Shuffle split, {} random samples with {}% data ".format(
                     stratified, self.NRepeats[self.n_repeats],
                     self.SampleSizes[self.sample_size]))
            ]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.score_table.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade settings stored by older widget versions, in place.

        * ``version < 2``: every stored non-zero ``resampling`` index is
          shifted up by one (presumably an option was inserted after index
          0). A missing ``resampling`` key is tolerated instead of raising
          KeyError.
        * ``version < 3``: contexts created by the older, incompatible
          context handler (recognised by a ``classes`` attribute) are
          dropped.
        """
        if version < 2:
            if settings_.get("resampling", 0) > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')
            ]

    @Slot(float)
    def setProgressValue(self, value):
        self.progressBarSet(value)

    def __update(self):
        """(Re)start the evaluation of learners that have no results yet.

        Cancels a running evaluation, refreshes the messages, validates the
        preconditions of the chosen resampling scheme (committing and
        returning early when they are not met) and finally submits the
        evaluation of every learner slot whose ``results`` is None.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Error.test_data_incompatible.clear()
        self.Warning.test_data_missing.clear()
        self.Information.test_data_transformed(
            shown=self.resampling == self.TestOnTest and self.data is not None
            and self.test_data is not None and
            self.data.domain.attributes != self.test_data.domain.attributes)
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestAndScore.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestAndScore.TestOnTest:
            if self.test_data is None:
                # don't add a "missing" warning on top of an "empty" error
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            # test data is only used by the TestOnTest scheme
            self.Warning.test_data_unused()

        rstate = 42
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # replace_learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestAndScore.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData(store_data=True,
                                                 store_models=True), self.data,
                self.test_data, learners_c, self.preprocessor)
        else:
            if self.resampling == OWTestAndScore.KFold:
                # Honour the `cv_stratified` setting (it is also what
                # send_report reports); it was never passed to the sampler.
                sampler = Orange.evaluation.CrossValidation(
                    k=self.NFolds[self.n_folds],
                    stratified=self.cv_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestAndScore.FeatureFold:
                sampler = Orange.evaluation.CrossValidationFeature(
                    feature=self.fold_feature)
            elif self.resampling == OWTestAndScore.LeaveOneOut:
                sampler = Orange.evaluation.LeaveOneOut()
            elif self.resampling == OWTestAndScore.ShuffleSplit:
                sampler = Orange.evaluation.ShuffleSplit(
                    n_resamples=self.NRepeats[self.n_repeats],
                    train_size=self.SampleSizes[self.sample_size] / 100,
                    test_size=None,
                    stratified=self.shuffle_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestAndScore.TestOnTrain:
                sampler = Orange.evaluation.TestOnTrainingData(
                    store_models=True)
            else:
                assert False, "self.resampling %s" % self.resampling

            sampler.store_data = True
            test_f = partial(sampler, self.data, learners_c, self.preprocessor)

        def replace_learners(evalfunc, *args, **kwargs):
            # Swap the deep copies in the results back for the originals so
            # that __task_complete can map results to learner slots.
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[[float], None]], Results]) -> None
        """
        Submit a testing function for evaluation.

        MUST not be called while an evaluation is already pending/running;
        cancel the existing task first.

        Parameters
        ----------
        testfunc : Callable[[Callable[[float], None]], Results]
            Callable accepting a single `callback` keyword argument and
            returning a Results instance.
        """
        assert self.__state != State.Running
        task_state = TaskState()

        def report_progress(fraction):
            # Invoked by `testfunc` (presumably from the executor thread):
            # abort on a cancellation request, otherwise report percent.
            if task_state.is_interruption_requested():
                raise UserInterrupt()
            task_state.set_progress_value(100 * fraction)

        task_state.start(self.__executor,
                         partial(testfunc, callback=report_progress))

        task_state.progress_changed.connect(self.setProgressValue)
        task_state.watcher.finished.connect(self.__task_complete)

        for pending_output in (self.Outputs.evaluations_results,
                               self.Outputs.predictions):
            pending_output.invalidate()
        self.progressBarInit()
        self.setStatusMessage("Running")

        self.__state, self.__task = State.Running, task_state

    @Slot(object)
    def __task_complete(self, f: 'Future[Results]'):
        """Handle a finished evaluation task in the GUI thread.

        On failure the exception is logged and shown as a widget error. On
        success every evaluated learner slot in ``self.learners`` gets its
        results (wrapped in ``Try``) and per-scorer statistics, the tables
        are refreshed and the outputs are committed.
        """
        assert self.thread() is QThread.currentThread()
        assert self.__task is not None and self.__task.future is f
        self.progressBarFinished()
        self.setStatusMessage("")
        assert f.done()
        self.__task = None
        self.__state = State.Done
        try:
            results = f.result()  # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:  # pylint: disable=broad-except
            log.exception("testing error (in __task_complete):", exc_info=True)
            message_lines = traceback.format_exception_only(type(er), er)
            self.error("\n".join(message_lines))
            return

        key_of = {slot.learner: key for key, slot in self.learners.items()}
        assert all(learner in key_of for learner in learners)

        # Store the results (and scores) of each evaluated learner
        class_var = results.domain.class_var
        for learner, learner_results in zip(learners, results.split_by_model()):
            scores = None
            if class_var.is_primitive():
                failure = learner_results.failed[0]
                if not failure:
                    scores = [Try(scorer_caller(scorer, learner_results))
                              for scorer in self.scorers]
                    learner_results = Try.Success(learner_results)
                else:
                    scores = [Try.Fail(failure)] * len(self.scorers)
                    learner_results = Try.Fail(failure)
            slot_key = key_of.get(learner)
            updated_slot = self.learners[slot_key]._replace(
                results=learner_results, stats=scores)
            self.learners[slot_key] = updated_slot

        self.score_table.update_header(self.scorers)
        self.update_stats_model()
        self.update_comparison_table()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation (if any).

        Detaches the task's signals so a late completion is ignored and
        resets the progress bar and status message.
        """
        task = self.__task
        if task is None:
            return
        assert self.__state == State.Running
        self.__state = State.Cancelled
        self.__task = None
        task.cancel()
        task.progress_changed.disconnect(self.setProgressValue)
        task.watcher.finished.disconnect(self.__task_complete)

        self.progressBarFinished()
        self.setStatusMessage("")

    def onDeleteWidget(self):
        """Cancel pending work, stop the executor without waiting, then
        let the base class finish the teardown."""
        self.cancel()
        executor = self.__executor
        executor.shutdown(wait=False)
        super().onDeleteWidget()
Beispiel #31
0
class OWTranspose(OWWidget, ConcurrentWidgetMixin):
    """Transpose a data table.

    Feature names of the output are either generic, built from a user-given
    prefix (``DEFAULT_PREFIX`` when the prefix is empty), or taken from a
    numeric or string variable of the input. The transposition itself runs
    as a background task (`run`, started via ConcurrentWidgetMixin.start).
    """
    name = "Transpose"
    description = "Transpose data table."
    category = "Transform"
    icon = "icons/Transpose.svg"
    priority = 110
    keywords = []

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table, dynamic=False)

    # values of `feature_type`: generic names or names from a variable
    GENERIC, FROM_VAR = range(2)

    resizing_enabled = False
    want_main_area = False

    DEFAULT_PREFIX = "Feature"

    settingsHandler = DomainContextHandler()
    feature_type = ContextSetting(GENERIC)
    feature_name = ContextSetting("")
    feature_names_column = ContextSetting(None)
    remove_redundant_inst = ContextSetting(False)
    auto_apply = Setting(True)

    class Warning(OWWidget.Warning):
        duplicate_names = Msg("Values are not unique.\nTo avoid multiple "
                              "features with the same name, values \nof "
                              "'{}' have been augmented with indices.")
        discrete_attrs = Msg(
            "Categorical features have been encoded as numbers.")

    class Error(OWWidget.Error):
        value_error = Msg("{}")

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        self.data = None

        box = gui.radioButtons(self.controlArea,
                               self,
                               "feature_type",
                               box="Feature names",
                               callback=self.commit.deferred)

        button = gui.appendRadioButton(box, "Generic")
        edit = gui.lineEdit(gui.indentedBox(box,
                                            gui.checkButtonOffsetHint(button)),
                            self,
                            "feature_name",
                            placeholderText="Type a prefix ...",
                            toolTip="Custom feature name")
        edit.editingFinished.connect(self._apply_editing)

        self.meta_button = gui.appendRadioButton(box, "From variable:")
        self.feature_model = DomainModel(valid_types=(ContinuousVariable,
                                                      StringVariable),
                                         alphabetical=False)
        self.feature_combo = gui.comboBox(gui.indentedBox(
            box, gui.checkButtonOffsetHint(button)),
                                          self,
                                          "feature_names_column",
                                          contentsLength=12,
                                          searchable=True,
                                          callback=self._feature_combo_changed,
                                          model=self.feature_model)

        self.remove_check = gui.checkBox(gui.indentedBox(
            box, gui.checkButtonOffsetHint(button)),
                                         self,
                                         "remove_redundant_inst",
                                         "Remove redundant instance",
                                         callback=self.commit.deferred)

        gui.auto_apply(self.buttonsArea, self)

        self.set_controls()

    def _apply_editing(self):
        """Editing the prefix selects generic names; strip and recommit."""
        self.feature_type = self.GENERIC
        self.feature_name = self.feature_name.strip()
        self.commit.deferred()

    def _feature_combo_changed(self):
        """Choosing a variable selects naming from a variable; recommit."""
        self.feature_type = self.FROM_VAR
        self.commit.deferred()

    @Inputs.data
    def set_data(self, data):
        """Receive new input data, update controls and context, recommit."""
        # Skip the context if the combo is empty: a context with
        # feature_model == None would then match all domains
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.set_controls()
        if self.feature_model:
            self.openContext(data)
        self.commit.now()

    def set_controls(self):
        """Sync the variable combo and naming mode with the current data.

        When the data has a candidate variable, the first one is selected and
        naming from a variable is chosen. Otherwise the selection is cleared
        and naming falls back to generic. This keeps the disabled
        "From variable:" option from staying checked.
        """
        self.feature_model.set_domain(self.data.domain if self.data else None)
        self.meta_button.setEnabled(bool(self.feature_model))
        if self.feature_model:
            self.feature_names_column = self.feature_model[0]
            self.feature_type = self.FROM_VAR
        else:
            self.feature_names_column = None
            self.feature_type = self.GENERIC

    @gui.deferred
    def commit(self):
        """Validate the naming options and start the transposition task."""
        self.clear_messages()
        variable = self.feature_type == self.FROM_VAR and \
            self.feature_names_column
        if variable and self.data:
            names = self.data.get_column_view(variable)[0]
            if len(names) != len(set(names)):
                self.Warning.duplicate_names(variable)
        if self.data and self.data.domain.has_discrete_attributes():
            self.Warning.discrete_attrs()
        feature_name = self.feature_name or self.DEFAULT_PREFIX
        self.start(run, self.data, variable, feature_name,
                   self.remove_redundant_inst)

    def on_partial_result(self, _):
        pass

    def on_done(self, transposed: Optional[Table]):
        """Send the transposed table produced by the background task."""
        self.Outputs.data.send(transposed)

    def on_exception(self, ex: Exception):
        """Show a ValueError from the task as an error; re-raise others."""
        if isinstance(ex, ValueError):
            self.Error.value_error(ex)
        else:
            raise ex

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def send_report(self):
        """Report how feature names are chosen and the input data."""
        if self.feature_type == self.GENERIC:
            names = self.feature_name or self.DEFAULT_PREFIX
        else:
            names = "from variable"
            if self.feature_names_column:
                names += "  '{}'".format(self.feature_names_column.name)
        self.report_items("", [("Feature names", names)])
        if self.data:
            self.report_data("Data", self.data)
class OWGeneSets(OWWidget):
    name = "Gene Sets"
    description = ""
    icon = "icons/OWGeneSets.svg"
    priority = 9
    want_main_area = True

    COUNT, GENES, CATEGORY, TERM = range(4)
    DATA_HEADER_LABELS = ["Count", 'Genes In Set', 'Category', 'Term']

    organism = Setting(None, schema_only=True)
    stored_gene_sets_selection = Setting([], schema_only=True)
    selected_rows = Setting([], schema_only=True)
    custom_gene_set_indicator = Setting(None, schema_only=True)

    min_count = Setting(5)
    use_min_count = Setting(True)
    auto_commit = Setting(True)

    class Inputs:
        genes = Input("Data", Table)
        custom_sets = Input('Custom Gene Sets', Table)

    class Outputs:
        matched_genes = Output("Matched Genes", Table)

    class Information(OWWidget.Information):
        pass

    class Warning(OWWidget.Warning):
        all_sets_filtered = Msg('All sets were filtered out.')

    class Error(OWWidget.Error):
        organism_mismatch = Msg(
            'Organism in input data and custom gene sets does not match')
        missing_annotation = Msg(ERROR_ON_MISSING_ANNOTATION)
        missing_gene_id = Msg(ERROR_ON_MISSING_GENE_ID)
        missing_tax_id = Msg(ERROR_ON_MISSING_TAX_ID)
        cant_reach_host = Msg("Host orange.biolab.si is unreachable.")
        cant_load_organisms = Msg(
            "No available organisms, please check your connection.")

    def __init__(self):
        super().__init__()

        # -- placeholders; widgets are filled in by setup_gui or by the
        #    input handlers later on --
        self.commit_button = None
        self.progress_bar = self.progress_bar_iterations = None

        # -- main input data and its gene annotations --
        self.input_data = None
        self.input_genes = []
        self.tax_id = self.use_attr_names = None
        self.gene_id_attribute = self.gene_id_column = None

        # -- custom gene sets input --
        self.custom_data = None
        self.feature_model = DomainModel(
            valid_types=(DiscreteVariable, StringVariable))
        self.custom_gs_col_box = self.gs_label_combobox = None
        self.custom_tax_id = self.custom_use_attr_names = None
        self.custom_gene_id_attribute = self.custom_gene_id_column = None
        self.num_of_custom_sets = None

        # -- gene-set hierarchy widget, info box and filtering --
        self.gs_widget = None
        self.input_info = None
        self.num_of_sel_genes = 0
        self.line_edit_filter = None
        self.search_pattern = ''
        self.organism_select_combobox = None

        # -- table view/model, matcher, proxy, hierarchy, spin box --
        self.data_view = self.data_model = None
        self.gene_matcher = None
        self.filter_proxy_model = None
        self.hierarchy_widget = self.hierarchy_state = None
        self.spin_widget = None

        # -- background work --
        self.threadpool = QThreadPool(self)
        self.workers = None
        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        self.setup_gui()

    def __reset_widget_state(self):
        """Refresh the info box, re-init the item model and install a fresh
        filter proxy model (in that order)."""
        for reset_step in (self.update_info_box,
                           self.init_item_model,
                           self.setup_filter_model):
            reset_step()

    def cancel(self):
        """
        Cancel the current task (if any) and detach its completion slot.
        """
        running = self._task
        if running is None:
            return
        running.cancel()
        assert running.future.done()
        # the cancelled task must no longer reach `_init_gene_sets_finished`
        running.watcher.done.disconnect(self._init_gene_sets_finished)
        self._task = None

    @Slot()
    def progress_advance(self):
        # GUI should be updated in main thread. That's why we are calling advance method here
        if self.progress_bar:
            self.progress_bar.advance()

    def __get_input_genes(self):
        """Collect the input gene identifiers, as strings, into `input_genes`.

        If genes are attributes, the id is read from each attribute's
        annotations (``'?'`` when missing); otherwise the values of the gene
        id column are used.
        """
        self.input_genes = []

        if not self.use_attr_names:
            column_values, _ = self.input_data.get_column_view(
                self.gene_id_column)
            self.input_genes = [str(value) for value in column_values]
            return

        self.input_genes.extend(
            str(var.attributes.get(self.gene_id_attribute, '?'))
            for var in self.input_data.domain.attributes)

    def handle_custom_gene_sets(self, select_customs_flag=False):
        if self.custom_gene_set_indicator:
            if self.custom_data is not None and self.custom_gene_id_column is not None:

                if self.__check_organism_mismatch():
                    # self.gs_label_combobox.setDisabled(True)
                    self.Error.organism_mismatch()
                    self.gs_widget.update_gs_hierarchy()
                    return

                if isinstance(self.custom_gene_set_indicator,
                              DiscreteVariable):
                    labels = self.custom_gene_set_indicator.values
                    gene_sets_names = [
                        labels[int(idx)]
                        for idx in self.custom_data.get_column_view(
                            self.custom_gene_set_indicator)[0]
                    ]
                else:
                    gene_sets_names, _ = self.custom_data.get_column_view(
                        self.custom_gene_set_indicator)

                self.num_of_custom_sets = len(set(gene_sets_names))
                gene_names, _ = self.custom_data.get_column_view(
                    self.custom_gene_id_column)
                hierarchy_title = (self.custom_data.name if
                                   self.custom_data.name else 'Custom sets', )
                try:
                    self.gs_widget.add_custom_sets(
                        gene_sets_names,
                        gene_names,
                        hierarchy_title=hierarchy_title,
                        select_customs_flag=select_customs_flag)
                except geneset.GeneSetException:
                    pass
                # self.gs_label_combobox.setDisabled(False)
            else:
                self.gs_widget.update_gs_hierarchy()

        self.update_info_box()

    def update_tree_view(self):
        """Recompute the gene-set table (delegates to `init_gene_sets`)."""
        reload_sets = self.init_gene_sets
        reload_sets()

    def invalidate(self):
        # clear
        self.__reset_widget_state()
        self.update_info_box()

        if self.input_data is not None:
            # setup
            self.__get_input_genes()
            self.update_tree_view()

    def __check_organism_mismatch(self):
        """Check whether the organisms of the two inputs differ.

        :return: True only if both inputs carry a taxonomy id and the ids
            are different; False when either id is unknown.
        """
        if self.tax_id is None or self.custom_tax_id is None:
            return False
        return self.tax_id != self.custom_tax_id

    def __get_reference_genes(self):
        """Collect reference gene identifiers, as strings, into
        `reference_genes`.

        NOTE(review): the `reference_*` attributes read here are not set in
        `__init__`; this presumably relies on them being assigned elsewhere
        (otherwise it raises AttributeError) — confirm.
        """
        self.reference_genes = []

        if not self.reference_attr_names:
            column_values, _ = self.reference_data.get_column_view(
                self.reference_gene_id_column)
            self.reference_genes = [str(value) for value in column_values]
            return

        self.reference_genes.extend(
            str(var.attributes.get(self.reference_gene_id_attribute, '?'))
            for var in self.reference_data.domain.attributes)

    @Inputs.custom_sets
    def handle_custom_input(self, data):
        """Receive a table of custom gene sets.

        Reads the table's annotations (taxonomy id, gene-as-attribute flag,
        gene id attribute/column), offers its categorical and string
        variables as gene-set indicators, restores or picks an indicator and
        rebuilds the custom sets.
        """
        self.Error.clear()
        self.__reset_widget_state()
        self.custom_data = None
        self.custom_tax_id = None
        self.custom_use_attr_names = None
        self.custom_gene_id_attribute = None
        self.custom_gene_id_column = None
        self.feature_model.set_domain(None)

        if data:
            self.custom_data = data
            self.feature_model.set_domain(self.custom_data.domain)
            tax_id = self.custom_data.attributes.get(TAX_ID, None)
            # Keep a missing id as None: str(None) is the string 'None',
            # which __check_organism_mismatch treats as a real (mismatching)
            # taxonomy id.
            self.custom_tax_id = None if tax_id is None else str(tax_id)
            self.custom_use_attr_names = self.custom_data.attributes.get(
                GENE_AS_ATTRIBUTE_NAME, None)
            self.custom_gene_id_attribute = self.custom_data.attributes.get(
                GENE_ID_ATTRIBUTE, None)
            self.custom_gene_id_column = self.custom_data.attributes.get(
                GENE_ID_COLUMN, None)

            if self.gs_label_combobox is None:
                self.gs_label_combobox = comboBox(
                    self.custom_gs_col_box,
                    self,
                    "custom_gene_set_indicator",
                    sendSelectedValue=True,
                    model=self.feature_model,
                    callback=self.on_gene_set_indicator_changed)
            self.custom_gs_col_box.show()

            if self.custom_gene_set_indicator in self.feature_model:
                # use the model's own instance of the stored variable
                index = self.feature_model.indexOf(
                    self.custom_gene_set_indicator)
                self.custom_gene_set_indicator = self.feature_model[index]
            elif self.feature_model:
                self.custom_gene_set_indicator = self.feature_model[0]
            else:
                # no categorical/string variable can serve as indicator;
                # indexing the empty model used to raise IndexError
                self.custom_gene_set_indicator = None
        else:
            self.custom_gs_col_box.hide()

        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets(
            select_customs_flag=self.custom_gene_set_indicator is not None)
        self.invalidate()

    @Inputs.genes
    def handle_genes_input(self, data):
        """Receive the input data and load gene sets for its organism.

        The table must carry a taxonomy id (``TAX_ID``), the
        ``GENE_AS_ATTRIBUTE_NAME`` flag and exactly one of
        ``GENE_ID_ATTRIBUTE`` / ``GENE_ID_COLUMN``. Otherwise an error is
        shown and no gene sets are loaded. The output and all state derived
        from the previous input are cleared first.
        """
        self.Error.clear()
        self.__reset_widget_state()
        # clear output
        self.Outputs.matched_genes.send(None)
        # clear input values
        self.input_genes = []
        self.input_data = None
        self.tax_id = None
        self.use_attr_names = None
        self.gene_id_attribute = None
        self.gene_id_column = None
        self.gs_widget.clear()
        self.gs_widget.clear_gene_sets()
        self.update_info_box()

        if data:
            self.input_data = data
            tax_id = self.input_data.attributes.get(TAX_ID, None)
            # Keep a missing id as None: str(None) is the string 'None',
            # which made the `tax_id is None` checks below unreachable.
            self.tax_id = None if tax_id is None else str(tax_id)
            self.use_attr_names = self.input_data.attributes.get(
                GENE_AS_ATTRIBUTE_NAME, None)
            self.gene_id_attribute = self.input_data.attributes.get(
                GENE_ID_ATTRIBUTE, None)
            self.gene_id_column = self.input_data.attributes.get(
                GENE_ID_COLUMN, None)
            self.update_info_box()

            # exactly one of gene id attribute/column must be given
            if not (self.use_attr_names is not None and
                    ((self.gene_id_attribute is None) ^
                     (self.gene_id_column is None))):

                if self.tax_id is None:
                    self.Error.missing_annotation()
                    return

                self.Error.missing_gene_id()
                return

            elif self.tax_id is None:
                self.Error.missing_tax_id()
                return

            if self.__check_organism_mismatch():
                self.Error.organism_mismatch()
                return

            self.gs_widget.load_gene_sets(self.tax_id)

            # if input data change, we need to refresh custom sets
            if self.custom_data:
                self.gs_widget.clear_custom_sets()
                self.handle_custom_gene_sets()

            self.invalidate()

    def update_info_box(self):
        """Show input/output gene counts and custom-set info in the label."""
        parts = []
        if self.input_genes:
            parts.append('{} unique gene names on input.\n'.format(
                len(self.input_genes)))
            parts.append('{} genes on output.\n'.format(
                self.num_of_sel_genes))
        elif not self.input_data:
            parts.append('No data on input.\n')
        elif not any([self.gene_id_column, self.gene_id_attribute]):
            parts.append('Input data with incorrect meta data.\n'
                         'Use Gene Name Matcher widget.')

        if self.custom_data:
            parts.append('{} marker genes in {} sets\n'.format(
                self.custom_data.X.shape[0], self.num_of_custom_sets))

        self.input_info.setText(''.join(parts))

    def create_partial(self):
        """Return `set_items` bound to the current gene sets, the selected
        hierarchies, the unique input genes and the progress callback."""
        gene_sets = self.gs_widget.gs_object
        selected_hierarchies = self.stored_gene_sets_selection
        return partial(self.set_items, gene_sets, selected_hierarchies,
                       set(self.input_genes), self.callback)

    def callback(self):
        """Per-step hook passed to `set_items` (run by the executor).

        Raises KeyboardInterrupt to abort when the task has been cancelled.
        Otherwise, if a progress bar exists, it schedules
        `progress_advance` on the widget via `methodinvoke`.
        """
        if self._task.cancelled:
            raise KeyboardInterrupt()
        if not self.progress_bar:
            return
        advance = methodinvoke(self, "progress_advance")
        advance()

    def init_gene_sets(self):
        """Start matching the input genes against the selected gene sets.

        A running task is cancelled first. The selected hierarchies are
        stored, a progress bar sized to the number of sets in them is
        created, and the work is submitted to the executor. Completion is
        handled by `_init_gene_sets_finished`.
        """
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self._task = Task()
        self.init_item_model()

        # save setting on selected hierarchies
        self.stored_gene_sets_selection = self.gs_widget.get_hierarchies(
            only_selected=True)

        job = self.create_partial()

        sets_by_hierarchy = self.gs_widget.gs_object.map_hierarchy_to_sets()
        total_sets = sum(
            len(sets) for hierarchy, sets in sets_by_hierarchy.items()
            if hierarchy in self.stored_gene_sets_selection)
        self.progress_bar = ProgressBar(self, iterations=total_sets)

        future = self._executor.submit(job)
        self._task.future = future
        self._task.watcher = FutureWatcher(future)
        self._task.watcher.done.connect(self._init_gene_sets_finished)

    @Slot(concurrent.futures.Future)
    def _init_gene_sets_finished(self, f):
        """Fill the gene-set table from a finished matching task.

        Must run in the main thread (asserted). On success the returned
        items are appended to `data_model`, which becomes the proxy's source.
        Selection changes are connected to `commit`, the filters and the
        stored selection are applied, and the info box is refreshed. Any
        failure is logged with its traceback instead of being printed.
        """
        import logging

        assert self.thread() is QThread.currentThread()
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progress_bar.finish()
        self.setStatusMessage('')

        try:
            results = f.result()  # type: list
            for model_item in results:
                self.data_model.appendRow(model_item)
            self.filter_proxy_model.setSourceModel(self.data_model)
            self.data_view.selectionModel().selectionChanged.connect(
                self.commit)
            self.filter_data_view()
            self.set_selection()
            self.update_info_box()
        except Exception:  # pylint: disable=broad-except
            # Still best effort, but keep the traceback in the log rather
            # than a bare print() to stdout.
            logging.getLogger(__name__).exception(
                "Failed to initialize gene sets")

    def create_filters(self):
        """Build the proxy-model filters.

        The TERM column must contain every whitespace-separated word of the
        search pattern (case-insensitive). When `use_min_count` is on, the
        COUNT column must also be at least `min_count` (read when the filter
        runs).
        """
        words = self.search_pattern.lower().strip().split()

        def term_matches(value):
            return all(word in value.lower() for word in words)

        def count_is_enough(value):
            return value >= self.min_count

        filters = [
            FilterProxyModel.Filter(self.TERM, Qt.DisplayRole, term_matches)
        ]
        if self.use_min_count:
            filters.append(FilterProxyModel.Filter(
                self.COUNT, Qt.DisplayRole, count_is_enough))
        return filters

    def filter_data_view(self):
        """Apply the current filters; warn when they hide every gene set."""
        proxy = self.filter_proxy_model  # type: FilterProxyModel
        source = proxy.sourceModel()  # type: QStandardItemModel

        if not isinstance(source, QStandardItemModel):
            return

        proxy.set_filters(self.create_filters())
        everything_hidden = source.rowCount() and not proxy.rowCount()
        if everything_hidden:
            self.Warning.all_sets_filtered()
        else:
            self.Warning.clear()

    def set_selection(self):
        """Restore the stored row selection (`selected_rows`) in the view.

        `selected_rows` holds source-model rows, as stored by `commit`.
        Nothing is restored if any of them lies outside the current model.
        Rows hidden by the current filter cannot be selected and are skipped.
        """
        if not self.selected_rows:
            return

        model = self.data_model
        row_count = model.rowCount()
        # Validate every stored row before touching the model. The rows
        # follow selection order and need not be sorted, so checking only
        # the last one was insufficient.
        if any(not 0 <= row < row_count for row in self.selected_rows):
            return

        proxy = self.filter_proxy_model
        last_column = self.data_view.header().count() - 1
        selection = QItemSelection()

        for row in self.selected_rows:
            proxy_index = proxy.mapFromSource(
                model.indexFromItem(model.item(row)))
            if not proxy_index.isValid():
                # filtered out: there is no visible row to select
                continue
            proxy_row = proxy_index.row()
            selection.append(
                QItemSelectionRange(proxy.index(proxy_row, 0),
                                    proxy.index(proxy_row, last_column)))

        self.data_view.selectionModel().select(
            selection, QItemSelectionModel.ClearAndSelect)

    def commit(self):
        """Send the part of the input that matches the selected gene sets.

        The selected source rows are remembered in `selected_rows`. If genes
        are attributes, the matching columns are kept. Otherwise the rows
        whose gene id is among the selected sets' genes are kept. With no
        selected set or no input genes, the output is cleared.
        """
        selection_model = self.data_view.selectionModel()
        if not selection_model:
            return

        selection = selection_model.selectedRows(self.COUNT)
        self.selected_rows = [
            self.filter_proxy_model.mapToSource(index).row()
            for index in selection
        ]

        if not (selection and self.input_genes):
            # Clearing the selection used to leave the previous output and
            # gene count in place.
            self.num_of_sel_genes = 0
            self.update_info_box()
            self.Outputs.matched_genes.send(None)
            return

        gene_sets = [index.data(Qt.UserRole) for index in selection]
        output_genes = list(set.union(*gene_sets))
        self.num_of_sel_genes = len(output_genes)
        self.update_info_box()

        if self.use_attr_names:
            wanted = set(output_genes)  # O(1) membership per column
            selected = [
                column for column in self.input_data.domain.attributes
                if self.gene_id_attribute in column.attributes
                and str(column.attributes[self.gene_id_attribute]) in wanted
            ]
            domain = Domain(selected,
                            self.input_data.domain.class_vars,
                            self.input_data.domain.metas)
            new_data = self.input_data.from_table(domain, self.input_data)
            self.Outputs.matched_genes.send(new_data)
        else:
            # create filter from selected column for genes
            only_known = table_filter.FilterStringList(
                self.gene_id_column, output_genes)
            # apply filter to the data
            data_table = table_filter.Values([only_known])(self.input_data)
            self.Outputs.matched_genes.send(data_table)

    def assign_delegates(self):
        self.data_view.setItemDelegateForColumn(self.GENES,
                                                NumericalColumnDelegate(self))

        self.data_view.setItemDelegateForColumn(self.COUNT,
                                                NumericalColumnDelegate(self))

    def setup_filter_model(self):
        self.filter_proxy_model = FilterProxyModel()
        self.filter_proxy_model.setFilterKeyColumn(self.TERM)
        self.data_view.setModel(self.filter_proxy_model)

    def setup_filter_area(self):
        h_layout = QHBoxLayout()
        h_layout.setSpacing(100)
        h_widget = widgetBox(self.mainArea, orientation=h_layout)

        spin(h_widget,
             self,
             'min_count',
             0,
             1000,
             label='Count',
             tooltip='Minimum genes count',
             checked='use_min_count',
             callback=self.filter_data_view,
             callbackOnReturn=True,
             checkCallback=self.filter_data_view)

        self.line_edit_filter = lineEdit(h_widget, self, 'search_pattern')
        self.line_edit_filter.setPlaceholderText('Filter gene sets ...')
        self.line_edit_filter.textChanged.connect(self.filter_data_view)

    def on_gene_set_indicator_changed(self):
        # self._handle_future_model()
        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets()
        self.invalidate()

    def setup_control_area(self):
        # Control area
        self.input_info = widgetLabel(
            widgetBox(self.controlArea, "Info", addSpace=True),
            'No data on input.\n')
        self.custom_gs_col_box = box = vBox(self.controlArea,
                                            'Custom Gene Set Term Column')
        box.hide()

        gene_sets_box = widgetBox(self.controlArea, "Gene Sets")
        self.gs_widget = GeneSetsSelection(gene_sets_box, self,
                                           'stored_gene_sets_selection')
        self.gs_widget.hierarchy_tree_widget.itemClicked.connect(
            self.update_tree_view)

        self.commit_button = auto_commit(self.controlArea,
                                         self,
                                         "auto_commit",
                                         "&Commit",
                                         box=False)

    def setup_gui(self):
        # control area
        self.setup_control_area()

        # main area
        self.data_view = QTreeView()
        self.setup_filter_model()
        self.setup_filter_area()
        self.data_view.setAlternatingRowColors(True)
        self.data_view.sortByColumn(self.COUNT, Qt.DescendingOrder)
        self.data_view.setSortingEnabled(True)
        self.data_view.setSelectionMode(QTreeView.ExtendedSelection)
        self.data_view.setEditTriggers(QTreeView.NoEditTriggers)
        self.data_view.viewport().setMouseTracking(False)
        self.data_view.setItemDelegateForColumn(
            self.TERM, LinkStyledItemDelegate(self.data_view))

        self.mainArea.layout().addWidget(self.data_view)

        self.data_view.header().setSectionResizeMode(
            QHeaderView.ResizeToContents)
        self.assign_delegates()

    @staticmethod
    def set_items(gene_sets, sets_to_display, genes, callback):
        model_items = []
        if not genes:
            return

        for gene_set in sorted(gene_sets):
            if gene_set.hierarchy not in sets_to_display:
                continue

            callback()

            matched_set = gene_set.genes & genes
            if len(matched_set) > 0:
                category_column = QStandardItem()
                term_column = QStandardItem()
                count_column = QStandardItem()
                genes_column = QStandardItem()

                category_column.setData(", ".join(gene_set.hierarchy),
                                        Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.ToolTipRole)
                term_column.setData(gene_set.link, LinkRole)
                term_column.setForeground(QColor(Qt.blue))

                count_column.setData(matched_set, Qt.UserRole)
                count_column.setData(len(matched_set), Qt.DisplayRole)

                genes_column.setData(len(gene_set.genes), Qt.DisplayRole)
                genes_column.setData(
                    set(gene_set.genes), Qt.UserRole
                )  # store genes to get then on output on selection

                model_items.append(
                    [count_column, genes_column, term_column, category_column])

        return model_items

    def init_item_model(self):
        if self.data_model:
            self.data_model.clear()
            self.setup_filter_model()
        else:
            self.data_model = QStandardItemModel()

        self.data_model.setSortRole(Qt.UserRole)
        self.data_model.setHorizontalHeaderLabels(self.DATA_HEADER_LABELS)

    def sizeHint(self):
        return QSize(1280, 960)
Beispiel #33
0
class ImagePlot(QWidget, OWComponent, SelectionGroupMixin,
                ImageColorSettingMixin, ImageZoomMixin):

    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    gamma = Setting(0)

    selection_changed = Signal()

    def __init__(self, parent):
        QWidget.__init__(self)
        OWComponent.__init__(self, parent)
        SelectionGroupMixin.__init__(self)
        ImageColorSettingMixin.__init__(self)
        ImageZoomMixin.__init__(self)

        self.parent = parent

        self.selection_type = SELECTMANY
        self.saving_enabled = True
        self.selection_enabled = True
        self.viewtype = INDIVIDUAL  # required bt InteractiveViewBox
        self.highlighted = None
        self.data_points = None
        self.data_values = None
        self.data_imagepixels = None

        self.plotview = pg.PlotWidget(background="w", viewBox=InteractiveViewBox(self))
        self.plot = self.plotview.getPlotItem()

        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        layout = QVBoxLayout()
        self.setLayout(layout)
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().addWidget(self.plotview)

        self.img = ImageItemNan()
        self.img.setOpts(axisOrder='row-major')
        self.plot.addItem(self.img)
        self.plot.vb.setAspectLocked()
        self.plot.scene().sigMouseMoved.connect(self.plot.vb.mouseMovedEvent)

        layout = QGridLayout()
        self.plotview.setLayout(layout)
        self.button = QPushButton("Menu", self.plotview)
        self.button.setAutoDefault(False)

        layout.setRowStretch(1, 1)
        layout.setColumnStretch(1, 1)
        layout.addWidget(self.button, 0, 0)
        view_menu = MenuFocus(self)
        self.button.setMenu(view_menu)

        # prepare interface according to the new context
        self.parent.contextAboutToBeOpened.connect(lambda x: self.init_interface_data(x[0]))

        actions = []

        self.add_zoom_actions(view_menu)

        select_square = QAction(
            "Select (square)", self, triggered=self.plot.vb.set_mode_select_square,
        )
        select_square.setShortcuts([Qt.Key_S])
        select_square.setShortcutContext(Qt.WidgetWithChildrenShortcut)
        actions.append(select_square)

        select_polygon = QAction(
            "Select (polygon)", self, triggered=self.plot.vb.set_mode_select_polygon,
        )
        select_polygon.setShortcuts([Qt.Key_P])
        select_polygon.setShortcutContext(Qt.WidgetWithChildrenShortcut)
        actions.append(select_polygon)

        if self.saving_enabled:
            save_graph = QAction(
                "Save graph", self, triggered=self.save_graph,
            )
            save_graph.setShortcuts([QKeySequence(Qt.ControlModifier | Qt.Key_I)])
            actions.append(save_graph)

        view_menu.addActions(actions)
        self.addActions(actions)

        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str)

        choose_xy = QWidgetAction(self)
        box = gui.vBox(self)
        box.setFocusPolicy(Qt.TabFocus)
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=DomainModel.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self.update_attr,
            model=self.xy_model, **common_options)
        self.cb_attr_y = gui.comboBox(
            box, self, "attr_y", label="Axis y:", callback=self.update_attr,
            model=self.xy_model, **common_options)
        box.setFocusProxy(self.cb_attr_x)

        box.layout().addWidget(self.color_settings_box())

        choose_xy.setDefaultWidget(box)
        view_menu.addAction(choose_xy)

        self.markings_integral = []

        self.lsx = None  # info about the X axis
        self.lsy = None  # info about the Y axis

        self.data = None
        self.data_ids = {}

    def init_interface_data(self, data):
        same_domain = (self.data and data and
                       data.domain == self.data.domain)
        if not same_domain:
            self.init_attr_values(data)

    def help_event(self, ev):
        pos = self.plot.vb.mapSceneToView(ev.scenePos())
        sel = self._points_at_pos(pos)
        prepared = []
        if sel is not None:
            data, vals, points = self.data[sel], self.data_values[sel], self.data_points[sel]
            for d, v, p in zip(data, vals, points):
                basic = "({}, {}): {}".format(p[0], p[1], v)
                variables = [v for v in self.data.domain.metas + self.data.domain.class_vars
                             if v not in [self.attr_x, self.attr_y]]
                features = ['{} = {}'.format(attr.name, d[attr]) for attr in variables]
                prepared.append("\n".join([basic] + features))
        text = "\n\n".join(prepared)
        if text:
            text = ('<span style="white-space:pre">{}</span>'
                    .format(escape(text)))
            QToolTip.showText(ev.screenPos(), text, widget=self.plotview)
            return True
        else:
            return False

    def update_attr(self):
        self.update_view()

    def init_attr_values(self, data):
        domain = data.domain if data is not None else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def save_graph(self):
        saveplot.save_plot(self.plotview, self.parent.graph_writers)

    def set_data(self, data):
        if data:
            self.data = data
            self.data_ids = {e: i for i, e in enumerate(data.ids)}
            self.restore_selection_settings()
        else:
            self.data = None
            self.data_ids = {}

    def refresh_markings(self, di):
        refresh_integral_markings([{"draw": di}], self.markings_integral, self.parent.curveplot)

    def update_view(self):
        self.img.clear()
        self.img.setSelection(None)
        self.lsx = None
        self.lsy = None
        self.data_points = None
        self.data_values = None
        self.data_imagepixels = None
        if self.data and self.attr_x and self.attr_y:
            xat = self.data.domain[self.attr_x]
            yat = self.data.domain[self.attr_y]

            ndom = Orange.data.Domain([xat, yat])
            datam = Orange.data.Table(ndom, self.data)
            coorx = datam.X[:, 0]
            coory = datam.X[:, 1]
            self.data_points = datam.X
            self.lsx = lsx = values_to_linspace(coorx)
            self.lsy = lsy = values_to_linspace(coory)
            if lsx[-1] * lsy[-1] > IMAGE_TOO_BIG:
                self.parent.Error.image_too_big(lsx[-1], lsy[-1])
                return
            else:
                self.parent.Error.image_too_big.clear()

            di = {}
            if self.parent.value_type == 0:  # integrals
                imethod = self.parent.integration_methods[self.parent.integration_method]

                if imethod != Integrate.PeakAt:
                    datai = Integrate(methods=imethod,
                                      limits=[[self.parent.lowlim, self.parent.highlim]])(self.data)
                else:
                    datai = Integrate(methods=imethod,
                                      limits=[[self.parent.choose, self.parent.choose]])(self.data)

                if np.any(self.parent.curveplot.selection_group):
                    # curveplot can have a subset of curves on the input> match IDs
                    ind = np.flatnonzero(self.parent.curveplot.selection_group)[0]
                    dind = self.data_ids[self.parent.curveplot.data[ind].id]
                    di = datai.domain.attributes[0].compute_value.draw_info(self.data[dind:dind+1])
                d = datai.X[:, 0]
            else:
                dat = self.data.domain[self.parent.attr_value]
                ndom = Orange.data.Domain([dat])
                d = Orange.data.Table(ndom, self.data).X[:, 0]
            self.refresh_markings(di)

            # set data
            imdata = np.ones((lsy[2], lsx[2])) * float("nan")

            xindex = index_values(coorx, lsx)
            yindex = index_values(coory, lsy)
            imdata[yindex, xindex] = d
            self.data_values = d
            self.data_imagepixels = np.vstack((yindex, xindex)).T

            self.img.setImage(imdata, autoLevels=False)
            self.img.setLevels([0, 1])
            self.update_levels()
            self.update_color_schema()

            # shift centres of the pixels so that the axes are useful
            shiftx = _shift(lsx)
            shifty = _shift(lsy)
            left = lsx[0] - shiftx
            bottom = lsy[0] - shifty
            width = (lsx[1]-lsx[0]) + 2*shiftx
            height = (lsy[1]-lsy[0]) + 2*shifty
            self.img.setRect(QRectF(left, bottom, width, height))

            self.refresh_img_selection()

    def refresh_img_selection(self):
        selected_px = np.zeros((self.lsy[2], self.lsx[2]), dtype=np.uint8)
        selected_px[self.data_imagepixels[:, 0], self.data_imagepixels[:, 1]] = self.selection_group
        self.img.setSelection(selected_px)

    def make_selection(self, selected, add):
        """Add selected indices to the selection."""
        add_to_group, add_group, remove = selection_modifiers()
        if self.data and self.lsx and self.lsy:
            if add_to_group:  # both keys - need to test it before add_group
                selnum = np.max(self.selection_group)
            elif add_group:
                selnum = np.max(self.selection_group) + 1
            elif remove:
                selnum = 0
            else:
                self.selection_group *= 0
                selnum = 1
            if selected is not None:
                self.selection_group[selected] = selnum
            self.refresh_img_selection()
        self.prepare_settings_for_saving()
        self.selection_changed.emit()

    def select_square(self, p1, p2, add):
        """ Select elements within a square drawn by the user.
        A selection needs to contain whole pixels """
        x1, y1 = p1.x(), p1.y()
        x2, y2 = p2.x(), p2.y()
        polygon = [QPointF(x1, y1), QPointF(x2, y1), QPointF(x2, y2), QPointF(x1, y2), QPointF(x1, y1)]
        self.select_polygon(polygon, add)

    def select_polygon(self, polygon, add):
        """ Select by a polygon which has to contain whole pixels. """
        if self.data and self.lsx and self.lsy:
            polygon = [(p.x(), p.y()) for p in polygon]
            # a polygon should contain all pixel
            shiftx = _shift(self.lsx)
            shifty = _shift(self.lsy)
            points_edges = [self.data_points + [[shiftx, shifty]],
                            self.data_points + [[-shiftx, shifty]],
                            self.data_points + [[shiftx, -shifty]],
                            self.data_points + [[-shiftx, -shifty]]]
            inp = in_polygon(points_edges[0], polygon)
            for p in points_edges[1:]:
                inp *= in_polygon(p, polygon)
            self.make_selection(inp, add)

    def _points_at_pos(self, pos):
        if self.data and self.lsx and self.lsy:
            x, y = pos.x(), pos.y()
            distance = np.abs(self.data_points - [[x, y]])
            sel = (distance[:, 0] < _shift(self.lsx)) * (distance[:, 1] < _shift(self.lsy))
            return sel

    def select_by_click(self, pos, add):
        sel = self._points_at_pos(pos)
        self.make_selection(sel, add)
Beispiel #34
0
class OWSieveDiagram(OWWidget):
    name = "Sieve Diagram"
    description = "Visualize the observed and expected frequencies " \
                  "for a combination of values."
    icon = "icons/SieveDiagram.svg"
    priority = 200

    class Inputs:
        data = Input("Data", Table, default=True)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    graph_name = "canvas"

    want_control_area = False

    settings_version = 1
    settingsHandler = DomainContextHandler()
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    selection = ContextSetting(set())

    def __init__(self):
        # pylint: disable=missing-docstring
        super().__init__()

        self.data = self.discrete_data = None
        self.attrs = []
        self.input_features = None
        self.areas = []
        self.selection = set()

        self.attr_box = gui.hBox(self.mainArea)
        self.domain_model = DomainModel(valid_types=DomainModel.PRIMITIVE)
        combo_args = dict(widget=self.attr_box,
                          master=self,
                          contentsLength=12,
                          callback=self.update_attr,
                          sendSelectedValue=True,
                          valueType=str,
                          model=self.domain_model)
        fixed_size = (QSizePolicy.Fixed, QSizePolicy.Fixed)
        gui.comboBox(value="attr_x", **combo_args)
        gui.widgetLabel(self.attr_box, "\u2715", sizePolicy=fixed_size)
        gui.comboBox(value="attr_y", **combo_args)
        self.vizrank, self.vizrank_button = SieveRank.add_vizrank(
            self.attr_box, self, "Score Combinations", self.set_attr)
        self.vizrank_button.setSizePolicy(*fixed_size)

        self.canvas = QGraphicsScene()
        self.canvasView = ViewWithPress(self.canvas,
                                        self.mainArea,
                                        handler=self.reset_selection)
        self.mainArea.layout().addWidget(self.canvasView)
        self.canvasView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvasView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    def sizeHint(self):
        return QSize(450, 550)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self.update_graph()

    def showEvent(self, event):
        super().showEvent(event)
        self.update_graph()

    @classmethod
    def migrate_context(cls, context, version):
        if not version:
            settings.rename_setting(context, "attrX", "attr_x")
            settings.rename_setting(context, "attrY", "attr_y")
            settings.migrate_str_to_variable(context)

    @Inputs.data
    def set_data(self, data):
        """
        Discretize continuous attributes, and put all attributes and discrete
        metas into self.attrs.

        Select the first two attributes unless context overrides this.
        Method `resolve_shown_attributes` is called to use the attributes from
        the input, if it exists and matches the attributes in the data.

        Remove selection; again let the context override this.
        Initialize the vizrank dialog, but don't show it.

        Args:
            data (Table): input data
        """
        if isinstance(data, SqlTable) and data.approx_len() > LARGE_TABLE:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data
        self.areas = []
        self.selection = set()
        if self.data is None:
            self.attrs[:] = []
            self.domain_model.set_domain(None)
        else:
            self.domain_model.set_domain(data.domain)
        self.attrs = [x for x in self.domain_model if isinstance(x, Variable)]
        if self.attrs:
            self.attr_x = self.attrs[0]
            self.attr_y = self.attrs[len(self.attrs) > 1]
        else:
            self.attr_x = self.attr_y = None
            self.areas = []
            self.selection = set()
        self.openContext(self.data)
        if self.data:
            self.discrete_data = self.sparse_to_dense(data, True)
        self.resolve_shown_attributes()
        self.update_graph()
        self.update_selection()

        self.vizrank.initialize()
        self.vizrank_button.setEnabled(self.data is not None
                                       and len(self.data) > 1
                                       and len(self.data.domain.attributes) > 1
                                       and not self.data.is_sparse())

    def set_attr(self, attr_x, attr_y):
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def update_attr(self):
        """Update the graph and selection."""
        self.selection = set()
        self.discrete_data = self.sparse_to_dense(self.data)
        self.update_graph()
        self.update_selection()

    def sparse_to_dense(self, data, init=False):
        """
        Extracts two selected columns from sparse matrix.
        GH-2260
        """
        def discretizer(data):
            if any(attr.is_continuous for attr in chain(
                    data.domain.variables, data.domain.metas)):
                discretize = Discretize(method=EqualFreq(n=4),
                                        remove_const=False,
                                        discretize_classes=True,
                                        discretize_metas=True)
                return discretize(data)
            return data

        if not data.is_sparse() and not init:
            return self.discrete_data
        if data.is_sparse():
            attrs = {self.attr_x, self.attr_y}
            new_domain = data.domain.select_columns(attrs)
            data = Table.from_table(new_domain, data)
            data.X = data.X.toarray()
        return discretizer(data)

    @Inputs.features
    def set_input_features(self, attr_list):
        """
        Handler for the Features signal.

        The method stores the attributes and calls `resolve_shown_attributes`

        Args:
            attr_list (AttributeList): data from the signal
        """
        self.input_features = attr_list
        self.resolve_shown_attributes()
        self.update_selection()

    def resolve_shown_attributes(self):
        """
        Use the attributes from the input signal if the signal is present
        and at least two attributes appear in the domain. If there are
        multiple, use the first two. Combos are disabled if inputs are used.
        """
        self.warning()
        self.attr_box.setEnabled(True)
        if not self.input_features:  # None or empty
            return
        features = [f for f in self.input_features if f in self.domain_model]
        if not features:
            self.warning(
                "Features from the input signal are not present in the data")
            return
        old_attrs = self.attr_x, self.attr_y
        self.attr_x, self.attr_y = [f for f in (features * 2)[:2]]
        self.attr_box.setEnabled(False)
        if (self.attr_x, self.attr_y) != old_attrs:
            self.selection = set()
            self.update_graph()

    def reset_selection(self):
        self.selection = set()
        self.update_selection()

    def select_area(self, area, event):
        """
        Add or remove the clicked area from the selection

        Args:
            area (QRect): the area that is clicked
            event (QEvent): event description
        """
        if event.button() != Qt.LeftButton:
            return
        index = self.areas.index(area)
        if event.modifiers() & Qt.ControlModifier:
            self.selection ^= {index}
        else:
            self.selection = {index}
        self.update_selection()

    def update_selection(self):
        """
        Update the graph (pen width) to show the current selection.
        Filter and output the data.
        """
        if self.areas is None or not self.selection:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(
                create_annotated_table(self.data, []))
            return

        filts = []
        for i, area in enumerate(self.areas):
            if i in self.selection:
                width = 4
                val_x, val_y = area.value_pair
                filts.append(
                    filter.Values([
                        filter.FilterDiscrete(self.attr_x.name, [val_x]),
                        filter.FilterDiscrete(self.attr_y.name, [val_y])
                    ]))
            else:
                width = 1
            pen = area.pen()
            pen.setWidth(width)
            area.setPen(pen)
        if len(filts) == 1:
            filts = filts[0]
        else:
            filts = filter.Values(filts, conjunction=False)
        selection = filts(self.discrete_data)
        idset = set(selection.ids)
        sel_idx = [i for i, id in enumerate(self.data.ids) if id in idset]
        if self.discrete_data is not self.data:
            selection = self.data[sel_idx]
        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.data, sel_idx))

    def update_graph(self):
        # Function uses weird names like r, g, b, but it does it with utmost
        # caution, hence
        # pylint: disable=invalid-name
        """Update the graph."""
        def text(txt, *args, **kwargs):
            return CanvasText(self.canvas,
                              "",
                              html_text=to_html(txt),
                              *args,
                              **kwargs)

        def width(txt):
            return text(txt, 0, 0, show=False).boundingRect().width()

        def fmt(val):
            return str(int(val)) if val % 1 == 0 else "{:.2f}".format(val)

        def show_pearson(rect, pearson, pen_width):
            """
            Color the given rectangle according to its corresponding
            standardized Pearson residual.

            Args:
                rect (QRect): the rectangle being drawn
                pearson (float): signed standardized pearson residual
                pen_width (int): pen width (bolder pen is used for selection)
            """
            r = rect.rect()
            x, y, w, h = r.x(), r.y(), r.width(), r.height()
            if w == 0 or h == 0:
                return

            r = b = 255
            if pearson > 0:
                r = g = max(255 - 20 * pearson, 55)
            elif pearson < 0:
                b = g = max(255 + 20 * pearson, 55)
            else:
                r = g = b = 224
            rect.setBrush(QBrush(QColor(r, g, b)))
            pen_color = QColor(255 * (r == 255), 255 * (g == 255),
                               255 * (b == 255))
            pen = QPen(pen_color, pen_width)
            rect.setPen(pen)
            if pearson > 0:
                pearson = min(pearson, 10)
                dist = 20 - 1.6 * pearson
            else:
                pearson = max(pearson, -10)
                dist = 20 - 8 * pearson
            pen.setWidth(1)

            def _offseted_line(ax, ay):
                r = QGraphicsLineItem(x + ax, y + ay, x + (ax or w),
                                      y + (ay or h))
                self.canvas.addItem(r)
                r.setPen(pen)

            ax = dist
            while ax < w:
                _offseted_line(ax, 0)
                ax += dist

            ay = dist
            while ay < h:
                _offseted_line(0, ay)
                ay += dist

        def make_tooltip():
            """Create the tooltip. The function uses local variables from
            the enclosing scope."""

            # pylint: disable=undefined-loop-variable
            def _oper(attr, txt):
                if self.data.domain[attr.name] is ddomain[attr.name]:
                    return "="
                return " " if txt[0] in "<≥" else " in "

            return (
                "<b>{attr_x}{xeq}{xval_name}</b>: {obs_x}/{n} ({p_x:.0f} %)".
                format(attr_x=to_html(attr_x.name),
                       xeq=_oper(attr_x, xval_name),
                       xval_name=to_html(xval_name),
                       obs_x=fmt(chi.probs_x[x] * n),
                       n=int(n),
                       p_x=100 * chi.probs_x[x]) + "<br/>" +
                "<b>{attr_y}{yeq}{yval_name}</b>: {obs_y}/{n} ({p_y:.0f} %)".
                format(attr_y=to_html(attr_y.name),
                       yeq=_oper(attr_y, yval_name),
                       yval_name=to_html(yval_name),
                       obs_y=fmt(chi.probs_y[y] * n),
                       n=int(n),
                       p_y=100 * chi.probs_y[y]) + "<hr/>" +
                """<b>combination of values: </b><br/>
                   &nbsp;&nbsp;&nbsp;expected {exp} ({p_exp:.0f} %)<br/>
                   &nbsp;&nbsp;&nbsp;observed {obs} ({p_obs:.0f} %)""".format(
                    exp=fmt(chi.expected[y, x]),
                    p_exp=100 * chi.expected[y, x] / n,
                    obs=fmt(chi.observed[y, x]),
                    p_obs=100 * chi.observed[y, x] / n))

        for item in self.canvas.items():
            self.canvas.removeItem(item)
        if self.data is None or len(self.data) == 0 or \
                self.attr_x is None or self.attr_y is None:
            return

        ddomain = self.discrete_data.domain
        attr_x, attr_y = self.attr_x, self.attr_y
        disc_x, disc_y = ddomain[attr_x.name], ddomain[attr_y.name]
        view = self.canvasView

        chi = ChiSqStats(self.discrete_data, disc_x, disc_y)
        max_ylabel_w = max((width(val) for val in disc_y.values), default=0)
        max_ylabel_w = min(max_ylabel_w, 200)
        x_off = width(attr_x.name) + max_ylabel_w
        y_off = 15
        square_size = min(view.width() - x_off - 35,
                          view.height() - y_off - 80)
        square_size = max(square_size, 10)
        self.canvasView.setSceneRect(0, 0, view.width(), view.height())
        if not disc_x.values or not disc_y.values:
            text_ = "Features {} and {} have no values".format(disc_x, disc_y) \
                if not disc_x.values and \
                   not disc_y.values and \
                          disc_x != disc_y \
                else \
                    "Feature {} has no values".format(
                        disc_x if not disc_x.values else disc_y)
            text(text_,
                 view.width() / 2 + 70,
                 view.height() / 2, Qt.AlignRight | Qt.AlignVCenter)
            return
        n = chi.n
        curr_x = x_off
        max_xlabel_h = 0
        self.areas = []
        for x, (px, xval_name) in enumerate(zip(chi.probs_x, disc_x.values)):
            if px == 0:
                continue
            width = square_size * px

            curr_y = y_off
            for y in range(len(chi.probs_y) - 1, -1, -1):  # bottom-up order
                py = chi.probs_y[y]
                yval_name = disc_y.values[y]
                if py == 0:
                    continue
                height = square_size * py

                selected = len(self.areas) in self.selection
                rect = CanvasRectangle(self.canvas,
                                       curr_x + 2,
                                       curr_y + 2,
                                       width - 4,
                                       height - 4,
                                       z=-10,
                                       onclick=self.select_area)
                rect.value_pair = x, y
                self.areas.append(rect)
                show_pearson(rect, chi.residuals[y, x], 3 * selected)
                rect.setToolTip(make_tooltip())

                if x == 0:
                    text(yval_name, x_off, curr_y + height / 2,
                         Qt.AlignRight | Qt.AlignVCenter)
                curr_y += height

            xl = text(xval_name, curr_x + width / 2, y_off + square_size,
                      Qt.AlignHCenter | Qt.AlignTop)
            max_xlabel_h = max(int(xl.boundingRect().height()), max_xlabel_h)
            curr_x += width

        bottom = y_off + square_size + max_xlabel_h
        text(attr_y.name,
             0,
             y_off + square_size / 2,
             Qt.AlignLeft | Qt.AlignVCenter,
             bold=True,
             vertical=True)
        text(attr_x.name,
             x_off + square_size / 2,
             bottom,
             Qt.AlignHCenter | Qt.AlignTop,
             bold=True)
        bottom += 30
        xl = text("χ²={:.2f}, p={:.3f}".format(chi.chisq, chi.p), 0, bottom)
        # Assume similar height for both lines
        text("N = " + fmt(chi.n), 0, bottom - xl.boundingRect().height())

    def get_widget_name_extension(self):
        if self.data is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)

    def send_report(self):
        self.report_plot()
Beispiel #35
0
class OWTranspose(OWWidget):
    """Transpose the input data table.

    Names of the new features come either from a user-given prefix
    (``feature_name``, falling back to ``DEFAULT_PREFIX``) or from the values
    of a string meta attribute chosen in the combo box.
    """
    name = "Transpose"
    description = "Transpose data table."
    icon = "icons/Transpose.svg"
    priority = 2000
    keywords = []

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table, dynamic=False)

    # Values of ``feature_type``: the indices of the two radio buttons that
    # ``__init__`` appends in this order.
    GENERIC, FROM_META_ATTR = range(2)

    resizing_enabled = False
    want_main_area = False

    DEFAULT_PREFIX = "Feature"

    settingsHandler = DomainContextHandler()
    feature_type = ContextSetting(GENERIC)
    feature_name = ContextSetting("")
    feature_names_column = ContextSetting(None)
    auto_apply = Setting(True)

    class Error(OWWidget.Error):
        value_error = Msg("{}")

    def __init__(self):
        super().__init__()
        self.data = None

        # The lambda looks ``self.apply`` up when a button is clicked instead
        # of binding it now; presumably so that the wrapper installed on the
        # instance by ``gui.auto_commit`` below is the one that gets called.
        box = gui.radioButtons(self.controlArea,
                               self,
                               "feature_type",
                               box="Feature names",
                               callback=lambda: self.apply())

        button = gui.appendRadioButton(box, "Generic")
        edit = gui.lineEdit(gui.indentedBox(box,
                                            gui.checkButtonOffsetHint(button)),
                            self,
                            "feature_name",
                            placeholderText="Type a prefix ...",
                            toolTip="Custom feature name")
        edit.editingFinished.connect(self._apply_editing)

        self.meta_button = gui.appendRadioButton(box, "From meta attribute:")
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=StringVariable,
                                         alphabetical=True)
        self.feature_combo = gui.comboBox(gui.indentedBox(
            box, gui.checkButtonOffsetHint(button)),
                                          self,
                                          "feature_names_column",
                                          contentsLength=12,
                                          callback=self._feature_combo_changed,
                                          model=self.feature_model)

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False,
                                            commit=self.apply)
        self.apply_button.button.setAutoDefault(False)

        self.set_controls()

    def _apply_editing(self):
        """Switch to generic names with the stripped prefix and apply."""
        self.feature_type = self.GENERIC
        self.feature_name = self.feature_name.strip()
        self.apply()

    def _feature_combo_changed(self):
        """Switch to names taken from the chosen meta attribute and apply."""
        self.feature_type = self.FROM_META_ATTR
        self.apply()

    @Inputs.data
    def set_data(self, data):
        """Store the new input, refresh the controls and recompute output."""
        # Skip the context if the combo is empty: a context whose
        # feature_names_column is None would then match all domains
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.set_controls()
        if self.feature_model:
            self.openContext(data)
        self.apply()

    def set_controls(self):
        """Sync the meta-attribute combo and radio buttons with ``self.data``.

        If the data has eligible string meta attributes, the first one is
        selected and naming switches to FROM_META_ATTR. Otherwise the column
        is cleared and naming falls back to GENERIC, because the meta option
        is disabled and cannot be in effect.
        """
        # Test for None explicitly: Table defines __len__, so an empty table
        # is falsy and ``self.data and self.data.domain`` would hand the
        # Table itself (not a Domain) to set_domain.
        domain = self.data.domain if self.data is not None else None
        self.feature_model.set_domain(domain)
        self.meta_button.setEnabled(bool(self.feature_model))
        if self.feature_model:
            self.feature_names_column = self.feature_model[0]
            self.feature_type = self.FROM_META_ATTR
        else:
            self.feature_names_column = None
            # Don't leave the (now disabled) meta option checked.
            self.feature_type = self.GENERIC

    def apply(self):
        """Transpose ``self.data`` and send it; send None without data.

        A ``ValueError`` raised by ``Table.transpose`` is shown as an error
        message and None is sent instead.
        """
        self.clear_messages()
        transposed = None
        if self.data:
            try:
                transposed = Table.transpose(
                    self.data,
                    # Falsy unless names come from a meta attribute;
                    # presumably transpose then uses generic names.
                    self.feature_type == self.FROM_META_ATTR
                    and self.feature_names_column,
                    feature_name=self.feature_name or self.DEFAULT_PREFIX)
            except ValueError as e:
                self.Error.value_error(e)
        self.Outputs.data.send(transposed)

    def send_report(self):
        """Report how feature names are chosen, plus the input data."""
        if self.feature_type == self.GENERIC:
            names = self.feature_name or self.DEFAULT_PREFIX
        else:
            names = "from meta attribute"
            if self.feature_names_column:
                names += "  '{}'".format(self.feature_names_column.name)
        self.report_items("", [("Feature names", names)])
        if self.data:
            self.report_data("Data", self.data)
Beispiel #36
0
class OWCorrelations(OWWidget):
    """Rank pairs of continuous features by their correlation.

    Constant features are removed and the remaining continuous data is passed
    through ``SklImpute`` before ``CorrelationRank`` ranks the pairs. The
    input data is passed through unchanged, the selected pair is sent as
    ``Features`` and the whole ranking as the ``Correlations`` table.
    """
    name = "Correlations"
    description = "Compute all pairwise attribute correlations."
    icon = "icons/Correlations.svg"
    priority = 1106

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)
        features = Output("Features", AttributeList)
        correlations = Output("Correlations", Table)

    want_control_area = False

    correlation_type: int

    settings_version = 2
    settingsHandler = DomainContextHandler()
    # Selected pair as a sequence of (name, vartype) tuples (see
    # _vizrank_selection_changed); restored per domain context.
    selection = ContextSetting(())
    # Feature to which the ranking is restricted; None ("(All combinations)"
    # placeholder) means all pairs are ranked.
    feature = ContextSetting(None)
    # Index into CorrelationType.items().
    correlation_type = Setting(0)

    class Information(OWWidget.Information):
        removed_cons_feat = Msg("Constant features have been removed.")

    class Warning(OWWidget.Warning):
        not_enough_vars = Msg("At least two continuous features are needed.")
        not_enough_inst = Msg("At least two instances are needed.")

    def __init__(self):
        super().__init__()
        # Original input table (passed through to the Data output).
        self.data = None
        # Continuous, non-constant, imputed copy used for the ranking;
        # None when the input is missing or unsuitable.
        self.cont_data = None

        # GUI
        box = gui.vBox(self.mainArea)
        self.correlation_combo = gui.comboBox(
            box, self, "correlation_type", items=CorrelationType.items(),
            orientation=Qt.Horizontal, callback=self._correlation_combo_changed
        )

        self.feature_model = DomainModel(
            order=DomainModel.ATTRIBUTES, separators=False,
            placeholder="(All combinations)", valid_types=ContinuousVariable)
        gui.comboBox(
            box, self, "feature", callback=self._feature_combo_changed,
            model=self.feature_model
        )

        # The ranking component; its filter, rank table and button are
        # placed into this widget's own main area below.
        self.vizrank, _ = CorrelationRank.add_vizrank(
            None, self, None, self._vizrank_selection_changed)
        # Presumably makes vizrank report progress on this widget's bar.
        self.vizrank.progressBar = self.progressBar
        # Disabled until set_data provides usable data.
        self.vizrank.button.setEnabled(False)
        self.vizrank.threadStopped.connect(self._vizrank_stopped)

        gui.separator(box)
        box.layout().addWidget(self.vizrank.filter)
        box.layout().addWidget(self.vizrank.rank_table)

        button_box = gui.hBox(self.mainArea)
        button_box.layout().addWidget(self.vizrank.button)

    @staticmethod
    def sizeHint():
        return QSize(350, 400)

    def _correlation_combo_changed(self):
        """Recompute the ranking for the newly chosen correlation type."""
        self.apply()

    def _feature_combo_changed(self):
        """Recompute the ranking for the newly chosen feature filter."""
        self.apply()

    def _vizrank_selection_changed(self, *args):
        """Store the variables of the selected ranking row and commit.

        Registered with ``add_vizrank`` in ``__init__``; ``args`` are the
        selected variables (assumed to be Variable objects — they must have
        ``.name`` and be accepted by ``vartype``).
        """
        self.selection = [(var.name, vartype(var)) for var in args]
        self.commit()

    def _vizrank_stopped(self):
        """When the ranking thread stops, re-select the stored pair."""
        self._vizrank_select()

    def _vizrank_select(self):
        """Select the row matching ``self.selection``, else the first row.

        Rows are matched by the sorted names of their variables, so the
        order of the pair does not matter. Does nothing if the rank model
        is empty.
        """
        model = self.vizrank.rank_table.model()
        if not model.rowCount():
            return
        selection = QItemSelection()

        # This flag is needed because data in the model could be
        # filtered by a feature and therefore selection could not be found
        selection_in_model = False
        if self.selection:
            sel_names = sorted(name for name, _ in self.selection)
            for i in range(model.rowCount()):
                # pylint: disable=protected-access
                names = sorted(x.name for x in model.data(
                    model.index(i, 0), CorrelationRank._AttrRole))
                if names == sel_names:
                    selection.select(model.index(i, 0),
                                     model.index(i, model.columnCount() - 1))
                    selection_in_model = True
                    break
        if not selection_in_model:
            selection.select(model.index(0, 0),
                             model.index(0, model.columnCount() - 1))
        self.vizrank.rank_table.selectionModel().select(
            selection, QItemSelectionModel.ClearAndSelect)

    @Inputs.data
    def set_data(self, data):
        """Prepare the continuous data for ranking and restart the ranking.

        ``cont_data`` stays None (and a warning is shown) when there are
        fewer than two instances or, after removing constant features,
        fewer than two continuous attributes.
        """
        self.closeContext()
        self.clear_messages()
        self.data = data
        self.cont_data = None
        self.selection = ()
        if data is not None:
            if len(data) < 2:
                self.Warning.not_enough_inst()
            else:
                domain = data.domain
                # Keep only continuous attributes; class vars and metas
                # are carried over unchanged.
                cont_attrs = [a for a in domain.attributes if a.is_continuous]
                cont_dom = Domain(cont_attrs, domain.class_vars, domain.metas)
                cont_data = Table.from_table(cont_dom, data)
                remover = Remove(Remove.RemoveConstant)
                cont_data = remover(cont_data)
                if remover.attr_results["removed"]:
                    self.Information.removed_cons_feat()
                if len(cont_data.domain.attributes) < 2:
                    self.Warning.not_enough_vars()
                else:
                    # Presumably fills in missing values before ranking.
                    self.cont_data = SklImpute()(cont_data)
        self.set_feature_model()
        self.openContext(self.cont_data)
        self.apply()
        self.vizrank.button.setEnabled(self.cont_data is not None)

    def set_feature_model(self):
        """Fill the feature combo from ``cont_data`` and reset the filter."""
        self.feature_model.set_domain(self.cont_data and self.cont_data.domain)
        self.feature = None

    def apply(self):
        """Reinitialize the ranking and start it (or commit without data)."""
        self.vizrank.initialize()
        if self.cont_data is not None:
            # this triggers self.commit() by changing vizrank selection
            self.vizrank.toggle()
        else:
            self.commit()

    def commit(self):
        """Send the data, the selected pair and the table of correlations.

        The correlations table has "Correlation" and "FDR" columns (the
        p-values from ``CorrelationRank.PValRole`` corrected with ``FDR``)
        and the names of both features as string metas.
        """
        if self.data is None or self.cont_data is None:
            self.Outputs.data.send(self.data)
            self.Outputs.features.send(None)
            self.Outputs.correlations.send(None)
            return

        attrs = [ContinuousVariable("Correlation"), ContinuousVariable("FDR")]
        metas = [StringVariable("Feature 1"), StringVariable("Feature 2")]
        domain = Domain(attrs, metas=metas)
        model = self.vizrank.rank_model
        # One row per ranked pair: [displayed correlation, p-value].
        x = np.array([[float(model.data(model.index(row, 0), role))
                       for role in (Qt.DisplayRole, CorrelationRank.PValRole)]
                      for row in range(model.rowCount())])
        # NOTE(review): with an empty rank model ``x`` is 1-D and this
        # indexing would fail; callers appear to commit only after a row
        # has been selected — confirm.
        x[:, 1] = FDR(list(x[:, 1]))
        # pylint: disable=protected-access
        m = np.array([[a.name for a in model.data(model.index(row, 0),
                                                  CorrelationRank._AttrRole)]
                      for row in range(model.rowCount())], dtype=object)
        corr_table = Table(domain, x, metas=m)
        corr_table.name = "Correlations"

        self.Outputs.data.send(self.data)
        # data has been imputed; send original attributes
        self.Outputs.features.send(AttributeList(
            [self.data.domain[name] for name, _ in self.selection]))
        self.Outputs.correlations.send(corr_table)

    def send_report(self):
        """Report the ranking table under the correlation type's name."""
        self.report_table(CorrelationType.items()[self.correlation_type],
                          self.vizrank.rank_table)

    @classmethod
    def migrate_context(cls, context, version):
        """Convert pre-version-2 selections to (name, vartype) tuples.

        ``sel[0]`` is assumed to hold the selected Variable objects and
        ``sel[1]`` is kept unchanged (presumably the stored value's type
        marker) — confirm against old settings files.
        """
        if version < 2:
            sel = context.values["selection"]
            context.values["selection"] = ([(var.name, vartype(var))
                                            for var in sel[0]], sel[1])
Beispiel #37
0
class OWLinePlot(OWWidget):
    """Plot every instance as a profile over its continuous features.

    Profiles can be grouped by a discrete variable. For each group the widget
    can show individual lines, the range between the 10th and 90th
    percentile, the mean curve and error bars (standard deviation).
    Instances selected in the plot are sent to the outputs.
    """
    name = "Line Plot"
    description = "Visualization of data profiles (e.g., time series)."
    icon = "icons/LinePlot.svg"
    priority = 180

    # Connected to the view box's ``enable_selection`` in ``__init__``.
    enable_selection = Signal(bool)

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler()
    group_var = ContextSetting(None)
    show_profiles = Setting(False)
    show_range = Setting(True)
    show_mean = Setting(True)
    show_error = Setting(False)
    auto_commit = Setting(True)
    # Row indices into self.data of the selected instances (None or a list);
    # saved only with the workflow.
    selection = Setting(None, schema_only=True)

    graph_name = "graph.plotItem"

    class Error(OWWidget.Error):
        not_enough_attrs = Msg("Need at least one continuous feature.")
        no_valid_data = Msg("No plot due to no valid data.")

    class Warning(OWWidget.Warning):
        no_display_option = Msg("No display option is selected.")

    class Information(OWWidget.Information):
        hidden_instances = Msg("Instances with unknown values are not shown.")
        too_many_features = Msg("Data has too many features. Only first {}"
                                " are shown.".format(MAX_FEATURES))

    def __init__(self, parent=None):
        super().__init__(parent)
        self.__groups = []
        self.data = None
        # Boolean mask over rows of self.data: rows without missing values.
        self.valid_data = None
        self.subset_data = None
        # Ids of valid instances that are also in the subset (or None).
        self.subset_indices = None
        # Selection restored from settings; applied by apply_selection once
        # data is plotted.
        self.__pending_selection = self.selection
        self.graph_variables = []
        self.setup_gui()

        self.graph.view_box.selection_changed.connect(self.selection_changed)
        self.enable_selection.connect(self.graph.view_box.enable_selection)

    def setup_gui(self):
        """Create the plot and the control area."""
        self._add_graph()
        self._add_controls()

    def _add_graph(self):
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = LinePlotGraph(self)
        box.layout().addWidget(self.graph)

    def _add_controls(self):
        infobox = gui.widgetBox(self.controlArea, "Info")
        self.infoLabel = gui.widgetLabel(infobox, "No data on input.")
        displaybox = gui.widgetBox(self.controlArea, "Display")
        gui.checkBox(displaybox,
                     self,
                     "show_profiles",
                     "Lines",
                     callback=self.__show_profiles_changed,
                     tooltip="Plot lines")
        gui.checkBox(displaybox,
                     self,
                     "show_range",
                     "Range",
                     callback=self.__show_range_changed,
                     tooltip="Plot range between 10th and 90th percentile")
        gui.checkBox(displaybox,
                     self,
                     "show_mean",
                     "Mean",
                     callback=self.__show_mean_changed,
                     tooltip="Plot mean curve")
        gui.checkBox(displaybox,
                     self,
                     "show_error",
                     "Error bars",
                     callback=self.__show_error_changed,
                     tooltip="Show standard deviation")

        self.group_vars = DomainModel(placeholder="None",
                                      separators=False,
                                      valid_types=DiscreteVariable)
        self.group_view = gui.listView(self.controlArea,
                                       self,
                                       "group_var",
                                       box="Group by",
                                       model=self.group_vars,
                                       callback=self.__group_var_changed)
        self.group_view.setEnabled(False)
        self.group_view.setMinimumSize(QSize(30, 100))
        self.group_view.setSizePolicy(QSizePolicy.Expanding,
                                      QSizePolicy.Ignored)

        plot_gui = OWPlotGUI(self)
        plot_gui.box_zoom_select(self.controlArea)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

    def __show_profiles_changed(self):
        self.check_display_options()
        self._update_visibility("profiles")

    def __show_range_changed(self):
        self.check_display_options()
        self._update_visibility("range")

    def __show_mean_changed(self):
        self.check_display_options()
        self._update_visibility("mean")

    def __show_error_changed(self):
        self._update_visibility("error")

    def __group_var_changed(self):
        if self.data is None or not self.graph_variables:
            return
        self.plot_groups()
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self._update_sub_profiles()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input data, restore its context, replot and commit."""
        self.closeContext()
        self.data = data
        self.clear()
        self.check_data()
        self.check_display_options()

        if self.data is not None:
            self.group_vars.set_domain(self.data.domain)
            self.group_view.setEnabled(len(self.group_vars) > 1)
            self.group_var = self.data.domain.class_var \
                if self.data.domain.has_discrete_class else None

        # check_data sets self.data to None when it rejects the input; open
        # the context for what is actually shown, so settings such as
        # group_var are not restored for a domain missing from the model.
        self.openContext(self.data)
        self.setup_plot()
        self.commit()

    def check_data(self):
        """Validate self.data, show messages and pick the plotted variables.

        On an error, self.data is reset to None.
        """
        def error(err):
            err()
            self.data = None

        self.clear_messages()
        if self.data is not None:
            self.infoLabel.setText(
                "%i instances on input\n%i features" %
                (len(self.data), len(self.data.domain.attributes)))
            self.graph_variables = [
                var for var in self.data.domain.attributes if var.is_continuous
            ]
            # Rows with no missing value anywhere in X.
            self.valid_data = ~countnans(self.data.X, axis=1).astype(bool)
            if len(self.graph_variables) < 1:
                error(self.Error.not_enough_attrs)
            elif not np.sum(self.valid_data):
                error(self.Error.no_valid_data)
            else:
                if not np.all(self.valid_data):
                    self.Information.hidden_instances()
                if len(self.graph_variables) > MAX_FEATURES:
                    self.Information.too_many_features()
                    self.graph_variables = self.graph_variables[:MAX_FEATURES]

    def check_display_options(self):
        """Warn when nothing is shown; enable selection when feasible.

        Selection needs lines or range to be shown and fewer than
        SEL_MAX_INSTANCES valid instances.
        """
        self.Warning.no_display_option.clear()
        if self.data is not None:
            if not (self.show_profiles or self.show_range or self.show_mean):
                self.Warning.no_display_option()
            # Count valid rows from the mask instead of copying the table.
            enable = (self.show_profiles or self.show_range) and \
                np.count_nonzero(self.valid_data) < SEL_MAX_INSTANCES
            self.enable_selection.emit(enable)

    @Inputs.data_subset
    @check_sql_input
    def set_subset_data(self, subset):
        self.subset_data = subset

    def handleNewSignals(self):
        """Recompute subset ids and refresh subset-dependent styling."""
        self.set_subset_ids()
        if self.data is not None:
            self._update_profiles_color()
            self._update_sel_profiles_color()
            self._update_sub_profiles()

    def set_subset_ids(self):
        """Store ids of valid instances of self.data that are in the subset.

        subset_indices is None when there is no data or no subset.
        """
        sub_ids = {e.id for e in self.subset_data} \
            if self.subset_data is not None else set()
        self.subset_indices = None
        if self.data is not None and sub_ids:
            self.subset_indices = [
                x.id for x in self.data[self.valid_data] if x.id in sub_ids
            ]

    def setup_plot(self):
        """Set axis ticks, draw groups, apply pending selection, rescale."""
        if self.data is None:
            return

        ticks = [[(i, a.name) for i, a in enumerate(self.graph_variables, 1)]]
        self.graph.getAxis('bottom').setTicks(ticks)
        self.plot_groups()
        self.apply_selection()
        self.graph.view_box.enableAutoRange()
        self.graph.view_box.updateAutoRange()

    def plot_groups(self):
        """Redraw profile groups: one, or one per value of group_var."""
        self._remove_groups()
        data = self.data[self.valid_data, self.graph_variables]
        if self.group_var is None:
            self._plot_group(data, np.where(self.valid_data)[0])
        else:
            class_col_data, _ = self.data.get_column_view(self.group_var)
            for index in range(len(self.group_var.values)):
                mask = np.logical_and(class_col_data == index, self.valid_data)
                indices = np.flatnonzero(mask)
                if not len(indices):
                    continue
                group_data = self.data[indices, self.graph_variables]
                self._plot_group(group_data, indices, index)
        self.graph.update_legend(self.group_var)
        self.graph.view_box.add_profiles(data.X)

    def _remove_groups(self):
        for group in self.__groups:
            group.remove_items()
        self.graph.view_box.remove_profiles()
        self.__groups = []

    def _plot_group(self, data, indices, index=None):
        """Create a ProfileGroup for the given rows in the group's color."""
        color = self.__get_group_color(index)
        group = ProfileGroup(data, indices, color, self.graph)
        kwargs = self.__get_visibility_flags()
        group.set_visible_error(**kwargs)
        group.set_visible_mean(**kwargs)
        group.set_visible_range(**kwargs)
        group.set_visible_profiles(**kwargs)
        self.__groups.append(group)

    def __get_group_color(self, index):
        if self.group_var is not None:
            return QColor(*self.group_var.colors[index])
        return QColor(LinePlotStyle.DEFAULT_COLOR)

    def __get_visibility_flags(self):
        return {
            "show_profiles": self.show_profiles,
            "show_range": self.show_range,
            "show_mean": self.show_mean,
            "show_error": self.show_error
        }

    def _update_profiles_color(self):
        # color alpha depends on subset and selection; with selection or
        # subset profiles color has more opacity
        if not self.show_profiles:
            return
        has_sel = bool(self.subset_indices) or bool(self.selection)
        for group in self.__groups:
            group.update_profiles_color(has_sel)

    def _update_sel_profiles_and_range(self):
        # mark selected instances and selected range
        if not (self.show_profiles or self.show_range):
            return
        # A set gives O(1) membership tests for every profile.
        selected = set(self.selection or ())
        for group in self.__groups:
            inds = [i for i in group.indices if i in selected]
            table = self.data[inds, self.graph_variables].X if inds else None
            if self.show_profiles:
                group.update_sel_profiles(table)
            if self.show_range:
                group.update_sel_range(table)

    def _update_sel_profiles_color(self):
        # color depends on subset; when subset is present,
        # selected profiles are black
        if not self.selection or not self.show_profiles:
            return
        for group in self.__groups:
            group.update_sel_profiles_color(bool(self.subset_indices))

    def _update_sub_profiles(self):
        # mark subset instances
        if not (self.show_profiles or self.show_range):
            return
        # A set gives O(1) membership tests for every profile.
        subset = set(self.subset_indices or ())
        for group in self.__groups:
            inds = [
                i for i, _id in zip(group.indices, group.ids)
                if _id in subset
            ]
            table = self.data[inds, self.graph_variables].X if inds else None
            group.update_sub_profiles(table)

    def _update_visibility(self, obj_name):
        """Refresh colors and toggle ``set_visible_<obj_name>`` on groups."""
        if not self.__groups:
            return
        self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        kwargs = self.__get_visibility_flags()
        for group in self.__groups:
            getattr(group, "set_visible_{}".format(obj_name))(**kwargs)
        self.graph.view_box.updateAutoRange()

    def apply_selection(self):
        """Apply the selection restored from settings, once."""
        if self.data is not None and self.__pending_selection is not None:
            sel = [i for i in self.__pending_selection if i < len(self.data)]
            mask = np.zeros(len(self.data), dtype=bool)
            mask[sel] = True
            # selection_changed expects a mask over the valid rows only.
            mask = mask[self.valid_data]
            self.selection_changed(mask)
            self.__pending_selection = None

    def selection_changed(self, mask):
        """Handle a new selection given as a mask over the valid rows."""
        if self.data is None:
            return
        # need indices for self.data: mask refers to self.data[self.valid_data]
        indices = np.arange(len(self.data))[self.valid_data][mask]
        self.graph.select(indices)
        old = self.selection
        self.selection = None if self.data and isinstance(self.data, SqlTable)\
            else list(self.graph.selection)
        if not old and self.selection or old and not self.selection:
            self._update_profiles_color()
        self._update_sel_profiles_and_range()
        self._update_sel_profiles_color()
        self.commit()

    def commit(self):
        """Send the selected rows and the data annotated with selection."""
        selected = self.data[self.selection] \
            if self.data is not None and bool(self.selection) else None
        annotated = create_annotated_table(self.data, self.selection)
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        if self.data is None:
            return

        caption = report.render_items_vert((("Group by", self.group_var), ))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def sizeHint(self):
        return QSize(1132, 708)

    def clear(self):
        """Reset plot, selection and controls to the no-data state."""
        self.valid_data = None
        self.selection = None
        self.__groups = []
        self.graph_variables = []
        self.graph.reset()
        self.infoLabel.setText("No data on input.")
        self.group_vars.set_domain(None)
        self.group_view.setEnabled(False)
Beispiel #38
0
class OWMosaicDisplay(OWWidget):
    """Mosaic plot of up to four variables.

    Continuous variables are discretized before plotting. Each cell is
    coloured by its standardized (Pearson) residual when no colouring
    variable is chosen, and by the distribution of the colouring variable
    otherwise. Clicking cells selects them; the rows they cover are sent to
    the ``Selected Data`` and annotated-data outputs.
    """
    name = "Mosaic Display"
    description = "Display data in a mosaic plot."
    icon = "icons/MosaicDisplay.svg"
    priority = 220
    keywords = []

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler()
    vizrank = SettingProvider(MosaicVizRank)
    settings_version = 2
    use_boxes = Setting(True)
    variable1 = ContextSetting(None)
    variable2 = ContextSetting(None)
    variable3 = ContextSetting(None)
    variable4 = ContextSetting(None)
    variable_color = ContextSetting(None)
    # Indices into self.areas of the selected cells.
    selection = Setting(set(), schema_only=True)

    # Plot geometry: width of the prior/subset bars drawn inside cells, gap
    # between neighbouring cells, and offsets of attribute names and values.
    BAR_WIDTH = 5
    SPACING = 4
    ATTR_NAME_OFFSET = 20
    ATTR_VAL_OFFSET = 3
    # Residual colour scales (blue: positive, red: negative residual); the
    # index is chosen from the residual's magnitude in add_rect.
    BLUE_COLORS = [
        QColor(255, 255, 255),
        QColor(210, 210, 255),
        QColor(110, 110, 255),
        QColor(0, 0, 255)
    ]
    RED_COLORS = [
        QColor(255, 255, 255),
        QColor(255, 200, 200),
        QColor(255, 100, 100),
        QColor(255, 0, 0)
    ]
    graph_name = "canvas"

    # Emitted with the plotted (discretized) variables when the user changes
    # one of the variable combos.
    attrs_changed_manually = Signal(list)

    class Warning(OWWidget.Warning):
        incompatible_subset = Msg("Data subset is incompatible with Data")
        no_valid_data = Msg("No valid data")
        no_cont_selection_sql = \
            Msg("Selection of numeric features on SQL is not supported")

    def __init__(self):
        super().__init__()

        self.data = None
        self.discrete_data = None
        self.subset_data = None
        self.subset_indices = None
        # The selection restored from settings is kept aside and applied in
        # handleNewSignals, once data is available.
        self.__pending_selection = self.selection
        self.selection = set()

        self.color_data = None

        # One (used_attrs, used_vals, rectangle) tuple per drawn cell; the
        # positions in this list are what self.selection stores.
        self.areas = []

        self.canvas = QGraphicsScene(self)
        self.canvas_view = ViewWithPress(self.canvas,
                                         handler=self.clear_selection)
        self.mainArea.layout().addWidget(self.canvas_view)
        self.canvas_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setRenderHint(QPainter.Antialiasing)

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        box = gui.vBox(self.controlArea, box=True)
        # The first variable is mandatory; variables 2-4 may be "(None)".
        self.model_1 = DomainModel(order=DomainModel.MIXED,
                                   valid_types=DomainModel.PRIMITIVE)
        self.model_234 = DomainModel(order=DomainModel.MIXED,
                                     valid_types=DomainModel.PRIMITIVE,
                                     placeholder="(None)")
        self.attr_combos = [
            gui.comboBox(box,
                         self,
                         value="variable{}".format(i),
                         orientation=Qt.Horizontal,
                         contentsLength=12,
                         searchable=True,
                         callback=self.attr_changed,
                         model=self.model_1 if i == 1 else self.model_234)
            for i in range(1, 5)
        ]
        self.vizrank, self.vizrank_button = MosaicVizRank.add_vizrank(
            box, self, "Find Informative Mosaics", self.set_attr)

        box2 = gui.vBox(self.controlArea, box="Interior Coloring")
        self.color_model = DomainModel(order=DomainModel.MIXED,
                                       valid_types=DomainModel.PRIMITIVE,
                                       placeholder="(Pearson residuals)")
        self.cb_attr_color = gui.comboBox(box2,
                                          self,
                                          value="variable_color",
                                          orientation=Qt.Horizontal,
                                          contentsLength=12,
                                          labelWidth=50,
                                          searchable=True,
                                          callback=self.set_color_data,
                                          model=self.color_model)
        self.bar_button = gui.checkBox(box2,
                                       self,
                                       'use_boxes',
                                       label='Compare with total',
                                       callback=self.update_graph)
        gui.rubber(self.controlArea)

    def sizeHint(self):
        return QSize(720, 530)

    def _get_discrete_data(self, data):
        """
        Discretize continuous attributes.

        Continuous variables (including class variables and metas) are split
        into 4 equal-frequency intervals. Discretization is triggered only by
        continuous variables among ``domain.variables``; otherwise the data
        is returned unchanged.
        Return None when there is no data, no rows, or no primitive attributes.
        """
        if (data is None or not len(data) or not any(
                attr.is_discrete or attr.is_continuous
                for attr in chain(data.domain.variables, data.domain.metas))):
            return None
        elif any(attr.is_continuous for attr in data.domain.variables):
            return Discretize(method=EqualFreq(n=4),
                              remove_const=False,
                              discretize_classes=True,
                              discretize_metas=True)(data)
        else:
            return data

    def init_combos(self, data):
        """Fill the variable and colour combos from ``data``'s domain.

        Chooses the first two variables of the model as defaults (the same
        one twice if there is only one), clears variables 3 and 4, and uses
        the class variable (possibly None) for colouring. With ``data`` None,
        all combos are emptied and all variables reset to None.
        """
        def set_combos(value):
            self.model_1.set_domain(value)
            self.model_234.set_domain(value)
            self.color_model.set_domain(value)

        if data is None:
            set_combos(None)
            self.variable1 = self.variable2 = self.variable3 \
                = self.variable4 = self.variable_color = None
            return
        set_combos(data.domain)

        if len(self.model_1) > 0:
            self.variable1 = self.model_1[0]
            self.variable2 = self.model_1[min(1, len(self.model_1) - 1)]
        self.variable3 = self.variable4 = None
        self.variable_color = data.domain.class_var  # None is OK, too

    def get_disc_attr_list(self):
        """Return the discretized counterparts of the chosen variables.

        Variables are looked up by name in ``discrete_data``; unset (None)
        choices are skipped.
        """
        return [
            self.discrete_data.domain[var.name]
            for var in (self.variable1, self.variable2, self.variable3,
                        self.variable4) if var
        ]

    def set_attr(self, *attrs):
        """Set the four plotted variables (called by the VizRank dialog).

        Each given variable is replaced by the variable with the same name in
        the original data domain; falsy entries are kept as they are.
        """
        self.variable1, self.variable2, self.variable3, self.variable4 = [
            attr and self.data.domain[attr.name] for attr in attrs
        ]
        self.reset_graph()

    def attr_changed(self):
        """Handle a manual change in a variable combo."""
        self.attrs_changed_manually.emit(self.get_disc_attr_list())
        self.reset_graph()

    def resizeEvent(self, e):
        OWWidget.resizeEvent(self, e)
        # The plot is sized to the view, so it is redrawn on every resize.
        self.update_graph()

    def showEvent(self, ev):
        OWWidget.showEvent(self, ev)
        self.update_graph()

    @Inputs.data
    def set_data(self, data):
        """Store new input data and (re)initialize the controls.

        Large SQL tables are replaced by a time-limited sample. Drawing and
        sending outputs are deferred to handleNewSignals.
        """
        if isinstance(data, SqlTable) and data.approx_len() > LARGE_TABLE:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data

        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(
            self.data is not None and len(self.data) > 1
            and len(self.data.domain.attributes) >= 1)

        if self.data is None:
            self.discrete_data = None
            self.init_combos(None)
            self.info.set_input_summary(self.info.NoInput)
            return

        self.info.set_input_summary(len(data), format_summary_details(data))
        self.init_combos(self.data)
        self.openContext(self.data)

    @Inputs.data_subset
    def set_subset_data(self, data):
        self.subset_data = data

    # this is called by widget after setData and setSubsetData are called.
    # this way the graph is updated only once
    def handleNewSignals(self):
        self.Warning.incompatible_subset.clear()
        self.subset_indices = None
        if self.data is not None and self.subset_data:
            transformed = self.subset_data.transform(self.data.domain)
            # If no value survives the transformation, the subset does not
            # share any variables with the data.
            if np.all(np.isnan(transformed.X)) \
                    and np.all(np.isnan(transformed.Y)):
                self.Warning.incompatible_subset()
            else:
                # Boolean flag per data row: is the row (by id) in the subset?
                indices = {e.id for e in transformed}
                self.subset_indices = [ex.id in indices for ex in self.data]
        if self.data is not None and self.__pending_selection is not None:
            self.selection = self.__pending_selection
            self.__pending_selection = None
        else:
            self.selection = set()
        self.set_color_data()
        self.update_graph()
        self.send_selection()

    def clear_selection(self):
        self.selection = set()
        self.update_selection_rects()
        self.send_selection()

    def coloring_changed(self):
        self.vizrank.coloring_changed()
        self.update_graph()

    def reset_graph(self):
        self.clear_selection()
        self.update_graph()

    def set_color_data(self):
        """Rebuild the (discretized) data used for plotting and colouring.

        The colouring variable becomes the class variable of ``color_data``;
        all other primitive variables become attributes.
        """
        if self.data is None:
            return
        self.bar_button.setEnabled(self.variable_color is not None)
        attrs = [v for v in self.model_1 if v and v is not self.variable_color]
        domain = Domain(attrs, self.variable_color, None)
        self.color_data = self.data.from_table(domain, self.data)
        self.discrete_data = self._get_discrete_data(self.color_data)
        self.vizrank.stop_and_reset()
        # NOTE(review): this unconditionally enables the button, overriding
        # the row/attribute check done in set_data -- confirm it is intended.
        self.vizrank_button.setEnabled(True)
        self.coloring_changed()

    def update_selection_rects(self):
        """Outline selected cells with a thick dotted pen, others plainly."""
        pens = (QPen(), QPen(Qt.black, 3, Qt.DotLine))
        for i, (_, _, area) in enumerate(self.areas):
            area.setPen(pens[i in self.selection])

    def select_area(self, index, ev):
        """Select the cell ``index`` on left click; Ctrl toggles it."""
        if ev.button() != Qt.LeftButton:
            return
        if ev.modifiers() & Qt.ControlModifier:
            self.selection ^= {index}
        else:
            self.selection = {index}
        self.update_selection_rects()
        self.send_selection()

    def send_selection(self):
        """Send the rows that fall into the selected cells.

        Each selected cell contributes a conjunction of ``FilterDiscrete``
        conditions (one per variable that split it); cells are combined
        disjunctively and the filter is applied to ``discrete_data``. The
        matching rows are mapped back to ``data`` by instance id.
        """
        # Selection indices may refer to cells that were not drawn, e.g. a
        # selection restored from settings while update_graph returned early
        # (no valid data, canvas too small). Indexing self.areas with them
        # would raise IndexError, so they are ignored.
        selected_areas = [i for i in self.selection if i < len(self.areas)]
        if not selected_areas or self.data is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(
                create_annotated_table(self.data, []))
            self.info.set_output_summary(self.info.NoOutput)
            return
        filters = []
        self.Warning.no_cont_selection_sql.clear()
        if self.discrete_data is not self.data:
            if isinstance(self.data, SqlTable):
                self.Warning.no_cont_selection_sql()
        for i in selected_areas:
            cols, vals, _ = self.areas[i]
            filters.append(
                filter.Values(
                    filter.FilterDiscrete(col, [val])
                    for col, val in zip(cols, vals)))
        if len(filters) > 1:
            filters = filter.Values(filters, conjunction=False)
        else:
            filters = filters[0]
        selection = filters(self.discrete_data)
        idset = set(selection.ids)
        sel_idx = [i for i, id in enumerate(self.data.ids) if id in idset]
        if self.discrete_data is not self.data:
            # Send the original (non-discretized) rows.
            selection = self.data[sel_idx]

        summary = len(selection) if selection else self.info.NoOutput
        details = format_summary_details(selection) if selection else ""
        self.info.set_output_summary(summary, details)
        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.data, sel_idx))

    def send_report(self):
        self.report_plot(self.canvas)

    def update_graph(self):
        """Redraw the whole mosaic: cells, labels and legend.

        The cell list ``self.areas`` is rebuilt from scratch. Nothing is drawn
        when there is no discretized data, no valid rows, a chosen variable
        without values, or when the view is too small.
        """
        spacing = self.SPACING
        bar_width = self.BAR_WIDTH

        def get_counts(attr_vals, values):
            """Calculate rectangles' widths; if all are 0, they are set to 1.

            Returns ``(total, counts)`` with one count per value in
            ``values``, conditioned on the already-split prefix
            ``attr_vals``.
            """
            if not attr_vals:
                counts = [conditionaldict[val] for val in values]
            else:
                counts = [
                    conditionaldict[attr_vals + "-" + val] for val in values
                ]
            total = sum(counts)
            if total == 0:
                counts = [1] * len(values)
                total = sum(counts)
            return total, counts

        def draw_data(attr_list,
                      x0_x1,
                      y0_y1,
                      side,
                      condition,
                      total_attrs,
                      used_attrs,
                      used_vals,
                      attr_vals=""):
            """Recursively split a rectangle by the values of attr_list[0].

            Even ``side`` splits horizontally, odd splits vertically; each
            recursion level moves to the next side. Leaf cells are drawn by
            add_rect, and labels for this level by draw_text.
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if conditionaldict[attr_vals] == 0:
                add_rect(x0,
                         x1,
                         y0,
                         y1,
                         "",
                         used_attrs,
                         used_vals,
                         attr_vals=attr_vals)
                # store coordinates for later drawing of labels
                draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                          used_attrs, used_vals, attr_vals)
                return

            attr = attr_list[0]
            # how much smaller rectangles do we draw
            edge = len(attr_list) * spacing
            values = get_variable_values_sorted(attr)
            if side % 2:
                values = values[::-1]  # reverse names if necessary

            # When there is no room left for the cells themselves, the gaps
            # take all the space; with a single value there are no gaps, so
            # edge is irrelevant and dividing by len(values) - 1 would fail.
            if side % 2 == 0:  # we are drawing on the x axis
                # remove the space needed for separating different attr. values
                whole = max(0, (x1 - x0) - edge * (len(values) - 1))
                if whole == 0 and len(values) > 1:
                    edge = (x1 - x0) / float(len(values) - 1)
            else:  # we are drawing on the y axis
                whole = max(0, (y1 - y0) - edge * (len(values) - 1))
                if whole == 0 and len(values) > 1:
                    edge = (y1 - y0) / float(len(values) - 1)

            total, counts = get_counts(attr_vals, values)

            # when visualizing the third attribute and the first attribute has
            # the last value, reverse the order in which the boxes are drawn;
            # otherwise, if the last cell, nearest to the labels of the fourth
            # attribute, is empty, we wouldn't be able to position the labels
            valrange = list(range(len(values)))
            if len(attr_list + used_attrs) == 4 and len(used_attrs) == 2:
                attr1values = get_variable_values_sorted(used_attrs[0])
                if used_vals[0] == attr1values[-1]:
                    valrange = valrange[::-1]

            for i in valrange:
                start = i * edge + whole * float(sum(counts[:i]) / total)
                end = i * edge + whole * float(sum(counts[:i + 1]) / total)
                val = values[i]
                htmlval = to_html(val)
                newattrvals = attr_vals + "-" + val if attr_vals else val

                tooltip = "{}&nbsp;&nbsp;&nbsp;&nbsp;{}: <b>{}</b><br/>".format(
                    condition, attr.name, htmlval)
                attrs = used_attrs + [attr]
                vals = used_vals + [val]
                args = attrs, vals, newattrvals
                if side % 2 == 0:  # if we are moving horizontally
                    if len(attr_list) == 1:
                        add_rect(x0 + start, x0 + end, y0, y1, tooltip, *args)
                    else:
                        draw_data(attr_list[1:], (x0 + start, x0 + end),
                                  (y0, y1), side + 1, tooltip, total_attrs,
                                  *args)
                else:
                    if len(attr_list) == 1:
                        add_rect(x0, x1, y0 + start, y0 + end, tooltip, *args)
                    else:
                        draw_data(attr_list[1:], (x0, x1),
                                  (y0 + start, y0 + end), side + 1, tooltip,
                                  total_attrs, *args)
            draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                      used_attrs, used_vals, attr_vals)

        def draw_text(side, attr, x0_x1, y0_y1, total_attrs, used_attrs,
                      used_vals, attr_vals):
            """Draw the value labels and the name of ``attr`` on ``side``.

            Sides: 0 below, 1 left, 2 above, 3 right of the rectangle. Each
            side is labelled only once per redraw (tracked in drawn_sides).
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if side in drawn_sides:
                return

            # the text on the right will be drawn when we are processing
            # visualization of the last value of the first attribute
            if side == 3:
                attr1values = get_variable_values_sorted(used_attrs[0])
                if used_vals[0] != attr1values[-1]:
                    return

            if not conditionaldict[attr_vals]:
                if side not in draw_positions:
                    draw_positions[side] = (x0, x1, y0, y1)
                return
            else:
                if side in draw_positions:
                    # restore the positions of attribute values and name
                    (x0, x1, y0, y1) = draw_positions[side]

            drawn_sides.add(side)

            values = get_variable_values_sorted(attr)
            if side % 2:
                values = values[::-1]

            spaces = spacing * (total_attrs - side) * (len(values) - 1)
            width = x1 - x0 - spaces * (side % 2 == 0)
            height = y1 - y0 - spaces * (side % 2 == 1)

            # calculate position of first attribute
            currpos = 0
            total, counts = get_counts(attr_vals, values)
            aligns = [
                Qt.AlignTop | Qt.AlignHCenter, Qt.AlignRight | Qt.AlignVCenter,
                Qt.AlignBottom | Qt.AlignHCenter,
                Qt.AlignLeft | Qt.AlignVCenter
            ]
            align = aligns[side]
            for i, val in enumerate(values):
                if distributiondict[val] != 0:
                    perc = counts[i] / float(total)
                    rwidth = width * perc
                    xs = [
                        x0 + currpos + rwidth / 2, x0 - self.ATTR_VAL_OFFSET,
                        x0 + currpos + rwidth / 2, x1 + self.ATTR_VAL_OFFSET
                    ]
                    ys = [
                        y1 + self.ATTR_VAL_OFFSET,
                        y0 + currpos + height * 0.5 * perc,
                        y0 - self.ATTR_VAL_OFFSET,
                        y0 + currpos + height * 0.5 * perc
                    ]

                    CanvasText(self.canvas,
                               val,
                               xs[side],
                               ys[side],
                               align,
                               max_width=rwidth if side == 0 else None)
                    space = height if side % 2 else width
                    currpos += perc * space + spacing * (total_attrs - side)

            xs = [
                x0 + (x1 - x0) / 2, x0 - max_ylabel_w1 - self.ATTR_VAL_OFFSET,
                x0 + (x1 - x0) / 2, x1 + max_ylabel_w2 + self.ATTR_VAL_OFFSET
            ]
            ys = [
                y1 + self.ATTR_VAL_OFFSET + self.ATTR_NAME_OFFSET,
                y0 + (y1 - y0) / 2,
                y0 - self.ATTR_VAL_OFFSET - self.ATTR_NAME_OFFSET,
                y0 + (y1 - y0) / 2
            ]
            CanvasText(self.canvas,
                       attr.name,
                       xs[side],
                       ys[side],
                       align,
                       bold=True,
                       vertical=side % 2)

        def add_rect(x0,
                     x1,
                     y0,
                     y1,
                     condition,
                     used_attrs,
                     used_vals,
                     attr_vals=""):
            """Draw one mosaic cell, register it in self.areas, set tooltip.

            Without a colouring variable the cell is filled by its Pearson
            residual; otherwise it is split by the class distribution, with
            optional bars for the prior (left) and the subset (right).
            """
            area_index = len(self.areas)
            x1 += (x0 == x1)
            y1 += (y0 == y1)
            # rectangles of width and height 1 are not shown - increase
            y1 += (x1 - x0 + y1 - y0 == 2)
            colors = class_var and [QColor(*col) for col in class_var.colors]

            def select_area(_, ev):
                self.select_area(area_index, ev)

            def rect(x, y, w, h, z, pen_color=None, brush_color=None, **args):
                if pen_color is None:
                    return CanvasRectangle(self.canvas,
                                           x,
                                           y,
                                           w,
                                           h,
                                           z=z,
                                           onclick=select_area,
                                           **args)
                if brush_color is None:
                    brush_color = pen_color
                return CanvasRectangle(self.canvas,
                                       x,
                                       y,
                                       w,
                                       h,
                                       pen_color,
                                       brush_color,
                                       z=z,
                                       onclick=select_area,
                                       **args)

            def line(x1, y1, x2, y2):
                r = QGraphicsLineItem(x1, y1, x2, y2, None)
                self.canvas.addItem(r)
                r.setPen(QPen(Qt.white, 2))
                r.setZValue(30)

            outer_rect = rect(x0, y0, x1 - x0, y1 - y0, 30)
            self.areas.append((used_attrs, used_vals, outer_rect))
            if not conditionaldict[attr_vals]:
                return

            if self.variable_color is None:
                # Expected count under independence: N * prod(p_i).
                s = sum(apriori_dists[0])
                expected = s * reduce(
                    mul, (apriori_dists[i][used_vals[i]] / float(s)
                          for i in range(len(used_vals))))
                actual = conditionaldict[attr_vals]
                pearson = float((actual - expected) / sqrt(expected))
                if pearson == 0:
                    ind = 0
                else:
                    ind = max(0, min(int(log(abs(pearson), 2)), 3))
                color = [self.RED_COLORS, self.BLUE_COLORS][pearson > 0][ind]
                rect(x0, y0, x1 - x0, y1 - y0, -20, color)
                outer_rect.setToolTip(
                    condition + "<hr/>" + "Expected instances: %.1f<br>"
                    "Actual instances: %d<br>"
                    "Standardized (Pearson) residual: %.1f" %
                    (expected, conditionaldict[attr_vals], pearson))
            else:
                cls_values = get_variable_values_sorted(class_var)
                prior = get_distribution(data, class_var.name)
                total = 0
                for i, value in enumerate(cls_values):
                    val = conditionaldict[attr_vals + "-" + value]
                    if val == 0:
                        continue
                    # The last stripe takes the remainder to avoid gaps from
                    # rounding.
                    if i == len(cls_values) - 1:
                        v = y1 - y0 - total
                    else:
                        v = ((y1 - y0) * val) / conditionaldict[attr_vals]
                    rect(x0, y0 + total, x1 - x0, v, -20, colors[i])
                    total += v

                if self.use_boxes and \
                        abs(x1 - x0) > bar_width and abs(y1 - y0) > bar_width:
                    total = 0
                    line(x0 + bar_width, y0, x0 + bar_width, y1)
                    n = sum(prior)
                    for i, (val, color) in enumerate(zip(prior, colors)):
                        if i == len(prior) - 1:
                            h = y1 - y0 - total
                        else:
                            h = (y1 - y0) * val / n
                        rect(x0, y0 + total, bar_width, h, 20, color)
                        total += h

                if conditionalsubsetdict:
                    if conditionalsubsetdict[attr_vals]:
                        if self.subset_indices is not None:
                            line(x1 - bar_width, y0, x1 - bar_width, y1)
                            total = 0
                            n = conditionalsubsetdict[attr_vals]
                            if n:
                                for i, (cls, color) in \
                                        enumerate(zip(cls_values, colors)):
                                    val = conditionalsubsetdict[attr_vals +
                                                                "-" + cls]
                                    if val == 0:
                                        continue
                                    if i == len(prior) - 1:
                                        v = y1 - y0 - total
                                    else:
                                        v = ((y1 - y0) * val) / n
                                    rect(x1 - bar_width, y0 + total, bar_width,
                                         v, 15, color)
                                    total += v

                actual = [
                    conditionaldict[attr_vals + "-" + cls_values[i]]
                    for i in range(len(prior))
                ]
                n_actual = sum(actual)
                if n_actual > 0:
                    apriori = [prior[key] for key in cls_values]
                    n_apriori = sum(apriori)
                    text = "<br/>".join(
                        "<b>%s</b>: %d / %.1f%% (Expected %.1f / %.1f%%)" %
                        (cls, act, 100.0 * act / n_actual,
                         apr / n_apriori * n_actual, 100.0 * apr / n_apriori)
                        for cls, act, apr in zip(cls_values, actual, apriori))
                else:
                    text = ""
                # The lines are joined with "<br/>", so there is no trailing
                # separator to strip; slicing would cut the last line.
                outer_rect.setToolTip("{}<hr>Instances: {}<br><br>{}".format(
                    condition, n_actual, text))

        def create_legend():
            """Build the legend: residual ranges or colouring-variable values."""
            if self.variable_color is None:
                names = [
                    "<-8", "-8:-4", "-4:-2", "-2:2", "2:4", "4:8", ">8",
                    "Residuals:"
                ]
                colors = self.RED_COLORS[::-1] + self.BLUE_COLORS[1:]
                edges = repeat(Qt.black)
            else:
                names = get_variable_values_sorted(class_var)
                edges = colors = [QColor(*col) for col in class_var.colors]

            items = []
            size = 8
            for name, color, edgecolor in zip(names, colors, edges):
                item = QGraphicsItemGroup()
                item.addToGroup(
                    CanvasRectangle(None, -size / 2, -size / 2, size, size,
                                    edgecolor, color))
                item.addToGroup(
                    CanvasText(None, name, size, 0, Qt.AlignVCenter))
                items.append(item)
            return wrap_legend_items(items,
                                     hspacing=20,
                                     vspacing=16 + size,
                                     max_width=self.canvas_view.width() - xoff)

        self.canvas.clear()
        self.areas = []

        data = self.discrete_data
        if data is None:
            return
        attr_list = self.get_disc_attr_list()
        class_var = data.domain.class_var
        # TODO: check this
        # data = Preprocessor_dropMissing(data)

        unique = [v.name for v in set(attr_list + [class_var]) if v]
        if len(data[:, unique]) == 0:
            self.Warning.no_valid_data()
            return
        else:
            self.Warning.no_valid_data.clear()

        attrs = [attr for attr in attr_list if not attr.values]
        if attrs:
            CanvasText(self.canvas,
                       "Feature {} has no values".format(attrs[0]),
                       (self.canvas_view.width() - 120) / 2,
                       self.canvas_view.height() / 2)
            return
        if self.variable_color is None:
            apriori_dists = [
                get_distribution(data, attr) for attr in attr_list
            ]
        else:
            apriori_dists = []

        def get_max_label_width(attr):
            """Return the width of the widest value label of ``attr``."""
            values = get_variable_values_sorted(attr)
            maxw = 0
            for val in values:
                t = CanvasText(self.canvas, val, 0, 0, bold=0, show=False)
                maxw = max(int(t.boundingRect().width()), maxw)
            return maxw

        xoff = 20

        # get the maximum width of rectangle
        width = 20
        max_ylabel_w1 = max_ylabel_w2 = 0
        if len(attr_list) > 1:
            text = CanvasText(self.canvas, attr_list[1].name, bold=1, show=0)
            max_ylabel_w1 = min(get_max_label_width(attr_list[1]), 150)
            width = 5 + text.boundingRect().height() + \
                self.ATTR_VAL_OFFSET + max_ylabel_w1
            xoff = width
            if len(attr_list) == 4:
                text = CanvasText(self.canvas,
                                  attr_list[3].name,
                                  bold=1,
                                  show=0)
                max_ylabel_w2 = min(get_max_label_width(attr_list[3]), 150)
                width += text.boundingRect().height() + \
                    self.ATTR_VAL_OFFSET + max_ylabel_w2 - 10

        legend = create_legend()

        # get the maximum height of rectangle
        yoff = 45
        legendoff = yoff + self.ATTR_NAME_OFFSET + self.ATTR_VAL_OFFSET + 35
        square_size = min(
            self.canvas_view.width() - width - 20,
            self.canvas_view.height() - legendoff -
            legend.boundingRect().height())

        if square_size < 0:
            return  # canvas is too small to draw rectangles
        self.canvas_view.setSceneRect(0, 0, self.canvas_view.width(),
                                      self.canvas_view.height())

        # Shared state for draw_text: sides already labelled, and positions
        # remembered from empty cells.
        drawn_sides = set()
        draw_positions = {}

        conditionaldict, distributiondict = \
            get_conditional_distribution(data, attr_list)
        conditionalsubsetdict = None
        if self.subset_indices:
            conditionalsubsetdict, _ = get_conditional_distribution(
                self.discrete_data[self.subset_indices], attr_list)

        # draw rectangles
        draw_data(attr_list, (xoff, xoff + square_size),
                  (yoff, yoff + square_size), 0, "", len(attr_list), [], [])

        self.canvas.addItem(legend)
        legend.setPos(
            xoff - legend.boundingRect().x() +
            max(0, (square_size - legend.boundingRect().width()) / 2),
            legendoff + square_size)
        self.update_selection_rects()

    @classmethod
    def migrate_context(cls, context, version):
        """Migrate contexts saved with settings version < 2.

        Calls ``settings.migrate_str_to_variable`` with "(None)" as the
        ``none_placeholder``.
        """
        if version < 2:
            settings.migrate_str_to_variable(context,
                                             none_placeholder="(None)")
Beispiel #39
0
class OWTranspose(OWWidget):
    """Widget that swaps rows and columns of the input data table.

    Feature names of the transposed table are either generic or taken from
    a string meta attribute selected in the combo box.
    """
    name = "Transpose"
    description = "Transpose data table."
    icon = "icons/Transpose.svg"
    priority = 2000

    inputs = [("Data", Table, "set_data")]
    outputs = [("Data", Table)]

    resizing_enabled = False
    want_main_area = False

    settingsHandler = DomainContextHandler(metas_in_res=True)
    # 0 -> generic names, 1 -> names from the chosen meta attribute
    feature_type = ContextSetting(0)
    feature_names_column = ContextSetting(None)
    auto_apply = Setting(True)

    class Error(OWWidget.Error):
        value_error = Msg("{}")

    def __init__(self):
        super().__init__()
        self.data = None

        # GUI
        names_box = gui.vBox(self.controlArea, "Feature names")
        # NOTE: `self.apply` is looked up when the callback fires, not now.
        # gui.auto_commit (called below) presumably rebinds `apply` on the
        # widget, so keep the late lookup instead of passing `self.apply`.
        self.feature_radio = gui.radioButtonsInBox(
            names_box, self, "feature_type",
            callback=lambda: self.apply(),
            btnLabels=["Generic", "From meta attribute:"],
        )

        self.feature_model = DomainModel(
            order=DomainModel.METAS,
            valid_types=StringVariable,
            alphabetical=True,
        )
        combo_parent = gui.indentedBox(
            names_box,
            gui.checkButtonOffsetHint(self.feature_radio.buttons[0]),
        )
        self.feature_combo = gui.comboBox(
            combo_parent,
            self,
            "feature_names_column",
            callback=self._feature_combo_changed,
            model=self.feature_model,
        )

        self.apply_button = gui.auto_commit(
            self.controlArea, self, "auto_apply", "&Apply",
            box=False, commit=self.apply,
        )

    def _feature_combo_changed(self):
        """Choosing a column implies naming features from a meta attribute."""
        self.feature_type = 1
        self.apply()

    def set_data(self, data):
        """Store the new input, restore its context and re-apply."""
        # Skip the context if the combo is empty: a context with
        # feature_model == None would then match all domains
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.update_controls()
        if data is not None and self.feature_model:
            self.openContext(data)
        self.apply()

    def update_controls(self):
        """Refill the meta-attribute combo and enable controls accordingly."""
        model = self.feature_model
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if model:
                self.feature_names_column = model[0]
        has_candidates = bool(model)
        self.feature_radio.buttons[1].setEnabled(has_candidates)
        self.feature_combo.setEnabled(has_candidates)
        self.feature_type = int(has_candidates)

    def apply(self):
        """Transpose the current data and send it (None when unavailable)."""
        self.clear_messages()
        result = None
        if self.data:
            names_source = self.feature_type and self.feature_names_column
            try:
                result = Table.transpose(self.data, names_source)
            except ValueError as err:
                self.Error.value_error(err)
        self.send("Data", result)

    def send_report(self):
        """Report how feature names were chosen, followed by the input."""
        if self.feature_type:
            text = "from meta attribute: {}".format(self.feature_names_column)
        else:
            text = "generic"
        self.report_items("", [("Feature names", text)])
        if self.data:
            self.report_data("Data", self.data)
class OWFeatureStatistics(widget.OWWidget):
    """Show basic statistics for the features of the input data.

    Rows selected in the table define the outputs: the input restricted to
    the selected variables (`reduced_data`) and a table with the statistics
    of those variables (`statistics`).
    """
    name = 'Feature Statistics'
    description = 'Show basic statistics for data features.'
    icon = 'icons/FeatureStatistics.svg'

    class Inputs:
        data = Input('Data', Table, default=True)

    class Outputs:
        reduced_data = Output('Reduced Data', Table, default=True)
        statistics = Output('Statistics', Table)

    want_control_area = False
    buttons_area_orientation = Qt.Vertical

    settingsHandler = DomainContextHandler()
    settings_version = 2

    auto_commit = Setting(True)
    color_var = ContextSetting(None)  # type: Optional[Variable]
    # filter_string = ContextSetting('')

    sorting = Setting((0, Qt.DescendingOrder))
    selected_vars = ContextSetting([], schema_only=True)

    def __init__(self):
        super().__init__()

        self.data = None  # type: Optional[Table]

        # TODO: Implement filtering on the model
        # filter_box = gui.vBox(self.controlArea, 'Filter')
        # self.filter_text = gui.lineEdit(
        #     filter_box, self, value='filter_string',
        #     placeholderText='Filter variables by name',
        #     callback=self._filter_table_variables, callbackOnType=True,
        # )
        # shortcut = QShortcut(QKeySequence('Ctrl+f'), self, self.filter_text.setFocus)
        # shortcut.setWhatsThis('Filter variables by name')

        box = gui.hBox(None, box=False)
        box.setContentsMargins(0, 0, 0, 0)

        self.color_var_model = DomainModel(
            valid_types=(ContinuousVariable, DiscreteVariable),
            placeholder='None',
        )
        self.cb_color_var = gui.comboBox(box,
                                         master=self,
                                         value='color_var',
                                         model=self.color_var_model,
                                         label='Color:',
                                         orientation=Qt.Horizontal,
                                         contentsLength=13,
                                         searchable=True)
        self.cb_color_var.activated.connect(self.__color_var_changed)

        gui.rubber(box)
        gui.auto_send(box, self, "auto_commit", box=None)
        self.mainArea.layout().addWidget(box)

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        # Main area
        self.model = FeatureStatisticsTableModel(parent=self)
        self.table_view = FeatureStatisticsTableView(self.model, parent=self)
        self.table_view.selectionModel().selectionChanged.connect(
            self.on_select)
        self.table_view.horizontalHeader().sectionClicked.connect(
            self.on_header_click)

        self.mainArea.layout().addWidget(self.table_view)

    @staticmethod
    def sizeHint():
        return QSize(1050, 500)

    def _filter_table_variables(self):
        """Build a name filter from `filter_string` (not wired up yet).

        NOTE: currently unused. The line edit that would call this is
        commented out in `__init__`, and `filter_string` is not defined while
        its setting is commented out, so calling this would raise
        AttributeError. The regex built here is not applied to anything yet.
        """
        regex = QRegExp(self.filter_string)
        # If the user explicitly types different cases, we assume they know
        # what they are searching for and account for letter case in filter
        different_case = (any(c.islower() for c in self.filter_string)
                          and any(c.isupper() for c in self.filter_string))
        if not different_case:
            regex.setCaseSensitivity(Qt.CaseInsensitive)

    @Inputs.data
    def set_data(self, data):
        """Reset the widget for new input, restore context and commit."""
        # Clear outputs and reset widget state
        self.closeContext()
        self.selected_vars = []
        self.model.resetSorting()
        self.Outputs.reduced_data.send(None)
        self.Outputs.statistics.send(None)

        # Setup widget state for new data and restore settings
        self.data = data

        if data is not None:
            self.info.set_input_summary(len(data),
                                        format_summary_details(data))
            self.color_var_model.set_domain(data.domain)
            self.color_var = None
            # Default coloring: the first class variable, if any
            if self.data.domain.class_vars:
                self.color_var = self.data.domain.class_vars[0]
        else:
            self.info.set_input_summary(self.info.NoInput)
            self.color_var_model.set_domain(None)
            self.color_var = None
        self.model.set_data(data)

        self.openContext(self.data)
        self.__restore_selection()
        self.__restore_sorting()
        # self._filter_table_variables()
        self.__color_var_changed()

        self.commit()

    def __restore_selection(self):
        """Restore the selection on the table view from saved settings."""
        selection_model = self.table_view.selectionModel()
        selection = QItemSelection()
        if self.selected_vars:
            var_indices = {
                var: i
                for i, var in enumerate(self.model.variables)
            }
            selected_indices = [var_indices[var] for var in self.selected_vars]
            # Saved indices refer to the source model; map them to the rows
            # as currently shown (after sorting) before selecting.
            for row in self.model.mapFromSourceRows(selected_indices):
                selection.append(
                    QItemSelectionRange(
                        self.model.index(row, 0),
                        self.model.index(row,
                                         self.model.columnCount() - 1)))
        selection_model.select(selection, QItemSelectionModel.ClearAndSelect)

    def __restore_sorting(self):
        """Restore the sort column and order from saved settings."""
        sort_column, sort_order = self.sorting
        if self.model.n_attributes and sort_column < self.model.columnCount():
            self.model.sort(sort_column, sort_order)
            self.table_view.horizontalHeader().setSortIndicator(
                sort_column, sort_order)

    @pyqtSlot(int)
    def on_header_click(self, *_):
        """Remember the sort column and order chosen via the header."""
        # Store the header states
        sort_order = self.model.sortOrder()
        sort_column = self.model.sortColumn()
        self.sorting = sort_column, sort_order

    @pyqtSlot(int)
    def __color_var_changed(self, *_):
        """Pass the selected coloring variable on to the table model."""
        if self.model is not None:
            self.model.set_target_var(self.color_var)

    def on_select(self):
        """Store the variables of the selected rows and commit."""
        selection_indices = list(
            self.model.mapToSourceRows([
                i.row()
                for i in self.table_view.selectionModel().selectedRows()
            ]))
        self.selected_vars = list(self.model.variables[selection_indices])
        self.commit()

    def commit(self):
        """Send the selected columns and a table of their statistics.

        Both outputs are cleared when nothing is selected.
        """
        if not self.selected_vars:
            self.info.set_output_summary(self.info.NoOutput)
            self.Outputs.reduced_data.send(None)
            self.Outputs.statistics.send(None)
            return

        # Send a table with only selected columns to output. Slice once and
        # reuse the result for the summary and the output.
        variables = self.selected_vars
        reduced_data = self.data[:, variables]
        self.info.set_output_summary(
            len(reduced_data),
            format_summary_details(reduced_data))
        self.Outputs.reduced_data.send(reduced_data)

        # Send the statistics of the selected variables to output
        labels, data = self.model.get_statistics_matrix(variables,
                                                        return_labels=True)
        # One row per selected variable; its name goes into a meta column
        var_names = np.atleast_2d([var.name for var in variables]).T
        domain = Domain(
            attributes=[ContinuousVariable(name) for name in labels],
            metas=[StringVariable('Feature')])
        statistics = Table(domain, data, metas=var_names)
        statistics.name = '%s (Feature Statistics)' % self.data.name
        self.Outputs.statistics.send(statistics)

    def send_report(self):
        """Report the statistics table as displayed."""
        view = self.table_view
        self.report_table(view)

    @classmethod
    def migrate_context(cls, context, version):
        """Convert `selected_rows` (version < 2) into `selected_vars`."""
        if not version or version < 2:
            selected_rows = context.values.pop("selected_rows", None)
            if not selected_rows:
                selected_vars = []
            else:
                # This assumes that dict was saved by Python >= 3.6 so dict is
                # ordered; if not, context hasn't had worked anyway.
                all_vars = [
                    (var, tpe) for (var, tpe) in chain(
                        context.attributes.items(), context.metas.items())
                    # it would be nicer to use cls.HIDDEN_VAR_TYPES, but there
                    # is no suitable conversion function, and StringVariable (3)
                    # was the only hidden var when settings_version < 2, so:
                    if tpe != 3
                ]
                selected_vars = [all_vars[i] for i in selected_rows]
            context.values["selected_vars"] = selected_vars, -3
Beispiel #41
0
class OWBoxPlot(widget.OWWidget):
    """
    Here's how the widget's functions call each other:

    - `set_data` is a signal handler fills the list boxes and calls
    `grouping_changed`.

    - `grouping_changed` handles changes of grouping attribute: it enables or
    disables the box for ordering, orders attributes and calls `attr_changed`.

    - `attr_changed` handles changes of attribute. It recomputes box data by
    calling `compute_box_data`, shows the appropriate display box
    (discrete/continuous) and then calls`layout_changed`

    - `layout_changed` constructs all the elements for the scene (as lists of
    QGraphicsItemGroup) and calls `display_changed`. It is called when the
    attribute or grouping is changed (by attr_changed) and on resize event.

    - `display_changed` puts the elements corresponding to the current display
    settings on the scene. It is called when the elements are reconstructed
    (layout is changed due to selection of attributes or resize event), or
    when the user changes display settings or colors.

    For discrete attributes, the flow is a bit simpler: the elements are not
    constructed in advance (by layout_changed). Instead, layout_changed and
    display_changed call display_changed_disc that draws everything.
    """
    name = "Box Plot"
    description = "Visualize the distribution of feature values in a box plot."
    icon = "icons/BoxPlot.svg"
    priority = 100
    keywords = ["whisker"]

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        selected_data = Output("Selected Data", Orange.data.Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    #: Comparison types for continuous variables
    CompareNone, CompareMedians, CompareMeans = 0, 1, 2

    settingsHandler = DomainContextHandler()
    conditions = ContextSetting([])

    attribute = ContextSetting(None)
    order_by_importance = Setting(False)
    group_var = ContextSetting(None)
    show_annotations = Setting(True)
    compare = Setting(CompareMeans)
    stattest = Setting(0)
    sig_threshold = Setting(0.05)
    stretched = Setting(True)
    show_labels = Setting(True)
    sort_freqs = Setting(False)
    auto_commit = Setting(True)

    _sorting_criteria_attrs = {
        CompareNone: "", CompareMedians: "median", CompareMeans: "mean"
    }

    _pen_axis_tick = QPen(Qt.white, 5)
    _pen_axis = QPen(Qt.darkGray, 3)
    _pen_median = QPen(QBrush(QColor(0xff, 0xff, 0x00)), 2)
    _pen_paramet = QPen(QBrush(QColor(0x33, 0x00, 0xff)), 2)
    _pen_dotted = QPen(QBrush(QColor(0x33, 0x00, 0xff)), 1)
    _pen_dotted.setStyle(Qt.DotLine)
    _post_line_pen = QPen(Qt.lightGray, 2)
    _post_grp_pen = QPen(Qt.lightGray, 4)
    for pen in (_pen_paramet, _pen_median, _pen_dotted,
                _pen_axis, _pen_axis_tick, _post_line_pen, _post_grp_pen):
        pen.setCosmetic(True)
        pen.setCapStyle(Qt.RoundCap)
        pen.setJoinStyle(Qt.RoundJoin)
    _pen_axis_tick.setCapStyle(Qt.FlatCap)

    _box_brush = QBrush(QColor(0x33, 0x88, 0xff, 0xc0))

    _axis_font = QFont()
    _axis_font.setPixelSize(12)
    _label_font = QFont()
    _label_font.setPixelSize(11)
    _attr_brush = QBrush(QColor(0x33, 0x00, 0xff))

    graph_name = "box_scene"

    def __init__(self):
        """Build the control area (variable/grouping lists, display options)
        and the main area (graphics scene and test-result label)."""
        super().__init__()
        self.stats = []
        self.dataset = None
        self.posthoc_lines = []

        # Each attribute gets its own list. A chained `a = b = ... = []`
        # would bind them all to one shared list object, so any in-place
        # change to one would silently show up in all the others.
        self.label_txts = []
        self.mean_labels = []
        self.boxes = []
        self.labels = []
        self.label_txts_all = []
        self.attr_labels = []
        self.order = []
        self.scale_x = self.scene_min_x = self.scene_width = 0
        self.label_width = 0

        self.attrs = VariableListModel()
        view = gui.listView(
            self.controlArea, self, "attribute", box="Variable",
            model=self.attrs, callback=self.attr_changed)
        view.setMinimumSize(QSize(30, 30))
        # Any other policy than Ignored will let the QListBox's scrollbar
        # set the minimal height (see the penultimate paragraph of
        # http://doc.qt.io/qt-4.8/qabstractscrollarea.html#addScrollBarWidget)
        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)
        gui.separator(view.box, 6, 6)
        self.cb_order = gui.checkBox(
            view.box, self, "order_by_importance",
            "Order by relevance",
            tooltip="Order by 𝜒² or ANOVA over the subgroups",
            callback=self.apply_sorting)
        self.group_vars = DomainModel(
            placeholder="None", separators=False,
            valid_types=Orange.data.DiscreteVariable)
        self.group_view = view = gui.listView(
            self.controlArea, self, "group_var", box="Subgroups",
            model=self.group_vars, callback=self.grouping_changed)
        view.setEnabled(False)
        view.setMinimumSize(QSize(30, 30))
        # See the comment above
        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)

        # TODO: move Compare median/mean to grouping box
        # The vertical size policy is needed to let only the list views expand
        self.display_box = gui.vBox(
            self.controlArea, "Display",
            sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum),
            addSpace=False)

        gui.checkBox(self.display_box, self, "show_annotations", "Annotate",
                     callback=self.display_changed)
        self.compare_rb = gui.radioButtonsInBox(
            self.display_box, self, 'compare',
            btnLabels=["No comparison", "Compare medians", "Compare means"],
            callback=self.layout_changed)

        # The vertical size policy is needed to let only the list views expand
        self.stretching_box = box = gui.vBox(
            self.controlArea, box="Display",
            sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Fixed))
        # Both "Display" boxes report the same size hint, presumably so the
        # layout does not jump when update_display_box switches them.
        self.stretching_box.sizeHint = self.display_box.sizeHint
        gui.checkBox(
            box, self, 'stretched', "Stretch bars",
            callback=self.display_changed)
        gui.checkBox(
            box, self, 'show_labels', "Show box labels",
            callback=self.display_changed)
        self.sort_cb = gui.checkBox(
            box, self, 'sort_freqs', "Sort by subgroup frequencies",
            callback=self.display_changed)
        gui.rubber(box)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

        gui.vBox(self.mainArea, addSpace=True)
        self.box_scene = QGraphicsScene()
        # Connected after gui.auto_commit on purpose: `self.commit` is looked
        # up here, after auto_commit has had the chance to wrap it.
        self.box_scene.selectionChanged.connect(self.commit)
        self.box_view = QGraphicsView(self.box_scene)
        self.box_view.setRenderHints(QPainter.Antialiasing |
                                     QPainter.TextAntialiasing |
                                     QPainter.SmoothPixmapTransform)
        # Resize events of the viewport are handled in eventFilter
        self.box_view.viewport().installEventFilter(self)

        self.mainArea.layout().addWidget(self.box_view)

        e = gui.hBox(self.mainArea, addSpace=False)
        self.infot1 = gui.widgetLabel(e, "<center>No test results.</center>")
        self.mainArea.setMinimumWidth(300)

        self.stats, self.dist, self.conts = [], [], []
        self.is_continuous = False

        self.update_display_box()

    def sizeHint(self):
        """Preferred initial widget size: 900 x 500 pixels."""
        width, height = 900, 500
        return QSize(width, height)

    def eventFilter(self, obj, event):
        """Re-layout the plot whenever the view's viewport is resized.

        The event is always forwarded to the base implementation afterwards.
        """
        viewport_resized = (obj is self.box_view.viewport()
                            and event.type() == QEvent.Resize)
        if viewport_resized:
            self.layout_changed()
        return super().eventFilter(obj, event)

    def reset_attrs(self, domain):
        """Fill the variable list with the domain's primitive variables.

        Class variables come first, then metas, then attributes.
        """
        candidates = chain(domain.class_vars, domain.metas, domain.attributes)
        self.attrs[:] = list(filter(lambda var: var.is_primitive(), candidates))

    # noinspection PyTypeChecker
    @Inputs.data
    def set_data(self, dataset):
        """Receive new input and rebuild the lists, context and plot.

        A dataset without rows, or with an empty domain and no primitive
        meta variables, is treated as no input.
        """
        if dataset is not None:
            plottable = bool(dataset) and (
                len(dataset.domain) > 0
                or any(var.is_primitive() for var in dataset.domain.metas))
            if not plottable:
                dataset = None
        self.closeContext()
        self.dataset = dataset
        self.dist = self.stats = self.conts = []
        self.group_var = None
        self.attribute = None
        if not dataset:
            self.reset_all_data()
        else:
            domain = dataset.domain
            self.group_vars.set_domain(domain)
            # The "None" placeholder is always present, so more than one
            # entry means there is a real grouping candidate.
            self.group_view.setEnabled(len(self.group_vars) > 1)
            self.reset_attrs(domain)
            self.select_default_variables(domain)
            self.openContext(self.dataset)
            self.grouping_changed()
        self.commit()

    def select_default_variables(self, domain):
        """Pick the initial variable and grouping for a freshly set domain.

        The variable shown is the one right after the class variables in
        `self.attrs`, or the first one if there is none. A discrete class
        variable becomes the grouping; otherwise grouping is cleared.
        """
        n_class_vars = len(domain.class_vars)
        if len(self.attrs) > n_class_vars:
            self.attribute = self.attrs[n_class_vars]
        elif self.attrs:
            self.attribute = self.attrs[0]

        class_var = domain.class_var
        if class_var and class_var.is_discrete:
            self.group_var = class_var
        else:
            self.group_var = None  # Reset to trigger selection via callback

    def apply_sorting(self):
        """Reorder the variable list by relevance to the grouping variable.

        When `order_by_importance` is on and a grouping variable is chosen,
        `self.attrs` is sorted ascending by `compute_score`: a p-value
        (ANOVA for continuous, chi-square for discrete variables), 2 for
        variables that cannot be tested, and 3 for the grouping variable
        itself, so untestable ones and the grouping variable end up last.
        Otherwise the list is reset to the domain order. The previously
        selected `attribute` is re-assigned at the end.
        """
        def compute_score(attr):
            # Closure over `data`, `group_var`, `n_groups` and `group_col`,
            # which are bound below before `self.attrs.sort` calls this.
            if attr is group_var:
                return 3
            if attr.is_continuous:
                # One-way ANOVA
                col = data.get_column_view(attr)[0].astype(float)
                groups = (col[group_col == i] for i in range(n_groups))
                groups = (col[~np.isnan(col)] for col in groups)
                groups = [group for group in groups if len(group)]
                p = f_oneway(*groups)[1] if len(groups) > 1 else 2
            else:
                # Chi-square with the given distribution into groups
                # (see degrees of freedom in computation of the p-value)
                if not attr.values or not group_var.values:
                    return 2
                observed = np.array(
                    contingency.get_contingency(data, group_var, attr))
                # Drop all-zero rows and columns so that no expected
                # frequency computed below is zero
                observed = observed[observed.sum(axis=1) != 0, :]
                observed = observed[:, observed.sum(axis=0) != 0]
                if min(observed.shape) < 2:
                    return 2
                expected = \
                    np.outer(observed.sum(axis=1), observed.sum(axis=0)) / \
                    np.sum(observed)
                # NOTE(review): scipy's chisquare uses k - 1 - ddof degrees of
                # freedom with k = observed.size. An independence test on an
                # r x c table needs (r - 1)(c - 1), i.e. ddof = r + c - 2;
                # ddof = n_groups - 1 does not match that in general — confirm.
                p = chisquare(observed.ravel(), f_exp=expected.ravel(),
                              ddof=n_groups - 1)[1]
            if math.isnan(p):
                return 2
            return p

        data = self.dataset
        if data is None:
            return
        domain = data.domain
        attribute = self.attribute
        group_var = self.group_var
        if self.order_by_importance and group_var is not None:
            n_groups = len(group_var.values)
            # The group column is only needed for ANOVA, i.e. when there are
            # continuous variables to score
            group_col = data.get_column_view(group_var)[0] if \
                domain.has_continuous_attributes(
                    include_class=True, include_metas=True) else None
            self.attrs.sort(key=compute_score)
        else:
            self.reset_attrs(domain)
        # Re-assign the variable that was selected before the list changed
        self.attribute = attribute

    def reset_all_data(self):
        """Return the widget to its empty state (no data to show)."""
        self.clear_scene()
        self.infot1.setText("")
        self.attrs.clear()
        grouping_model, grouping_view = self.group_vars, self.group_view
        grouping_model.set_domain(None)
        grouping_view.setEnabled(False)
        self.is_continuous = False
        self.update_display_box()

    def grouping_changed(self):
        """React to a new grouping variable: re-sort, then redraw."""
        is_grouped = self.group_var is not None
        self.cb_order.setEnabled(is_grouped)
        self.apply_sorting()
        self.attr_changed()

    def select_box_items(self):
        """Select the scene's filter boxes whose conditions were saved."""
        # Snapshot the saved conditions before touching the selection:
        # setSelected() can emit selectionChanged, which is connected to
        # commit() and may presumably update self.conditions mid-loop.
        # Building the list once also avoids recomputing it for every item.
        saved_conditions = [cond.conditions for cond in self.conditions]
        for item in self.box_scene.items():
            if isinstance(item, FilterGraphicsRectItem):
                item.setSelected(item.filter.conditions in saved_conditions)

    def attr_changed(self):
        """Recompute and redraw for the current attribute, then re-center."""
        self.compute_box_data()
        self.update_display_box()
        self.layout_changed()

        if not self.is_continuous:
            center_x = self.scene_width / 2
            center_y = -30 - len(self.boxes) * 40 / 2 + 45
        else:
            row_height = 90 if self.show_annotations else 60
            center_x = self.scene_min_x + self.scene_width / 2
            center_y = -30 - len(self.stats) * row_height / 2 + 45
        self.box_view.centerOn(center_x, center_y)

    def compute_box_data(self):
        """Recompute the data behind the boxes for the current attribute.

        Sets `is_continuous`, then either `dist` (ungrouped distribution) or
        `conts` (contingency over the grouping variable), `label_txts_all`
        and, for continuous attributes, `stats` (BoxData per non-empty
        group) together with the matching `label_txts`.
        """
        attr = self.attribute
        if not attr:
            return
        dataset = self.dataset
        self.is_continuous = attr.is_continuous
        # Nothing to compute without data, for a discrete attribute without
        # values, or for a grouping variable without values
        if dataset is None or not self.is_continuous and not attr.values or \
                        self.group_var and not self.group_var.values:
            self.stats = self.dist = self.conts = []
            return
        if self.group_var:
            self.dist = []
            self.conts = contingency.get_contingency(
                dataset, attr, self.group_var)
            if self.is_continuous:
                stats, label_texts = [], []
                for i, cont in enumerate(self.conts):
                    # Skip empty groups (cont[1] presumably holds the
                    # counts — confirm against Contingency)
                    if np.sum(cont[1]):
                        stats.append(BoxData(cont, attr, i, self.group_var))
                        label_texts.append(self.group_var.values[i])
                self.stats = stats
                self.label_txts_all = label_texts
            else:
                # Labels only for groups with at least one instance
                self.label_txts_all = \
                    [v for v, c in zip(self.group_var.values, self.conts)
                     if np.sum(c) > 0]
        else:
            self.dist = distribution.get_distribution(dataset, attr)
            self.conts = []
            if self.is_continuous:
                self.stats = [BoxData(self.dist, attr, None)]
            self.label_txts_all = [""]
        # NOTE(review): for discrete attributes `self.stats` is not
        # reassigned above, so the two lines below filter whatever list it
        # held before; the discrete drawing path in this class does not
        # read `stats` or `label_txts`.
        self.label_txts = [txts for stat, txts in zip(self.stats,
                                                      self.label_txts_all)
                           if stat.n > 0]
        self.stats = [stat for stat in self.stats if stat.n > 0]

    def update_display_box(self):
        """Show the option box matching the attribute type.

        Continuous attributes get the comparison options; discrete ones get
        the stretching/sorting options. Grouping-dependent controls are
        enabled only when a grouping variable is set.
        """
        grouped = self.group_var is not None
        if not self.is_continuous:
            self.stretching_box.show()
            self.display_box.hide()
            self.sort_cb.setEnabled(grouped)
            return
        self.stretching_box.hide()
        self.display_box.show()
        self.compare_rb.setEnabled(grouped)

    def clear_scene(self):
        """Remove all items from the scene and reset the item lists.

        The context is closed before clearing and reopened afterwards,
        presumably so the stored selection survives clearSelection(), which
        may emit selectionChanged (connected to commit).
        """
        self.closeContext()
        scene = self.box_scene
        scene.clearSelection()
        scene.clear()
        self.box_view.viewport().update()
        for attr_name in ("attr_labels", "labels", "boxes",
                          "mean_labels", "posthoc_lines"):
            setattr(self, attr_name, [])
        self.openContext(self.dataset)

    def layout_changed(self):
        """Rebuild the scene items for the current attribute.

        Discrete attributes are drawn directly by `display_changed_disc`.
        For continuous ones, mean labels, the axis, boxes and labels are
        created here and then positioned by `display_changed`.
        """
        attr = self.attribute
        if not attr:
            return
        self.clear_scene()
        nothing_to_draw = (self.dataset is None
                           or len(self.conts) == len(self.dist) == 0)
        if nothing_to_draw:
            return

        if not self.is_continuous:
            self.display_changed_disc()
            return

        self.mean_labels = [
            self.mean_label(stat, attr, text)
            for stat, text in zip(self.stats, self.label_txts)]
        self.draw_axis()
        self.boxes = [self.box_group(stat) for stat in self.stats]
        self.labels = [
            self.label_group(stat, attr, mean_text)
            for stat, mean_text in zip(self.stats, self.mean_labels)]
        self.attr_labels = [
            QGraphicsSimpleTextItem(text) for text in self.label_txts]
        for graphics_item in chain(self.labels, self.attr_labels):
            self.box_scene.addItem(graphics_item)
        self.display_changed()

    def display_changed(self):
        """Place the prepared items for a continuous attribute on the scene.

        Discrete attributes are delegated to `display_changed_disc`. For
        continuous ones, rows are ordered by the comparison criterion
        (median or mean; rows without a value go last), boxes and labels
        are positioned row by row, and finally the scene rect, statistical
        tests, post-hoc lines and selection are updated.
        """
        if self.dataset is None:
            return

        if not self.is_continuous:
            self.display_changed_disc()
            return

        self.order = list(range(len(self.stats)))
        # "" for CompareNone, which keeps the original order
        criterion = self._sorting_criteria_attrs[self.compare]
        if criterion:
            vals = [getattr(stat, criterion) for stat in self.stats]
            # Missing values are replaced by something larger than every
            # present value so they sort last
            overmax = max((val for val in vals if val is not None), default=0) \
                      + 1
            vals = [val if val is not None else overmax for val in vals]
            self.order = sorted(self.order, key=vals.__getitem__)

        # Rows need more room when annotations are shown
        heights = 90 if self.show_annotations else 60

        for row, box_index in enumerate(self.order):
            # Rows occupy negative y coordinates; row 0 is the top-most
            y = (-len(self.stats) + row) * heights + 10
            for item in self.boxes[box_index]:
                self.box_scene.addItem(item)
                item.setY(y)
            labels = self.labels[box_index]

            if self.show_annotations:
                labels.show()
                labels.setY(y)
            else:
                labels.hide()

            # The attribute (group) label is shown only without annotations
            label = self.attr_labels[box_index]
            label.setY(y - 15 - label.boundingRect().height())
            if self.show_annotations:
                label.hide()
            else:
                stat = self.stats[box_index]

                # Put the label 5 px right of the compared statistic
                # (median or mean), or at q25 otherwise; the offset is
                # given in data units, hence the division by scale_x.
                if self.compare == OWBoxPlot.CompareMedians and \
                        stat.median is not None:
                    pos = stat.median + 5 / self.scale_x
                elif self.compare == OWBoxPlot.CompareMeans or stat.q25 is None:
                    pos = stat.mean + 5 / self.scale_x
                else:
                    pos = stat.q25
                label.setX(pos * self.scale_x)
                label.show()

        r = QRectF(self.scene_min_x, -30 - len(self.stats) * heights,
                   self.scene_width, len(self.stats) * heights + 90)
        self.box_scene.setSceneRect(r)

        self.compute_tests()
        self.show_posthoc()
        self.select_box_items()

    def display_changed_disc(self):
        """Draw the stacked bars for a discrete attribute.

        The scene is rebuilt from scratch with one row per non-empty group,
        or a single row when there is no grouping variable. Each row gets
        its group label, optional row counts (only for unstretched bars),
        optional bar labels and the bars themselves.
        """
        assert not self.is_continuous
        self.clear_scene()
        self.attr_labels = [QGraphicsSimpleTextItem(lab)
                            for lab in self.label_txts_all]

        if not self.stretched:
            # Row totals are displayed only for unstretched bars
            if self.group_var:
                self.labels = [
                    QGraphicsTextItem("{}".format(int(sum(cont))))
                    for cont in self.conts if np.sum(cont) > 0]
            else:
                self.labels = [
                    QGraphicsTextItem(str(int(sum(self.dist))))]

        self.order = list(range(len(self.attr_labels)))

        self.draw_axis_disc()
        if self.group_var:
            self.boxes = \
                [self.strudel(cont, i) for i, cont in enumerate(self.conts)
                 if np.sum(cont) > 0]
            # Keep only non-empty groups so row indices into `self.conts`
            # match those into `self.boxes` and `self.labels`
            self.conts = self.conts[np.sum(np.array(self.conts), axis=1) > 0]

            if self.sort_freqs:
                # pylint: disable=invalid-unary-operand-type
                self.order = sorted(self.order, key=(-np.sum(self.conts, axis=1)).__getitem__)
        else:
            self.boxes = [self.strudel(self.dist)]

        row_height = 40
        for row, box_index in enumerate(self.order):
            y = (-len(self.boxes) + row) * row_height + 10
            box = self.boxes[box_index]
            # Even positions hold the bars, odd positions their labels
            bars, labels = box[::2], box[1::2]

            self.__draw_group_labels(y, box_index)
            if not self.stretched:
                self.__draw_row_counts(y, box_index)
            if self.show_labels and self.attribute is not self.group_var:
                self.__draw_bar_labels(y, bars, labels)
            self.__draw_bars(y, bars)

        # Was `len(self.boxes * 40)`, which built a 40x repeated list just
        # to measure it; the value is the same as n_rows * row_height.
        n_rows = len(self.boxes)
        self.box_scene.setSceneRect(-self.label_width - 5,
                                    -30 - n_rows * row_height,
                                    self.scene_width, n_rows * row_height + 90)
        self.infot1.setText("")
        self.select_box_items()

    def __draw_group_labels(self, y, row):
        """Add the group label of one row, right-aligned left of the bars.

        Parameters
        ----------
        y: int
            vertical offset of the row's bars
        row: int
            index of the row (and of its label in `attr_labels`)
        """
        text_item = self.attr_labels[row]
        rect = text_item.boundingRect()
        x_pos = -rect.width() - 10
        y_pos = y - rect.height() / 2
        text_item.setPos(x_pos, y_pos)
        self.box_scene.addItem(text_item)

    def __draw_row_counts(self, y, row):
        """Add the total-count label just right of a row's bars.

        Parameters
        ----------
        y: int
            vertical offset of the row's bars
        row: int
            index of the row (and of its label in `labels`)
        """
        assert not self.is_continuous
        count_item = self.labels[row]
        rect = count_item.boundingRect()
        counts = self.conts[row] if self.group_var else self.dist
        bar_end = self.scale_x * sum(counts)
        count_item.setPos(bar_end + 10, y - rect.height() / 2)
        self.box_scene.addItem(count_item)

    def __draw_bar_labels(self, y, bars, labels):
        """Add a width-limited text label above each bar of a row.

        Parameters
        ----------
        y: int
            vertical offset of bars
        bars: List[FilterGraphicsRectItem]
            list of bars being drawn
        labels: List[QGraphicsTextItem]
            list of labels for corresponding bars
        """
        for text_item, bar_part in zip(labels, bars):
            bar_rect = bar_part.boundingRect()
            bar_label = self.Label(text_item.toPlainText())
            # left-aligned with the bar, 8 px above the row's baseline
            bar_label.setPos(bar_rect.x(),
                             y - bar_label.boundingRect().height() - 8)
            # text is elided to the bar's width when painted
            bar_label.setMaxWidth(bar_rect.width())
            self.box_scene.addItem(bar_label)

    def __draw_bars(self, y, bars):
        """Move every bar of a row to vertical offset `y` and add it to the scene.

        Parameters
        ----------
        y: int
            vertical offset of bars

        bars: List[FilterGraphicsRectItem]
            list of bars to draw
        """
        add_to_scene = self.box_scene.addItem
        for bar in bars:
            bar.setPos(0, y)
            add_to_scene(bar)

    # noinspection PyPep8Naming
    def compute_tests(self):
        """Run a significance test between the groups and show the result.

        The outcome is written, wrapped in ``<center>``, into ``self.infot1``:

        * an empty string when ``self.compare`` is ``OWBoxPlot.CompareNone``
          or ``self.stats`` has fewer than two groups;
        * a message when any group has at most one instance;
        * an empty string for median comparison (the non-parametric tests
          are only present as commented-out code);
        * otherwise a t statistic (two groups) or ANOVA F statistic (more
          groups) with its p-value, or an empty string if it is NaN.
        """
        # The t-test and ANOVA are implemented here since they efficiently use
        # the widget-specific data in self.stats.
        # The non-parametric tests can't do this, so we use statistics.tests

        # pylint: disable=comparison-with-itself
        def stat_ttest():
            """Return (|t|, two-sided p) for the two groups, or (nan, nan).

            The variance term is var1/n1 + var2/n2 and the degrees of freedom
            use the Welch-Satterthwaite formula, i.e. this is Welch's
            unequal-variance t-test, although the caller labels it
            "Student's t".
            """
            d1, d2 = self.stats
            if d1.n < 2 or d2.n < 2:
                return np.nan, np.nan
            # Despite the name, this is the squared standard error of the
            # difference of the means (not a pooled variance)
            pooled_var = d1.var / d1.n + d2.var / d2.n
            # pylint: disable=comparison-with-itself
            if pooled_var == 0 or np.isnan(pooled_var):
                return np.nan, np.nan
            # Welch-Satterthwaite approximation of the degrees of freedom
            df = pooled_var ** 2 / \
                ((d1.var / d1.n) ** 2 / (d1.n - 1) +
                 (d2.var / d2.n) ** 2 / (d2.n - 1))
            t = abs(d1.mean - d2.mean) / math.sqrt(pooled_var)
            # t is non-negative, so this is the two-sided tail probability
            p = 2 * (1 - scipy.special.stdtr(df, t))
            return t, p

        # TODO: Check this function
        # noinspection PyPep8Naming
        def stat_ANOVA():
            """Return (F, p) of a one-way ANOVA over all groups, or (nan, nan)."""
            if any(stat.n == 0 for stat in self.stats):
                return np.nan, np.nan
            n = sum(stat.n for stat in self.stats)
            # Mean over all instances of all groups
            grand_avg = sum(stat.n * stat.mean for stat in self.stats) / n
            # Between-group sum of squares
            var_between = sum(stat.n * (stat.mean - grand_avg) ** 2
                              for stat in self.stats)
            df_between = len(self.stats) - 1

            # Within-group sum of squares, provided stat.var is the
            # population (ddof=0) variance of each group - TODO confirm
            var_within = sum(stat.n * stat.var for stat in self.stats)
            df_within = n - len(self.stats)
            if var_within == 0 or df_within == 0 or df_between == 0:
                return np.nan, np.nan
            F = (var_between / df_between) / (var_within / df_within)
            # Upper tail of the F(df_between, df_within) distribution
            p = 1 - scipy.special.fdtr(df_between, df_within, F)
            return F, p

        if self.compare == OWBoxPlot.CompareNone or len(self.stats) < 2:
            t = ""
        elif any(s.n <= 1 for s in self.stats):
            t = "At least one group has just one instance, " \
                "cannot compute significance"
        elif len(self.stats) == 2:
            if self.compare == OWBoxPlot.CompareMedians:
                # Non-parametric test not implemented; nothing is shown
                t = ""
                # z, p = tests.wilcoxon_rank_sum(
                #    self.stats[0].dist, self.stats[1].dist)
                # t = "Mann-Whitney's z: %.1f (p=%.3f)" % (z, p)
            else:
                t, p = stat_ttest()
                t = "" if np.isnan(t) else f"Student's t: {t:.3f} (p={p:.3f})"
        else:
            if self.compare == OWBoxPlot.CompareMedians:
                # Non-parametric test not implemented; nothing is shown
                t = ""
                # U, p = -1, -1
                # t = "Kruskal Wallis's U: %.1f (p=%.3f)" % (U, p)
            else:
                F, p = stat_ANOVA()
                t = "" if np.isnan(F) else f"ANOVA: {F:.3f} (p={p:.3f})"
        self.infot1.setText("<center>%s</center>" % t)

    def mean_label(self, stat, attr, val_name):
        """Build a graphics group reading ``[val_name: ]mean ± dev``.

        The mean text is centred horizontally on the group's origin and the
        whole label sits above it. The returned group carries an extra
        attribute ``min_x``: the x offset of the label's left edge relative
        to the origin (negative).
        """
        decimals = attr.number_of_decimals + 1
        font = self._label_font
        group = QGraphicsItemGroup()

        mean_item = QGraphicsSimpleTextItem(
            "%.*f" % (decimals, stat.mean), group)
        mean_item.setFont(font)
        mean_rect = mean_item.boundingRect()
        half_width, height = mean_rect.width() / 2, mean_rect.height()
        mean_item.setPos(-half_width, -height)

        dev_item = QGraphicsSimpleTextItem(
            " \u00b1 " + "%.*f" % (decimals, stat.dev), group)
        dev_item.setFont(font)
        dev_item.setPos(half_width, -height)

        if not val_name:
            group.min_x = -half_width
            return group

        # optional value-name prefix, placed left of the mean
        name_item = QGraphicsSimpleTextItem(val_name + ": ", group)
        name_item.setFont(font)
        name_item.setBrush(self._attr_brush)
        group.min_x = -half_width - name_item.boundingRect().width()
        name_item.setPos(group.min_x, -height)
        return group

    def draw_axis(self):
        """Draw the horizontal axis and sets self.scale_x

        Also sets ``self.scene_min_x`` and ``self.scene_width``. The scale is
        chosen so the range covering all boxes, all mean +/- dev lines and
        the mean labels fits into the view's viewport (minus margins). When
        there are no stats, a placeholder range is used and tick labels are
        drawn as "?".
        """
        misssing_stats = not self.stats
        # Placeholder stats over the values 0 and 1, presumably so an axis can
        # still be drawn without data
        stats = self.stats or [BoxData(np.array([[0.], [1.]]), self.attribute)]
        mean_labels = self.mean_labels or [self.mean_label(stats[0], self.attribute, "")]
        bottom = min(stat.a_min for stat in stats)
        top = max(stat.a_max for stat in stats)

        first_val, step = compute_scale(bottom, top)
        # Move the first tick strictly below the smallest value
        while bottom <= first_val:
            first_val -= step
        bottom = first_val
        no_ticks = math.ceil((top - first_val) / step) + 1
        top = max(top, first_val + no_ticks * step)

        # Extend the range to also cover mean +/- dev of every group
        gbottom = min(bottom, min(stat.mean - stat.dev for stat in stats))
        gtop = max(top, max(stat.mean + stat.dev for stat in stats))

        bv = self.box_view
        viewrect = bv.viewport().rect().adjusted(15, 15, -15, -30)
        self.scale_x = scale_x = viewrect.width() / (gtop - gbottom)

        # mean_lab.min_x is the (negative) pixel offset of a mean label's
        # left edge; dividing by scale_x converts it to data units.
        # In principle we should repeat this until convergence since the new
        # scaling is too conservative. (No chance am I doing this.)
        mlb = min(stat.mean + mean_lab.min_x / scale_x
                  for stat, mean_lab in zip(stats, mean_labels))
        if mlb < gbottom:
            gbottom = mlb
            self.scale_x = scale_x = viewrect.width() / (gtop - gbottom)

        self.scene_min_x = gbottom * scale_x
        self.scene_width = (gtop - gbottom) * scale_x

        val = first_val
        decimals = max(3, 4 - int(math.log10(step)))
        # One tick mark and label per step, up to and including the first
        # value >= top
        while True:
            l = self.box_scene.addLine(val * scale_x, -1, val * scale_x, 1,
                                       self._pen_axis_tick)
            l.setZValue(100)
            t = self.box_scene.addSimpleText(
                repr(round(val, decimals)) if not misssing_stats else "?",
                self._axis_font)
            t.setFlags(
                t.flags() | QGraphicsItem.ItemIgnoresTransformations)
            r = t.boundingRect()
            t.setPos(val * scale_x - r.width() / 2, 8)
            if val >= top:
                break
            val += step
        # The axis line itself, 4 px past the outermost ticks
        self.box_scene.addLine(
            bottom * scale_x - 4, 0, top * scale_x + 4, 0, self._pen_axis)

    def draw_axis_disc(self):
        """
        Draw the horizontal axis and sets self.scale_x for discrete attributes

        Also sets ``self.scene_width`` and ``self.label_width``. In stretched
        mode the axis shows 0..100 in steps of 10, and ``self.scale_x`` is
        finally multiplied by 100 so a fraction in [0, 1] maps onto the whole
        axis. Otherwise the axis covers the largest row total, leaving room
        on the right for the row-count labels.
        """
        assert not self.is_continuous
        if self.stretched:
            if not self.attr_labels:
                return
            step = steps = 10
        else:
            if self.group_var:
                max_box = max(float(np.sum(dist)) for dist in self.conts)
            else:
                max_box = float(np.sum(self.dist))
            if max_box == 0:
                # Nothing to draw; keep a harmless scale
                self.scale_x = 1
                return
            _, step = compute_scale(0, max_box)
            # Counts are integers, so use an integer step of at least 1
            step = int(step) if step > 1 else 1
            steps = int(math.ceil(max_box / step))
        max_box = step * steps

        bv = self.box_view
        viewrect = bv.viewport().rect().adjusted(15, 15, -15, -30)
        self.scene_width = viewrect.width()

        # Group-label column width: widest label, clamped to
        # [40, scene_width / 3].
        # NOTE(review): in non-stretched mode max() raises ValueError if
        # self.attr_labels is empty - confirm it is always populated here.
        lab_width = max(lab.boundingRect().width() for lab in self.attr_labels)
        lab_width = max(lab_width, 40)
        lab_width = min(lab_width, self.scene_width / 3)
        self.label_width = lab_width

        right_offset = 0  # offset for the right label
        if not self.stretched and self.labels:
            if self.group_var:
                rows = list(zip(self.conts, self.labels))
            else:
                rows = [(self.dist, self.labels[0])]
            # space for bars and counts: scene width minus the
            # group-label column and its 10 px gap
            available = self.scene_width - lab_width - 10
            # First-pass scale (right_offset is still 0 here), used only to
            # estimate how far the count labels stick out past the axis
            scale_x = (available - right_offset) / max_box
            max_right = max(sum(dist) * scale_x + 10 +
                            lbl.boundingRect().width()
                            for dist, lbl in rows)
            right_offset = max(0, max_right - max_box * scale_x)

        self.scale_x = scale_x = \
            (self.scene_width - lab_width - 10 - right_offset) / max_box

        self.box_scene.addLine(0, 0, max_box * scale_x, 0, self._pen_axis)
        for val in range(0, step * steps + 1, step):
            l = self.box_scene.addLine(val * scale_x, -1, val * scale_x, 1,
                                       self._pen_axis_tick)
            l.setZValue(100)
            t = self.box_scene.addSimpleText(str(val), self._axis_font)
            t.setPos(val * scale_x - t.boundingRect().width() / 2, 8)
        if self.stretched:
            # Axis spans 0..100 (percent); rescale so that 1.0 spans it
            self.scale_x *= 100

    def label_group(self, stat, attr, mean_lab):
        """Build the annotation group for one continuous box.

        The group contains `mean_lab` placed above the mean, and the median
        and quartile values placed below the box, each with a short tick
        line. A quartile label that would overlap the median label is
        shifted aside and connected to its position by a small bracket
        path.

        When the median is missing, quartile labels are drawn at their own
        positions. Previously the overlap check read the median's position
        unconditionally, which raised UnboundLocalError when ``stat.median``
        was None but a quartile was not.

        Parameters
        ----------
        stat: box statistics with ``mean``, ``median``, ``q25`` and ``q75``
            (the latter three may be None)
        attr: variable whose ``number_of_decimals`` sets the label precision
        mean_lab: graphics item for the mean label (see ``mean_label``)

        Returns
        -------
        QGraphicsItemGroup
        """
        def centered_text(val, pos):
            t = QGraphicsSimpleTextItem(
                "%.*f" % (attr.number_of_decimals + 1, val), labels)
            t.setFont(self._label_font)
            bbox = t.boundingRect()
            t.setPos(pos - bbox.width() / 2, 22)
            return t

        def line(x, down=1):
            QGraphicsLineItem(x, 12 * down, x, 20 * down, labels)

        def move_label(label, frm, to):
            # t_box is the bounding rect of the label being moved; it is
            # bound in the enclosing scope before each call
            label.setX(to)
            to += t_box.width() / 2
            path = QPainterPath()
            path.lineTo(0, 4)
            path.lineTo(to - frm, 4)
            path.lineTo(to - frm, 8)
            p = QGraphicsPathItem(path)
            p.setPos(frm, 12)
            labels.addToGroup(p)

        labels = QGraphicsItemGroup()

        labels.addToGroup(mean_lab)
        m = stat.mean * self.scale_x
        mean_lab.setPos(m, -22)
        line(m, -1)

        # Horizontal extent of the median label; None when there is no median
        med_left = med_right = None
        if stat.median is not None:
            msc = stat.median * self.scale_x
            med_t = centered_text(stat.median, msc)
            med_box_width2 = med_t.boundingRect().width() / 2
            med_left = msc - med_box_width2
            med_right = msc + med_box_width2
            line(msc)

        if stat.q25 is not None:
            x = stat.q25 * self.scale_x
            t = centered_text(stat.q25, x)
            t_box = t.boundingRect()
            if med_left is not None and x + t_box.width() / 2 >= med_left - 5:
                # would overlap the median label: shift left of it
                move_label(t, x, med_left - t_box.width() - 5)
            else:
                line(x)

        if stat.q75 is not None:
            x = stat.q75 * self.scale_x
            t = centered_text(stat.q75, x)
            t_box = t.boundingRect()
            if med_right is not None and x - t_box.width() / 2 <= med_right + 5:
                # would overlap the median label: shift right of it
                move_label(t, x, med_right + 5)
            else:
                line(x)

        return labels

    def box_group(self, stat, height=20):
        """Create the graphics items of one continuous box plot.

        Returns a list with, in order: the two whiskers, the dotted min-max
        line, the mean mark and the mean +/- dev line. If both quartiles are
        known, the filled interquartile rectangle follows, and if the median
        is known, the median line comes last. X coordinates are data values
        multiplied by ``self.scale_x``.
        """
        scale = self.scale_x

        def scaled_line(x0, y0, x1, y1, *args):
            return QGraphicsLineItem(x0 * scale, y0, x1 * scale, y1, *args)

        left_whisker = scaled_line(stat.a_min, -1.5, stat.a_min, 1.5)
        right_whisker = scaled_line(stat.a_max, -1.5, stat.a_max, 1.5)
        range_line = scaled_line(stat.a_min, 0, stat.a_max, 0)
        mean_mark = scaled_line(stat.mean, -height / 3, stat.mean, height / 3)
        dev_line = scaled_line(stat.mean - stat.dev, 0, stat.mean + stat.dev, 0)

        for item in (left_whisker, right_whisker, mean_mark, dev_line):
            item.setPen(self._pen_paramet)
        range_line.setPen(self._pen_dotted)

        parts = [left_whisker, right_whisker, range_line, mean_mark, dev_line]

        if stat.q25 is not None and stat.q75 is not None:
            iqr_rect = FilterGraphicsRectItem(
                stat.conditions, stat.q25 * scale, -height / 2,
                (stat.q75 - stat.q25) * scale, height)
            iqr_rect.setBrush(self._box_brush)
            iqr_rect.setPen(QPen(Qt.NoPen))
            # keep the box behind the lines
            iqr_rect.setZValue(-200)
            parts.append(iqr_rect)

        if stat.median is not None:
            median_mark = scaled_line(stat.median, -height / 2,
                                      stat.median, height / 2)
            median_mark.setPen(self._pen_median)
            median_mark.setZValue(-150)
            parts.append(median_mark)

        return parts

    def strudel(self, dist, group_val_index=None):
        """Build the stacked bar items for one distribution of a discrete attribute.

        Returns a flat list alternating a FilterGraphicsRectItem bar and a
        QGraphicsTextItem holding the value name, for every value with a
        non-negligible count. If the whole distribution is (near) empty, a
        single placeholder rectangle comes first. In stretched mode bar
        widths are fractions of the total; otherwise they are raw counts.
        Both are multiplied by ``self.scale_x``.
        """
        attr = self.attribute
        total = np.sum(dist)

        def conditions(values):
            conds = [FilterDiscrete(attr, values)]
            if group_val_index is not None:
                conds.append(FilterDiscrete(self.group_var, [group_val_index]))
            return conds

        items = []
        if total < 1e-6:
            items.append(FilterGraphicsRectItem(conditions(None), 0, -10, 1, 10))

        offset = 0
        for i, count in enumerate(dist):
            if count < 1e-6:
                continue
            width = count / total if self.stretched else count
            width *= self.scale_x
            rect = FilterGraphicsRectItem(
                conditions([i]), offset + 1, -6, width - 2, 12)
            rect.setBrush(QBrush(QColor(*attr.colors[i])))
            rect.setPen(QPen(Qt.NoPen))
            if self.stretched:
                tooltip = "{}: {:.2f}%".format(attr.values[i],
                                               100 * dist[i] / sum(dist))
            else:
                tooltip = "{}: {}".format(attr.values[i], int(dist[i]))
            rect.setToolTip(tooltip)
            items.append(rect)
            items.append(QGraphicsTextItem(attr.values[i]))
            offset += width
        return items

    def commit(self):
        """Send the data selected in the plot.

        Collects the filters of the selected scene items into
        ``self.conditions``. It then sends the rows matching any of them
        (``conjunction=False``) as selected data, or None when nothing is
        selected. The dataset, annotated with the indices of those rows, is
        sent as annotated data.
        """
        self.conditions = [item.filter for item in
                           self.box_scene.selectedItems() if item.filter]
        selected, selection = None, []
        if self.conditions:
            selected = Values(self.conditions, conjunction=False)(self.dataset)
            # np.in1d is deprecated since NumPy 2.0; np.isin is equivalent
            # for 1-D inputs such as the ids array
            selection = np.isin(
                self.dataset.ids, selected.ids, assume_unique=True).nonzero()[0]
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.dataset, selection))

    def show_posthoc(self):
        """Redraw the post-hoc comparison markers.

        Removes previously drawn post-hoc lines. When a comparison mode is
        active and there are at least two groups, it draws a vertical marker
        through each group's mean or median (per ``self.compare``), in
        ``self.order``. It then draws short horizontal connectors near the
        axis joining runs of consecutive markers whose positions are within
        1.5 scene units of the run's first marker. Overlapping connectors
        are stacked in separate rows.

        Fix: marker positions collected in ``xs`` were multiplied by
        ``self.scale_x`` twice, so the connectors did not line up with the
        vertical markers, which use the once-scaled position.
        """
        def line(y0, y1):
            it = self.box_scene.addLine(x, y0, x, y1, self._post_line_pen)
            it.setZValue(-100)
            self.posthoc_lines.append(it)

        while self.posthoc_lines:
            self.box_scene.removeItem(self.posthoc_lines.pop())

        if self.compare == OWBoxPlot.CompareNone or len(self.stats) < 2:
            return

        if self.compare == OWBoxPlot.CompareMedians:
            crit_line = "median"
        else:
            crit_line = "mean"

        # Scene x-coordinates of the drawn markers, in display order
        xs = []

        height = 90 if self.show_annotations else 60

        y_up = -len(self.stats) * height + 10
        for pos, box_index in enumerate(self.order):
            stat = self.stats[box_index]
            x = getattr(stat, crit_line)
            if x is None:
                continue
            x *= self.scale_x
            # Same coordinate the vertical markers (line()) are drawn at
            xs.append(x)
            by = y_up + pos * height
            line(by + 12, 3)
            line(by - 12, by - 25)

        used_to = []
        last_to = to = 0
        for frm, frm_x in enumerate(xs[:-1]):
            # Find the last marker within 1.5 of frm_x (assumes xs increasing)
            for to in range(frm + 1, len(xs)):
                if xs[to] - frm_x > 1.5:
                    to -= 1
                    break
            # Skip single markers and runs already covered by the previous one
            if to in (last_to, frm):
                continue
            # Reuse the first connector row that is free at frm, else add one
            for rowi, used in enumerate(used_to):
                if used < frm:
                    used_to[rowi] = to
                    break
            else:
                rowi = len(used_to)
                used_to.append(to)
            y = - 6 - rowi * 6
            it = self.box_scene.addLine(frm_x - 2, y, xs[to] + 2, y,
                                        self._post_grp_pen)
            self.posthoc_lines.append(it)
            last_to = to

    def get_widget_name_extension(self):
        """Return the name of the plotted attribute, or None if there is none."""
        attribute = self.attribute
        if not attribute:
            return None
        return attribute.name

    def send_report(self):
        """Report the plot, captioned with the attribute and grouping variable."""
        self.report_plot()
        parts = []
        if self.attribute:
            parts.append("Box plot for attribute '{}' ".format(self.attribute.name))
        if self.group_var:
            parts.append("grouped by '{}'".format(self.group_var.name))
        caption = "".join(parts)
        if caption:
            self.report_caption(caption)

    class Label(QGraphicsSimpleTextItem):
        """Boxplot Label with settable maxWidth

        When painted, the text is elided on the right to fit the maximal
        width (or the item's rect width when none is set). It is not painted
        at all if that width is below MIN_LABEL_WIDTH.
        """
        # Minimum width to display label text
        MIN_LABEL_WIDTH = 25

        # padding below the text
        PADDING = 3

        __max_width = None

        def maxWidth(self):
            """Return the width set by setMaxWidth, or None if unset."""
            return self.__max_width

        def setMaxWidth(self, max_width):
            """Limit the painted text to `max_width` (it is elided to fit)."""
            self.__max_width = max_width

        def paint(self, painter, option, widget):
            """Overrides QGraphicsSimpleTextItem.paint

            If label text is too long, it is elided
            to fit into the allowed region
            """
            if self.__max_width is None:
                width = option.rect.width()
            else:
                width = self.__max_width

            if width < self.MIN_LABEL_WIDTH:
                # if space is too narrow, no label
                return

            fm = painter.fontMetrics()
            # QFontMetrics.elidedText and QPainter.drawText(int, int, str)
            # take ints, while widths and bounding-rect heights are floats.
            # PyQt no longer converts floats implicitly (TypeError on
            # Python >= 3.10), so truncate explicitly as the old implicit
            # conversion did.
            text = fm.elidedText(self.text(), Qt.ElideRight, int(width))
            painter.drawText(
                int(option.rect.x()),
                int(option.rect.y() + self.boundingRect().height()
                    - self.PADDING),
                text)
Beispiel #42
0
class OWMap(OWDataProjectionWidget):
    """
    Scatter plot visualization of coordinates data with geographic maps for
    background.
    """

    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/GeoMap.svg"
    priority = 100

    # Former qualified name(s) of this widget (presumably so that older
    # saved workflows still resolve to it)
    replaces = [
        "Orange.widgets.visualize.owmap.OWMap",
    ]

    # Current version; migrate_settings / migrate_context upgrade older
    # settings via their `version < N` checks
    settings_version = 3

    # Variables used as latitude and longitude (context-dependent settings)
    attr_lat = settings.ContextSetting(None)
    attr_lon = settings.ContextSetting(None)

    GRAPH_CLASS = OWScatterPlotMapGraph
    graph = settings.SettingProvider(OWScatterPlotMapGraph)
    embedding_variables_names = None

    class Error(OWDataProjectionWidget.Error):
        no_lat_lon_vars = Msg("Data has no latitude and longitude variables.")

    class Warning(OWDataProjectionWidget.Warning):
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points")
        out_of_range = Msg(
            "Points with out of range latitude or longitude are not displayed."
        )
        no_internet = Msg("Cannot fetch map from the internet. "
                          "Displaying only cached parts.")

    class Information(OWDataProjectionWidget.Information):
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        super().__init__()
        self._attr_lat, self._attr_lon = None, None
        # The graph signals (with a bool) whether to show the no-internet
        # warning; presumably emitted when map tiles cannot be fetched
        self.graph.show_internet_error.connect(self._show_internet_error)

    def _show_internet_error(self, show):
        """Show or clear the no-internet warning, only on a change of state."""
        if not self.Warning.no_internet.is_shown() and show:
            self.Warning.no_internet()
        elif self.Warning.no_internet.is_shown() and not show:
            self.Warning.no_internet.clear()

    def _add_controls(self):
        """Add the map, latitude and longitude combos, then the base controls
        and a 'Freeze map' check box in the plot box."""
        # Latitude/longitude candidates: continuous variables only
        self.lat_lon_model = DomainModel(DomainModel.MIXED,
                                         valid_types=ContinuousVariable)

        lat_lon_box = gui.vBox(self.controlArea, True)
        # Keyword arguments shared by the three combo boxes below
        options = dict(labelWidth=75,
                       orientation=Qt.Horizontal,
                       sendSelectedValue=True,
                       valueType=str,
                       contentsLength=14)

        gui.comboBox(lat_lon_box,
                     self,
                     'graph.tile_provider_key',
                     label='Map:',
                     items=list(TILE_PROVIDERS.keys()),
                     callback=self.graph.update_tile_provider,
                     **options)

        gui.comboBox(lat_lon_box,
                     self,
                     'attr_lat',
                     label='Latitude:',
                     callback=self.setup_plot,
                     model=self.lat_lon_model,
                     **options,
                     searchable=True)

        gui.comboBox(lat_lon_box,
                     self,
                     'attr_lon',
                     label='Longitude:',
                     callback=self.setup_plot,
                     model=self.lat_lon_model,
                     **options,
                     searchable=True)

        super()._add_controls()

        gui.checkBox(
            self._plot_box,
            self,
            value="graph.freeze",
            label="Freeze map",
            tooltip="If checked, the map won't change position to fit new data."
        )

    def get_embedding(self):
        """Return an (n, 2) array of normalized map coordinates, or None.

        Sets ``self.valid_data`` to a boolean mask of points with finite,
        in-range latitude and longitude, and shows or clears the
        missing-coordinates and out-of-range messages. Returns None when
        there is no data, a column is unavailable, or every point has
        finite but out-of-range coordinates.
        """
        self.valid_data = None
        if self.data is None:
            return None

        lat_data = self.get_column(self.attr_lat, filter_valid=False)
        lon_data = self.get_column(self.attr_lon, filter_valid=False)
        if lat_data is None or lon_data is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        self.valid_data = np.isfinite(lat_data) & np.isfinite(lon_data)
        # (valid_data was just assigned an array, so the None check is
        # always true here)
        if self.valid_data is not None and not np.all(self.valid_data):
            # Warning if no point has coordinates, Information if only some
            msg = self.Information if np.any(self.valid_data) else self.Warning
            msg.missing_coords(self.attr_lat.name, self.attr_lon.name)

        # Comparisons with NaN are False, so missing points are "not in
        # range" here; XNOR with valid_data turns them True, leaving False
        # only for finite coordinates that are out of range.
        in_range = (-MAX_LONGITUDE <= lon_data) & (lon_data <= MAX_LONGITUDE) &\
                   (-MAX_LATITUDE <= lat_data) & (lat_data <= MAX_LATITUDE)
        in_range = ~np.bitwise_xor(in_range, self.valid_data)
        self.Warning.out_of_range.clear()
        if in_range.sum() != len(lon_data):
            self.Warning.out_of_range()
        if in_range.sum() == 0:
            return None
        self.valid_data &= in_range

        # deg2norm presumably maps degrees to normalized [0, 1] map
        # coordinates - confirm against its definition
        x, y = deg2norm(lon_data, lat_data)
        # invert y to increase from bottom to top
        y = 1 - y
        return np.vstack((x, y)).T

    def check_data(self):
        """Run the base checks, then treat data with no rows or no
        variables as no data."""
        super().check_data()

        if self.data is not None and (len(self.data) == 0
                                      or len(self.data.domain) == 0):
            self.data = None

    def init_attr_values(self):
        """Choose latitude/longitude variables with find_lat_lon.

        If either is not found, shows the no_lat_lon_vars error and drops
        the data. The lat/lon model and attributes are set after the base
        class initialization.
        """
        lat, lon = None, None
        if self.data is not None:
            lat, lon = find_lat_lon(self.data, filter_hidden=True)
            if lat is None or lon is None:
                # we either find both or we don't have valid data
                self.Error.no_lat_lon_vars()
                self.data = None
                lat, lon = None, None

        super().init_attr_values()

        self.lat_lon_model.set_domain(self.data.domain if self.data else None)
        self.attr_lat, self.attr_lon = lat, lon

    @property
    def effective_variables(self):
        """[attr_lat, attr_lon] when both are set, otherwise []."""
        return [self.attr_lat, self.attr_lon] \
            if self.attr_lat and self.attr_lon else []

    def showEvent(self, ev):
        super().showEvent(ev)
        # reset the map on show event since before that we didn't know the
        # right resolution
        self.graph.update_view_range()

    def resizeEvent(self, ev):
        super().resizeEvent(ev)
        # when resizing we need to constantly reset the map so that new
        # portions are drawn
        self.graph.update_view_range(match_data=False)

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Move pre-version-3 top-level options into the nested "graph" settings.

        "Watercolor" tile provider is replaced by DEFAULT_TILE_PROVIDER;
        opacity is multiplied by 2.55 (presumably percent -> 0..255 alpha);
        zoom is multiplied by 0.02 to give point_width; jittering and
        show_legend are copied over.
        """
        if version < 3:
            _settings["graph"] = {}
            if "tile_provider" in _settings:
                if _settings["tile_provider"] == "Watercolor":
                    _settings["tile_provider"] = DEFAULT_TILE_PROVIDER
                _settings["graph"]["tile_provider_key"] = \
                    _settings["tile_provider"]
            if "opacity" in _settings:
                _settings["graph"]["alpha_value"] = \
                    round(_settings["opacity"] * 2.55)
            if "zoom" in _settings:
                _settings["graph"]["point_width"] = \
                    round(_settings["zoom"] * 0.02)
            if "jittering" in _settings:
                _settings["graph"]["jitter_size"] = _settings["jittering"]
            if "show_legend" in _settings:
                _settings["graph"]["show_legend"] = _settings["show_legend"]

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade saved contexts.

        Before version 2, attribute settings were stored as strings and are
        converted to variables (mapping the listed placeholders to none).
        Before version 3, the *_attr settings are renamed to attr_*.
        """
        if version < 2:
            settings.migrate_str_to_variable(context,
                                             names="lat_attr",
                                             none_placeholder="")
            settings.migrate_str_to_variable(context,
                                             names="lon_attr",
                                             none_placeholder="")
            settings.migrate_str_to_variable(context,
                                             names="class_attr",
                                             none_placeholder="(None)")

            # those settings can have two none placeholder
            # NOTE(review): context.values[attr] raises KeyError if a saved
            # context lacks one of these keys - confirm they are always present
            attr_placeholders = [("color_attr", "(Same color)"),
                                 ("label_attr", "(No labels)"),
                                 ("shape_attr", "(Same shape)"),
                                 ("size_attr", "(Same size)")]
            for attr, place in attr_placeholders:
                # Normalize the widget-specific placeholder to "" first
                if context.values[attr][0] == place:
                    context.values[attr] = ("", context.values[attr][1])

                settings.migrate_str_to_variable(context,
                                                 names=attr,
                                                 none_placeholder="")
        if version < 3:
            settings.rename_setting(context, "lat_attr", "attr_lat")
            settings.rename_setting(context, "lon_attr", "attr_lon")
            settings.rename_setting(context, "color_attr", "attr_color")
            settings.rename_setting(context, "label_attr", "attr_label")
            settings.rename_setting(context, "shape_attr", "attr_shape")
            settings.rename_setting(context, "size_attr", "attr_size")
Beispiel #43
0
class OWScatterPlot(OWWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements.

    Two primitive variables (`attr_x`, `attr_y`) are plotted against each
    other. The widget outputs the selected rows, the input data annotated
    with the selection, and the two shown variables as `Features`. Large SQL
    tables are shown as a sample that can grow over time (see `set_data`
    and `add_data`).
    """

    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        features = Output("Features", AttributeList, dynamic=False)

    # Version 2 stores the selection as (row index, group) pairs in
    # `selection_group`; older settings are converted in `migrate_settings`
    settings_version = 2
    settingsHandler = DomainContextHandler()

    auto_send_selection = Setting(True)
    auto_sample = Setting(True)
    toolbar_selection = Setting(0)

    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)

    #: Serialized selection state to be restored: a list of
    #: ``(row index, group)`` pairs, or None when nothing is selected
    selection_group = Setting(None, schema_only=True)

    graph = SettingProvider(OWScatterPlotGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    class Information(OWWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")

    def __init__(self):
        super().__init__()

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWScatterPlotGraph(self, box, "ScatterPlot")
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        # Draw both axes with the palette's text color
        axispen = QPen(self.palette().color(QPalette.Text))
        axis = plot.getAxis("bottom")
        axis.setPen(axispen)

        axis = plot.getAxis("left")
        axis.setPen(axispen)

        self.data = None  # Orange.data.Table
        self.subset_data = None  # Orange.data.Table
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # While sampling a large SQL table, periodically fetch more rows
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        #: Remember the saved state to restore; it is applied once data
        #: arrives (see `handleNewSignals`)
        self.__pending_selection_restore = self.selection_group
        self.selection_group = None

        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str)
        box = gui.vBox(self.controlArea, "Axis Data")
        dmod = DomainModel
        # Both axis combos share one model of primitive variables
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self.update_attr,
            model=self.xy_model, **common_options)
        self.cb_attr_y = gui.comboBox(
            box, self, "attr_y", label="Axis y:", callback=self.update_attr,
            model=self.xy_model, **common_options)

        vizrank_box = gui.hBox(box)
        gui.separator(vizrank_box, width=common_options["labelWidth"])
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

        gui.separator(box)

        g = self.graph.gui
        g.add_widgets([g.JitterSizeSlider,
                       g.JitterNumericValues], box)

        self.sampling = gui.auto_commit(
            self.controlArea, self, "auto_sample", "Sample", box="Sampling",
            callback=self.switch_sampling, commit=lambda: self.add_data(1))
        # Only shown when the input is a large SQL table (see `set_data`)
        self.sampling.setVisible(False)

        g.point_properties_box(self.controlArea)
        self.models = [self.xy_model] + g.points_models

        box_plot_prop = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([g.ShowLegend,
                       g.ShowGridLines,
                       g.ToolTipShowsAll,
                       g.ClassDensity,
                       g.RegressionLine,
                       g.LabelOnlySelected], box_plot_prop)

        self.graph.box_zoom_select(self.controlArea)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_send_selection",
                        "Send Selection", "Send Automatically")

        self.graph.zoom_actions(self)

    def keyPressEvent(self, event):
        """Forward the event, then let the graph update its tooltip for the
        current keyboard modifiers."""
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Forward the event, then let the graph update its tooltip for the
        current keyboard modifiers."""
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def reset_graph_data(self, *_):
        """Rescale the graph's data and redraw it, if there is data."""
        if self.data is not None:
            self.graph.rescale_data()
            self.update_graph()

    def _vizrank_color_change(self):
        """Re-initialize vizrank and enable its button only if ranking is
        possible.

        Ranking requires dense data with more than one row and more than two
        primitive variables (including metas), plus a color variable whose
        column is not entirely missing.
        """
        self.vizrank.initialize()
        is_enabled = self.data is not None and not self.data.is_sparse() and \
                     len([v for v in chain(self.data.domain.variables, self.data.domain.metas)
                          if v.is_primitive]) > 2\
                     and len(self.data) > 1
        self.vizrank_button.setEnabled(
            is_enabled and self.graph.attr_color is not None and
            not np.isnan(self.data.get_column_view(self.graph.attr_color)[0].astype(float)).all())
        if is_enabled and self.graph.attr_color is None:
            self.vizrank_button.setToolTip("Color variable has to be selected.")
        else:
            self.vizrank_button.setToolTip("")

    @Inputs.data
    def set_data(self, data):
        """Handler for the Data input.

        SQL tables with fewer than 4000 rows (approximately) are downloaded
        whole. Larger ones are replaced by a sample, and the sampling timer
        is started if `auto_sample` is set. Empty data, or data without
        variables, is treated as no data. If the data's checksum equals that
        of the current data, nothing is changed.

        Args:
            data (Table or SqlTable or None): input data
        """
        self.clear_messages()
        self.Information.sampled_sql.clear()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled_sql()
                self.sql_data = data
                data_sample = data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if data is not None and (len(data) == 0 or len(data.domain) == 0):
            data = None
        if self.data and data and self.data.checksum() == data.checksum():
            return

        self.closeContext()
        same_domain = (self.data and data and
                       data.domain.checksum() == self.data.domain.checksum())
        self.data = data

        if not same_domain:
            self.init_attr_values()
        self.openContext(self.data)
        self._vizrank_color_change()

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Orange.data.Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.graph.attr_label, str):
            self.graph.attr_label = findvar(
                self.graph.attr_label, self.graph.gui.label_model)
        if isinstance(self.graph.attr_color, str):
            self.graph.attr_color = findvar(
                self.graph.attr_color, self.graph.gui.color_model)
        if isinstance(self.graph.attr_shape, str):
            self.graph.attr_shape = findvar(
                self.graph.attr_shape, self.graph.gui.shape_model)
        if isinstance(self.graph.attr_size, str):
            self.graph.attr_size = findvar(
                self.graph.attr_size, self.graph.gui.size_model)

    def add_data(self, time=0.4):
        """Append another sample of the SQL table to the shown data.

        Once more than 2000 rows are held, the sampling timer is stopped
        instead.

        Args:
            time (float): time budget passed to `SqlTable.sample_time`
        """
        if self.data and len(self.data) > 2000:
            return self.__timer.stop()
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def switch_sampling(self):
        """Stop sampling, or add a sample and restart the timer when
        automatic sampling of an SQL table is on."""
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        """Handler for the Data Subset input.

        SQL tables are downloaded if they are smaller than `AUTO_DL_LIMIT`
        rows (approximately); larger ones are ignored with a warning. The
        opacity control is enabled only while there is no subset.
        """
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() < AUTO_DL_LIMIT:
                subset_data = Table(subset_data)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
        self.subset_data = subset_data
        self.controls.graph.alpha_value.setEnabled(subset_data is None)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        """Pass the data to the graph and redraw it.

        The axes come from the Features input if all of its variables are in
        the graph's domain. Any selection still pending from the saved
        settings is restored, and the outputs are sent.
        """
        self.graph.new_data(self.data, self.subset_data)
        if self.attribute_selection_list and self.graph.domain is not None and \
                all(attr in self.graph.domain
                        for attr in self.attribute_selection_list):
            self.attr_x = self.attribute_selection_list[0]
            self.attr_y = self.attribute_selection_list[1]
        self.attribute_selection_list = None
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        if self.data is not None and self.__pending_selection_restore is not None:
            self.apply_selection(self.__pending_selection_restore)
            self.__pending_selection_restore = None
        self.unconditional_commit()

    def apply_selection(self, selection):
        """Apply `selection` to the current plot.

        Pairs whose row index is not within the current data are dropped.
        The remaining pairs are stored in `selection_group`, or None if no
        pair remains (the same convention as `selection_changed`).

        Args:
            selection (list of tuple): ``(row index, group)`` pairs, as
                stored in `selection_group`
        """
        if self.data is None:
            return
        self.graph.selection = np.zeros(len(self.data), dtype=np.uint8)
        self.selection_group = \
            [x for x in selection if x[0] < len(self.data)] or None
        # With no remaining pairs, np.array([]).T has no rows to index, so
        # only fill the selection when there is something to fill
        if self.selection_group:
            selection_array = np.array(self.selection_group).T
            self.graph.selection[selection_array[0]] = selection_array[1]
        self.graph.update_colors(keep_colors=True)

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Handler for the Features input: remember the first two variables
        to be used as axes (applied in `handleNewSignals`)."""
        if attributes and len(attributes) >= 2:
            self.attribute_selection_list = attributes[:2]
        else:
            self.attribute_selection_list = None

    def init_attr_values(self):
        """Fill the axis model from the data's domain and choose the first
        two variables as axes (the same one twice if there is only one)."""
        data = self.data
        domain = data.domain if data and len(data) else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x
        self.graph.set_domain(data)

    def set_attr(self, attr_x, attr_y):
        """Set both axes (used by vizrank) and update the widget."""
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def update_attr(self):
        """Redraw after an axis change, update the density and regression
        line check boxes, and output the new features."""
        self.update_graph()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())
        self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
        self.send_features()

    def update_colors(self):
        """Refresh vizrank and the density check box after a color change."""
        self._vizrank_color_change()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())

    def update_density(self):
        """Redraw, keeping the current view."""
        self.update_graph(reset_view=False)

    def update_regression_line(self):
        """Redraw, keeping the current view."""
        self.update_graph(reset_view=False)

    def update_graph(self, reset_view=True, **_):
        """Clear the zoom stack and redraw the graph for the current axes.

        Args:
            reset_view (bool): passed on to the graph's `update_data`
        """
        self.graph.zoomStack = []
        if self.graph.data is None:
            return
        self.graph.update_data(self.attr_x, self.attr_y, reset_view)

    def selection_changed(self):
        """Store the graph's selection in `selection_group` and commit."""

        # Store current selection in a setting that is stored in workflow
        if isinstance(self.data, SqlTable):
            selection = None
        elif self.data is not None:
            selection = self.graph.get_selection()
        else:
            selection = None
        if selection is not None and len(selection):
            self.selection_group = list(zip(selection, self.graph.selection[selection]))
        else:
            self.selection_group = None

        self.commit()

    def send_data(self):
        """Send the selected rows and the annotated data.

        Selected rows are sent as a table with a "Group" column, or None
        without a selection. If more than one group is used, the annotated
        output marks groups; otherwise it marks the selected rows.
        """
        # TODO: Implement selection for sql data
        def _get_selected():
            if not len(selection):
                return None
            return create_groups_table(data, graph.selection, False, "Group")

        def _get_annotated():
            if graph.selection is not None and np.max(graph.selection) > 1:
                return create_groups_table(data, graph.selection)
            else:
                return create_annotated_table(data, selection)

        graph = self.graph
        data = self.data
        selection = graph.get_selection()
        self.Outputs.annotated_data.send(_get_annotated())
        self.Outputs.selected_data.send(_get_selected())

    def send_features(self):
        """Send the shown axis variables, or None if neither is set."""
        features = [attr for attr in [self.attr_x, self.attr_y] if attr]
        self.Outputs.features.send(features or None)

    def commit(self):
        """Send all outputs."""
        self.send_data()
        self.send_features()

    def get_widget_name_extension(self):
        """Return "<x> vs <y>" for the shown axes, or None without them."""
        if self.data is None or self.attr_x is None or self.attr_y is None:
            return None
        return "{} vs {}".format(self.attr_x.name, self.attr_y.name)

    def send_report(self):
        """Report the plot and the visual settings used to draw it."""
        if self.data is None:
            return

        def name(var):
            return var and var.name

        # The jittering item holds the jitter size when jittering applies
        # (a discrete axis, or continuous jittering enabled) and a falsy
        # value otherwise; missing axes are simply not discrete
        axes_discrete = any(attr is not None and attr.is_discrete
                            for attr in (self.attr_x, self.attr_y))
        caption = report.render_items_vert((
            ("Color", name(self.graph.attr_color)),
            ("Label", name(self.graph.attr_label)),
            ("Shape", name(self.graph.attr_shape)),
            ("Size", name(self.graph.attr_size)),
            ("Jittering", (axes_discrete or
                           self.graph.jitter_continuous) and
             self.graph.jitter_size)))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def onDeleteWidget(self):
        """Release the plot's view box and items when the widget is
        deleted."""
        super().onDeleteWidget()
        self.graph.plot_widget.getViewBox().deleteLater()
        self.graph.plot_widget.clear()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Convert the pre-version-2 list of selected row indices into
        ``(row index, 1)`` pairs in `selection_group`."""
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1) for a in settings["selection"]]
Beispiel #44
0
class OWSieveDiagram(OWWidget):
    """Sieve diagram of two variables.

    Continuous variables are discretized (see `sparse_to_dense`). A square
    is split into one rectangle per combination of values. The rectangle's
    sides are proportional to the marginal probabilities of the two values,
    so its area corresponds to the combination's expected frequency under
    independence. Each rectangle is colored and hatched according to the
    standardized Pearson residual of the combination. Clicking rectangles
    selects the corresponding rows.
    """
    name = "Sieve Diagram"
    description = "Visualize the observed and expected frequencies " \
                  "for a combination of values."
    icon = "icons/SieveDiagram.svg"
    priority = 200

    class Inputs:
        data = Input("Data", Table, default=True)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    graph_name = "canvas"

    want_control_area = False

    settings_version = 1
    settingsHandler = DomainContextHandler()
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    # Indices into `self.areas` of the selected rectangles
    selection = ContextSetting(set())

    def __init__(self):
        # pylint: disable=missing-docstring
        super().__init__()

        self.data = self.discrete_data = None
        self.attrs = []
        self.input_features = None
        self.areas = []
        self.selection = set()

        self.attr_box = gui.hBox(self.mainArea)
        self.domain_model = DomainModel(valid_types=DomainModel.PRIMITIVE)
        combo_args = dict(
            widget=self.attr_box, master=self, contentsLength=12,
            callback=self.update_attr, sendSelectedValue=True, valueType=str,
            model=self.domain_model)
        fixed_size = (QSizePolicy.Fixed, QSizePolicy.Fixed)
        gui.comboBox(value="attr_x", **combo_args)
        gui.widgetLabel(self.attr_box, "\u2715", sizePolicy=fixed_size)
        gui.comboBox(value="attr_y", **combo_args)
        self.vizrank, self.vizrank_button = SieveRank.add_vizrank(
            self.attr_box, self, "Score Combinations", self.set_attr)
        self.vizrank_button.setSizePolicy(*fixed_size)

        self.canvas = QGraphicsScene()
        # Clicking the view outside the rectangles resets the selection
        self.canvasView = ViewWithPress(
            self.canvas, self.mainArea, handler=self.reset_selection)
        self.mainArea.layout().addWidget(self.canvasView)
        self.canvasView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvasView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    def sizeHint(self):
        return QSize(450, 550)

    def resizeEvent(self, event):
        # The diagram's size depends on the view's size, so redraw it
        super().resizeEvent(event)
        self.update_graph()

    def showEvent(self, event):
        super().showEvent(event)
        self.update_graph()

    @classmethod
    def migrate_context(cls, context, version):
        """Rename the unversioned `attrX`/`attrY` settings and convert the
        stored variables with `settings.migrate_str_to_variable`."""
        if not version:
            settings.rename_setting(context, "attrX", "attr_x")
            settings.rename_setting(context, "attrY", "attr_y")
            settings.migrate_str_to_variable(context)

    @Inputs.data
    def set_data(self, data):
        """
        Discretize continuous attributes, and put all attributes and discrete
        metas into self.attrs.

        Select the first two attributes unless context overrides this.
        Method `resolve_shown_attributes` is called to use the attributes from
        the input, if it exists and matches the attributes in the data.

        Remove selection; again let the context override this.
        Initialize the vizrank dialog, but don't show it.

        Large SQL tables are replaced by a sample.

        Args:
            data (Table): input data
        """
        if isinstance(data, SqlTable) and data.approx_len() > LARGE_TABLE:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data
        self.areas = []
        self.selection = set()
        if self.data is None:
            self.attrs[:] = []
            self.domain_model.set_domain(None)
            self.discrete_data = None
        else:
            self.domain_model.set_domain(data.domain)
        self.attrs = [x for x in self.domain_model if isinstance(x, Variable)]
        if self.attrs:
            self.attr_x = self.attrs[0]
            # The second variable if there is one, otherwise the first again
            self.attr_y = self.attrs[len(self.attrs) > 1]
        else:
            self.attr_x = self.attr_y = None
            self.areas = []
            self.selection = set()
        self.openContext(self.data)
        # Discretize only after the context has restored attr_x/attr_y, since
        # sparse data keeps just those two columns. Empty data has nothing to
        # discretize; reset it instead of keeping the previous data's result.
        self.discrete_data = \
            self.sparse_to_dense(data, True) if self.data else None
        self.resolve_shown_attributes()
        self.update_graph()
        self.update_selection()

        self.vizrank.initialize()
        self.vizrank_button.setEnabled(
            self.data is not None and len(self.data) > 1 and
            len(self.data.domain.attributes) > 1 and not self.data.is_sparse())

    def set_attr(self, attr_x, attr_y):
        """Set both shown variables (used by vizrank) and update."""
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def update_attr(self):
        """Update the graph and selection."""
        self.selection = set()
        # Sparse data keeps only the shown columns, so rebuild them; there
        # is nothing to rebuild without (non-empty) data
        if self.data:
            self.discrete_data = self.sparse_to_dense(self.data)
        self.update_graph()
        self.update_selection()

    def sparse_to_dense(self, data, init=False):
        """
        Return the discretized data used for drawing the diagram.

        For sparse data, only the two shown columns (`attr_x`, `attr_y`) are
        extracted from the sparse matrix before discretization (GH-2260).
        Dense data is discretized only when `init` is set; otherwise the
        previously computed `self.discrete_data` is returned.

        Args:
            data (Table): input data
            init (bool): discretize dense data, too (used for new data)

        Returns:
            (Table): if any variable, class or meta is continuous, the data
            discretized with `EqualFreq(n=4)` and converted to dense;
            otherwise the data as it is
        """
        def discretizer(data):
            if any(attr.is_continuous for attr in chain(data.domain.variables, data.domain.metas)):
                discretize = Discretize(
                    method=EqualFreq(n=4), remove_const=False,
                    discretize_classes=True, discretize_metas=True)
                return discretize(data).to_dense()
            return data

        if not data.is_sparse() and not init:
            return self.discrete_data
        if data.is_sparse():
            attrs = {self.attr_x,
                     self.attr_y}
            new_domain = data.domain.select_columns(attrs)
            data = Table.from_table(new_domain, data)
        return discretizer(data)

    @Inputs.features
    def set_input_features(self, attr_list):
        """
        Handler for the Features signal.

        The method stores the attributes and calls `resolve_shown_attributes`

        Args:
            attr_list (AttributeList): data from the signal
        """
        self.input_features = attr_list
        self.resolve_shown_attributes()
        self.update_selection()

    def resolve_shown_attributes(self):
        """
        Use the attributes from the input signal if the signal is present
        and at least one of them appears in the domain. If there are
        multiple, use the first two; a single one is shown on both axes.
        Combos are disabled if inputs are used.
        """
        self.warning()
        self.attr_box.setEnabled(True)
        if not self.input_features:  # None or empty
            return
        features = [f for f in self.input_features if f in self.domain_model]
        if not features:
            self.warning(
                "Features from the input signal are not present in the data")
            return
        old_attrs = self.attr_x, self.attr_y
        self.attr_x, self.attr_y = (features * 2)[:2]
        self.attr_box.setEnabled(False)
        if (self.attr_x, self.attr_y) != old_attrs:
            self.selection = set()
            # For sparse data, discrete_data holds only the previously shown
            # columns and must be rebuilt before drawing the new pair
            if self.data:
                self.discrete_data = self.sparse_to_dense(self.data)
            self.update_graph()

    def reset_selection(self):
        """Clear the selection and update the outputs."""
        self.selection = set()
        self.update_selection()

    def select_area(self, area, event):
        """
        Add or remove the clicked area from the selection

        A left click selects only the clicked area; with Ctrl, it toggles the
        area in the current selection. Other buttons are ignored.

        Args:
            area (QRect): the area that is clicked
            event (QEvent): event description
        """
        if event.button() != Qt.LeftButton:
            return
        index = self.areas.index(area)
        if event.modifiers() & Qt.ControlModifier:
            self.selection ^= {index}
        else:
            self.selection = {index}
        self.update_selection()

    def update_selection(self):
        """
        Update the graph (pen width) to show the current selection.
        Filter and output the data.
        """
        if self.areas is None or not self.selection:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(create_annotated_table(self.data, []))
            return

        filts = []
        for i, area in enumerate(self.areas):
            if i in self.selection:
                width = 4
                val_x, val_y = area.value_pair
                filts.append(
                    filter.Values([
                        filter.FilterDiscrete(self.attr_x.name, [val_x]),
                        filter.FilterDiscrete(self.attr_y.name, [val_y])
                    ]))
            else:
                width = 1
            pen = area.pen()
            pen.setWidth(width)
            area.setPen(pen)
        if len(filts) == 1:
            filts = filts[0]
        else:
            # A row is selected if it matches any of the selected areas
            filts = filter.Values(filts, conjunction=False)
        selection = filts(self.discrete_data)
        # Filtering was done on the discretized data; map the selected rows
        # back to the original data by their ids
        idset = set(selection.ids)
        sel_idx = [i for i, row_id in enumerate(self.data.ids)
                   if row_id in idset]
        if self.discrete_data is not self.data:
            selection = self.data[sel_idx]
        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(create_annotated_table(self.data, sel_idx))

    def update_graph(self):
        # Function uses weird names like r, g, b, but it does it with utmost
        # caution, hence
        # pylint: disable=invalid-name
        """Redraw the diagram for the current pair of variables.

        All items are removed from the scene. If there is data and both
        variables are set, one rectangle is drawn for each combination of
        values with nonzero marginal probabilities (collected in
        `self.areas`), followed by value labels, variable names, the
        chi-square statistic with its p-value, and the number of instances.
        """

        def text(txt, *args, **kwargs):
            return CanvasText(self.canvas, "", html_text=to_html(txt),
                              *args, **kwargs)

        def text_width(txt):
            return text(txt, 0, 0, show=False).boundingRect().width()

        def fmt(val):
            return str(int(val)) if val % 1 == 0 else "{:.2f}".format(val)

        def show_pearson(rect, pearson, pen_width):
            """
            Color the given rectangle according to its corresponding
            standardized Pearson residual.

            Positive residuals give bluish, negative reddish colors; the
            hatching is denser for larger absolute residuals.

            Args:
                rect (QRect): the rectangle being drawn
                pearson (float): signed standardized pearson residual
                pen_width (int): pen width (bolder pen is used for selection)
            """
            bounds = rect.rect()
            x, y, w, h = bounds.x(), bounds.y(), bounds.width(), bounds.height()
            if w == 0 or h == 0:
                return

            r = b = 255
            if pearson > 0:
                r = g = max(255 - 20 * pearson, 55)
            elif pearson < 0:
                b = g = max(255 + 20 * pearson, 55)
            else:
                r = g = b = 224
            rect.setBrush(QBrush(QColor(r, g, b)))
            pen_color = QColor(255 * (r == 255), 255 * (g == 255),
                               255 * (b == 255))
            pen = QPen(pen_color, pen_width)
            rect.setPen(pen)
            # Distance between hatching lines, clamped residual
            if pearson > 0:
                pearson = min(pearson, 10)
                dist = 20 - 1.6 * pearson
            else:
                pearson = max(pearson, -10)
                dist = 20 - 8 * pearson
            pen.setWidth(1)

            def _offseted_line(ax, ay):
                line = QGraphicsLineItem(x + ax, y + ay, x + (ax or w),
                                         y + (ay or h))
                self.canvas.addItem(line)
                line.setPen(pen)

            ax = dist
            while ax < w:
                _offseted_line(ax, 0)
                ax += dist

            ay = dist
            while ay < h:
                _offseted_line(0, ay)
                ay += dist

        def make_tooltip():
            """Create the tooltip. The function uses local variables from
            the enclosing scope."""
            # pylint: disable=undefined-loop-variable
            def _oper(attr, txt):
                # "=" for variables that were not discretized; otherwise the
                # value name is an interval
                if self.data.domain[attr.name] is ddomain[attr.name]:
                    return "="
                return " " if txt[0] in "<≥" else " in "

            return (
                "<b>{attr_x}{xeq}{xval_name}</b>: {obs_x}/{n} ({p_x:.0f} %)".
                format(attr_x=to_html(attr_x.name),
                       xeq=_oper(attr_x, xval_name),
                       xval_name=to_html(xval_name),
                       obs_x=fmt(chi.probs_x[x] * n),
                       n=int(n),
                       p_x=100 * chi.probs_x[x]) +
                "<br/>" +
                "<b>{attr_y}{yeq}{yval_name}</b>: {obs_y}/{n} ({p_y:.0f} %)".
                format(attr_y=to_html(attr_y.name),
                       yeq=_oper(attr_y, yval_name),
                       yval_name=to_html(yval_name),
                       obs_y=fmt(chi.probs_y[y] * n),
                       n=int(n),
                       p_y=100 * chi.probs_y[y]) +
                "<hr/>" +
                """<b>combination of values: </b><br/>
                   &nbsp;&nbsp;&nbsp;expected {exp} ({p_exp:.0f} %)<br/>
                   &nbsp;&nbsp;&nbsp;observed {obs} ({p_obs:.0f} %)""".
                format(exp=fmt(chi.expected[y, x]),
                       p_exp=100 * chi.expected[y, x] / n,
                       obs=fmt(chi.observed[y, x]),
                       p_obs=100 * chi.observed[y, x] / n))

        for item in self.canvas.items():
            self.canvas.removeItem(item)
        if self.data is None or len(self.data) == 0 or \
                self.attr_x is None or self.attr_y is None:
            return

        ddomain = self.discrete_data.domain
        attr_x, attr_y = self.attr_x, self.attr_y
        disc_x, disc_y = ddomain[attr_x.name], ddomain[attr_y.name]
        view = self.canvasView

        chi = ChiSqStats(self.discrete_data, disc_x, disc_y)
        max_ylabel_w = max((text_width(val) for val in disc_y.values),
                           default=0)
        max_ylabel_w = min(max_ylabel_w, 200)
        x_off = text_width(attr_x.name) + max_ylabel_w
        y_off = 15
        square_size = min(view.width() - x_off - 35, view.height() - y_off - 80)
        square_size = max(square_size, 10)
        self.canvasView.setSceneRect(0, 0, view.width(), view.height())
        if not disc_x.values or not disc_y.values:
            text_ = "Features {} and {} have no values".format(disc_x, disc_y) \
                if not disc_x.values and \
                   not disc_y.values and \
                          disc_x != disc_y \
                else \
                    "Feature {} has no values".format(
                        disc_x if not disc_x.values else disc_y)
            text(text_, view.width() / 2 + 70, view.height() / 2,
                 Qt.AlignRight | Qt.AlignVCenter)
            return
        n = chi.n
        curr_x = x_off
        max_xlabel_h = 0
        self.areas = []
        for x, (px, xval_name) in enumerate(zip(chi.probs_x, disc_x.values)):
            if px == 0:
                continue
            col_width = square_size * px

            curr_y = y_off
            for y in range(len(chi.probs_y) - 1, -1, -1):  # bottom-up order
                py = chi.probs_y[y]
                yval_name = disc_y.values[y]
                if py == 0:
                    continue
                row_height = square_size * py

                selected = len(self.areas) in self.selection
                rect = CanvasRectangle(
                    self.canvas, curr_x + 2, curr_y + 2,
                    col_width - 4, row_height - 4,
                    z=-10, onclick=self.select_area)
                rect.value_pair = x, y
                self.areas.append(rect)
                show_pearson(rect, chi.residuals[y, x], 3 * selected)
                rect.setToolTip(make_tooltip())

                if x == 0:
                    text(yval_name, x_off, curr_y + row_height / 2,
                         Qt.AlignRight | Qt.AlignVCenter)
                curr_y += row_height

            xl = text(xval_name, curr_x + col_width / 2, y_off + square_size,
                      Qt.AlignHCenter | Qt.AlignTop)
            max_xlabel_h = max(int(xl.boundingRect().height()), max_xlabel_h)
            curr_x += col_width

        bottom = y_off + square_size + max_xlabel_h
        text(attr_y.name, 0, y_off + square_size / 2,
             Qt.AlignLeft | Qt.AlignVCenter, bold=True, vertical=True)
        text(attr_x.name, x_off + square_size / 2, bottom,
             Qt.AlignHCenter | Qt.AlignTop, bold=True)
        bottom += 30
        xl = text("χ²={:.2f}, p={:.3f}".format(chi.chisq, chi.p),
                  0, bottom)
        # Assume similar height for both lines
        text("N = " + fmt(chi.n), 0, bottom - xl.boundingRect().height())

    def get_widget_name_extension(self):
        """Return "<x> vs <y>" for the shown variables, or None without
        them."""
        if self.data is None or self.attr_x is None or self.attr_y is None:
            return None
        return "{} vs {}".format(self.attr_x.name, self.attr_y.name)

    def send_report(self):
        self.report_plot()
Beispiel #45
0
class OWBoxPlot(widget.OWWidget):
    """
    Here's how the widget's functions call each other:

    - `set_data` is a signal handler fills the list boxes and calls
    `grouping_changed`.

    - `grouping_changed` handles changes of grouping attribute: it enables or
    disables the box for ordering, orders attributes and calls `attr_changed`.

    - `attr_changed` handles changes of attribute. It recomputes box data by
    calling `compute_box_data`, shows the appropriate display box
    (discrete/continuous) and then calls`layout_changed`

    - `layout_changed` constructs all the elements for the scene (as lists of
    QGraphicsItemGroup) and calls `display_changed`. It is called when the
    attribute or grouping is changed (by attr_changed) and on resize event.

    - `display_changed` puts the elements corresponding to the current display
    settings on the scene. It is called when the elements are reconstructed
    (layout is changed due to selection of attributes or resize event), or
    when the user changes display settings or colors.

    For discrete attributes, the flow is a bit simpler: the elements are not
    constructed in advance (by layout_changed). Instead, layout_changed and
    display_changed call display_changed_disc that draws everything.
    """
    name = "Box Plot"
    description = "Visualize the distribution of feature values in a box plot."
    icon = "icons/BoxPlot.svg"
    priority = 100
    keywords = ["whisker"]

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    #: Comparison types for continuous variables
    CompareNone, CompareMedians, CompareMeans = 0, 1, 2

    settingsHandler = DomainContextHandler()
    conditions = ContextSetting([])

    attribute = ContextSetting(None)
    order_by_importance = Setting(False)
    group_var = ContextSetting(None)
    show_annotations = Setting(True)
    compare = Setting(CompareMeans)
    stattest = Setting(0)
    sig_threshold = Setting(0.05)
    stretched = Setting(True)
    show_labels = Setting(True)
    sort_freqs = Setting(False)
    auto_commit = Setting(True)

    _sorting_criteria_attrs = {
        CompareNone: "",
        CompareMedians: "median",
        CompareMeans: "mean"
    }

    _pen_axis_tick = QPen(Qt.white, 5)
    _pen_axis = QPen(Qt.darkGray, 3)
    _pen_median = QPen(QBrush(QColor(0xff, 0xff, 0x00)), 2)
    _pen_paramet = QPen(QBrush(QColor(0x33, 0x00, 0xff)), 2)
    _pen_dotted = QPen(QBrush(QColor(0x33, 0x00, 0xff)), 1)
    _pen_dotted.setStyle(Qt.DotLine)
    _post_line_pen = QPen(Qt.lightGray, 2)
    _post_grp_pen = QPen(Qt.lightGray, 4)
    for pen in (_pen_paramet, _pen_median, _pen_dotted, _pen_axis,
                _pen_axis_tick, _post_line_pen, _post_grp_pen):
        pen.setCosmetic(True)
        pen.setCapStyle(Qt.RoundCap)
        pen.setJoinStyle(Qt.RoundJoin)
    _pen_axis_tick.setCapStyle(Qt.FlatCap)

    _box_brush = QBrush(QColor(0x33, 0x88, 0xff, 0xc0))

    _axis_font = QFont()
    _axis_font.setPixelSize(12)
    _label_font = QFont()
    _label_font.setPixelSize(11)
    _attr_brush = QBrush(QColor(0x33, 0x00, 0xff))

    graph_name = "box_scene"

    def __init__(self):
        """Build the GUI and initialize the (empty) plot state.

        The control area gets the variable list, the subgroup list, two
        alternative "Display" boxes (one for continuous, one for discrete
        variables) and the auto-commit button; the main area gets the
        graphics view with the box plot and a label for test results.
        """
        super().__init__()
        self.stats = []
        self.dataset = None
        self.posthoc_lines = []

        # NOTE(review): all of these names are bound to ONE shared list
        # object. This is only safe while they are reassigned (as the
        # methods in this class do) rather than mutated in place.
        self.label_txts = self.mean_labels = self.boxes = self.labels = \
            self.label_txts_all = self.attr_labels = self.order = []
        self.p = -1.0
        self.scale_x = self.scene_min_x = self.scene_width = 0
        self.label_width = 0

        # List of variables that can be plotted (filled by reset_attrs)
        self.attrs = VariableListModel()
        view = gui.listView(self.controlArea,
                            self,
                            "attribute",
                            box="Variable",
                            model=self.attrs,
                            callback=self.attr_changed)
        view.setMinimumSize(QSize(30, 30))
        # Any other policy than Ignored will let the QListBox's scrollbar
        # set the minimal height (see the penultimate paragraph of
        # http://doc.qt.io/qt-4.8/qabstractscrollarea.html#addScrollBarWidget)
        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)
        gui.separator(view.box, 6, 6)
        self.cb_order = gui.checkBox(
            view.box,
            self,
            "order_by_importance",
            "Order by relevance",
            tooltip="Order by 𝜒² or ANOVA over the subgroups",
            callback=self.apply_sorting)
        # Candidate subgroup variables: discrete variables only, with a
        # "None" placeholder entry for "no subgrouping"
        self.group_vars = DomainModel(placeholder="None",
                                      separators=False,
                                      valid_types=Orange.data.DiscreteVariable)
        self.group_view = view = gui.listView(self.controlArea,
                                              self,
                                              "group_var",
                                              box="Subgroups",
                                              model=self.group_vars,
                                              callback=self.grouping_changed)
        view.setEnabled(False)
        view.setMinimumSize(QSize(30, 30))
        # See the comment above
        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)

        # TODO: move Compare median/mean to grouping box
        # The vertical size policy is needed to let only the list views expand
        # Display options for continuous variables
        self.display_box = gui.vBox(self.controlArea,
                                    "Display",
                                    sizePolicy=(QSizePolicy.Minimum,
                                                QSizePolicy.Maximum),
                                    addSpace=False)

        gui.checkBox(self.display_box,
                     self,
                     "show_annotations",
                     "Annotate",
                     callback=self.display_changed)
        self.compare_rb = gui.radioButtonsInBox(
            self.display_box,
            self,
            'compare',
            btnLabels=["No comparison", "Compare medians", "Compare means"],
            callback=self.layout_changed)

        # The vertical size policy is needed to let only the list views expand
        # Display options for discrete variables
        self.stretching_box = box = gui.vBox(self.controlArea,
                                             box="Display",
                                             sizePolicy=(QSizePolicy.Minimum,
                                                         QSizePolicy.Fixed))
        # Only one of the two "Display" boxes is visible at a time (see
        # update_display_box); sharing the size hint presumably keeps the
        # control area from changing size when they are swapped.
        self.stretching_box.sizeHint = self.display_box.sizeHint
        gui.checkBox(box,
                     self,
                     'stretched',
                     "Stretch bars",
                     callback=self.display_changed)
        gui.checkBox(box,
                     self,
                     'show_labels',
                     "Show box labels",
                     callback=self.display_changed)
        self.sort_cb = gui.checkBox(box,
                                    self,
                                    'sort_freqs',
                                    "Sort by subgroup frequencies",
                                    callback=self.display_changed)
        gui.rubber(box)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

        gui.vBox(self.mainArea, addSpace=True)
        self.box_scene = QGraphicsScene()
        # Every change of the selected boxes calls commit
        self.box_scene.selectionChanged.connect(self.commit)
        self.box_view = QGraphicsView(self.box_scene)
        self.box_view.setRenderHints(QPainter.Antialiasing
                                     | QPainter.TextAntialiasing
                                     | QPainter.SmoothPixmapTransform)
        # Lets eventFilter re-layout the plot when the viewport is resized
        self.box_view.viewport().installEventFilter(self)

        self.mainArea.layout().addWidget(self.box_view)

        e = gui.hBox(self.mainArea, addSpace=False)
        # Shows the result of the significance test (see compute_tests)
        self.infot1 = gui.widgetLabel(e, "<center>No test results.</center>")
        self.mainArea.setMinimumWidth(600)

        # NOTE(review): again one list object shared by three names
        self.stats = self.dist = self.conts = []
        self.is_continuous = False

        # Show the display options matching is_continuous (False here)
        self.update_display_box()

    def sizeHint(self):
        # Only the width is meaningful; the height is governed by mainArea.
        width, height = 100, 500
        return QSize(width, height)

    def eventFilter(self, obj, event):
        """Re-layout the plot when the view's viewport is resized.

        The event is always passed on to the base implementation.
        """
        viewport = self.box_view.viewport()
        if obj is viewport and event.type() == QEvent.Resize:
            self.layout_changed()
        return super().eventFilter(obj, event)

    def reset_attrs(self, domain):
        """Fill the variable list with the primitive variables of `domain`.

        Order: class variables, then meta attributes, then attributes.
        """
        candidates = chain(domain.class_vars, domain.metas, domain.attributes)
        self.attrs[:] = list(filter(lambda var: var.is_primitive(), candidates))

    # noinspection PyTypeChecker
    @Inputs.data
    def set_data(self, dataset):
        """Handle a new input table.

        An empty table, or one whose domain has no variables, is treated as
        no data. With data, the subgroup and variable lists are refilled,
        default variables are chosen, the stored context is opened and the
        plot is rebuilt through grouping_changed; without data, everything
        is cleared. Outputs are committed in both cases.
        """
        if dataset is not None and (not bool(dataset)
                                    or not len(dataset.domain)):
            dataset = None
        # closeContext/openContext bracket the context settings (attribute,
        # group_var, conditions) around the data change
        self.closeContext()
        self.dataset = dataset
        self.dist = self.stats = self.conts = []
        self.group_var = None
        self.attribute = None
        if dataset:
            domain = dataset.domain
            self.group_vars.set_domain(domain)
            # The model always holds the "None" placeholder, so > 1 means
            # there is at least one discrete variable to group by
            self.group_view.setEnabled(len(self.group_vars) > 1)
            self.reset_attrs(domain)
            self.select_default_variables(domain)
            # May override the defaults with settings stored for this data
            self.openContext(self.dataset)
            self.grouping_changed()
        else:
            self.reset_all_data()
        self.commit()

    def select_default_variables(self, domain):
        """Choose the initial plotted variable and subgroup variable.

        The first variable after the class variables is plotted (or the
        first variable, if there are only class variables). Subgroups are
        formed by the class variable when it is discrete; otherwise the
        subgroup variable is reset to None.
        """
        n_class_vars = len(domain.class_vars)
        if len(self.attrs) > n_class_vars:
            self.attribute = self.attrs[n_class_vars]
        elif self.attrs:
            self.attribute = self.attrs[0]

        class_var = domain.class_var
        # Resetting to None triggers selection via callback
        self.group_var = \
            class_var if class_var and class_var.is_discrete else None

    def apply_sorting(self):
        """Order the variable list by relevance to the subgroup variable.

        If "Order by relevance" is checked and a subgroup variable is set,
        variables are sorted by ascending p-value:

        - continuous variables use a one-way ANOVA over the subgroups;
        - discrete variables use a chi-square test of independence on the
          contingency table with empty rows and columns removed.

        Variables without a computable test get a score of 2. The subgroup
        variable itself gets 3, so both kinds sort after every real p-value.
        Otherwise the default order is restored. The selected attribute is
        kept in both cases.
        """
        def compute_score(attr):
            if attr is group_var:
                return 3
            if attr.is_continuous:
                # One-way ANOVA
                col = data.get_column_view(attr)[0].astype(float)
                groups = (col[group_col == i] for i in range(n_groups))
                groups = (col[~np.isnan(col)] for col in groups)
                groups = [group for group in groups if len(group)]
                p = f_oneway(*groups)[1] if len(groups) > 1 else 2
            else:
                # Chi-square test of independence between group_var and attr
                if not attr.values or not group_var.values:
                    return 2
                observed = np.array(
                    contingency.get_contingency(data, group_var, attr))
                observed = observed[observed.sum(axis=1) != 0, :]
                observed = observed[:, observed.sum(axis=0) != 0]
                if min(observed.shape) < 2:
                    return 2
                expected = \
                    np.outer(observed.sum(axis=1), observed.sum(axis=0)) / \
                    np.sum(observed)
                # An r x c table has (r - 1)(c - 1) degrees of freedom.
                # chisquare uses (number of cells) - 1 - ddof = r*c - 1 - ddof,
                # so ddof must be r + c - 2. The size of the *filtered* table
                # is used because empty rows/columns were dropped above.
                n_rows, n_cols = observed.shape
                p = chisquare(observed.ravel(),
                              f_exp=expected.ravel(),
                              ddof=n_rows + n_cols - 2)[1]
            if math.isnan(p):
                return 2
            return p

        data = self.dataset
        if data is None:
            return
        domain = data.domain
        attribute = self.attribute
        group_var = self.group_var
        if self.order_by_importance and group_var is not None:
            n_groups = len(group_var.values)
            # The group column is only needed for the ANOVA on continuous
            # variables
            group_col = data.get_column_view(group_var)[0] if \
                domain.has_continuous_attributes(
                    include_class=True, include_metas=True) else None
            self.attrs.sort(key=compute_score)
        else:
            self.reset_attrs(domain)
        # Re-sorting/refilling the list may change the selection; restore it
        self.attribute = attribute

    def reset_all_data(self):
        """Clear the plot, the test result text and both variable lists.

        Also disables the subgroup list and shows the display options for
        discrete variables (is_continuous is set to False).
        """
        self.clear_scene()
        self.infot1.setText("")
        self.attrs.clear()
        self.group_vars.set_domain(None)
        self.group_view.setEnabled(False)
        self.is_continuous = False
        self.update_display_box()

    def grouping_changed(self):
        """Handle a change of the subgroup variable.

        "Order by relevance" is only enabled when there is a subgroup
        variable. The variable list is then re-sorted and the plot is
        recomputed and redrawn.
        """
        self.cb_order.setEnabled(self.group_var is not None)
        self.apply_sorting()
        self.attr_changed()

    def select_box_items(self):
        """Select exactly those boxes whose filter conditions are stored in
        self.conditions (e.g. after the context was restored)."""
        # Take a snapshot first: changing the selection emits
        # selectionChanged, which is connected to commit and presumably
        # replaces self.conditions.
        wanted = [cond.conditions for cond in self.conditions.copy()]
        for item in self.box_scene.items():
            if not isinstance(item, FilterGraphicsRectItem):
                continue
            item.setSelected(item.filter.conditions in wanted)

    def attr_changed(self):
        """Recompute and redraw the plot for the current attribute, then
        center the view on the drawn rows."""
        self.compute_box_data()
        self.update_display_box()
        self.layout_changed()

        if self.is_continuous:
            row_height = 90 if self.show_annotations else 60
            n_rows = len(self.stats)
            center_x = self.scene_min_x + self.scene_width / 2
        else:
            row_height, n_rows = 40, len(self.boxes)
            center_x = self.scene_width / 2
        self.box_view.centerOn(center_x, -30 - n_rows * row_height / 2 + 45)

    def compute_box_data(self):
        """Compute the data plotted for the current attribute.

        Sets is_continuous and fills:

        - conts: contingency of the attribute per subgroup value (with a
          subgroup variable), otherwise an empty list;
        - dist: distribution of the attribute (without a subgroup variable),
          otherwise an empty list;
        - stats: BoxData per non-empty subgroup (continuous attributes only);
        - label_txts_all / label_txts: labels of the non-empty subgroups.

        Does nothing if no attribute is selected.
        """
        attr = self.attribute
        if not attr:
            return
        dataset = self.dataset
        self.is_continuous = attr.is_continuous
        # Nothing to compute without data, for a discrete attribute without
        # values, or for a subgroup variable without values
        if dataset is None or not self.is_continuous and not attr.values or \
                        self.group_var and not self.group_var.values:
            self.stats = self.dist = self.conts = []
            return
        if self.group_var:
            self.dist = []
            # Row i corresponds to self.group_var.values[i] (see below)
            self.conts = contingency.get_contingency(dataset, attr,
                                                     self.group_var)
            if self.is_continuous:
                stats, label_texts = [], []
                for i, cont in enumerate(self.conts):
                    # cont[1] presumably holds the counts; skip empty groups
                    if np.sum(cont[1]):
                        stats.append(BoxData(cont, attr, i, self.group_var))
                        label_texts.append(self.group_var.values[i])
                self.stats = stats
                self.label_txts_all = label_texts
            else:
                # NOTE(review): self.stats is not reset here, so it keeps
                # the stats of a previously shown continuous attribute.
                # The discrete drawing code does not appear to read it,
                # but confirm.
                self.label_txts_all = \
                    [v for v, c in zip(self.group_var.values, self.conts)
                     if np.sum(c) > 0]
        else:
            self.dist = distribution.get_distribution(dataset, attr)
            self.conts = []
            if self.is_continuous:
                self.stats = [BoxData(self.dist, attr, None)]
            self.label_txts_all = [""]
        # Keep labels aligned with stats while dropping empty stats
        self.label_txts = [
            txts for stat, txts in zip(self.stats, self.label_txts_all)
            if stat.n > 0
        ]
        self.stats = [stat for stat in self.stats if stat.n > 0]

    def update_display_box(self):
        """Show the display options that fit the current attribute type.

        Options that only make sense with subgroups are enabled only when a
        subgroup variable is set.
        """
        has_groups = self.group_var is not None
        if not self.is_continuous:
            self.stretching_box.show()
            self.display_box.hide()
            self.sort_cb.setEnabled(has_groups)
            return
        self.stretching_box.hide()
        self.display_box.show()
        self.compare_rb.setEnabled(has_groups)

    def clear_scene(self):
        """Remove all items from the scene and forget the item lists."""
        # NOTE(review): closing and re-opening the context around the clear
        # presumably keeps the stored `conditions` (selection) from being
        # lost when the selection is cleared below; confirm.
        self.closeContext()
        self.box_scene.clearSelection()
        self.box_scene.clear()
        self.box_view.viewport().update()
        self.attr_labels = []
        self.labels = []
        self.boxes = []
        self.mean_labels = []
        self.posthoc_lines = []
        self.openContext(self.dataset)

    def layout_changed(self):
        """Rebuild all scene items for the current attribute.

        For discrete attributes, everything is delegated to
        display_changed_disc. For continuous attributes, this builds the
        mean labels, the axis (which sets scale_x), the boxes, the
        annotations and the group labels, and then calls display_changed
        to place them.
        """
        attr = self.attribute
        if not attr:
            return
        self.clear_scene()
        if self.dataset is None or len(self.conts) == len(self.dist) == 0:
            return

        if not self.is_continuous:
            self.display_changed_disc()
            return

        # Mean labels are needed by draw_axis to reserve room on the left
        self.mean_labels = [
            self.mean_label(stat, attr, lab)
            for stat, lab in zip(self.stats, self.label_txts)
        ]
        # Must precede box_group/label_group, which use self.scale_x
        self.draw_axis()
        self.boxes = [self.box_group(stat) for stat in self.stats]
        self.labels = [
            self.label_group(stat, attr, mean_lab)
            for stat, mean_lab in zip(self.stats, self.mean_labels)
        ]
        self.attr_labels = [
            QGraphicsSimpleTextItem(lab) for lab in self.label_txts
        ]
        for it in chain(self.labels, self.attr_labels):
            self.box_scene.addItem(it)
        self.display_changed()

    def display_changed(self):
        """Place the prebuilt items on the scene per the display settings.

        Rows are ordered by median or mean when comparison is enabled.
        Annotations or plain group labels are shown depending on
        show_annotations. Afterwards the scene rect is set, the tests are
        recomputed, post-hoc lines are drawn and the stored selection is
        restored. For discrete attributes this delegates to
        display_changed_disc.
        """
        if self.dataset is None:
            return

        if not self.is_continuous:
            self.display_changed_disc()
            return

        self.order = list(range(len(self.stats)))
        criterion = self._sorting_criteria_attrs[self.compare]
        if criterion:
            vals = [getattr(stat, criterion) for stat in self.stats]
            # Missing values (None) sort after all existing ones
            overmax = max((val for val in vals if val is not None), default=0) \
                      + 1
            vals = [val if val is not None else overmax for val in vals]
            self.order = sorted(self.order, key=vals.__getitem__)

        heights = 90 if self.show_annotations else 60

        for row, box_index in enumerate(self.order):
            # Rows are stacked upwards (negative y) above the axis at y = 0
            y = (-len(self.stats) + row) * heights + 10
            for item in self.boxes[box_index]:
                self.box_scene.addItem(item)
                item.setY(y)
            labels = self.labels[box_index]

            if self.show_annotations:
                labels.show()
                labels.setY(y)
            else:
                labels.hide()

            label = self.attr_labels[box_index]
            label.setY(y - 15 - label.boundingRect().height())
            if self.show_annotations:
                # Annotations already include the group name (mean label)
                label.hide()
            else:
                stat = self.stats[box_index]

                # Put the group label next to the compared statistic, or at
                # the first quartile when not comparing
                if self.compare == OWBoxPlot.CompareMedians and \
                        stat.median is not None:
                    pos = stat.median + 5 / self.scale_x
                elif self.compare == OWBoxPlot.CompareMeans or stat.q25 is None:
                    pos = stat.mean + 5 / self.scale_x
                else:
                    pos = stat.q25
                label.setX(pos * self.scale_x)
                label.show()

        r = QRectF(self.scene_min_x, -30 - len(self.stats) * heights,
                   self.scene_width,
                   len(self.stats) * heights + 90)
        self.box_scene.setSceneRect(r)

        self.compute_tests()
        self.show_posthoc()
        self.select_box_items()

    def display_changed_disc(self):
        """Draw the whole scene for a discrete attribute.

        Unlike the continuous case, nothing is prepared in advance by
        layout_changed. This method runs on every redraw and:

        - builds the subgroup labels and, in non-stretched mode, the row
          counts;
        - draws the axis;
        - builds one stacked bar per non-empty subgroup (a single bar
          without subgrouping);
        - places everything row by row and restores the stored selection.
        """
        assert not self.is_continuous
        self.clear_scene()
        # Nothing to draw; same condition as in layout_changed. Without this
        # guard, filtering an empty self.conts below raises.
        if len(self.conts) == len(self.dist) == 0:
            return
        self.attr_labels = [
            QGraphicsSimpleTextItem(lab) for lab in self.label_txts_all
        ]

        if not self.stretched:
            # Total count of each row, drawn to the right of the bar
            if self.group_var:
                self.labels = [
                    QGraphicsTextItem("{}".format(int(sum(cont))))
                    for cont in self.conts if np.sum(cont) > 0
                ]
            else:
                self.labels = [QGraphicsTextItem(str(int(sum(self.dist))))]

        self.order = list(range(len(self.attr_labels)))

        self.draw_axis_disc()
        if self.group_var:
            # self.conts is filtered in place below, and this method runs on
            # every redraw (resize, display options). After the first call,
            # positions in self.conts are therefore no longer indices into
            # group_var.values. Recover the indices from the labels of the
            # non-empty subgroups, which compute_box_data derives, in the
            # same order, from the unfiltered contingency.
            group_values = list(self.group_var.values)
            group_indices = [group_values.index(lab)
                             for lab in self.label_txts_all]
            nonempty = [cont for cont in self.conts if np.sum(cont) > 0]
            self.boxes = [self.strudel(cont, i)
                          for i, cont in zip(group_indices, nonempty)]
            # Keep only non-empty rows, so that self.conts[row] matches
            # self.boxes[row] and self.labels[row]
            self.conts = self.conts[np.sum(np.array(self.conts), axis=1) > 0]

            if self.sort_freqs:
                # pylint: disable=invalid-unary-operand-type
                self.order = sorted(
                    self.order, key=(-np.sum(self.conts, axis=1)).__getitem__)
        else:
            self.boxes = [self.strudel(self.dist)]

        for row, box_index in enumerate(self.order):
            # Rows are stacked upwards (negative y) above the axis at y = 0
            y = (-len(self.boxes) + row) * 40 + 10
            box = self.boxes[box_index]
            # Items of a box presumably alternate: bar, label, bar, label...
            bars, labels = box[::2], box[1::2]

            self.__draw_group_labels(y, box_index)
            if not self.stretched:
                self.__draw_row_counts(y, box_index)
            if self.show_labels and self.attribute is not self.group_var:
                self.__draw_bar_labels(y, bars, labels)
            self.__draw_bars(y, bars)

        self.box_scene.setSceneRect(-self.label_width - 5,
                                    -30 - len(self.boxes) * 40,
                                    self.scene_width,
                                    len(self.boxes) * 40 + 90)
        self.infot1.setText("")
        self.select_box_items()

    def __draw_group_labels(self, y, row):
        """Put the subgroup label of a row to the left of its bar.

        Parameters
        ----------
        y: int
            vertical offset of bars
        row: int
            row index
        """
        label = self.attr_labels[row]
        rect = label.boundingRect()
        x_pos = -rect.width() - 10
        y_pos = y - rect.height() / 2
        label.setPos(x_pos, y_pos)
        self.box_scene.addItem(label)

    def __draw_row_counts(self, y, row):
        """Put the total count of a row just right of its bar.

        Parameters
        ----------
        y: int
            vertical offset of bars
        row: int
            row index
        """
        assert not self.is_continuous
        label = self.labels[row]
        height = label.boundingRect().height()
        counts = self.conts[row] if self.group_var else self.dist
        bar_end = self.scale_x * sum(counts)
        label.setPos(bar_end + 10, y - height / 2)
        self.box_scene.addItem(label)

    def __draw_bar_labels(self, y, bars, labels):
        """Put a text label above each bar segment, clipped to its width.

        Parameters
        ----------
        y: int
            vertical offset of bars
        bars: List[FilterGraphicsRectItem]
            list of bars being drawn
        labels: List[QGraphicsTextItem]
            list of labels for corresponding bars
        """
        for source_text, bar in zip(labels, bars):
            label = self.Label(source_text.toPlainText())
            bar_rect = bar.boundingRect()
            label.setPos(bar_rect.x(), y - label.boundingRect().height() - 8)
            label.setMaxWidth(bar_rect.width())
            self.box_scene.addItem(label)

    def __draw_bars(self, y, bars):
        """Move the bar segments of a row to its vertical offset and add
        them to the scene.

        Parameters
        ----------
        y: int
            vertical offset of bars

        bars: List[FilterGraphicsRectItem]
            list of bars to draw
        """
        scene = self.box_scene
        for bar in bars:
            bar.setPos(0, y)
            scene.addItem(bar)

    # noinspection PyPep8Naming
    def compute_tests(self):
        """Run the significance test for the current comparison.

        With two groups, means are compared by Welch's t-test; with more,
        by one-way ANOVA. Median comparison shows no test. The result text
        goes to self.infot1 and the p-value to self.p.
        """
        # The t-test and ANOVA are implemented here since they efficiently use
        # the widget-specific data in self.stats.
        # The non-parametric tests can't do this, so we use statistics.tests
        def stat_ttest():
            """Welch's t-test for two groups; return (|t|, two-sided p)."""
            d1, d2 = self.stats
            if d1.n == 0 or d2.n == 0:
                return np.nan, np.nan
            v1, v2 = d1.var / d1.n, d2.var / d2.n
            pooled_var = v1 + v2
            # Check before computing df: when both variances are zero, the
            # denominator of df is zero as well (0 / 0)
            if pooled_var == 0:
                return np.nan, np.nan
            df = pooled_var ** 2 / \
                (v1 ** 2 / (d1.n - 1) + v2 ** 2 / (d2.n - 1))
            t = abs(d1.mean - d2.mean) / math.sqrt(pooled_var)
            # stdtr(df, -t) == 1 - stdtr(df, t), without the cancellation
            # that rounds very small p-values to 0
            p = 2 * scipy.special.stdtr(df, -t)
            return t, p

        # TODO: Check this function
        # noinspection PyPep8Naming
        def stat_ANOVA():
            """One-way ANOVA over all groups; return (F, p)."""
            if any(stat.n == 0 for stat in self.stats):
                return np.nan, np.nan
            n = sum(stat.n for stat in self.stats)
            grand_avg = sum(stat.n * stat.mean for stat in self.stats) / n
            var_between = sum(stat.n * (stat.mean - grand_avg)**2
                              for stat in self.stats)
            df_between = len(self.stats) - 1

            var_within = sum(stat.n * stat.var for stat in self.stats)
            df_within = n - len(self.stats)
            F = (var_between / df_between) / (var_within / df_within)
            # Survival function directly, instead of 1 - cdf
            p = scipy.special.fdtrc(df_between, df_within, F)
            return F, p

        if self.compare == OWBoxPlot.CompareNone or len(self.stats) < 2:
            t = ""
        elif any(s.n <= 1 for s in self.stats):
            t = "At least one group has just one instance, " \
                "cannot compute significance"
        elif len(self.stats) == 2:
            if self.compare == OWBoxPlot.CompareMedians:
                t = ""
                # z, self.p = tests.wilcoxon_rank_sum(
                #    self.stats[0].dist, self.stats[1].dist)
                # t = "Mann-Whitney's z: %.1f (p=%.3f)" % (z, self.p)
            else:
                t, self.p = stat_ttest()
                t = "Student's t: %.3f (p=%.3f)" % (t, self.p)
        else:
            if self.compare == OWBoxPlot.CompareMedians:
                t = ""
                # U, self.p = -1, -1
                # t = "Kruskal Wallis's U: %.1f (p=%.3f)" % (U, self.p)
            else:
                F, self.p = stat_ANOVA()
                t = "ANOVA: %.3f (p=%.3f)" % (F, self.p)
        self.infot1.setText("<center>%s</center>" % t)

    def mean_label(self, stat, attr, val_name):
        """Build the "[group: ]mean ± dev" label for one box.

        The mean is centered at x = 0 and the group name, if any, is placed
        to its left. The returned group carries `min_x`, the leftmost x of
        the label relative to the mean.
        """
        decimals = attr.number_of_decimals + 1
        font = self._label_font
        group = QGraphicsItemGroup()

        mean_item = QGraphicsSimpleTextItem(
            "%.*f" % (decimals, stat.mean), group)
        mean_item.setFont(font)
        mean_rect = mean_item.boundingRect()
        half_width, height = mean_rect.width() / 2, mean_rect.height()
        mean_item.setPos(-half_width, -height)

        dev_text = " \u00b1 " + "%.*f" % (decimals, stat.dev)
        dev_item = QGraphicsSimpleTextItem(dev_text, group)
        dev_item.setFont(font)
        dev_item.setPos(half_width, -height)

        group.min_x = -half_width
        if val_name:
            name_item = QGraphicsSimpleTextItem(val_name + ": ", group)
            name_item.setFont(font)
            name_item.setBrush(self._attr_brush)
            group.min_x -= name_item.boundingRect().width()
            name_item.setPos(group.min_x, -height)
        return group

    def draw_axis(self):
        """Draw the horizontal axis for continuous attributes.

        Sets self.scale_x (pixels per data unit), self.scene_min_x and
        self.scene_width so that all boxes, mean ± dev bars and mean labels
        fit the view's width. Without stats, a dummy axis with "?" tick
        labels is drawn.
        """
        misssing_stats = not self.stats
        # Placeholder stats so that an (unlabeled) axis can still be drawn
        stats = self.stats or [BoxData(np.array([[0.], [1.]]), self.attribute)]
        mean_labels = self.mean_labels or [
            self.mean_label(stats[0], self.attribute, "")
        ]
        bottom = min(stat.a_min for stat in stats)
        top = max(stat.a_max for stat in stats)

        first_val, step = compute_scale(bottom, top)
        # Move the first tick strictly below the smallest value
        while bottom <= first_val:
            first_val -= step
        bottom = first_val
        no_ticks = math.ceil((top - first_val) / step) + 1
        top = max(top, first_val + no_ticks * step)

        # The mean ± dev bars may reach beyond the data range
        gbottom = min(bottom, min(stat.mean - stat.dev for stat in stats))
        gtop = max(top, max(stat.mean + stat.dev for stat in stats))

        bv = self.box_view
        viewrect = bv.viewport().rect().adjusted(15, 15, -15, -30)
        self.scale_x = scale_x = viewrect.width() / (gtop - gbottom)

        # In principle we should repeat this until convergence since the new
        # scaling is too conservative. (No chance am I doing this.)
        # Leftmost data coordinate covered by a mean label (min_x is in pixels)
        mlb = min(stat.mean + mean_lab.min_x / scale_x
                  for stat, mean_lab in zip(stats, mean_labels))
        if mlb < gbottom:
            gbottom = mlb
            self.scale_x = scale_x = viewrect.width() / (gtop - gbottom)

        self.scene_min_x = gbottom * scale_x
        self.scene_width = (gtop - gbottom) * scale_x

        val = first_val
        decimals = max(3, 4 - int(math.log10(step)))
        # Ticks from first_val in `step` increments; stops after the first
        # tick at or beyond `top`
        while True:
            l = self.box_scene.addLine(val * scale_x, -1, val * scale_x, 1,
                                       self._pen_axis_tick)
            l.setZValue(100)
            t = self.box_scene.addSimpleText(
                repr(round(val, decimals)) if not misssing_stats else "?",
                self._axis_font)
            t.setFlags(t.flags() | QGraphicsItem.ItemIgnoresTransformations)
            r = t.boundingRect()
            t.setPos(val * scale_x - r.width() / 2, 8)
            if val >= top:
                break
            val += step
        self.box_scene.addLine(bottom * scale_x - 4, 0, top * scale_x + 4, 0,
                               self._pen_axis)

    def draw_axis_disc(self):
        """
        Draw the horizontal axis and sets self.scale_x for discrete attributes

        In stretched mode the axis runs 0..100 in steps of 10. Otherwise it
        runs from 0 to the largest row count, with room reserved on the
        right for the row-count labels. Also sets self.scene_width and
        self.label_width (the width reserved for subgroup labels on the
        left).
        """
        assert not self.is_continuous
        if self.stretched:
            if not self.attr_labels:
                return
            step = steps = 10
        else:
            if self.group_var:
                max_box = max(float(np.sum(dist)) for dist in self.conts)
            else:
                max_box = float(np.sum(self.dist))
            if max_box == 0:
                self.scale_x = 1
                return
            _, step = compute_scale(0, max_box)
            step = int(step) if step > 1 else 1
            steps = int(math.ceil(max_box / step))
        max_box = step * steps

        bv = self.box_view
        viewrect = bv.viewport().rect().adjusted(15, 15, -15, -30)
        self.scene_width = viewrect.width()

        # Subgroup labels get between 40 px and a third of the scene width
        lab_width = max(lab.boundingRect().width() for lab in self.attr_labels)
        lab_width = max(lab_width, 40)
        lab_width = min(lab_width, self.scene_width / 3)
        self.label_width = lab_width

        right_offset = 0  # offset for the right label
        if not self.stretched and self.labels:
            if self.group_var:
                # NOTE(review): self.labels only has entries for non-empty
                # subgroups, while self.conts may still contain empty rows
                # here (display_changed_disc filters it after this call), so
                # on a first draw this zip can pair counts with the wrong row.
                rows = list(zip(self.conts, self.labels))
            else:
                rows = [(self.dist, self.labels[0])]
            # available space left of the 'group labels'
            available = self.scene_width - lab_width - 10
            scale_x = (available - right_offset) / max_box
            max_right = max(
                sum(dist) * scale_x + 10 + lbl.boundingRect().width()
                for dist, lbl in rows)
            right_offset = max(0, max_right - max_box * scale_x)

        self.scale_x = scale_x = \
            (self.scene_width - lab_width - 10 - right_offset) / max_box

        self.box_scene.addLine(0, 0, max_box * scale_x, 0, self._pen_axis)
        for val in range(0, step * steps + 1, step):
            l = self.box_scene.addLine(val * scale_x, -1, val * scale_x, 1,
                                       self._pen_axis_tick)
            l.setZValue(100)
            t = self.box_scene.addSimpleText(str(val), self._axis_font)
            t.setPos(val * scale_x - t.boundingRect().width() / 2, 8)
        if self.stretched:
            # The axis above is in percent; the scale is converted to pixels
            # per unit fraction (presumably because stretched bars are drawn
            # from fractions in [0, 1] -- confirm in strudel)
            self.scale_x *= 100

    def label_group(self, stat, attr, mean_lab):
        """Build the annotation group for one continuous box.

        The mean label goes above the axis and the median and quartile
        values below it. A quartile label that would overlap the median
        label is shifted aside and connected to its position by a bent
        line.
        """
        def centered_text(val, pos):
            # Value text centered horizontally at `pos`, below the box
            t = QGraphicsSimpleTextItem(
                "%.*f" % (attr.number_of_decimals + 1, val), labels)
            t.setFont(self._label_font)
            bbox = t.boundingRect()
            t.setPos(pos - bbox.width() / 2, 22)
            return t

        def line(x, down=1):
            # Short tick connecting a label to the box (down=-1: above)
            QGraphicsLineItem(x, 12 * down, x, 20 * down, labels)

        def move_label(label, frm, to):
            # Move `label` to x=`to` and draw a bent connector from `frm`;
            # reads t_box of the label currently being placed
            label.setX(to)
            to += t_box.width() / 2
            path = QPainterPath()
            path.lineTo(0, 4)
            path.lineTo(to - frm, 4)
            path.lineTo(to - frm, 8)
            p = QGraphicsPathItem(path)
            p.setPos(frm, 12)
            labels.addToGroup(p)

        labels = QGraphicsItemGroup()

        labels.addToGroup(mean_lab)
        m = stat.mean * self.scale_x
        mean_lab.setPos(m, -22)
        line(m, -1)

        if stat.median is not None:
            msc = stat.median * self.scale_x
            med_t = centered_text(stat.median, msc)
            med_box_width2 = med_t.boundingRect().width() / 2
            line(msc)

        # NOTE(review): the quartile branches below use msc/med_box_width2,
        # which are only bound when stat.median is not None; this assumes
        # BoxData never has quartiles without a median -- confirm.
        if stat.q25 is not None:
            x = stat.q25 * self.scale_x
            t = centered_text(stat.q25, x)
            t_box = t.boundingRect()
            med_left = msc - med_box_width2
            if x + t_box.width() / 2 >= med_left - 5:
                move_label(t, x, med_left - t_box.width() - 5)
            else:
                line(x)

        if stat.q75 is not None:
            x = stat.q75 * self.scale_x
            t = centered_text(stat.q75, x)
            t_box = t.boundingRect()
            med_right = msc + med_box_width2
            if x - t_box.width() / 2 <= med_right + 5:
                move_label(t, x, med_right + 5)
            else:
                line(x)

        return labels

    def box_group(self, stat, height=20):
        """Build the graphics items that draw a single continuous box.

        x coordinates are taken from ``stat`` in data units and multiplied by
        ``self.scale_x``. y coordinates are scene units around 0.

        :param stat: box statistics (``a_min``, ``a_max``, ``mean``, ``dev``,
            ``q25``, ``q75``, ``median``, ``conditions``)
        :param height: height of the interquartile rectangle
        :return: list with the two whiskers, the dotted min-max line, the
            mean tick and the mean +/- dev line. These are followed by the
            selectable interquartile rectangle (only if both quartiles are
            known) and the median line (only if the median is known).
        """
        scale_x = self.scale_x

        def segment(x_from, y_from, x_to, y_to):
            # Only the x coordinates are in data units and need scaling.
            return QGraphicsLineItem(x_from * scale_x, y_from,
                                     x_to * scale_x, y_to)

        left_whisker = segment(stat.a_min, -1.5, stat.a_min, 1.5)
        right_whisker = segment(stat.a_max, -1.5, stat.a_max, 1.5)
        range_line = segment(stat.a_min, 0, stat.a_max, 0)
        mean_tick = segment(stat.mean, -height / 3, stat.mean, height / 3)
        spread_line = segment(stat.mean - stat.dev, 0,
                              stat.mean + stat.dev, 0)

        for item in (left_whisker, right_whisker, mean_tick, spread_line):
            item.setPen(self._pen_paramet)
        range_line.setPen(self._pen_dotted)

        items = [left_whisker, right_whisker, range_line, mean_tick,
                 spread_line]

        has_quartiles = stat.q25 is not None and stat.q75 is not None
        if has_quartiles:
            iqr_box = FilterGraphicsRectItem(
                stat.conditions,
                stat.q25 * scale_x,
                -height / 2,
                (stat.q75 - stat.q25) * scale_x,
                height)
            iqr_box.setBrush(self._box_brush)
            iqr_box.setPen(QPen(Qt.NoPen))
            iqr_box.setZValue(-200)
            items.append(iqr_box)

        if stat.median is not None:
            median_line = segment(stat.median, -height / 2,
                                  stat.median, height / 2)
            median_line.setPen(self._pen_median)
            median_line.setZValue(-150)
            items.append(median_line)

        return items

    def strudel(self, dist, group_val_index=None):
        """Build the stacked-bar items for a discrete attribute.

        :param dist: per-value counts of ``self.attribute``, indexed by value
        :param group_val_index: index of the ``self.group_var`` value this bar
            belongs to, or None when the plot is not grouped
        :return: list of graphics items. If the distribution sums to
            (almost) zero, a placeholder rectangle comes first. Then, for
            every value with a count of at least 1e-6, there is a colored
            selectable rectangle with a tooltip, followed by a text item
            holding the value's name.
        """
        attr = self.attribute
        total = np.sum(dist)

        def conditions(value_filter):
            # Filter on the attribute value(s); also restrict to this group
            # when the plot is grouped.
            conds = [FilterDiscrete(attr, value_filter)]
            if group_val_index is not None:
                conds.append(
                    FilterDiscrete(self.group_var, [group_val_index]))
            return conds

        items = []
        if total < 1e-6:
            items.append(
                FilterGraphicsRectItem(conditions(None), 0, -10, 1, 10))

        offset = 0
        for idx, count in enumerate(dist):
            if count < 1e-6:
                continue
            share = count / total if self.stretched else count
            width = share * self.scale_x
            rect = FilterGraphicsRectItem(conditions([idx]), offset + 1, -6,
                                          width - 2, 12)
            rect.setBrush(QBrush(QColor(*attr.colors[idx])))
            rect.setPen(QPen(Qt.NoPen))
            value_name = attr.values[idx]
            if self.stretched:
                rect.setToolTip("{}: {:.2f}%".format(
                    value_name, 100 * dist[idx] / sum(dist)))
            else:
                rect.setToolTip("{}: {}".format(value_name, int(dist[idx])))
            items.extend((rect, QGraphicsTextItem(value_name)))
            offset += width
        return items

    def commit(self):
        """Send the instances matched by the selected scene items.

        Each selected item with a truthy ``filter`` contributes it to
        ``self.conditions``. Instances of ``self.dataset`` matching any of
        them are sent on ``selected_data``, and their positions are flagged
        in ``annotated_data``. With no conditions, ``selected_data`` gets
        None and nothing is flagged.
        """
        self.conditions = [
            item.filter for item in self.box_scene.selectedItems()
            if item.filter
        ]
        selected, selection = None, []
        if self.conditions:
            selected = Values(self.conditions, conjunction=False)(self.dataset)
            # np.in1d is deprecated since NumPy 2.0; np.isin gives the same
            # result for these 1-D id arrays.
            selection = np.isin(self.dataset.ids,
                                selected.ids,
                                assume_unique=True).nonzero()[0]
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.dataset, selection))

    def show_posthoc(self):
        """Redraw the post-hoc comparison marks for the current boxes.

        Previously drawn marks are removed first. Unless comparison is off or
        there are fewer than two boxes, this then draws:

        * a vertical mark at each box's median or mean (per ``self.compare``);
        * horizontal bars joining runs of consecutive boxes whose criterion
          lies within 1.5 scene units of the run's first box.

        All created items are tracked in ``self.posthoc_lines``.
        """
        def line(x, y0, y1):
            it = self.box_scene.addLine(x, y0, x, y1, self._post_line_pen)
            it.setZValue(-100)
            self.posthoc_lines.append(it)

        while self.posthoc_lines:
            self.box_scene.removeItem(self.posthoc_lines.pop())

        if self.compare == OWBoxPlot.CompareNone or len(self.stats) < 2:
            return

        if self.compare == OWBoxPlot.CompareMedians:
            crit_line = "median"
        else:
            crit_line = "mean"

        xs = []

        height = 90 if self.show_annotations else 60

        y_up = -len(self.stats) * height + 10
        for pos, box_index in enumerate(self.order):
            stat = self.stats[box_index]
            x = getattr(stat, crit_line)
            if x is None:
                continue
            # Scene x-coordinate of the criterion. The grouping bars below are
            # drawn from `xs` in the same scene coordinates as these vertical
            # marks, so the value must be scaled exactly once.
            x *= self.scale_x
            xs.append(x)
            by = y_up + pos * height
            line(x, by + 12, 3)
            line(x, by - 12, by - 25)

        used_to = []
        last_to = to = 0
        for frm, frm_x in enumerate(xs[:-1]):
            # Find the last box that is still close enough to box `frm`.
            for to in range(frm + 1, len(xs)):
                if xs[to] - frm_x > 1.5:
                    to -= 1
                    break
            if to in (last_to, frm):
                continue
            # Reuse the first bar row whose previous bar ended before `frm`.
            for rowi, used in enumerate(used_to):
                if used < frm:
                    used_to[rowi] = to
                    break
            else:
                rowi = len(used_to)
                used_to.append(to)
            y = -6 - rowi * 6
            it = self.box_scene.addLine(frm_x - 2, y, xs[to] + 2, y,
                                        self._post_grp_pen)
            self.posthoc_lines.append(it)
            last_to = to

    def get_widget_name_extension(self):
        """Return the name of the plotted attribute, or None if there is none."""
        attribute = self.attribute
        if not attribute:
            return None
        return attribute.name

    def send_report(self):
        """Add the plot to the report, captioned with attribute and grouping."""
        self.report_plot()
        parts = []
        if self.attribute:
            parts.append(
                "Box plot for attribute '{}' ".format(self.attribute.name))
        if self.group_var:
            parts.append("grouped by '{}'".format(self.group_var.name))
        caption = "".join(parts)
        if caption:
            self.report_caption(caption)

    class Label(QGraphicsSimpleTextItem):
        """Boxplot Label with settable maxWidth.

        Text wider than the available width is elided on the right. The
        available width is ``maxWidth`` if set, otherwise the width of the
        paint option's rect. If it is below ``MIN_LABEL_WIDTH``, nothing is
        painted.
        """
        # Minimum width to display label text
        MIN_LABEL_WIDTH = 25

        # padding below the text
        PADDING = 3

        # Maximal width available to the label; None means "use option.rect"
        __max_width = None

        def maxWidth(self):
            """Return the width limit set by setMaxWidth (None if unset)."""
            return self.__max_width

        def setMaxWidth(self, max_width):
            """Limit the painted text width to `max_width` (None: no limit)."""
            self.__max_width = max_width

        def paint(self, painter, option, widget):
            """Overrides QGraphicsSimpleTextItem.paint

            If label text is too long, it is elided
            to fit into the allowed region
            """
            if self.__max_width is None:
                width = option.rect.width()
            else:
                width = self.__max_width

            if width < self.MIN_LABEL_WIDTH:
                # if space is too narrow, no label
                return

            fm = painter.fontMetrics()
            # elidedText and drawText(int, int, str) take ints. The width set
            # through setMaxWidth and boundingRect().height() may be floats,
            # which PyQt rejects with TypeError on Python >= 3.10.
            text = fm.elidedText(self.text(), Qt.ElideRight, int(width))
            painter.drawText(
                option.rect.x(),
                int(option.rect.y() + self.boundingRect().height()
                    - self.PADDING),
                text)
Beispiel #46
0
class OWContingencyTable(widget.OWWidget):
    """Cross-tabulate two discrete variables of the input data.

    Outputs the contingency table, the instances falling into the cells
    selected in the view, and the input data annotated with that selection.
    """
    name = "Contingency Table"
    description = "Construct a contingency table from given data."
    icon = "icons/Contingency.svg"
    priority = 2010

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        contingency = Output("Contingency Table", Table, default=True)
        selected_data = Output("Selected Data", Table)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler(metas_in_res=True)
    rows = ContextSetting(None)
    columns = ContextSetting(None)
    # Selected cells as (row value index, column value index) pairs
    selection = ContextSetting(set())
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        super().__init__()

        self.data = None
        # Only discrete variables can be tabulated
        self.feature_model = DomainModel(valid_types=DiscreteVariable)
        self.table = None

        box = gui.vBox(self.controlArea, "Rows")
        gui.comboBox(box,
                     self,
                     'rows',
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._attribute_changed)

        box = gui.vBox(self.controlArea, "Columns")
        gui.comboBox(box,
                     self,
                     'columns',
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._attribute_changed)

        gui.rubber(self.controlArea)

        box = gui.vBox(self.controlArea, "Scores")
        self.scores = gui.widgetLabel(box, "\n\n")

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Store new input data and build the initial table.

        The first discrete variable is used for both rows and columns until
        a stored context restores other choices. When there is no data, or
        the data has no discrete variables, the table view is cleared.
        """
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.feature_model.set_domain(None)
        self.rows = None
        self.columns = None
        if self.data:
            self.feature_model.set_domain(self.data.domain)
        if not self.feature_model:
            # No data, or no discrete variables to tabulate: clear the view
            # so it does not keep showing the table of the previous input.
            self.tableview.clear()
            return
        self.rows = self.feature_model[0]
        self.columns = self.feature_model[0]
        self.openContext(data)
        self.tableview.set_variables(self.rows, self.columns)
        self.table = contingency_table(self.data, self.columns, self.rows)
        self.tableview.update_table(self.table.X, formatstr="{:.0f}")

    def handleNewSignals(self):
        self._attribute_changed()

    def commit(self):
        """Send the table, the instances in selected cells and annotated data."""
        if len(self.selection):
            cells = []
            for ir, r in enumerate(self.rows.values):
                for ic, c in enumerate(self.columns.values):
                    if (ir, ic) in self.selection:
                        cells.append(
                            Values([
                                FilterDiscrete(self.rows, [r]),
                                FilterDiscrete(self.columns, [c])
                            ]))
            selected_data = Values(cells, conjunction=False)(self.data)
            # np.in1d is deprecated since NumPy 2.0; np.isin is equivalent
            # for 1-D id arrays. Pass a flat index array, not np.where's tuple.
            annotated_data = create_annotated_table(
                self.data,
                np.flatnonzero(np.isin(self.data.ids, selected_data.ids,
                                       assume_unique=True)))
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])
        self.Outputs.contingency.send(self.table)
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)

    def _invalidate(self):
        # Take the current selection from the view, then send outputs
        self.selection = self.tableview.get_selection()
        self.commit()

    def _attribute_changed(self):
        """Rebuild the table and scores for the current row/column variables."""
        self.tableview.set_selection(self.selection)
        self.table = None
        if self.data and self.rows and self.columns:
            self.tableview.set_variables(self.rows, self.columns)
            self.table = contingency_table(self.data, self.columns, self.rows)
            self.tableview.update_table(self.table.X, formatstr="{:.0f}")

            chi = ChiSqStats(self.data, self.rows, self.columns)
            vardata1 = self.data.get_column_view(self.rows.name)[0]
            vardata2 = self.data.get_column_view(self.columns.name)[0]
            self.scores.setText(
                "ARI: {:.3f}\nAMI: {:.3f}\nχ²={:.2f}, p={:.3f}".format(
                    adjusted_rand_score(vardata1, vardata2),
                    adjusted_mutual_info_score(vardata1, vardata2), chi.chisq,
                    chi.p))
        else:
            self.scores.setText("\n\n")
        self._invalidate()

    def send_report(self):
        rows = None
        columns = None
        if self.data is not None:
            rows = self.rows
            if rows in self.data.domain:
                rows = self.data.domain[rows]
            columns = self.columns
            if columns in self.data.domain:
                columns = self.data.domain[columns]
        self.report_items((
            ("Rows", rows),
            ("Columns", columns),
        ))
Beispiel #47
0
class OWUnique(widget.OWWidget):
    """Keep one instance from each group of instances that share the values
    of the selected variables (all variables when none are selected)."""
    name = '唯一(Unique)'
    icon = 'icons/Unique.svg'
    description = '根据所选特征删除重复的实例。'
    category = "数据(Data)"

    class Inputs:
        data = widget.Input("数据(Data)", Table, replaces=['Data'])

    class Outputs:
        data = widget.Output("数据(Data)", Table, replaces=['Data'])

    want_main_area = False

    # Maps a tie-breaker label to a function that receives the list of row
    # indices of one group and returns the row to keep (None keeps nothing).
    TIEBREAKERS = {'最后的实例': itemgetter(-1),
                   '第一个实例': itemgetter(0),
                   '中间的实例': lambda seq: seq[len(seq) // 2],
                   '随机实例': np.random.choice,
                   '丢弃重复实例':
                   lambda seq: seq[0] if len(seq) == 1 else None}

    settingsHandler = settings.DomainContextHandler()
    selected_vars = settings.ContextSetting([])
    tiebreaker = settings.Setting(next(iter(TIEBREAKERS)))
    autocommit = settings.Setting(True)

    def __init__(self):
        # Commit is thunked because autocommit redefines it
        # pylint: disable=unnecessary-lambda
        super().__init__()
        self.data = None

        self.var_model = DomainModel(parent=self, order=DomainModel.MIXED)
        listing = gui.listView(
            self.controlArea, self, "selected_vars", box="分组依据",
            model=self.var_model, callback=self.commit.deferred,
            viewType=ListViewSearch
        )
        listing.setSelectionMode(listing.ExtendedSelection)

        gui.comboBox(
            self.controlArea, self, 'tiebreaker', box=True,
            label='每组中选择的实例:',
            items=tuple(self.TIEBREAKERS),
            callback=self.commit.deferred, sendSelectedValue=True)
        gui.auto_commit(
            self.controlArea, self, 'autocommit', 'Commit',
            orientation=Qt.Horizontal)

    @Inputs.data
    def set_data(self, data):
        """Store the data; by default group by every variable of its domain."""
        self.closeContext()
        self.data = data
        self.selected_vars = []
        if not data:
            self.var_model.set_domain(None)
        else:
            self.var_model.set_domain(data.domain)
            self.selected_vars = self.var_model[:]
            self.openContext(data.domain)

        self.commit.now()

    @gui.deferred
    def commit(self):
        result = None if self.data is None else self._compute_unique_data()
        self.Outputs.data.send(result)

    def _compute_unique_data(self):
        """Return the data reduced to one instance per group, or None if no
        instance is kept."""
        group_by = self.selected_vars or self.var_model
        columns = [self.data.get_column_view(var)[0] for var in group_by]

        # Group row indices by their tuple of values; dict insertion order
        # keeps groups in order of first appearance.
        groups = {}
        for row_index, key in enumerate(zip(*columns)):
            bucket = groups.get(key)
            if bucket is None:
                groups[key] = [row_index]
            else:
                bucket.append(row_index)

        pick = self.TIEBREAKERS[self.tiebreaker]
        chosen = [pick(rows) for rows in groups.values()]
        selection = sorted(idx for idx in chosen if idx is not None)
        return self.data[selection] if selection else None
Beispiel #48
0
class OWScatterPlot(OWDataProjectionWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs(OWDataProjectionWidget.Inputs):
        features = Input("Features", AttributeList)

    class Outputs(OWDataProjectionWidget.Outputs):
        features = Output("Features", AttributeList, dynamic=False)

    settings_version = 3
    auto_sample = Setting(True)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    tooltip_shows_all = Setting(True)

    GRAPH_CLASS = OWScatterPlotGraph
    graph = SettingProvider(OWScatterPlotGraph)
    embedding_variables_names = None

    class Warning(OWDataProjectionWidget.Warning):
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points")

    class Information(OWDataProjectionWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # While sampling a large SQL table, add_data is called on every tick
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        super().__init__()

        # manually register Matplotlib file writers
        self.graph_writers = self.graph_writers.copy()
        for w in [MatplotlibFormat, MatplotlibPDFFormat]:
            for ext in w.EXTENSIONS:
                self.graph_writers[ext] = w

    def _add_controls(self):
        self._add_controls_axis()
        self._add_controls_sampling()
        super()._add_controls()
        self.graph.gui.add_widget(self.graph.gui.JitterNumericValues,
                                  self._effects_box)
        self.graph.gui.add_widgets([
            self.graph.gui.ShowGridLines, self.graph.gui.ToolTipShowsAll,
            self.graph.gui.RegressionLine
        ], self._plot_box)

    def _add_controls_axis(self):
        """Add the x/y axis combo boxes and the VizRank button."""
        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str,
                              contentsLength=14)
        box = gui.vBox(self.controlArea, True)
        dmod = DomainModel
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self.attr_changed,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(box,
                                      self,
                                      "attr_y",
                                      label="Axis y:",
                                      callback=self.attr_changed,
                                      model=self.xy_model,
                                      **common_options)
        vizrank_box = gui.hBox(box)
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

    def _add_controls_sampling(self):
        """Add the (initially hidden) SQL sampling box."""
        self.sampling = gui.auto_commit(self.controlArea,
                                        self,
                                        "auto_sample",
                                        "Sample",
                                        box="Sampling",
                                        callback=self.switch_sampling,
                                        commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

    def _vizrank_color_change(self):
        """Re-initialize VizRank and enable its button only when usable.

        The button is enabled when the data is dense, has more than two axis
        candidates, more than one valid point and no zero-variance column
        in X, and a color variable with at least one non-missing value is
        selected.
        """
        self.vizrank.initialize()
        is_enabled = self.data is not None and not self.data.is_sparse() and \
            len(self.xy_model) > 2 and len(self.data[self.valid_data]) > 1 \
            and np.all(np.nan_to_num(np.nanstd(self.data.X, 0)) != 0)
        self.vizrank_button.setEnabled(
            is_enabled and self.attr_color is not None and not np.isnan(
                self.data.get_column_view(
                    self.attr_color)[0].astype(float)).all())
        text = "Color variable has to be selected." \
            if is_enabled and self.attr_color is None else ""
        self.vizrank_button.setToolTip(text)

    def set_data(self, data):
        """Set new data (ignored if its checksum equals the current data's).

        Also converts attribute settings restored from old workflows, where
        they were stored by name, to the matching Variable objects.
        """
        if self.data and data and self.data.checksum() == data.checksum():
            return
        super().set_data(data)

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.attr_label, str):
            self.attr_label = findvar(self.attr_label,
                                      self.graph.gui.label_model)
        if isinstance(self.attr_color, str):
            self.attr_color = findvar(self.attr_color,
                                      self.graph.gui.color_model)
        if isinstance(self.attr_shape, str):
            self.attr_shape = findvar(self.attr_shape,
                                      self.graph.gui.shape_model)
        if isinstance(self.attr_size, str):
            self.attr_size = findvar(self.attr_size, self.graph.gui.size_model)

    def check_data(self):
        """Prepare self.data for plotting.

        SQL tables with an approximate length below 4000 are downloaded
        completely. Larger ones are replaced by a partial download of up to
        2000 rows from a sample, and the sampling box is shown; with
        auto_sample on, the timer keeps adding samples. Data with no rows or
        an empty domain is replaced by None.
        """
        self.clear_messages()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(self.data, SqlTable):
            if self.data.approx_len() < 4000:
                self.data = Table(self.data)
            else:
                self.Information.sampled_sql()
                self.sql_data = self.data
                data_sample = self.data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                self.data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if self.data is not None and (len(self.data) == 0
                                      or len(self.data.domain) == 0):
            self.data = None

    def get_embedding(self):
        """Return an (n, 2) array with the x and y columns, or None.

        Sets ``self.valid_data`` to a mask of points with finite x and y.
        Shows an information message if some points are missing a
        coordinate, or a warning if all of them are.
        """
        self.valid_data = None
        if self.data is None:
            return None

        x_data = self.get_column(self.attr_x, filter_valid=False)
        y_data = self.get_column(self.attr_y, filter_valid=False)
        if x_data is None or y_data is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        self.valid_data = np.isfinite(x_data) & np.isfinite(y_data)
        if not np.all(self.valid_data):
            msg = self.Information if np.any(self.valid_data) else self.Warning
            msg.missing_coords(self.attr_x.name, self.attr_y.name)
        return np.vstack((x_data, y_data)).T

    # Tooltip
    def _point_tooltip(self, point_id, skip_attrs=()):
        """HTML tooltip: x and y values first, then other attributes if the
        tooltip_shows_all setting is on."""
        point_data = self.data[point_id]
        xy_attrs = (self.attr_x, self.attr_y)
        text = "<br/>".join(
            escape('{} = {}'.format(var.name, point_data[var]))
            for var in xy_attrs)
        if self.tooltip_shows_all:
            others = super()._point_tooltip(point_id, skip_attrs=xy_attrs)
            if others:
                text = "<b>{}</b><br/><br/>{}".format(text, others)
        return text

    def can_draw_regresssion_line(self):
        """True when there is data and both axis variables are continuous."""
        return self.data is not None and \
            self.data.domain is not None and \
            self.attr_x is not None and \
            self.attr_y is not None and \
            self.attr_x.is_continuous and \
            self.attr_y.is_continuous

    def add_data(self, time=0.4):
        """Append another sample from the SQL table to self.data.

        Stops the sampling timer once the data has more than 2000 rows.

        :param time: time budget passed to ``SqlTable.sample_time``
        """
        if self.data and len(self.data) > 2000:
            self.__timer.stop()
            return
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def init_attr_values(self):
        """Fill the axis model and default to its first two variables."""
        super().init_attr_values()
        data = self.data
        domain = data.domain if data and len(data) else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def switch_sampling(self):
        """Restart or stop SQL sampling after the auto_sample toggle."""
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    def set_subset_data(self, subset_data):
        """Set subset data; large SQL tables are rejected with a warning."""
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() < AUTO_DL_LIMIT:
                subset_data = Table(subset_data)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
        super().set_subset_data(subset_data)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        # Apply features received on the Features input if they all exist
        # in the current domain.
        if self.attribute_selection_list and self.data is not None and \
                self.data.domain is not None and \
                all(attr in self.data.domain for attr
                        in self.attribute_selection_list):
            self.attr_x = self.attribute_selection_list[0]
            self.attr_y = self.attribute_selection_list[1]
        self.attribute_selection_list = None
        super().handleNewSignals()
        self._vizrank_color_change()
        self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Remember the first two received features for the axes."""
        if attributes and len(attributes) >= 2:
            self.attribute_selection_list = attributes[:2]
        else:
            self.attribute_selection_list = None

    def set_attr(self, attr_x, attr_y):
        """Set both axis variables (used by VizRank) and update the plot."""
        self.attr_x, self.attr_y = attr_x, attr_y
        self.attr_changed()

    def attr_changed(self):
        self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())
        self.setup_plot()
        self.commit()

    def setup_plot(self):
        """Set up the plot and its axis titles; discrete axes get value labels."""
        super().setup_plot()
        for axis, var in (("bottom", self.attr_x), ("left", self.attr_y)):
            self.graph.set_axis_title(axis, var)
            if var and var.is_discrete:
                self.graph.set_axis_labels(axis,
                                           get_variable_values_sorted(var))
            else:
                self.graph.set_axis_labels(axis, None)

    def colors_changed(self):
        super().colors_changed()
        self._vizrank_color_change()

    def commit(self):
        super().commit()
        self.send_features()

    def send_features(self):
        """Send the axis variables that are set (None if neither is)."""
        features = [attr for attr in [self.attr_x, self.attr_y] if attr]
        self.Outputs.features.send(features or None)

    def get_widget_name_extension(self):
        """Return "<x> vs <y>", or None without data or axis variables."""
        if self.data is not None and self.attr_x is not None \
                and self.attr_y is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
        return None

    def _get_send_report_caption(self):
        return report.render_items_vert(
            (("Color", self._get_caption_var_name(self.attr_color)),
             ("Label", self._get_caption_var_name(self.attr_label)),
             ("Shape", self._get_caption_var_name(self.attr_shape)),
             ("Size", self._get_caption_var_name(self.attr_size)),
             ("Jittering", (self.attr_x.is_discrete or self.attr_y.is_discrete
                            or self.graph.jitter_continuous)
              and self.graph.jitter_size)))

    @classmethod
    def migrate_settings(cls, settings, version):
        # v2: selection became (index, group) pairs in selection_group;
        # v3: auto_send_selection -> auto_commit, selection_group -> selection
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1)
                                           for a in settings["selection"]]
        if version < 3:
            if "auto_send_selection" in settings:
                settings["auto_commit"] = settings["auto_send_selection"]
            if "selection_group" in settings:
                settings["selection"] = settings["selection_group"]

    @classmethod
    def migrate_context(cls, context, version):
        # Before v3 the visual attributes were stored under the "graph" key
        if version < 3:
            values = context.values
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
Beispiel #49
0
class OWMosaicDisplay(OWWidget):
    """Mosaic plot of up to four (discretized) features.

    The area of each cell is proportional to the number of instances with the
    cell's combination of values. Continuous features are discretized into 4
    equal-frequency bins. When no coloring variable is chosen, cells are
    shaded by standardized (Pearson) residuals. Otherwise they show the
    distribution of the coloring variable, optionally next to its overall
    distribution ("Compare with total") and the data-subset distribution.
    Clicking cells selects the matching instances.
    """
    name = "Mosaic Display"
    description = "Display data in a mosaic plot."
    icon = "icons/MosaicDisplay.svg"
    priority = 220

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    # Possible values of `interior_coloring`
    PEARSON, CLASS_DISTRIBUTION = 0, 1

    settingsHandler = DomainContextHandler()
    use_boxes = Setting(True)
    interior_coloring = Setting(CLASS_DISTRIBUTION)
    # Names of the features on the four mosaic levels; "" or "(None)" = unused
    variable1 = ContextSetting("")
    variable2 = ContextSetting("")
    variable3 = ContextSetting("")
    variable4 = ContextSetting("")
    # Name of the coloring variable; combo index 0 means Pearson residuals
    variable_color = ContextSetting("")
    # Indices into `self.areas` of the selected cells
    selection = ContextSetting(set())

    # Layout constants
    BAR_WIDTH = 5
    SPACING = 4
    ATTR_NAME_OFFSET = 20
    ATTR_VAL_OFFSET = 3
    # Residual shades, from weakest (index 0) to strongest (index 3)
    BLUE_COLORS = [QColor(255, 255, 255), QColor(210, 210, 255),
                   QColor(110, 110, 255), QColor(0, 0, 255)]
    RED_COLORS = [QColor(255, 255, 255), QColor(255, 200, 200),
                  QColor(255, 100, 100), QColor(255, 0, 0)]

    vizrank = SettingProvider(MosaicVizRank)

    graph_name = "canvas"

    class Warning(OWWidget.Warning):
        incompatible_subset = Msg("Data subset is incompatible with Data")
        no_valid_data = Msg("No valid data")
        no_cont_selection_sql = \
            Msg("Selection of numeric features on SQL is not supported")

    def __init__(self):
        super().__init__()

        self.data = None
        self.discrete_data = None
        self.subset_data = None
        self.subset_indices = None

        self.color_data = None

        # One (attribute names, values, outer rectangle) tuple per drawn cell
        self.areas = []

        self.canvas = QGraphicsScene()
        self.canvas_view = ViewWithPress(self.canvas,
                                         handler=self.clear_selection)
        self.mainArea.layout().addWidget(self.canvas_view)
        self.canvas_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setRenderHint(QPainter.Antialiasing)

        box = gui.vBox(self.controlArea, box=True)
        self.attr_combos = [
            gui.comboBox(
                box, self, value="variable{}".format(i),
                orientation=Qt.Horizontal, contentsLength=12,
                callback=self.reset_graph,
                sendSelectedValue=True, valueType=str, emptyString="(None)")
            for i in range(1, 5)]
        self.vizrank, self.vizrank_button = MosaicVizRank.add_vizrank(
            box, self, "Find Informative Mosaics", self.set_attr)

        box2 = gui.vBox(self.controlArea, box="Interior Coloring")
        dmod = DomainModel
        self.color_model = DomainModel(order=dmod.MIXED,
                                       valid_types=dmod.PRIMITIVE,
                                       placeholder="(Pearson residuals)")
        self.cb_attr_color = gui.comboBox(
            box2, self, value="variable_color",
            orientation=Qt.Horizontal, contentsLength=12, labelWidth=50,
            callback=self.set_color_data,
            sendSelectedValue=True, model=self.color_model, valueType=str)
        self.bar_button = gui.checkBox(
            box2, self, 'use_boxes', label='Compare with total',
            callback=self._compare_with_total)
        gui.rubber(self.controlArea)

    def sizeHint(self):
        return QSize(720, 530)

    def _compare_with_total(self):
        """Callback of the "Compare with total" checkbox."""
        if self.data is not None and \
                self.data.domain.class_var is not None and \
                self.interior_coloring != self.CLASS_DISTRIBUTION:
            self.interior_coloring = self.CLASS_DISTRIBUTION
            self.coloring_changed()  # This also calls self.update_graph
        else:
            self.update_graph()

    def _get_discrete_data(self, data):
        """
        Discretizes continuous attributes.
        Returns None when there is no data, no rows, or no discrete or continuous attributes.
        """
        if (data is None or
                not len(data) or
                not any(attr.is_discrete or attr.is_continuous
                        for attr in chain(data.domain.variables, data.domain.metas))):
            return None
        elif any(attr.is_continuous for attr in data.domain.variables):
            return Discretize(
                method=EqualFreq(n=4), remove_const=False, discretize_classes=True,
                discretize_metas=True)(data)
        else:
            return data

    def init_combos(self, data):
        """Repopulate the feature combos and the color model for `data`.

        The first combo lists only the primitive (discrete or continuous)
        variables; the other three start with a "(None)" entry. Defaults:
        the first variable on level 1, the second one on level 2 (or
        "(None)" if there is only one), "(None)" on levels 3 and 4, and the
        class variable, if any, for coloring.
        """
        for combo in self.attr_combos:
            combo.clear()
        if data is None:
            self.color_model.set_domain(None)
            return
        self.color_model.set_domain(data.domain)
        for combo in self.attr_combos[1:]:
            combo.addItem("(None)")

        icons = gui.attributeIconDict
        for attr in chain(data.domain.variables, data.domain.metas):
            # is_primitive must be *called*; the bare bound method is always
            # truthy and would let string variables into the combos.
            if attr.is_primitive():
                for combo in self.attr_combos:
                    combo.addItem(icons[attr], attr.name)

        if self.attr_combos[0].count() > 0:
            self.variable1 = self.attr_combos[0].itemText(0)
            # Index 2 is the second variable (index 0 is "(None)"); fall back
            # to "(None)" when there is only one variable.
            self.variable2 = self.attr_combos[1].itemText(
                2 * (self.attr_combos[1].count() > 2))
        self.variable3 = self.attr_combos[2].itemText(0)
        self.variable4 = self.attr_combos[3].itemText(0)
        if data.domain.class_var:
            self.variable_color = data.domain.class_var.name
            idx = self.cb_attr_color.findText(self.variable_color)
        else:
            idx = 0
        self.cb_attr_color.setCurrentIndex(idx)

    def get_attr_list(self):
        """Return the names of the features chosen on the used levels."""
        return [
            a for a in [self.variable1, self.variable2,
                        self.variable3, self.variable4]
            if a and a != "(None)"]

    def set_attr(self, *attrs):
        """Set the four levels to `attrs` (VizRank callback) and redraw.

        Falsy entries leave the corresponding level unused.
        """
        self.variable1, self.variable2, self.variable3, self.variable4 = \
            [a.name if a else "" for a in attrs]
        self.reset_graph()

    def resizeEvent(self, e):
        OWWidget.resizeEvent(self, e)
        self.update_graph()

    def showEvent(self, ev):
        OWWidget.showEvent(self, ev)
        self.update_graph()

    @Inputs.data
    def set_data(self, data):
        """Store new input data, refill the combos and open its context.

        Large SQL tables are replaced by a time-limited sample. The graph is
        redrawn later in `handleNewSignals`.
        """
        if type(data) == SqlTable and data.approx_len() > LARGE_TABLE:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data

        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(
            self.data is not None and len(self.data) > 1
            and len(self.data.domain.attributes) >= 1)

        if self.data is None:
            self.discrete_data = None
            self.init_combos(None)
            return

        self.init_combos(self.data)

        self.openContext(self.data)

    @Inputs.data_subset
    def set_subset_data(self, data):
        self.subset_data = data

    # this is called by widget after setData and setSubsetData are called.
    # this way the graph is updated only once
    def handleNewSignals(self):
        self.Warning.incompatible_subset.clear()
        self.subset_indices = indices = None
        if self.data is not None and self.subset_data:
            transformed = self.subset_data.transform(self.data.domain)
            if np.all(np.isnan(transformed.X)) and np.all(np.isnan(transformed.Y)):
                self.Warning.incompatible_subset()
            else:
                indices = {e.id for e in transformed}
                # Boolean mask over the rows of self.data
                self.subset_indices = [ex.id in indices for ex in self.data]

        self.set_color_data()
        self.reset_graph()

    def clear_selection(self):
        self.selection = set()
        self.update_selection_rects()
        self.send_selection()

    def coloring_changed(self):
        self.vizrank.coloring_changed()
        self.update_graph()

    def reset_graph(self):
        self.clear_selection()
        self.update_graph()

    def set_color_data(self):
        """Rebuild `color_data` and `discrete_data` for the chosen coloring.

        `color_data` holds all primitive variables except the coloring one as
        attributes, with the coloring variable (if any) as the class. Choosing
        no coloring variable switches to Pearson-residual shading. With no
        usable data, the derived tables are cleared so no stale plot remains.
        """
        if self.data is None or len(self.data) < 2 or len(self.data.domain.attributes) < 1:
            # Do not keep tables derived from a previous input: they would be
            # drawn with the feature names of the new one.
            self.color_data = None
            self.discrete_data = None
            return
        if self.cb_attr_color.currentIndex() <= 0:
            color_var = None
            self.interior_coloring = self.PEARSON
            self.bar_button.setEnabled(False)
        else:
            color_var = self.data.domain[self.cb_attr_color.currentText()]
            self.interior_coloring = self.CLASS_DISTRIBUTION
            self.bar_button.setEnabled(True)
        attributes = [v for v in self.data.domain.attributes + self.data.domain.class_vars
                      + self.data.domain.metas if v != color_var and v.is_primitive()]
        domain = Domain(attributes, color_var, None)
        self.color_data = color_data = self.data.from_table(domain, self.data)
        self.discrete_data = self._get_discrete_data(color_data)
        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(True)
        self.coloring_changed()

    def update_selection_rects(self):
        """Draw selected cells with a thick dotted outline, others plainly."""
        for i, (_, _, area) in enumerate(self.areas):
            if i in self.selection:
                area.setPen(QPen(Qt.black, 3, Qt.DotLine))
            else:
                area.setPen(QPen())

    def select_area(self, index, ev):
        """Handle a left click on cell `index`; Ctrl toggles it in the selection."""
        if ev.button() != Qt.LeftButton:
            return
        if ev.modifiers() & Qt.ControlModifier:
            self.selection ^= {index}
        else:
            self.selection = {index}
        self.update_selection_rects()
        self.send_selection()

    def send_selection(self):
        """Send the instances falling into the selected cells.

        Each cell becomes a conjunction of value filters on the discretized
        data; several cells are combined with a disjunction. Matching rows
        are mapped back to the original data by instance id.
        """
        if not self.selection or self.data is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(create_annotated_table(self.data, []))
            return
        filters = []
        self.Warning.no_cont_selection_sql.clear()
        if self.discrete_data is not self.data:
            if isinstance(self.data, SqlTable):
                self.Warning.no_cont_selection_sql()
        for i in self.selection:
            cols, vals, _ = self.areas[i]
            filters.append(
                filter.Values(
                    filter.FilterDiscrete(col, [val])
                    for col, val in zip(cols, vals)))
        if len(filters) > 1:
            filters = filter.Values(filters, conjunction=False)
        else:
            filters = filters[0]
        selection = filters(self.discrete_data)
        idset = set(selection.ids)
        sel_idx = [i for i, inst_id in enumerate(self.data.ids)
                   if inst_id in idset]
        if self.discrete_data is not self.data:
            selection = self.data[sel_idx]
        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(create_annotated_table(self.data, sel_idx))

    def send_report(self):
        self.report_plot(self.canvas)

    def update_graph(self):
        """Clear the canvas and redraw the mosaic, its labels and legend.

        The nested helpers read `data`, `class_var`, `conditionaldict`,
        `distributiondict`, `conditionalsubsetdict`, `apriori_dists`,
        `drawn_sides`, `draw_positions` and the label widths from this
        method's scope; they are assigned below, before `draw_data` runs.
        """
        spacing = self.SPACING
        bar_width = self.BAR_WIDTH

        def get_counts(attr_vals, values):
            """This function calculates rectangles' widths.
            If all widths are zero then all widths are set to 1."""
            if attr_vals == "":
                counts = [conditionaldict[val] for val in values]
            else:
                counts = [conditionaldict[attr_vals + "-" + val]
                          for val in values]
            total = sum(counts)
            if total == 0:
                counts = [1] * len(values)
                total = sum(counts)
            return total, counts

        def draw_data(attr_list, x0_x1, y0_y1, side, condition,
                      total_attrs, used_attrs, used_vals, attr_vals=""):
            """Recursively split the given region by the values of attr_list[0].

            `side` alternates between horizontal (even) and vertical (odd)
            splitting; `attr_vals` is the "-"-joined prefix of values already
            fixed by the enclosing levels and is used as the count key.
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if conditionaldict[attr_vals] == 0:
                add_rect(x0, x1, y0, y1, "",
                         used_attrs, used_vals, attr_vals=attr_vals)
                # store coordinates for later drawing of labels
                draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                          used_attrs, used_vals, attr_vals)
                return

            attr = attr_list[0]
            # how much smaller rectangles do we draw
            edge = len(attr_list) * spacing
            values = get_variable_values_sorted(data.domain[attr])
            if side % 2:
                values = values[::-1]  # reverse names if necessary

            # When there is no room left, shrink the gaps so they fill the
            # extent; a single value has no gaps (and would divide by zero).
            if side % 2 == 0:  # we are drawing on the x axis
                # remove the space needed for separating different attr. values
                whole = max(0, (x1 - x0) - edge * (
                    len(values) - 1))
                if whole == 0 and len(values) > 1:
                    edge = (x1 - x0) / float(len(values) - 1)
            else:  # we are drawing on the y axis
                whole = max(0, (y1 - y0) - edge * (len(values) - 1))
                if whole == 0 and len(values) > 1:
                    edge = (y1 - y0) / float(len(values) - 1)

            total, counts = get_counts(attr_vals, values)

            # if we are visualizing the third attribute and the first attribute
            # has the last value, we have to reverse the order in which the
            # boxes will be drawn otherwise, if the last cell, nearest to the
            # labels of the fourth attribute, is empty, we wouldn't be able to
            # position the labels
            valrange = list(range(len(values)))
            if len(attr_list + used_attrs) == 4 and len(used_attrs) == 2:
                attr1values = get_variable_values_sorted(
                    data.domain[used_attrs[0]])
                if used_vals[0] == attr1values[-1]:
                    valrange = valrange[::-1]

            for i in valrange:
                start = i * edge + whole * float(sum(counts[:i]) / total)
                end = i * edge + whole * float(sum(counts[:i + 1]) / total)
                val = values[i]
                htmlval = to_html(val)
                if attr_vals != "":
                    newattrvals = attr_vals + "-" + val
                else:
                    newattrvals = val

                tooltip = condition + 4 * "&nbsp;" + attr + \
                    ": <b>" + htmlval + "</b><br>"
                attrs = used_attrs + [attr]
                vals = used_vals + [val]
                common_args = attrs, vals, newattrvals
                if side % 2 == 0:  # if we are moving horizontally
                    if len(attr_list) == 1:
                        add_rect(x0 + start, x0 + end, y0, y1,
                                 tooltip, *common_args)
                    else:
                        draw_data(attr_list[1:], (x0 + start, x0 + end),
                                  (y0, y1), side + 1,
                                  tooltip, total_attrs, *common_args)
                else:
                    if len(attr_list) == 1:
                        add_rect(x0, x1, y0 + start, y0 + end,
                                 tooltip, *common_args)
                    else:
                        draw_data(attr_list[1:], (x0, x1),
                                  (y0 + start, y0 + end), side + 1,
                                  tooltip, total_attrs, *common_args)

            draw_text(side, attr_list[0], (x0, x1), (y0, y1),
                      total_attrs, used_attrs, used_vals, attr_vals)

        def draw_text(side, attr, x0_x1, y0_y1,
                      total_attrs, used_attrs, used_vals, attr_vals):
            """Draw value labels and the name of `attr` on side `side` (once per side).

            Sides: 0 = below, 1 = left, 2 = above, 3 = right of the mosaic.
            If the region is empty, its position is remembered and the labels
            are drawn from it once a non-empty region for that side appears.
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if side in drawn_sides:
                return

            # the text on the right will be drawn when we are processing
            # visualization of the last value of the first attribute
            if side == 3:
                attr1values = \
                    get_variable_values_sorted(data.domain[used_attrs[0]])
                if used_vals[0] != attr1values[-1]:
                    return

            if not conditionaldict[attr_vals]:
                if side not in draw_positions:
                    draw_positions[side] = (x0, x1, y0, y1)
                return
            else:
                if side in draw_positions:
                    # restore the positions of attribute values and name
                    (x0, x1, y0, y1) = draw_positions[side]

            drawn_sides.add(side)

            values = get_variable_values_sorted(data.domain[attr])
            if side % 2:
                values = values[::-1]

            spaces = spacing * (total_attrs - side) * (len(values) - 1)
            width = x1 - x0 - spaces * (side % 2 == 0)
            height = y1 - y0 - spaces * (side % 2 == 1)

            # calculate position of first attribute
            currpos = 0

            total, counts = get_counts(attr_vals, values)

            aligns = [Qt.AlignTop | Qt.AlignHCenter,
                      Qt.AlignRight | Qt.AlignVCenter,
                      Qt.AlignBottom | Qt.AlignHCenter,
                      Qt.AlignLeft | Qt.AlignVCenter]
            align = aligns[side]
            for i, val in enumerate(values):
                perc = counts[i] / float(total)
                if distributiondict[val] != 0:
                    if side == 0:
                        CanvasText(self.canvas, str(val),
                                   x0 + currpos + width * 0.5 * perc,
                                   y1 + self.ATTR_VAL_OFFSET, align)
                    elif side == 1:
                        CanvasText(self.canvas, str(val),
                                   x0 - self.ATTR_VAL_OFFSET,
                                   y0 + currpos + height * 0.5 * perc, align)
                    elif side == 2:
                        CanvasText(self.canvas, str(val),
                                   x0 + currpos + width * perc * 0.5,
                                   y0 - self.ATTR_VAL_OFFSET, align)
                    else:
                        CanvasText(self.canvas, str(val),
                                   x1 + self.ATTR_VAL_OFFSET,
                                   y0 + currpos + height * 0.5 * perc, align)

                if side % 2 == 0:
                    currpos += perc * width + spacing * (total_attrs - side)
                else:
                    currpos += perc * height + spacing * (total_attrs - side)

            if side == 0:
                CanvasText(
                    self.canvas, attr,
                    x0 + (x1 - x0) / 2,
                    y1 + self.ATTR_VAL_OFFSET + self.ATTR_NAME_OFFSET,
                    align, bold=1)
            elif side == 1:
                CanvasText(
                    self.canvas, attr,
                    x0 - max_ylabel_w1 - self.ATTR_VAL_OFFSET,
                    y0 + (y1 - y0) / 2,
                    align, bold=1, vertical=True)
            elif side == 2:
                CanvasText(
                    self.canvas, attr,
                    x0 + (x1 - x0) / 2,
                    y0 - self.ATTR_VAL_OFFSET - self.ATTR_NAME_OFFSET,
                    align, bold=1)
            else:
                CanvasText(
                    self.canvas, attr,
                    x1 + max_ylabel_w2 + self.ATTR_VAL_OFFSET,
                    y0 + (y1 - y0) / 2,
                    align, bold=1, vertical=True)

        def add_rect(x0, x1, y0, y1, condition,
                     used_attrs, used_vals, attr_vals=""):
            """Draw one mosaic cell, its interior coloring and its tooltip.

            The clickable outer rectangle is registered in `self.areas`
            together with the attribute names and values that define it.
            """
            area_index = len(self.areas)
            if x0 == x1:
                x1 += 1
            if y0 == y1:
                y1 += 1

            # rectangles of width and height 1 are not shown - increase
            if x1 - x0 + y1 - y0 == 2:
                y1 += 1

            if class_var:
                colors = [QColor(*col) for col in class_var.colors]
            else:
                colors = None

            def select_area(_, ev):
                self.select_area(area_index, ev)

            def rect(x, y, w, h, z, pen_color=None, brush_color=None, **args):
                if pen_color is None:
                    return CanvasRectangle(
                        self.canvas, x, y, w, h, z=z, onclick=select_area,
                        **args)
                if brush_color is None:
                    brush_color = pen_color
                return CanvasRectangle(
                    self.canvas, x, y, w, h, pen_color, brush_color, z=z,
                    onclick=select_area, **args)

            def line(x1, y1, x2, y2):
                r = QGraphicsLineItem(x1, y1, x2, y2, None)
                self.canvas.addItem(r)
                r.setPen(QPen(Qt.white, 2))
                r.setZValue(30)

            outer_rect = rect(x0, y0, x1 - x0, y1 - y0, 30)
            self.areas.append((used_attrs, used_vals, outer_rect))
            if not conditionaldict[attr_vals]:
                return

            if self.interior_coloring == self.PEARSON:
                # Expected count under independence of the shown features,
                # compared with the actual count as a standardized residual.
                s = sum(apriori_dists[0])
                expected = s * reduce(
                    mul,
                    (apriori_dists[i][used_vals[i]] / float(s)
                     for i in range(len(used_vals))))
                actual = conditionaldict[attr_vals]
                pearson = (actual - expected) / sqrt(expected)
                if pearson == 0:
                    ind = 0
                else:
                    ind = max(0, min(int(log(abs(pearson), 2)), 3))
                color = [self.RED_COLORS, self.BLUE_COLORS][pearson > 0][ind]
                rect(x0, y0, x1 - x0, y1 - y0, -20, color)
                outer_rect.setToolTip(
                    condition + "<hr/>" +
                    "Expected instances: %.1f<br>"
                    "Actual instances: %d<br>"
                    "Standardized (Pearson) residual: %.1f" %
                    (expected, conditionaldict[attr_vals], pearson))
            else:
                cls_values = get_variable_values_sorted(class_var)
                prior = get_distribution(data, class_var.name)
                total = 0
                for i, value in enumerate(cls_values):
                    val = conditionaldict[attr_vals + "-" + value]
                    if val == 0:
                        continue
                    if i == len(cls_values) - 1:
                        v = y1 - y0 - total
                    else:
                        v = ((y1 - y0) * val) / conditionaldict[attr_vals]
                    rect(x0, y0 + total, x1 - x0, v, -20, colors[i])
                    total += v

                # Left bar: overall distribution of the coloring variable
                if self.use_boxes and \
                        abs(x1 - x0) > bar_width and \
                        abs(y1 - y0) > bar_width:
                    total = 0
                    line(x0 + bar_width, y0, x0 + bar_width, y1)
                    n = sum(prior)
                    for i, (val, color) in enumerate(zip(prior, colors)):
                        if i == len(prior) - 1:
                            h = y1 - y0 - total
                        else:
                            h = (y1 - y0) * val / n
                        rect(x0, y0 + total, bar_width, h, 20, color)
                        total += h

                # Right bar: distribution within the data subset
                if conditionalsubsetdict:
                    if conditionalsubsetdict[attr_vals]:
                        if self.subset_indices is not None:
                            line(x1 - bar_width, y0, x1 - bar_width, y1)
                            total = 0
                            n = conditionalsubsetdict[attr_vals]
                            if n:
                                for i, (cls, color) in \
                                        enumerate(zip(cls_values, colors)):
                                    val = conditionalsubsetdict[
                                        attr_vals + "-" + cls]
                                    if val == 0:
                                        continue
                                    if i == len(prior) - 1:
                                        v = y1 - y0 - total
                                    else:
                                        v = ((y1 - y0) * val) / n
                                    rect(x1 - bar_width, y0 + total,
                                         bar_width, v, 15, color)
                                    total += v

                actual = [conditionaldict[attr_vals + "-" + cls_values[i]]
                          for i in range(len(prior))]
                n_actual = sum(actual)
                if n_actual > 0:
                    apriori = [prior[key] for key in cls_values]
                    n_apriori = sum(apriori)
                    text = "<br/>".join(
                        "<b>%s</b>: %d / %.1f%% (Expected %.1f / %.1f%%)" %
                        (cls, act, 100.0 * act / n_actual,
                         apr / n_apriori * n_actual, 100.0 * apr / n_apriori)
                        for cls, act, apr in zip(cls_values, actual, apriori))
                else:
                    text = ""
                # `text` is joined with "<br/>" and has no trailing break, so
                # it is used whole (slicing off 4 chars truncated the last line).
                outer_rect.setToolTip(
                    "{}<hr>Instances: {}<br><br>{}".format(
                        condition, n_actual, text))

        def draw_legend(x0_x1, y0_y1):
            """Draw the color legend centered below the mosaic."""
            x0, x1 = x0_x1
            _, y1 = y0_y1
            if self.interior_coloring == self.PEARSON:
                names = ["<-8", "-8:-4", "-4:-2", "-2:2", "2:4", "4:8", ">8",
                         "Residuals:"]
                colors = self.RED_COLORS[::-1] + self.BLUE_COLORS[1:]
            else:
                names = get_variable_values_sorted(class_var) + \
                        [class_var.name + ":"]
                colors = [QColor(*col) for col in class_var.colors]

            names = [CanvasText(self.canvas, name, alignment=Qt.AlignVCenter)
                     for name in names]
            totalwidth = sum(text.boundingRect().width() for text in names)

            # compute the x position of the center of the legend
            y = y1 + self.ATTR_NAME_OFFSET + self.ATTR_VAL_OFFSET + 35
            distance = 30
            startx = (x0 + x1) / 2 - (totalwidth + (len(names)) * distance) / 2

            names[-1].setPos(startx + 15, y)
            names[-1].show()
            xoffset = names[-1].boundingRect().width() + distance

            size = 8

            for i in range(len(names) - 1):
                if self.interior_coloring == self.PEARSON:
                    edgecolor = Qt.black
                else:
                    edgecolor = colors[i]

                CanvasRectangle(self.canvas, startx + xoffset, y - size / 2,
                                size, size, edgecolor, colors[i])
                names[i].setPos(startx + xoffset + 10, y)
                xoffset += distance + names[i].boundingRect().width()

        self.canvas.clear()
        self.areas = []

        data = self.discrete_data
        if data is None:
            return
        attr_list = self.get_attr_list()
        class_var = data.domain.class_var
        if class_var:
            sql = type(data) == SqlTable
            name = not sql and data.name
            # save class_var because it is removed in the next line
            data = data[:, attr_list + [class_var]]
            data.domain.class_var = class_var
            if not sql:
                data.name = name
        else:
            data = data[:, attr_list]
        # TODO: check this
        # data = Preprocessor_dropMissing(data)
        if len(data) == 0:
            self.Warning.no_valid_data()
            return
        else:
            self.Warning.no_valid_data.clear()

        attrs = [attr for attr in attr_list if not data.domain[attr].values]
        if attrs:
            CanvasText(self.canvas,
                       "Feature {} has no values".format(attrs[0]),
                       (self.canvas_view.width() - 120) / 2,
                       self.canvas_view.height() / 2)
            return
        if self.interior_coloring == self.PEARSON:
            apriori_dists = [get_distribution(data, attr) for attr in attr_list]
        else:
            apriori_dists = []

        def get_max_label_width(attr):
            """Return the width of the widest value label of `attr`."""
            values = get_variable_values_sorted(data.domain[attr])
            maxw = 0
            for val in values:
                t = CanvasText(self.canvas, val, 0, 0, bold=0, show=False)
                maxw = max(int(t.boundingRect().width()), maxw)
            return maxw

        # get the maximum width of rectangle
        xoff = 20
        width = 20
        if len(attr_list) > 1:
            text = CanvasText(self.canvas, attr_list[1], bold=1, show=0)
            max_ylabel_w1 = min(get_max_label_width(attr_list[1]), 150)
            width = 5 + text.boundingRect().height() + \
                self.ATTR_VAL_OFFSET + max_ylabel_w1
            xoff = width
            if len(attr_list) == 4:
                text = CanvasText(self.canvas, attr_list[3], bold=1, show=0)
                max_ylabel_w2 = min(get_max_label_width(attr_list[3]), 150)
                width += text.boundingRect().height() + \
                    self.ATTR_VAL_OFFSET + max_ylabel_w2 - 10

        # get the maximum height of rectangle
        height = 100
        yoff = 45
        square_size = min(self.canvas_view.width() - width - 20,
                          self.canvas_view.height() - height - 20)

        if square_size < 0:
            return  # canvas is too small to draw rectangles
        self.canvas_view.setSceneRect(
            0, 0, self.canvas_view.width(), self.canvas_view.height())

        drawn_sides = set()
        draw_positions = {}

        conditionaldict, distributiondict = \
            get_conditional_distribution(data, attr_list)
        conditionalsubsetdict = None
        if self.subset_indices:
            conditionalsubsetdict, _ = \
                get_conditional_distribution(self.discrete_data[self.subset_indices], attr_list)

        # draw rectangles
        draw_data(
            attr_list, (xoff, xoff + square_size), (yoff, yoff + square_size),
            0, "", len(attr_list), [], [])
        draw_legend((xoff, xoff + square_size), (yoff, yoff + square_size))
        self.update_selection_rects()
Beispiel #50
0
class LineScanPlot(QWidget, OWComponent, SelectionGroupMixin,
                   ImageColorSettingMixin, ImageZoomMixin):
    """Image view of the spectra of a line scan.

    The image x axis is the chosen ``attr_x`` variable (or the row index
    when it is None), the y axis holds the values returned by ``getx`` for
    the data, and each pixel is coloured by the matching value of
    ``data.X``.
    """

    attr_x = ContextSetting(None)
    gamma = Setting(0)

    selection_changed = Signal()

    def __init__(self, parent):
        QWidget.__init__(self)
        OWComponent.__init__(self, parent)
        SelectionGroupMixin.__init__(self)
        ImageColorSettingMixin.__init__(self)

        self.parent = parent

        self.selection_type = SELECTMANY
        self.saving_enabled = True
        self.selection_enabled = True
        self.viewtype = INDIVIDUAL  # required by InteractiveViewBox
        self.highlighted = None
        self.data_points = None
        self.data_imagepixels = None
        # Filled in by update_view; defined here so that the attributes
        # exist before the first view update.
        self.data_xs = None
        self.wavenumbers = None

        self.plotview = pg.GraphicsLayoutWidget()
        self.plotview.show()

        self.plot = pg.PlotItem(background="w",
                                viewBox=InteractiveViewBox(self))
        self.plotview.addItem(self.plot)

        self.legend = ImageColorLegend()
        self.plotview.addItem(self.legend)

        # tooltips over the plot are produced by help_event
        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        layout = QVBoxLayout()
        self.setLayout(layout)
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().addWidget(self.plotview)

        self.img = ImageItemNan()
        self.img.setOpts(axisOrder='row-major')
        self.plot.addItem(self.img)
        self.plot.scene().sigMouseMoved.connect(self.plot.vb.mouseMovedEvent)

        # "Menu" button placed in the top-left cell of a grid laid over
        # the plot (row/column 1 take all remaining space)
        layout = QGridLayout()
        self.plotview.setLayout(layout)
        self.button = QPushButton("Menu", self.plotview)
        self.button.setAutoDefault(False)

        layout.setRowStretch(1, 1)
        layout.setColumnStretch(1, 1)
        layout.addWidget(self.button, 0, 0)
        view_menu = MenuFocus(self)
        self.button.setMenu(view_menu)

        # prepare interface according to the new context
        self.parent.contextAboutToBeOpened.connect(
            lambda x: self.init_interface_data(x[0]))

        self.add_zoom_actions(view_menu)

        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str)

        choose_xy = QWidgetAction(self)
        box = gui.vBox(self)
        box.setFocusPolicy(Qt.TabFocus)
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=DomainModel.PRIMITIVE,
                                    placeholder="Position (index)")
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self.update_attr,
                                      model=self.xy_model,
                                      **common_options)

        box.setFocusProxy(self.cb_attr_x)

        box.layout().addWidget(self.color_settings_box())

        choose_xy.setDefaultWidget(box)
        view_menu.addAction(choose_xy)

        self.lsx = None  # info about the X axis
        self.lsy = None  # info about the Y axis

        self.data = None
        self.data_ids = {}

    def init_interface_data(self, data):
        """Reset the x-axis choices unless *data* has the current domain."""
        same_domain = (self.data and data and data.domain == self.data.domain)
        if not same_domain:
            self.init_attr_values(data)

    def help_event(self, ev):
        """Show a tooltip for the spectra under the mouse position.

        The tooltip lists the wavenumber nearest to the cursor and, for every
        instance whose column is under the cursor, its meta/class values
        (except ``attr_x``) and the value at that wavenumber.

        Returns True when a tooltip was shown, False otherwise.
        """
        pos = self.plot.vb.mapSceneToView(ev.scenePos())
        sel, wavenumber_ind = self._points_at_pos(pos)
        prepared = []
        if sel is not None:
            prepared.append(str(self.wavenumbers[wavenumber_ind]))
            # the listed variables do not depend on the row: compute once
            domain = self.data.domain
            variables = [v for v in domain.metas + domain.class_vars
                         if v not in [self.attr_x]]
            for d in self.data[sel]:
                features = ['{} = {}'.format(attr.name, d[attr])
                            for attr in variables]
                features.append('value = {}'.format(d[wavenumber_ind]))
                prepared.append("\n".join(features))
        text = "\n\n".join(prepared)
        if text:
            text = ('<span style="white-space:pre">{}</span>'.format(
                escape(text)))
            QToolTip.showText(ev.screenPos(), text, widget=self.plotview)
            return True
        else:
            return False

    def update_attr(self):
        """Callback of the x-axis combo: redraw the image."""
        self.update_view()

    def init_attr_values(self, data):
        """Fill the x-axis model from *data* and select its first entry.

        ``attr_x`` becomes None when the model is empty (e.g. no data).
        """
        domain = data.domain if data is not None else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None

    def set_data(self, data):
        """Store *data*, index its ids and restore the saved selection.

        Does not redraw; ``update_view`` must be called separately.
        """
        if data:
            self.data = data
            self.data_ids = {e: i for i, e in enumerate(data.ids)}
            self.restore_selection_settings()
        else:
            self.data = None
            self.data_ids = {}

    def update_view(self):
        """Rebuild the image from ``self.data`` and the chosen x axis.

        All derived state is reset first; the image stays empty when there
        is no data, no attributes, or when an axis cannot be mapped to a
        regular grid.
        """
        self.img.clear()
        self.img.setSelection(None)
        self.legend.set_colors(None)
        self.lsx = None
        self.lsy = None
        self.wavenumbers = None
        self.data_xs = None
        self.data_imagepixels = None
        if not (self.data and len(self.data.domain.attributes)):
            return

        if self.attr_x is not None:
            xat = self.data.domain[self.attr_x]
            ndom = Domain([xat])
            datam = Table(ndom, self.data)
            coorx = datam.X[:, 0]
        else:
            coorx = np.arange(len(self.data))
        lsx = values_to_linspace(coorx)
        wavenumbers = getx(self.data)
        lsy = values_to_linspace(wavenumbers)
        if lsx is None or lsy is None:
            # values_to_linspace can return None when it cannot build a grid
            # (e.g. all values NaN -- verify); leave the view empty instead
            # of failing on lsx[2] / lsy[2] below.
            return
        self.lsx = lsx
        self.lsy = lsy
        self.data_xs = coorx
        self.wavenumbers = wavenumbers

        # set data: each instance's values go into the column given by its
        # x index, at the rows given by the wavenumber indices
        imdata = np.ones((lsy[2], lsx[2])) * float("nan")
        xindex = index_values(coorx, lsx)
        yindex = index_values(wavenumbers, lsy)
        for xind, d in zip(xindex, self.data.X):
            imdata[yindex, xind] = d

        self.data_imagepixels = xindex

        self.img.setImage(imdata, autoLevels=False)
        self.update_levels()
        self.update_color_schema()

        # shift centres of the pixels so that the axes are useful
        shiftx = _shift(lsx)
        shifty = _shift(lsy)
        left = lsx[0] - shiftx
        bottom = lsy[0] - shifty
        width = (lsx[1] - lsx[0]) + 2 * shiftx
        height = (lsy[1] - lsy[0]) + 2 * shifty
        self.img.setRect(QRectF(left, bottom, width, height))

        self.refresh_img_selection()

    def refresh_img_selection(self):
        """Paint each instance's selection group into its image column."""
        selected_px = np.zeros((self.lsy[2], self.lsx[2]), dtype=np.uint8)
        selected_px[:, self.data_imagepixels] = self.selection_group
        self.img.setSelection(selected_px)

    def make_selection(self, selected):
        """Add selected indices to the selection.

        *selected* indexes ``selection_group`` (e.g. the boolean mask from
        ``_points_at_pos``); None only refreshes. The keyboard modifiers
        decide whether the rows join the current group, a new group, are
        removed (group 0) or replace the whole selection.
        """
        add_to_group, add_group, remove = selection_modifiers()
        if self.data and self.lsx and self.lsy:
            if add_to_group:  # both keys - need to test it before add_group
                selnum = np.max(self.selection_group)
            elif add_group:
                selnum = np.max(self.selection_group) + 1
            elif remove:
                selnum = 0
            else:
                self.selection_group *= 0
                selnum = 1
            if selected is not None:
                self.selection_group[selected] = selnum
            self.refresh_img_selection()
        self.prepare_settings_for_saving()
        self.selection_changed.emit()

    def _points_at_pos(self, pos):
        """Return ``(mask, wavenumber_index)`` for a position in view coords.

        ``mask`` marks instances whose x coordinate is closer than
        ``_shift(self.lsx)`` to ``pos.x()``; ``wavenumber_index`` is the
        wavenumber nearest to ``pos.y()``. Returns ``(None, None)`` when no
        image is shown.
        """
        if self.data and self.lsx and self.lsy:
            x, y = pos.x(), pos.y()
            x_distance = np.abs(self.data_xs - x)
            sel = (x_distance < _shift(self.lsx))
            wavenumber_distance = np.abs(self.wavenumbers - y)
            wavenumber_ind = np.argmin(wavenumber_distance)
            return sel, wavenumber_ind
        return None, None

    def select_by_click(self, pos):
        """Select the instances whose column lies under *pos*."""
        sel, _ = self._points_at_pos(pos)
        self.make_selection(sel)
class LineScanPlot(QWidget, OWComponent, SelectionGroupMixin,
                   ImageColorSettingMixin, ImageZoomMixin):
    """Image view of the spectra of a line scan.

    The image x axis is the chosen ``attr_x`` variable (or the row index
    when it is None), the y axis holds the values returned by ``getx`` for
    the data, and each pixel is coloured by the matching value of
    ``data.X``.
    """

    attr_x = ContextSetting(None)
    gamma = Setting(0)

    selection_changed = Signal()

    def __init__(self, parent):
        QWidget.__init__(self)
        OWComponent.__init__(self, parent)
        SelectionGroupMixin.__init__(self)
        ImageColorSettingMixin.__init__(self)

        self.parent = parent

        self.selection_type = SELECTMANY
        self.saving_enabled = True
        self.selection_enabled = True
        self.viewtype = INDIVIDUAL  # required by InteractiveViewBox
        self.highlighted = None
        self.data_points = None
        self.data_imagepixels = None
        # Filled in by update_view; defined here so that the attributes
        # exist before the first view update.
        self.data_xs = None
        self.wavenumbers = None

        self.plotview = pg.PlotWidget(background="w", viewBox=InteractiveViewBox(self))
        self.plot = self.plotview.getPlotItem()

        # tooltips over the plot are produced by help_event
        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        layout = QVBoxLayout()
        self.setLayout(layout)
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().addWidget(self.plotview)

        self.img = ImageItemNan()
        self.img.setOpts(axisOrder='row-major')
        self.plot.addItem(self.img)
        self.plot.scene().sigMouseMoved.connect(self.plot.vb.mouseMovedEvent)

        # "Menu" button placed in the top-left cell of a grid laid over
        # the plot (row/column 1 take all remaining space)
        layout = QGridLayout()
        self.plotview.setLayout(layout)
        self.button = QPushButton("Menu", self.plotview)
        self.button.setAutoDefault(False)

        layout.setRowStretch(1, 1)
        layout.setColumnStretch(1, 1)
        layout.addWidget(self.button, 0, 0)
        view_menu = MenuFocus(self)
        self.button.setMenu(view_menu)

        # prepare interface according to the new context
        self.parent.contextAboutToBeOpened.connect(lambda x: self.init_interface_data(x[0]))

        self.add_zoom_actions(view_menu)

        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str)

        choose_xy = QWidgetAction(self)
        box = gui.vBox(self)
        box.setFocusPolicy(Qt.TabFocus)
        self.xy_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                    valid_types=DomainModel.PRIMITIVE,
                                    placeholder="Position (index)")
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self.update_attr,
            model=self.xy_model, **common_options)

        box.setFocusProxy(self.cb_attr_x)

        box.layout().addWidget(self.color_settings_box())

        choose_xy.setDefaultWidget(box)
        view_menu.addAction(choose_xy)

        self.lsx = None  # info about the X axis
        self.lsy = None  # info about the Y axis

        self.data = None
        self.data_ids = {}

    def init_interface_data(self, data):
        """Reset the x-axis choices unless *data* has the current domain."""
        same_domain = (self.data and data and
                       data.domain == self.data.domain)
        if not same_domain:
            self.init_attr_values(data)

    def help_event(self, ev):
        """Show a tooltip for the spectra under the mouse position.

        The tooltip lists the wavenumber nearest to the cursor and, for every
        instance whose column is under the cursor, its meta/class values
        (except ``attr_x``) and the value at that wavenumber.

        Returns True when a tooltip was shown, False otherwise.
        """
        pos = self.plot.vb.mapSceneToView(ev.scenePos())
        sel, wavenumber_ind = self._points_at_pos(pos)
        prepared = []
        if sel is not None:
            prepared.append(str(self.wavenumbers[wavenumber_ind]))
            # the listed variables do not depend on the row: compute once
            domain = self.data.domain
            variables = [v for v in domain.metas + domain.class_vars
                         if v not in [self.attr_x]]
            for d in self.data[sel]:
                features = ['{} = {}'.format(attr.name, d[attr]) for attr in variables]
                features.append('value = {}'.format(d[wavenumber_ind]))
                prepared.append("\n".join(features))
        text = "\n\n".join(prepared)
        if text:
            text = ('<span style="white-space:pre">{}</span>'
                    .format(escape(text)))
            QToolTip.showText(ev.screenPos(), text, widget=self.plotview)
            return True
        else:
            return False

    def update_attr(self):
        """Callback of the x-axis combo: redraw the image."""
        self.update_view()

    def init_attr_values(self, data):
        """Fill the x-axis model from *data* and select its first entry.

        ``attr_x`` becomes None when the model is empty (e.g. no data).
        """
        domain = data.domain if data is not None else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None

    def set_data(self, data):
        """Store *data*, index its ids and restore the saved selection.

        Does not redraw; ``update_view`` must be called separately.
        """
        if data:
            self.data = data
            self.data_ids = {e: i for i, e in enumerate(data.ids)}
            self.restore_selection_settings()
        else:
            self.data = None
            self.data_ids = {}

    def update_view(self):
        """Rebuild the image from ``self.data`` and the chosen x axis.

        All derived state is reset first; the image stays empty when there
        is no data, no attributes, or when an axis cannot be mapped to a
        regular grid.
        """
        self.img.clear()
        self.img.setSelection(None)
        self.lsx = None
        self.lsy = None
        self.wavenumbers = None
        self.data_xs = None
        self.data_imagepixels = None
        if not (self.data and len(self.data.domain.attributes)):
            return

        if self.attr_x is not None:
            xat = self.data.domain[self.attr_x]
            ndom = Domain([xat])
            datam = Table(ndom, self.data)
            coorx = datam.X[:, 0]
        else:
            coorx = np.arange(len(self.data))
        lsx = values_to_linspace(coorx)
        wavenumbers = getx(self.data)
        lsy = values_to_linspace(wavenumbers)
        if lsx is None or lsy is None:
            # values_to_linspace can return None when it cannot build a grid
            # (e.g. all values NaN -- verify); leave the view empty instead
            # of failing on lsx[2] / lsy[2] below.
            return
        self.lsx = lsx
        self.lsy = lsy
        self.data_xs = coorx
        self.wavenumbers = wavenumbers

        # set data: each instance's values go into the column given by its
        # x index, at the rows given by the wavenumber indices
        imdata = np.ones((lsy[2], lsx[2])) * float("nan")
        xindex = index_values(coorx, lsx)
        yindex = index_values(wavenumbers, lsy)
        for xind, d in zip(xindex, self.data.X):
            imdata[yindex, xind] = d

        self.data_imagepixels = xindex

        self.img.setImage(imdata, autoLevels=False)
        self.img.setLevels([0, 1])
        self.update_levels()
        self.update_color_schema()

        # shift centres of the pixels so that the axes are useful
        shiftx = _shift(lsx)
        shifty = _shift(lsy)
        left = lsx[0] - shiftx
        bottom = lsy[0] - shifty
        width = (lsx[1] - lsx[0]) + 2 * shiftx
        height = (lsy[1] - lsy[0]) + 2 * shifty
        self.img.setRect(QRectF(left, bottom, width, height))

        self.refresh_img_selection()

    def refresh_img_selection(self):
        """Paint each instance's selection group into its image column."""
        selected_px = np.zeros((self.lsy[2], self.lsx[2]), dtype=np.uint8)
        selected_px[:, self.data_imagepixels] = self.selection_group
        self.img.setSelection(selected_px)

    def make_selection(self, selected, add):
        """Add selected indices to the selection.

        *selected* indexes ``selection_group`` (e.g. the boolean mask from
        ``_points_at_pos``); None only refreshes. The keyboard modifiers
        decide whether the rows join the current group, a new group, are
        removed (group 0) or replace the whole selection.

        *add* is accepted for the caller's interface but is not used; the
        behaviour comes from ``selection_modifiers()`` only.
        """
        add_to_group, add_group, remove = selection_modifiers()
        if self.data and self.lsx and self.lsy:
            if add_to_group:  # both keys - need to test it before add_group
                selnum = np.max(self.selection_group)
            elif add_group:
                selnum = np.max(self.selection_group) + 1
            elif remove:
                selnum = 0
            else:
                self.selection_group *= 0
                selnum = 1
            if selected is not None:
                self.selection_group[selected] = selnum
            self.refresh_img_selection()
        self.prepare_settings_for_saving()
        self.selection_changed.emit()

    def _points_at_pos(self, pos):
        """Return ``(mask, wavenumber_index)`` for a position in view coords.

        ``mask`` marks instances whose x coordinate is closer than
        ``_shift(self.lsx)`` to ``pos.x()``; ``wavenumber_index`` is the
        wavenumber nearest to ``pos.y()``. Returns ``(None, None)`` when no
        image is shown.
        """
        if self.data and self.lsx and self.lsy:
            x, y = pos.x(), pos.y()
            x_distance = np.abs(self.data_xs - x)
            sel = (x_distance < _shift(self.lsx))
            wavenumber_distance = np.abs(self.wavenumbers - y)
            wavenumber_ind = np.argmin(wavenumber_distance)
            return sel, wavenumber_ind
        return None, None

    def select_by_click(self, pos, add):
        """Select the instances whose column lies under *pos*."""
        sel, _ = self._points_at_pos(pos)
        self.make_selection(sel, add)
class OWTestLearners(OWWidget):
    # Widget metadata used by the canvas (name, tooltip, icon, ordering).
    name = "Test & Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100

    # Input signals; their handlers are bound below via @Inputs.<name>.
    class Inputs:
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        # multiple=True: set_learner receives a separate key per connection
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    # Output signals of the widget.
    class Outputs:
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    settings_version = 3
    # One-off hint shown to the user about the header context menu.
    UserAdviceMessages = [
        widget.Message(
            "Click on the table header to select shown columns",
            "click_header")]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Resampling/testing types; the values match the order in which the
    #: radio buttons are appended in __init__.
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes (in percent)
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type (one of the constants above; 0 == KFold)
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation, as an index into NFolds
    #: (the combo in __init__ is built from NFolds; 3 -> 10 folds)
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling, as an index into NRepeats
    #: (3 -> 10 repeats)
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size, as an index into SampleSizes (9 -> 66 %)
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    # When True, set_train_data switches back to FeatureFold after the
    # context is opened (if a usable fold feature exists).
    fold_feature_selected = settings.ContextSetting(False)

    # Value of class_selection meaning "average the scores over classes".
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    # Preferred order of the built-in scorers per target-variable type.
    BUILTIN_ORDER = {
        DiscreteVariable: ("AUC", "CA", "F1", "Precision", "Recall"),
        ContinuousVariable: ("MSE", "RMSE", "MAE", "R2")}

    # Names of score columns shown in the table (all built-ins by default).
    shown_scores = \
        settings.Setting(set(chain(*BUILTIN_ORDER.values())))

    # Errors: each one blocks evaluation of the offending input.
    class Error(OWWidget.Error):
        train_data_empty = Msg("Train dataset is empty.")
        test_data_empty = Msg("Test dataset is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg("Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train datasets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        no_class_values = Msg("Target variable has no values.")
        only_one_class_var_value = Msg("Target variable has only one value.")

    # Warnings: evaluation proceeds, but the user should be informed.
    class Warning(OWWidget.Warning):
        # "{}" is filled by _which_missing_data (" ", " train " or " test ")
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    # Information: shown when SQL inputs were downloaded only partially.
    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")

    def __init__(self):
        """Create the widget state, the sampling controls and the results table."""
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        # whether rows with unknown targets were found in train/test input
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, InputLearner]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[Task]
        self.__executor = ThreadExecutor()

        # The radio buttons are appended in the order of the KFold ...
        # TestOnTest class constants, so `resampling` holds those values.
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(
            sbox, self, "resampling", callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_folds", label="Number of folds: ",
            items=[str(x) for x in self.NFolds], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.kfold_changed)
        gui.checkBox(
            ibox, self, "cv_stratified", "Stratified",
            callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # only discrete meta attributes can define the folds
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(
            ibox, self, "fold_feature", model=self.feature_model,
            orientation=Qt.Horizontal, callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_repeats", label="Repeat train/test: ",
            items=[str(x) for x in self.NRepeats], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.shuffle_split_changed)
        gui.comboBox(
            ibox, self, "sample_size", label="Training set size: ",
            items=["{} %".format(x) for x in self.SampleSizes],
            maximumContentsLength=5, orientation=Qt.Horizontal,
            callback=self.shuffle_split_changed)
        gui.checkBox(
            ibox, self, "shuffle_stratified", "Stratified",
            callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # target class combo; stores the selected class name as a string
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox, self, "class_selection", items=[],
            sendSelectedValue=True, valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)

        self.view = gui.TableView(
            wordWrap=True,
        )
        header = self.view.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setStretchLastSection(False)
        # context menu on the header lets the user choose the shown scores
        header.setContextMenuPolicy(Qt.CustomContextMenu)
        header.customContextMenuRequested.connect(self.show_column_chooser)

        self.result_model = QStandardItemModel(self)
        self.result_model.setHorizontalHeaderLabels(["Method"])
        self.view.setModel(self.result_model)
        self.view.setItemDelegate(ItemDelegate())

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.view)

    def sizeHint(self):
        """Suggest a default size: 780 px wide, minimal height."""
        preferred = QSize(780, 1)
        return preferred

    def _update_controls(self):
        """Refill the fold-feature combo and enable/disable feature-fold CV.

        ``fold_feature`` is reset and then set to the first usable meta
        attribute of the current data, if any. Without such an attribute the
        "Cross validation by feature" option is disabled and, if it was
        selected, resampling falls back to k-fold cross validation.
        """
        self.fold_feature = None
        self.feature_model.set_domain(None)
        if self.data:
            self.feature_model.set_domain(self.data.domain)
            has_candidates = bool(self.feature_model)
            if self.fold_feature is None and has_candidates:
                self.fold_feature = self.feature_model[0]

        has_fold_features = bool(self.feature_model)
        feature_fold_button = \
            self.controls.resampling.buttons[OWTestLearners.FeatureFold]
        feature_fold_button.setEnabled(has_fold_features)
        self.features_combo.setEnabled(has_fold_features)
        feature_fold_chosen = self.resampling == OWTestLearners.FeatureFold
        if feature_fold_chosen and not has_fold_features:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        A None `learner` for an already known `key` removes that input;
        otherwise the learner is (re)registered with no results or stats
        yet. Either way the `key` is invalidated.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        is_removal = key in self.learners and learner is None
        if not is_removal:
            self.learners[key] = InputLearner(learner, None, None)
            self._invalidate([key])
            return
        # Removed: invalidate first, then drop the entry
        self._invalidate([key])
        del self.learners[key]

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Set the input training dataset.

        Empty data, or data without exactly one target that has usable
        values, is rejected with an error and treated as None. SQL tables are
        downloaded (sampled when at least AUTO_DL_LIMIT rows). When the train
        or the test data have unknown targets, such rows are dropped from
        `data` and a warning is shown. Finally the context is reopened and
        the results are invalidated.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.data_sampled.clear()
        self.Error.train_data_empty.clear()
        self.Error.class_required.clear()
        self.Error.too_many_classes.clear()
        self.Error.no_class_values.clear()
        self.Error.only_one_class_var_value.clear()
        if data is not None and not len(data):
            self.Error.train_data_empty()
            data = None
        if data:
            # conditions are paired with errors; the first one that holds
            # is reported and the data is rejected
            conds = [not data.domain.class_vars,
                     len(data.domain.class_vars) > 1,
                     np.isnan(data.Y).all(),
                     data.domain.has_discrete_class and len(data.domain.class_var.values) == 1]
            errors = [self.Error.class_required,
                      self.Error.too_many_classes,
                      self.Error.no_class_values,
                      self.Error.only_one_class_var_value]
            for cond, error in zip(conds, errors):
                if cond:
                    error()
                    data = None
                    break

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                # too large: download only a partial sample
                self.Information.data_sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(data_sample)

        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                # keep only rows with a known target
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # restore "by feature" CV if the context remembered it
            if self.fold_feature_selected and bool(self.feature_model):
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Optional[Orange.data.Table]) -> None
        """
        Set the input separate testing dataset.

        Empty data or data without a target variable is rejected with an
        error and treated as None. SQL tables are downloaded (sampled when at
        least AUTO_DL_LIMIT rows), rows with unknown targets are dropped when
        either input has them, and results are invalidated only when
        "Test on test data" is the selected resampling.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not len(data):
            self.Error.test_data_empty()
            data = None
        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                # too large: download only a partial sample
                self.Information.test_data_sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(data_sample)

        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                # keep only rows with a known target
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        # the test data only matters for "Test on test data"
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        """Return the fragment inserted into the ``missing_data`` warning.

        " train " or " test " when only that input had unknown targets,
        a single space when both had them. Raises KeyError when neither
        had them; callers only call it when at least one did.
        """
        labels = {
            (True, False): " train ",
            (False, True): " test ",
            (True, True): " ",  # both, don't specify
        }
        key = (self.train_data_missing_vals, self.test_data_missing_vals)
        return labels[key]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        """Collect the scorers applicable to the current target variable.

        Sets ``self.scorers`` to an empty list when there is no data or no
        single target variable. Otherwise it holds the registered scalar,
        non-abstract scorers whose ``class_types`` accept the target. The
        built-in ones come first, in ``BUILTIN_ORDER`` order, followed by
        all others.
        """
        if self.data is None or self.data.domain.class_var is None:
            self.scorers = []
            return
        class_var = self.data.domain.class_var
        # Match BUILTIN_ORDER with isinstance instead of an exact-type
        # lookup, so that subclasses of the listed variable types (e.g.
        # TimeVariable, derived from ContinuousVariable) do not raise
        # KeyError; unknown types simply get no built-in ordering.
        builtin_names = next(
            (names for var_type, names in self.BUILTIN_ORDER.items()
             if isinstance(class_var, var_type)),
            ())
        order = {name: i for i, name in enumerate(builtin_names)}
        # 'abstract' is retrieved from __dict__ to avoid inheriting
        usable = (cls for cls in scoring.Score.registry.values()
                  if cls.is_scalar and not cls.__dict__.get("abstract")
                  and isinstance(class_var, cls.class_types))
        self.scorers = sorted(usable, key=lambda cls: order.get(cls.name, 99))

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Store the preprocessor that is applied to the training data.

        Any previously computed results become invalid.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals.

        Refreshes the target-class combo, the table header and the scores,
        then re-runs the evaluation only if something was invalidated.
        """
        self._update_class_selection()
        self._update_header()
        self._update_stats_model()
        if not self.__needupdate:
            return
        self.__update()

    def kfold_changed(self):
        """Select k-fold cross validation after one of its options changed."""
        self.resampling = self.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Select cross validation by feature after the feature changed."""
        self.resampling = self.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Select random sampling after one of its options changed."""
        self.resampling = self.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """Invalidate all results and start recomputing them right away."""
        self._invalidate()
        self.__update()

    def _update_header(self):
        """Set the horizontal header labels of the results model.

        Column 0 is the method name; each following column belongs to one
        scorer, labelled with its short name and with its long name as the
        tooltip. Column visibility is refreshed afterwards.
        """
        model = self.result_model
        model.setColumnCount(len(self.scorers) + 1)
        for column, scorer in enumerate(self.scorers, start=1):
            header_item = QStandardItem(scorer.name)
            header_item.setToolTip(scorer.long_name)
            model.setHorizontalHeaderItem(column, header_item)
        self._update_shown_columns()

    def _update_shown_columns(self):
        """Hide every score column whose name is not in ``shown_scores``."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        model = self.result_model
        header = self.view.horizontalHeader()
        for section in range(1, model.columnCount()):
            score_name = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            is_hidden = score_name not in self.shown_scores
            header.setSectionHidden(section, is_hidden)

    def _update_stats_model(self):
        """Rebuild the rows of the results table from ``self.learners``.

        One row per learner: its name (marked red with "(error)" when
        evaluation failed) followed by one cell per score. With a discrete
        target and a specific target class selected, scores are recomputed
        here on one-vs-rest results; otherwise the stored ``slot.stats`` are
        shown. Failed evaluations are reported via ``self.error`` and
        uncomputable scores via the ``scores_not_computed`` warning.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.view.model()
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # index of the selected target class, or None for the average
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            head = QStandardItem(name)
            head.setData(key, Qt.UserRole)
            if isinstance(slot.results, Try.Fail):
                head.setToolTip(str(slot.results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                errors.append("{name} failed with error:\n"
                              "{exc.__class__.__name__}: {exc!s}"
                              .format(name=name, exc=slot.results.exception))

            row = [head]

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(
                        slot.results.value, target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [Try(scorer_caller(scorer, ovr_results))
                             for scorer in self.scorers]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat in stats:
                    item = QStandardItem()
                    if stat.success:
                        item.setText("{:.3f}".format(stat.value[0]))
                    else:
                        # leave the cell empty; the exception goes in the tooltip
                        item.setToolTip(str(stat.exception))
                        has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _update_class_selection(self):
        """Repopulate the target-class combo box from the current data.

        The combo lists ``TARGET_AVERAGE`` followed by the values of the
        discrete class variable. The previously selected class is kept if it
        is still among the values; otherwise the first entry is selected and
        ``class_selection`` is updated to match. ``self.cbox`` is hidden when
        the class is not discrete.
        """
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if not self.data.domain.has_discrete_class:
            self.cbox.setVisible(False)
            return

        self.cbox.setVisible(True)
        # `values` may be a tuple rather than a list; concatenating a tuple
        # to a list raises TypeError, so normalize it first.
        values = list(self.data.domain.class_var.values)
        items = [self.TARGET_AVERAGE] + values
        self.class_selection_combo.addItems(items)

        class_index = 0
        if self.class_selection in values:
            # +1 skips the leading TARGET_AVERAGE entry
            class_index = values.index(self.class_selection) + 1

        self.class_selection_combo.setCurrentIndex(class_index)
        self.class_selection = items[class_index]

    def _on_target_class_changed(self):
        """Refresh the score table for the newly selected target class."""
        self._update_stats_model()

    def _invalidate(self, which=None):
        """Forget learner results and blank their score cells.

        Parameters
        ----------
        which : iterable of learner keys, optional
            Keys whose results/stats are reset to None; ``None`` (the
            default) resets every learner.

        Also refreshes ``fold_feature_selected`` from the resampling choice
        and flags that an update is needed.
        """
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        keys = self.learners.keys() if which is None else which

        model = self.view.model()
        # Learner key of each table row (stored on the column-0 item)
        row_keys = [model.item(r, 0).data(Qt.UserRole)
                    for r in range(model.rowCount())]
        ncols = model.columnCount()

        for key in keys:
            self.learners[key] = self.learners[key]._replace(
                results=None, stats=None)
            if key not in row_keys:
                continue
            row = row_keys.index(key)
            for col in range(1, ncols):
                cell = model.item(row, col)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.__needupdate = True

    def show_column_chooser(self, pos):
        """Show a menu of checkable score columns at header position `pos`.

        Toggling an entry adds its column name to, or removes it from,
        ``shown_scores`` and re-applies column visibility.
        """
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        def toggle(name, checked):
            if not checked:
                self.shown_scores.remove(name)
            else:
                self.shown_scores.add(name)
            self._update_shown_columns()

        model = self.result_model
        header = self.view.horizontalHeader()
        menu = QMenu()
        for section in range(1, model.columnCount()):
            name = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            action = menu.addAction(name)
            action.setCheckable(True)
            action.setChecked(name in self.shown_scores)
            action.triggered.connect(partial(toggle, name))
        menu.exec(header.mapToGlobal(pos))

    def commit(self):
        """
        Commit the results to output.

        Successful per-learner results are merged into one Results object,
        which is sent as evaluation results. Its augmented data (with
        predictions) is sent as predictions. Both outputs are None when no
        learner succeeded. A MemoryError while building the predictions
        shows an error, and only the predictions output becomes None.
        """
        self.Error.memory_error.clear()
        successful = [slot for slot in self.learners.values()
                      if slot.results is not None and slot.results.success]
        combined = predictions = None
        if successful:
            # Evaluation results
            combined = results_merge([s.results.value for s in successful])
            combined.learner_names = [learner_name(s.learner)
                                      for s in successful]
            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """Report on the testing schema and results

        Lists the sampling type and, for a discrete class, the target class.
        Resampling schemes not handled here (e.g. feature folds) contribute
        no sampling entry. The score table is always reported.
        """
        if not self.data or not self.learners:
            return

        sampling = None
        if self.resampling == self.KFold:
            prefix = 'Stratified ' if self.cv_stratified else ''
            sampling = "{}{}-fold Cross validation".format(
                prefix, self.NFolds[self.n_folds])
        elif self.resampling == self.LeaveOneOut:
            sampling = "Leave one out"
        elif self.resampling == self.ShuffleSplit:
            prefix = 'Stratified ' if self.shuffle_stratified else ''
            sampling = "{}Shuffle split, {} random samples with {}% data ".format(
                prefix, self.NRepeats[self.n_repeats],
                self.SampleSizes[self.sample_size])
        elif self.resampling == self.TestOnTrain:
            sampling = "No sampling, test on training data"
        elif self.resampling == self.TestOnTest:
            sampling = "No sampling, test on testing data"

        items = [] if sampling is None else [("Sampling type", sampling)]
        if self.data.domain.has_discrete_class:
            items.append(("Target class", self.class_selection.strip("()")))
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings (in place) from schema ``version``.

        * version < 2: a ``resampling`` value greater than 0 is shifted up by
          one. NOTE(review): presumably a new resampling option was inserted
          at index 1 in version 2; confirm against the history.
        * version < 3: stored contexts that have a ``classes`` attribute come
          from an older, incompatible context handler and are dropped.

        Settings with no ``resampling`` entry are left without one instead
        of raising KeyError.
        """
        if version < 2:
            resampling = settings_.get("resampling")
            if resampling is not None and resampling > 0:
                settings_["resampling"] = resampling + 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')]

    @Slot(float)
    def setProgressValue(self, value):
        """Qt slot: move the progress bar to `value`.

        Pending events are not processed here (``processEvents=False``).
        """
        self.progressBarSet(value, processEvents=False)

    def __update(self):
        """Validate the inputs and start evaluating learners lacking results.

        A running evaluation is cancelled and the messages this method owns
        are cleared. The method then returns early, after sending outputs via
        ``commit``, when any of these holds:

        * there is no data or there are no learners;
        * there are too few instances for the chosen number of folds;
        * test-on-test is selected and the test data is missing or has a
          different class variable.

        Otherwise it builds the evaluation call for the selected resampling
        scheme and submits it for the learners whose ``results`` is None.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Warning.test_data_missing.clear()
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            self.Warning.test_data_unused()

        rstate = 42
        common_args = dict(
            store_data=True,
            preprocessor=self.preprocessor,
        )
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see restore
        # learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.KFold:
            folds = self.NFolds[self.n_folds]
            test_f = partial(
                Orange.evaluation.CrossValidation,
                self.data, learners_c, k=folds,
                # Honour the widget's "stratified" option for k-fold CV too
                # (shuffle split already passes its own flag).
                stratified=self.cv_stratified,
                random_state=rstate, **common_args)
        elif self.resampling == OWTestLearners.FeatureFold:
            test_f = partial(
                Orange.evaluation.CrossValidationFeature,
                self.data, learners_c, self.fold_feature,
                **common_args
            )
        elif self.resampling == OWTestLearners.LeaveOneOut:
            test_f = partial(
                Orange.evaluation.LeaveOneOut,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.ShuffleSplit:
            train_size = self.SampleSizes[self.sample_size] / 100
            test_f = partial(
                Orange.evaluation.ShuffleSplit,
                self.data, learners_c,
                n_resamples=self.NRepeats[self.n_repeats],
                train_size=train_size, test_size=None,
                stratified=self.shuffle_stratified,
                random_state=rstate, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTrain:
            test_f = partial(
                Orange.evaluation.TestOnTrainingData,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData,
                self.data, self.test_data, learners_c, **common_args
            )
        else:
            assert False, "self.resampling %s" % self.resampling

        def replace_learners(evalfunc, *args, **kwargs):
            # restore learners: swap the deep copies in the results back
            # for the original learner objects
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[float]], Results]) -> None
        """
        Submit a testing function for evaluation

        MUST not be called if an evaluation is already pending/running.
        Cancel the existing task first.

        The function is run on ``self.__executor``. Progress is forwarded to
        ``setProgressValue`` and completion to ``__task_complete``, both via
        queued ``QMetaObject.invokeMethod`` calls. The widget is then marked
        blocking, its status is set to "Running" and its state to Running.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            Must be a callable taking a single `callback` argument and
            returning a Results instance
        """
        assert self.__state != State.Running
        # Setup the task
        task = Task()

        def progress_callback(finished):
            # Called from the evaluation; raising UserInterrupt here
            # presumably aborts the evaluation of a cancelled task.
            if task.cancelled:
                raise UserInterrupt()
            # Queued (not direct) call of the setProgressValue slot
            QMetaObject.invokeMethod(
                self, "setProgressValue", Qt.QueuedConnection,
                Q_ARG(float, 100 * finished)
            )

        def ondone(_):
            # Queued call of __task_complete with the task that finished
            QMetaObject.invokeMethod(
                self, "__task_complete", Qt.QueuedConnection,
                Q_ARG(object, task))

        testfunc = partial(testfunc, callback=progress_callback)
        task.future = self.__executor.submit(testfunc)
        task.future.add_done_callback(ondone)

        self.progressBarInit(processEvents=None)
        self.setBlocking(True)
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, task):
        """Handle a finished evaluation task.

        A task that is no longer the current one (it was cancelled) is
        ignored. Otherwise the blocking/progress UI is reset and the
        evaluation's exception, if any, is logged and shown as an error.
        Successful results are split per learner. For a primitive class
        variable, each learner's results and scores are wrapped in
        ``Try.Success``/``Try.Fail``. The score table is then refreshed and
        the outputs are committed.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        if self.__task is not task:
            assert task.cancelled
            log.debug("Reaping cancelled task: %r", task)
            return

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)
        self.setStatusMessage("")
        # Named `future` so it is not confused with the per-learner
        # `result` in the loop below.
        future = task.future
        assert future.done()
        self.__task = None
        try:
            results = future.result()    # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:  # pylint: disable=broad-except
            log.exception("testing error (in __task_complete):",
                          exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er), er)))
            self.__state = State.Done
            return

        self.__state = State.Done

        learner_key = {slot.learner: key for key, slot in
                       self.learners.items()}
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            stats = None
            if class_var.is_primitive():
                ex = result.failed[0]
                if ex:
                    # The learner failed: every score reports that failure
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [Try(scorer_caller(scorer, result))
                             for scorer in self.scorers]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self._update_header()
        self._update_stats_model()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation (if any).

        The state becomes Cancelled, the task is detached from the widget and
        cancelled. Its future is asserted to be done afterwards.
        """
        task = self.__task
        if task is None:
            return
        assert self.__state == State.Running
        self.__state = State.Cancelled
        self.__task = None
        task.cancel()
        assert task.future.done()

    def onDeleteWidget(self):
        """Cancel any running evaluation, then defer to the base class."""
        self.cancel()
        super().onDeleteWidget()
Beispiel #53
0
class OWTranspose(OWWidget):
    """Widget that outputs the transposed input table.

    Names of the new features are either generic or taken from a string meta
    attribute selected in a combo box. Generic names use the prefix
    ``feature_name``, or ``DEFAULT_PREFIX`` when it is empty.
    """
    name = "Transpose"
    description = "Transpose data table."
    icon = "icons/Transpose.svg"
    priority = 2000

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table, dynamic=False)

    # Values of `feature_type` (radio-button indices)
    GENERIC, FROM_META_ATTR = range(2)

    resizing_enabled = False
    want_main_area = False

    DEFAULT_PREFIX = "Feature"

    settingsHandler = DomainContextHandler()
    feature_type = ContextSetting(GENERIC)
    feature_name = ContextSetting("")
    feature_names_column = ContextSetting(None)
    auto_apply = Setting(True)

    class Error(OWWidget.Error):
        value_error = Msg("{}")

    def __init__(self):
        super().__init__()
        self.data = None

        # The lambda looks up self.apply only when the callback fires
        # (NOTE(review): presumably so the wrapper installed by
        # gui.auto_commit below is used - confirm).
        box = gui.radioButtons(
            self.controlArea, self, "feature_type", box="Feature names",
            callback=lambda: self.apply())

        button = gui.appendRadioButton(box, "Generic")
        edit = gui.lineEdit(
            gui.indentedBox(box, gui.checkButtonOffsetHint(button)), self,
            "feature_name",
            placeholderText="Type a prefix ...", toolTip="Custom feature name")
        edit.editingFinished.connect(self._apply_editing)

        self.meta_button = gui.appendRadioButton(box, "From meta attribute:")
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=StringVariable,
            alphabetical=True)
        # Indent the combo relative to its own radio button
        self.feature_combo = gui.comboBox(
            gui.indentedBox(box, gui.checkButtonOffsetHint(self.meta_button)),
            self, "feature_names_column", callback=self._feature_combo_changed,
            model=self.feature_model)

        self.apply_button = gui.auto_commit(
            self.controlArea, self, "auto_apply", "&Apply",
            box=False, commit=self.apply)
        self.apply_button.button.setAutoDefault(False)

        self.set_controls()

    def _apply_editing(self):
        """Switch to generic names using the (stripped) edited prefix."""
        self.feature_type = self.GENERIC
        self.feature_name = self.feature_name.strip()
        self.apply()

    def _feature_combo_changed(self):
        """Switch to naming features from the chosen meta attribute."""
        self.feature_type = self.FROM_META_ATTR
        self.apply()

    @Inputs.data
    def set_data(self, data):
        """Receive new input data, restore its context and re-apply."""
        # Skip the context if the combo is empty: a context with
        # feature_model == None would then match all domains
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.set_controls()
        if self.feature_model:
            self.openContext(data)
        self.apply()

    def set_controls(self):
        """Sync the meta-attribute controls with the current data.

        If the data has string meta attributes, the first one is selected and
        naming from meta attributes is chosen. Otherwise the meta option is
        disabled and generic names are selected, so a disabled, empty option
        does not stay checked.
        """
        self.feature_model.set_domain(self.data and self.data.domain)
        self.meta_button.setEnabled(bool(self.feature_model))
        if self.feature_model:
            self.feature_names_column = self.feature_model[0]
            self.feature_type = self.FROM_META_ATTR
        else:
            self.feature_names_column = None
            self.feature_type = self.GENERIC

    def apply(self):
        """Transpose the data and send it.

        A ValueError from ``Table.transpose`` is shown as an error, and None
        is sent.
        """
        self.clear_messages()
        transposed = None
        if self.data:
            # False (no column) when generic names are requested
            names_column = (self.feature_type == self.FROM_META_ATTR
                            and self.feature_names_column)
            try:
                transposed = Table.transpose(
                    self.data,
                    names_column,
                    feature_name=self.feature_name or self.DEFAULT_PREFIX)
            except ValueError as e:
                self.Error.value_error(e)
        self.Outputs.data.send(transposed)

    def send_report(self):
        """Report how feature names are chosen and describe the input data."""
        if self.feature_type == self.GENERIC:
            names = self.feature_name or self.DEFAULT_PREFIX
        else:
            names = "from meta attribute"
            if self.feature_names_column:
                names += "  '{}'".format(self.feature_names_column.name)
        self.report_items("", [("Feature names", names)])
        if self.data:
            self.report_data("Data", self.data)
class OWContingencyTable(widget.OWWidget):
    name = "Contingency Table"
    description = "Construct a contingency table from given data."
    icon = "icons/Contingency.svg"
    priority = 2010

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        contingency = Output("Contingency Table", Table, default=True)
        selected_data = Output("Selected Data", Table)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler(metas_in_res=True)
    rows = ContextSetting(None)
    columns = ContextSetting(None)
    selection = ContextSetting(set())
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        """Build the row/column combos, the score label and the table view."""
        super().__init__()

        self.data = None
        self.table = None
        # Both combos share one model listing the discrete variables
        self.feature_model = DomainModel(valid_types=DiscreteVariable)

        for label, attr in (("Rows", "rows"), ("Columns", "columns")):
            gui.comboBox(gui.vBox(self.controlArea, label), self, attr,
                         sendSelectedValue=True, model=self.feature_model,
                         callback=self._attribute_changed)

        gui.rubber(self.controlArea)

        scores_box = gui.vBox(self.controlArea, "Scores")
        self.scores = gui.widgetLabel(scores_box, "\n\n")

        self.apply_button = gui.auto_commit(
            self.controlArea, self, "auto_apply", "&Apply", box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive new input data and reset the row/column variables.

        When the data has discrete variables, both combos are set to the
        first one, the stored context (if any) is restored and the table is
        computed. Otherwise the table view is cleared, so no stale table is
        left on display.
        """
        # A context is only opened when the model is non-empty, so only
        # close one in that case.
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.feature_model.set_domain(self.data.domain if self.data else None)
        self.rows = None
        self.columns = None
        self.table = None
        # A selection made on the previous data must not carry over;
        # openContext may restore a stored selection for this domain.
        self.selection = set()
        if self.feature_model:
            self.rows = self.feature_model[0]
            self.columns = self.feature_model[0]
            self.openContext(data)
            self.tableview.set_variables(self.rows, self.columns)
            self.table = contingency_table(self.data, self.columns, self.rows)
            self.tableview.update_table(self.table.X, formatstr="{:.0f}")
        else:
            self.tableview.clear()

    def handleNewSignals(self):
        """Framework hook, presumably run after inputs are processed.

        Rebuilds the table and scores.
        """
        self._attribute_changed()

    def commit(self):
        """Send the contingency table, the selected data and annotated data.

        Selected data are the instances whose (rows, columns) values fall
        into one of the selected cells. The annotated table marks those
        instances. Without a selection, or without data or variables, None
        is sent as selected data and the annotated table is built with an
        empty selection.
        """
        has_selection = (bool(self.selection) and self.data is not None
                         and self.rows is not None
                         and self.columns is not None)
        if has_selection:
            # One filter per selected (row value, column value) cell;
            # the cells are combined with OR.
            cells = [
                Values([FilterDiscrete(self.rows, [r]),
                        FilterDiscrete(self.columns, [c])])
                for ir, r in enumerate(self.rows.values)
                for ic, c in enumerate(self.columns.values)
                if (ir, ic) in self.selection]
            selected_data = Values(cells, conjunction=False)(self.data)
            # np.isin replaces the deprecated np.in1d (same semantics)
            mask = np.isin(self.data.ids, selected_data.ids, assume_unique=True)
            annotated_data = create_annotated_table(self.data, np.where(mask))
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])
        self.Outputs.contingency.send(self.table)
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)

    def _invalidate(self):
        """Store the table view's current selection and resend outputs."""
        current = self.tableview.get_selection()
        self.selection = current
        self.commit()

    def _attribute_changed(self):
        """Recompute the table and scores for the chosen row/column variables.

        The scores shown are ARI, AMI and the chi-square statistic with its
        p-value. The outputs are re-sent afterwards.
        """
        self.tableview.set_selection(self.selection)
        self.table = None
        scores_text = "\n\n"
        if self.data and self.rows and self.columns:
            self.tableview.set_variables(self.rows, self.columns)
            self.table = contingency_table(self.data, self.columns, self.rows)
            self.tableview.update_table(self.table.X, formatstr="{:.0f}")

            chi = ChiSqStats(self.data, self.rows, self.columns)
            row_values = self.data.get_column_view(self.rows.name)[0]
            col_values = self.data.get_column_view(self.columns.name)[0]
            scores_text = "ARI: {:.3f}\nAMI: {:.3f}\nχ²={:.2f}, p={:.3f}".format(
                adjusted_rand_score(row_values, col_values),
                adjusted_mutual_info_score(row_values, col_values),
                chi.chisq,
                chi.p)
        self.scores.setText(scores_text)
        self._invalidate()

    def send_report(self):
        """Report the variables used for the table's rows and columns.

        When data is present, ``rows``/``columns`` are replaced by
        ``self.data.domain[...]`` if they are found in the data's domain.
        Without data, both are reported as None.
        """
        rows = None
        columns = None
        if self.data is not None:
            rows = self.rows
            if rows in self.data.domain:
                rows = self.data.domain[rows]
            columns = self.columns
            if columns in self.data.domain:
                columns = self.data.domain[columns]
        self.report_items((
            ("Rows", rows),
            ("Columns", columns),
        ))